Migrate netconf-topology to new transport
[netconf.git] / netconf / callhome-protocol / src / main / java / org / opendaylight / netconf / callhome / protocol / CallHomeSessionContext.java
index 047342a6ab94e57cdfe814adb553e625c826271f..b2e41e7eb0fc5f26a0885de075d257033564b99f 100644 (file)
@@ -10,35 +10,48 @@ package org.opendaylight.netconf.callhome.protocol;
 import static com.google.common.base.Preconditions.checkArgument;
 import static java.util.Objects.requireNonNull;
 
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
 import io.netty.channel.EventLoopGroup;
 import io.netty.util.concurrent.GlobalEventExecutor;
 import io.netty.util.concurrent.Promise;
 import java.io.IOException;
 import java.net.InetSocketAddress;
-import java.net.SocketAddress;
 import java.security.PublicKey;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
-import org.apache.sshd.client.channel.ClientChannel;
-import org.apache.sshd.client.future.AuthFuture;
-import org.apache.sshd.client.future.OpenFuture;
-import org.apache.sshd.client.session.ClientSession;
-import org.apache.sshd.common.future.SshFutureListener;
-import org.apache.sshd.common.session.Session;
 import org.eclipse.jdt.annotation.Nullable;
 import org.opendaylight.netconf.client.NetconfClientSession;
 import org.opendaylight.netconf.client.NetconfClientSessionListener;
 import org.opendaylight.netconf.client.NetconfClientSessionNegotiatorFactory;
+import org.opendaylight.netconf.nettyutil.handler.ssh.client.AsyncSshHandlerWriter;
+import org.opendaylight.netconf.nettyutil.handler.ssh.client.NetconfClientSessionImpl;
+import org.opendaylight.netconf.shaded.sshd.client.channel.ChannelSubsystem;
+import org.opendaylight.netconf.shaded.sshd.client.channel.ClientChannel;
+import org.opendaylight.netconf.shaded.sshd.client.future.AuthFuture;
+import org.opendaylight.netconf.shaded.sshd.client.future.OpenFuture;
+import org.opendaylight.netconf.shaded.sshd.client.session.ClientSession;
+import org.opendaylight.netconf.shaded.sshd.common.future.SshFutureListener;
+import org.opendaylight.netconf.shaded.sshd.common.session.Session;
+import org.opendaylight.yang.gen.v1.urn.opendaylight.netconf.device.rev231025.connection.parameters.Protocol.Name;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+// Non-final for testing
 class CallHomeSessionContext implements CallHomeProtocolSessionContext {
 
     private static final Logger LOG = LoggerFactory.getLogger(CallHomeSessionContext.class);
-    static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
-
     private static final String NETCONF = "netconf";
 
+    @VisibleForTesting
+    static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
+
     private final ClientSession sshSession;
     private final CallHomeAuthorization authorization;
     private final Factory factory;
@@ -48,15 +61,19 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
     private final InetSocketAddress remoteAddress;
     private final PublicKey serverKey;
 
+    @SuppressFBWarnings(value = "MC_OVERRIDABLE_METHOD_CALL_IN_CONSTRUCTOR", justification = "Passing 'this' around")
     CallHomeSessionContext(final ClientSession sshSession, final CallHomeAuthorization authorization,
-                           final SocketAddress remoteAddress, final Factory factory) {
+                           final Factory factory) {
         this.authorization = requireNonNull(authorization, "authorization");
         checkArgument(this.authorization.isServerAllowed(), "Server was not allowed.");
-        this.factory = requireNonNull(factory, "factory");
-        this.sshSession = requireNonNull(sshSession, "sshSession");
-        this.sshSession.setAttribute(SESSION_KEY, this);
-        this.remoteAddress = (InetSocketAddress) this.sshSession.getIoSession().getRemoteAddress();
-        this.serverKey = this.sshSession.getKex().getServerKey();
+        this.factory = requireNonNull(factory);
+        this.sshSession = requireNonNull(sshSession);
+        remoteAddress = (InetSocketAddress) this.sshSession.getIoSession().getRemoteAddress();
+        serverKey = this.sshSession.getServerKey();
+    }
+
+    final void associate() {
+        sshSession.setAttribute(SESSION_KEY, this);
     }
 
     static CallHomeSessionContext getFrom(final ClientSession sshSession) {
@@ -71,19 +88,22 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
     void openNetconfChannel() {
         LOG.debug("Opening NETCONF Subsystem on {}", sshSession);
         try {
-            final ClientChannel netconfChannel = sshSession.createSubsystemChannel(NETCONF);
+            final MinaSshNettyChannel nettyChannel = newMinaSshNettyChannel();
+            final ClientChannel netconfChannel =
+                    ((NetconfClientSessionImpl) sshSession).createSubsystemChannel(NETCONF, nettyChannel.pipeline());
             netconfChannel.setStreaming(ClientChannel.Streaming.Async);
-            netconfChannel.open().addListener(newSshFutureListener(netconfChannel));
+            netconfChannel.open().addListener(newSshFutureListener(netconfChannel, nettyChannel));
         } catch (IOException e) {
             throw new IllegalStateException(e);
         }
     }
 
-    SshFutureListener<OpenFuture> newSshFutureListener(final ClientChannel netconfChannel) {
+    SshFutureListener<OpenFuture> newSshFutureListener(final ClientChannel netconfChannel,
+            final MinaSshNettyChannel nettyChannel) {
         return future -> {
             if (future.isOpened()) {
                 factory.getChannelOpenListener().onNetconfSubsystemOpened(this,
-                    listener -> doActivate(netconfChannel, listener));
+                    listener -> doActivate(netconfChannel, listener, nettyChannel));
             } else {
                 channelOpenFailed(future.getException());
             }
@@ -96,28 +116,44 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
         removeSelf();
     }
 
+    @Override
+    public Name getTransportType() {
+        return Name.SSH;
+    }
+
     private void channelOpenFailed(final Throwable throwable) {
         LOG.error("Unable to open netconf subsystem, disconnecting.", throwable);
         sshSession.close(false);
     }
 
-    private synchronized Promise<NetconfClientSession> doActivate(final ClientChannel netconfChannel,
-            final NetconfClientSessionListener listener) {
+    private synchronized ListenableFuture<NetconfClientSession> doActivate(final ClientChannel netconfChannel,
+            final NetconfClientSessionListener listener, final MinaSshNettyChannel nettyChannel) {
         if (activated) {
-            return newSessionPromise().setFailure(new IllegalStateException("Session already activated."));
+            return Futures.immediateFailedFuture(new IllegalStateException("Session already activated."));
         }
-
         activated = true;
+        nettyChannel.pipeline().addFirst(new SshWriteAsyncHandlerAdapter(netconfChannel));
         LOG.info("Activating Netconf channel for {} with {}", getRemoteAddress(), listener);
-        Promise<NetconfClientSession> activationPromise = newSessionPromise();
-        final MinaSshNettyChannel nettyChannel = newMinaSshNettyChannel(netconfChannel);
+        final Promise<NetconfClientSession> activationPromise = newSessionPromise();
         factory.getChannelInitializer(listener).initialize(nettyChannel, activationPromise);
+        ((ChannelSubsystem) netconfChannel).onClose(nettyChannel::doNettyDisconnect);
         factory.getNettyGroup().register(nettyChannel).awaitUninterruptibly(500);
-        return activationPromise;
+        final SettableFuture<NetconfClientSession> future = SettableFuture.create();
+        activationPromise.addListener(ignored -> {
+            final var cause = activationPromise.cause();
+            if (cause != null) {
+                future.setException(cause);
+            } else {
+                future.set(activationPromise.getNow());
+            }
+        });
+        return future;
     }
 
-    protected MinaSshNettyChannel newMinaSshNettyChannel(final ClientChannel netconfChannel) {
-        return new MinaSshNettyChannel(this, sshSession, netconfChannel);
+    @Deprecated(since = "7.0.0", forRemoval = true)
+    @VisibleForTesting
+    MinaSshNettyChannel newMinaSshNettyChannel() {
+        return new MinaSshNettyChannel(this, sshSession);
     }
 
     private static Promise<NetconfClientSession> newSessionPromise() {
@@ -129,18 +165,13 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
         return serverKey;
     }
 
-    @Override
-    public String getRemoteServerVersion() {
-        return sshSession.getServerVersion();
-    }
-
     @Override
     public InetSocketAddress getRemoteAddress() {
         return remoteAddress;
     }
 
     @Override
-    public String getSessionName() {
+    public String getSessionId() {
         return authorization.getSessionName();
     }
 
@@ -149,43 +180,66 @@ class CallHomeSessionContext implements CallHomeProtocolSessionContext {
     }
 
     static class Factory {
-
+        private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
         private final EventLoopGroup nettyGroup;
         private final NetconfClientSessionNegotiatorFactory negotiatorFactory;
         private final CallHomeNetconfSubsystemListener subsystemListener;
-        private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
 
         Factory(final EventLoopGroup nettyGroup, final NetconfClientSessionNegotiatorFactory negotiatorFactory,
                 final CallHomeNetconfSubsystemListener subsystemListener) {
-            this.nettyGroup = requireNonNull(nettyGroup, "nettyGroup");
-            this.negotiatorFactory = requireNonNull(negotiatorFactory, "negotiatorFactory");
+            this.nettyGroup = requireNonNull(nettyGroup);
+            this.negotiatorFactory = requireNonNull(negotiatorFactory);
             this.subsystemListener = requireNonNull(subsystemListener);
         }
 
-        void remove(final CallHomeSessionContext session) {
-            sessions.remove(session.getSessionName(), session);
-        }
-
         ReverseSshChannelInitializer getChannelInitializer(final NetconfClientSessionListener listener) {
             return ReverseSshChannelInitializer.create(negotiatorFactory, listener);
         }
 
         CallHomeNetconfSubsystemListener getChannelOpenListener() {
-            return this.subsystemListener;
+            return subsystemListener;
+        }
+
+        EventLoopGroup getNettyGroup() {
+            return nettyGroup;
         }
 
         @Nullable CallHomeSessionContext createIfNotExists(final ClientSession sshSession,
-                final CallHomeAuthorization authorization, final SocketAddress remoteAddress) {
-            CallHomeSessionContext session = new CallHomeSessionContext(sshSession, authorization,
-                    remoteAddress, this);
-            CallHomeSessionContext preexisting = sessions.putIfAbsent(session.getSessionName(), session);
-            // If preexisting is null - session does not exist, so we can safely create new one, otherwise we return
-            // null and incoming connection will be rejected.
-            return preexisting == null ? session : null;
+                                                           final CallHomeAuthorization authorization) {
+            final var newSession = new CallHomeSessionContext(sshSession, authorization, this);
+            final var existing = sessions.putIfAbsent(newSession.getSessionId(), newSession);
+            if (existing == null) {
+                // There was no mapping, but now there is. Associate the context with the session.
+                newSession.associate();
+                return newSession;
+            }
+
+            // We already have a mapping, do not create a new one. But also check if the current session matches
+            // the one stored in the session. This can happen during re-keying.
+            return existing == CallHomeSessionContext.getFrom(sshSession) ? existing : null;
         }
 
-        EventLoopGroup getNettyGroup() {
-            return nettyGroup;
+        void remove(final CallHomeSessionContext session) {
+            sessions.remove(session.getSessionId(), session);
+        }
+    }
+
+    static class SshWriteAsyncHandlerAdapter extends ChannelOutboundHandlerAdapter {
+        private final AsyncSshHandlerWriter sshWriteAsyncHandler;
+        private final ClientChannel sshChannel;
+
+        SshWriteAsyncHandlerAdapter(final ClientChannel sshChannel) {
+            this.sshChannel = sshChannel;
+            sshWriteAsyncHandler = new AsyncSshHandlerWriter(sshChannel.getAsyncIn());
+        }
+
+        @Override
+        public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
+            sshWriteAsyncHandler.write(ctx, msg, promise);
+        }
+
+        public ClientChannel getSshChannel() {
+            return sshChannel;
         }
     }
 }