Fix periodic NETCONF Call Home connection dropping
[netconf.git] / netconf / callhome-protocol / src / main / java / org / opendaylight / netconf / callhome / protocol / CallHomeSessionContext.java
index 298763768a67d9bc8bcce76792c049dd624b22b0..83a62f4fcdf9f9c2c3de563665a705c7aeb3c0c4 100644 (file)
@@ -10,6 +10,7 @@ package org.opendaylight.netconf.callhome.protocol;
 import static com.google.common.base.Preconditions.checkArgument;
 import static java.util.Objects.requireNonNull;
 
+import com.google.common.annotations.VisibleForTesting;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import io.netty.channel.EventLoopGroup;
 import io.netty.util.concurrent.GlobalEventExecutor;
@@ -37,10 +38,11 @@ import org.slf4j.LoggerFactory;
 class CallHomeSessionContext implements CallHomeProtocolSessionContext {
 
     private static final Logger LOG = LoggerFactory.getLogger(CallHomeSessionContext.class);
-    static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
-
     private static final String NETCONF = "netconf";
 
+    @VisibleForTesting
+    static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
+
     private final ClientSession sshSession;
     private final CallHomeAuthorization authorization;
     private final Factory factory;
@@ -57,11 +59,14 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
         checkArgument(this.authorization.isServerAllowed(), "Server was not allowed.");
         this.factory = requireNonNull(factory);
         this.sshSession = requireNonNull(sshSession);
-        this.sshSession.setAttribute(SESSION_KEY, this);
         this.remoteAddress = (InetSocketAddress) this.sshSession.getIoSession().getRemoteAddress();
         serverKey = this.sshSession.getServerKey();
     }
 
+    final void associate() {
+        sshSession.setAttribute(SESSION_KEY, this);
+    }
+
     static CallHomeSessionContext getFrom(final ClientSession sshSession) {
         return sshSession.getAttribute(SESSION_KEY);
     }
@@ -152,23 +157,18 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
     }
 
     static class Factory {
-
+        private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
         private final EventLoopGroup nettyGroup;
         private final NetconfClientSessionNegotiatorFactory negotiatorFactory;
         private final CallHomeNetconfSubsystemListener subsystemListener;
-        private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
 
         Factory(final EventLoopGroup nettyGroup, final NetconfClientSessionNegotiatorFactory negotiatorFactory,
                 final CallHomeNetconfSubsystemListener subsystemListener) {
-            this.nettyGroup = requireNonNull(nettyGroup, "nettyGroup");
-            this.negotiatorFactory = requireNonNull(negotiatorFactory, "negotiatorFactory");
+            this.nettyGroup = requireNonNull(nettyGroup);
+            this.negotiatorFactory = requireNonNull(negotiatorFactory);
             this.subsystemListener = requireNonNull(subsystemListener);
         }
 
-        void remove(final CallHomeSessionContext session) {
-            sessions.remove(session.getSessionId(), session);
-        }
-
         ReverseSshChannelInitializer getChannelInitializer(final NetconfClientSessionListener listener) {
             return ReverseSshChannelInitializer.create(negotiatorFactory, listener);
         }
@@ -177,18 +177,27 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
             return subsystemListener;
         }
 
+        EventLoopGroup getNettyGroup() {
+            return nettyGroup;
+        }
+
         @Nullable CallHomeSessionContext createIfNotExists(final ClientSession sshSession,
                 final CallHomeAuthorization authorization, final SocketAddress remoteAddress) {
-            CallHomeSessionContext session = new CallHomeSessionContext(sshSession, authorization,
-                    remoteAddress, this);
-            CallHomeSessionContext preexisting = sessions.putIfAbsent(session.getSessionId(), session);
-            // If preexisting is null - session does not exist, so we can safely create new one, otherwise we return
-            // null and incoming connection will be rejected.
-            return preexisting == null ? session : null;
+            final var newSession = new CallHomeSessionContext(sshSession, authorization, remoteAddress, this);
+            final var existing = sessions.putIfAbsent(newSession.getSessionId(), newSession);
+            if (existing == null) {
+                // There was no mapping, but now there is. Associate the the context with the session.
+                newSession.associate();
+                return newSession;
+            }
+
+            // We already have a mapping, do not create a new one. But also check if the current session matches
+            // the one stored in the session. This can happen during rekeying.
+            return existing == CallHomeSessionContext.getFrom(sshSession) ? existing : null;
         }
 
-        EventLoopGroup getNettyGroup() {
-            return nettyGroup;
+        void remove(final CallHomeSessionContext session) {
+            sessions.remove(session.getSessionId(), session);
         }
     }
 }