import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.concurrent.Promise;
import java.io.IOException;
import java.net.InetSocketAddress;
-import java.net.SocketAddress;
import java.security.PublicKey;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
-import org.apache.sshd.client.channel.ClientChannel;
-import org.apache.sshd.client.future.AuthFuture;
-import org.apache.sshd.client.future.OpenFuture;
-import org.apache.sshd.client.session.ClientSession;
-import org.apache.sshd.common.future.SshFutureListener;
-import org.apache.sshd.common.session.Session;
import org.eclipse.jdt.annotation.Nullable;
import org.opendaylight.netconf.client.NetconfClientSession;
import org.opendaylight.netconf.client.NetconfClientSessionListener;
import org.opendaylight.netconf.client.NetconfClientSessionNegotiatorFactory;
+import org.opendaylight.netconf.nettyutil.handler.ssh.client.AsyncSshHandlerWriter;
+import org.opendaylight.netconf.nettyutil.handler.ssh.client.NetconfClientSessionImpl;
+import org.opendaylight.netconf.shaded.sshd.client.channel.ChannelSubsystem;
+import org.opendaylight.netconf.shaded.sshd.client.channel.ClientChannel;
+import org.opendaylight.netconf.shaded.sshd.client.future.AuthFuture;
+import org.opendaylight.netconf.shaded.sshd.client.future.OpenFuture;
+import org.opendaylight.netconf.shaded.sshd.client.session.ClientSession;
+import org.opendaylight.netconf.shaded.sshd.common.future.SshFutureListener;
+import org.opendaylight.netconf.shaded.sshd.common.session.Session;
+import org.opendaylight.yang.gen.v1.urn.opendaylight.netconf.device.rev231025.connection.parameters.Protocol.Name;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+// Non-final for testing
class CallHomeSessionContext implements CallHomeProtocolSessionContext {
private static final Logger LOG = LoggerFactory.getLogger(CallHomeSessionContext.class);
- static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
-
private static final String NETCONF = "netconf";
+ @VisibleForTesting
+ static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
+
private final ClientSession sshSession;
private final CallHomeAuthorization authorization;
private final Factory factory;
private final InetSocketAddress remoteAddress;
private final PublicKey serverKey;
+ @SuppressFBWarnings(value = "MC_OVERRIDABLE_METHOD_CALL_IN_CONSTRUCTOR", justification = "Passing 'this' around")
CallHomeSessionContext(final ClientSession sshSession, final CallHomeAuthorization authorization,
- final SocketAddress remoteAddress, final Factory factory) {
+ final Factory factory) {
this.authorization = requireNonNull(authorization, "authorization");
checkArgument(this.authorization.isServerAllowed(), "Server was not allowed.");
- this.factory = requireNonNull(factory, "factory");
- this.sshSession = requireNonNull(sshSession, "sshSession");
- this.sshSession.setAttribute(SESSION_KEY, this);
- this.remoteAddress = (InetSocketAddress) this.sshSession.getIoSession().getRemoteAddress();
- this.serverKey = this.sshSession.getKex().getServerKey();
+ this.factory = requireNonNull(factory);
+ this.sshSession = requireNonNull(sshSession);
+ remoteAddress = (InetSocketAddress) this.sshSession.getIoSession().getRemoteAddress();
+ serverKey = this.sshSession.getServerKey();
+ }
+
+ final void associate() {
+ sshSession.setAttribute(SESSION_KEY, this);
}
static CallHomeSessionContext getFrom(final ClientSession sshSession) {
void openNetconfChannel() {
LOG.debug("Opening NETCONF Subsystem on {}", sshSession);
try {
- final ClientChannel netconfChannel = sshSession.createSubsystemChannel(NETCONF);
+ final MinaSshNettyChannel nettyChannel = newMinaSshNettyChannel();
+ final ClientChannel netconfChannel =
+ ((NetconfClientSessionImpl) sshSession).createSubsystemChannel(NETCONF, nettyChannel.pipeline());
netconfChannel.setStreaming(ClientChannel.Streaming.Async);
- netconfChannel.open().addListener(newSshFutureListener(netconfChannel));
+ netconfChannel.open().addListener(newSshFutureListener(netconfChannel, nettyChannel));
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
- SshFutureListener<OpenFuture> newSshFutureListener(final ClientChannel netconfChannel) {
+ SshFutureListener<OpenFuture> newSshFutureListener(final ClientChannel netconfChannel,
+ final MinaSshNettyChannel nettyChannel) {
return future -> {
if (future.isOpened()) {
factory.getChannelOpenListener().onNetconfSubsystemOpened(this,
- listener -> doActivate(netconfChannel, listener));
+ listener -> doActivate(netconfChannel, listener, nettyChannel));
} else {
channelOpenFailed(future.getException());
}
removeSelf();
}
+ @Override
+ public Name getTransportType() {
+ return Name.SSH;
+ }
+
private void channelOpenFailed(final Throwable throwable) {
LOG.error("Unable to open netconf subsystem, disconnecting.", throwable);
sshSession.close(false);
}
- private synchronized Promise<NetconfClientSession> doActivate(final ClientChannel netconfChannel,
- final NetconfClientSessionListener listener) {
+ private synchronized ListenableFuture<NetconfClientSession> doActivate(final ClientChannel netconfChannel,
+ final NetconfClientSessionListener listener, final MinaSshNettyChannel nettyChannel) {
if (activated) {
- return newSessionPromise().setFailure(new IllegalStateException("Session already activated."));
+ return Futures.immediateFailedFuture(new IllegalStateException("Session already activated."));
}
-
activated = true;
+ nettyChannel.pipeline().addFirst(new SshWriteAsyncHandlerAdapter(netconfChannel));
LOG.info("Activating Netconf channel for {} with {}", getRemoteAddress(), listener);
- Promise<NetconfClientSession> activationPromise = newSessionPromise();
- final MinaSshNettyChannel nettyChannel = newMinaSshNettyChannel(netconfChannel);
+ final Promise<NetconfClientSession> activationPromise = newSessionPromise();
factory.getChannelInitializer(listener).initialize(nettyChannel, activationPromise);
+ ((ChannelSubsystem) netconfChannel).onClose(nettyChannel::doNettyDisconnect);
factory.getNettyGroup().register(nettyChannel).awaitUninterruptibly(500);
- return activationPromise;
+ final SettableFuture<NetconfClientSession> future = SettableFuture.create();
+ activationPromise.addListener(ignored -> {
+ final var cause = activationPromise.cause();
+ if (cause != null) {
+ future.setException(cause);
+ } else {
+ future.set(activationPromise.getNow());
+ }
+ });
+ return future;
}
- protected MinaSshNettyChannel newMinaSshNettyChannel(final ClientChannel netconfChannel) {
- return new MinaSshNettyChannel(this, sshSession, netconfChannel);
+ @Deprecated(since = "7.0.0", forRemoval = true)
+ @VisibleForTesting
+ MinaSshNettyChannel newMinaSshNettyChannel() {
+ return new MinaSshNettyChannel(this, sshSession);
}
private static Promise<NetconfClientSession> newSessionPromise() {
return serverKey;
}
- @Override
- public String getRemoteServerVersion() {
- return sshSession.getServerVersion();
- }
-
@Override
public InetSocketAddress getRemoteAddress() {
return remoteAddress;
}
@Override
- public String getSessionName() {
+ public String getSessionId() {
return authorization.getSessionName();
}
}
static class Factory {
-
+ private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
private final EventLoopGroup nettyGroup;
private final NetconfClientSessionNegotiatorFactory negotiatorFactory;
private final CallHomeNetconfSubsystemListener subsystemListener;
- private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
Factory(final EventLoopGroup nettyGroup, final NetconfClientSessionNegotiatorFactory negotiatorFactory,
final CallHomeNetconfSubsystemListener subsystemListener) {
- this.nettyGroup = requireNonNull(nettyGroup, "nettyGroup");
- this.negotiatorFactory = requireNonNull(negotiatorFactory, "negotiatorFactory");
+ this.nettyGroup = requireNonNull(nettyGroup);
+ this.negotiatorFactory = requireNonNull(negotiatorFactory);
this.subsystemListener = requireNonNull(subsystemListener);
}
- void remove(final CallHomeSessionContext session) {
- sessions.remove(session.getSessionName(), session);
- }
-
ReverseSshChannelInitializer getChannelInitializer(final NetconfClientSessionListener listener) {
return ReverseSshChannelInitializer.create(negotiatorFactory, listener);
}
CallHomeNetconfSubsystemListener getChannelOpenListener() {
- return this.subsystemListener;
+ return subsystemListener;
+ }
+
+ EventLoopGroup getNettyGroup() {
+ return nettyGroup;
}
@Nullable CallHomeSessionContext createIfNotExists(final ClientSession sshSession,
- final CallHomeAuthorization authorization, final SocketAddress remoteAddress) {
- CallHomeSessionContext session = new CallHomeSessionContext(sshSession, authorization,
- remoteAddress, this);
- CallHomeSessionContext preexisting = sessions.putIfAbsent(session.getSessionName(), session);
- // If preexisting is null - session does not exist, so we can safely create new one, otherwise we return
- // null and incoming connection will be rejected.
- return preexisting == null ? session : null;
+ final CallHomeAuthorization authorization) {
+ final var newSession = new CallHomeSessionContext(sshSession, authorization, this);
+ final var existing = sessions.putIfAbsent(newSession.getSessionId(), newSession);
+ if (existing == null) {
+ // There was no mapping, but now there is. Associate the context with the session.
+ newSession.associate();
+ return newSession;
+ }
+
+ // We already have a mapping, do not create a new one. But also check if the current session matches
+ // the one stored in the session. This can happen during re-keying.
+ return existing == CallHomeSessionContext.getFrom(sshSession) ? existing : null;
}
- EventLoopGroup getNettyGroup() {
- return nettyGroup;
+ void remove(final CallHomeSessionContext session) {
+ sessions.remove(session.getSessionId(), session);
+ }
+ }
+
+ static class SshWriteAsyncHandlerAdapter extends ChannelOutboundHandlerAdapter {
+ private final AsyncSshHandlerWriter sshWriteAsyncHandler;
+ private final ClientChannel sshChannel;
+
+ SshWriteAsyncHandlerAdapter(final ClientChannel sshChannel) {
+ this.sshChannel = sshChannel;
+ sshWriteAsyncHandler = new AsyncSshHandlerWriter(sshChannel.getAsyncIn());
+ }
+
+ @Override
+ public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
+ sshWriteAsyncHandler.write(ctx, msg, promise);
+ }
+
+ public ClientChannel getSshChannel() {
+ return sshChannel;
}
}
}