BUG 624 - Make netconf TCP port optional.
[controller.git] / opendaylight / netconf / netconf-ssh / src / main / java / org / opendaylight / controller / netconf / ssh / threads / Handshaker.java
diff --git a/opendaylight/netconf/netconf-ssh/src/main/java/org/opendaylight/controller/netconf/ssh/threads/Handshaker.java b/opendaylight/netconf/netconf-ssh/src/main/java/org/opendaylight/controller/netconf/ssh/threads/Handshaker.java
new file mode 100644 (file)
index 0000000..d999d37
--- /dev/null
@@ -0,0 +1,406 @@
+/*
+ * Copyright (c) 2013 Cisco Systems, Inc. and others.  All rights reserved.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License v1.0 which accompanies this distribution,
+ * and is available at http://www.eclipse.org/legal/epl-v10.html
+ */
+package org.opendaylight.controller.netconf.ssh.threads;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import ch.ethz.ssh2.AuthenticationResult;
+import ch.ethz.ssh2.PtySettings;
+import ch.ethz.ssh2.ServerAuthenticationCallback;
+import ch.ethz.ssh2.ServerConnection;
+import ch.ethz.ssh2.ServerConnectionCallback;
+import ch.ethz.ssh2.ServerSession;
+import ch.ethz.ssh2.ServerSessionCallback;
+import ch.ethz.ssh2.SimpleServerSessionCallback;
+import com.google.common.base.Supplier;
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufProcessor;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.handler.stream.ChunkedStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.Socket;
+import javax.annotation.concurrent.NotThreadSafe;
+import javax.annotation.concurrent.ThreadSafe;
+import org.opendaylight.controller.netconf.ssh.authentication.AuthProvider;
+import org.opendaylight.controller.netconf.util.messages.NetconfHelloMessageAdditionalHeader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * One instance represents per connection, responsible for ssh handshake.
+ * Once auth succeeds and correct subsystem is chosen, backend connection with
+ * netty netconf server is made. This task finishes right after negotiation is done.
+ */
+@ThreadSafe
+public class Handshaker implements Runnable {
+    private static final Logger logger = LoggerFactory.getLogger(Handshaker.class);
+
+    private final ServerConnection ganymedConnection;
+    private final String session;
+
+
+    public Handshaker(Socket socket, LocalAddress localAddress, long sessionId, AuthProvider authProvider,
+                      EventLoopGroup bossGroup) throws IOException {
+
+        this.session = "Session " + sessionId;
+
+        String remoteAddressWithPort = socket.getRemoteSocketAddress().toString().replace("/", "");
+        logger.debug("{} started with {}", session, remoteAddressWithPort);
+        String remoteAddress, remotePort;
+        if (remoteAddressWithPort.contains(":")) {
+            String[] split = remoteAddressWithPort.split(":");
+            remoteAddress = split[0];
+            remotePort = split[1];
+        } else {
+            remoteAddress = remoteAddressWithPort;
+            remotePort = "";
+        }
+        ServerAuthenticationCallbackImpl serverAuthenticationCallback = new ServerAuthenticationCallbackImpl(
+                authProvider, session);
+
+        ganymedConnection = new ServerConnection(socket);
+
+        ServerConnectionCallbackImpl serverConnectionCallback = new ServerConnectionCallbackImpl(
+                serverAuthenticationCallback, remoteAddress, remotePort, session,
+                getGanymedAutoCloseable(ganymedConnection), localAddress, bossGroup);
+
+        // initialize ganymed
+        ganymedConnection.setPEMHostKey(authProvider.getPEMAsCharArray(), null);
+        ganymedConnection.setAuthenticationCallback(serverAuthenticationCallback);
+        ganymedConnection.setServerConnectionCallback(serverConnectionCallback);
+    }
+
+
+    private static AutoCloseable getGanymedAutoCloseable(final ServerConnection ganymedConnection) {
+        return new AutoCloseable() {
+            @Override
+            public void close() throws Exception {
+                ganymedConnection.close();
+            }
+        };
+    }
+
+    @Override
+    public void run() {
+        // let ganymed process handshake
+        logger.trace("{} SocketThread is started", session);
+        try {
+            // TODO this should be guarded with a timer to prevent resource exhaustion
+            ganymedConnection.connect();
+        } catch (IOException e) {
+            logger.warn("{} SocketThread error ", session, e);
+        }
+        logger.trace("{} SocketThread is exiting", session);
+    }
+}
+
+/**
+ * Netty client handler that forwards bytes from backed server to supplied output stream.
+ * When backend server closes the connection, remoteConnection.close() is called to tear
+ * down ssh connection.
+ */
+class SSHClientHandler extends ChannelInboundHandlerAdapter {
+    private static final Logger logger = LoggerFactory.getLogger(SSHClientHandler.class);
+    private final AutoCloseable remoteConnection;
+    private final OutputStream remoteOutputStream;
+    private final String session;
+    private ChannelHandlerContext channelHandlerContext;
+
+    public SSHClientHandler(AutoCloseable remoteConnection, OutputStream remoteOutputStream,
+                            String session) {
+        this.remoteConnection = remoteConnection;
+        this.remoteOutputStream = remoteOutputStream;
+        this.session = session;
+    }
+
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) {
+        this.channelHandlerContext = ctx;
+        logger.debug("{} Client active", session);
+    }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) {
+        ByteBuf bb = (ByteBuf) msg;
+        // we can block the server here so that slow client does not cause memory pressure
+        try {
+            bb.forEachByte(new ByteBufProcessor() {
+                @Override
+                public boolean process(byte value) throws Exception {
+                    remoteOutputStream.write(value);
+                    return true;
+                }
+            });
+        } finally {
+            bb.release();
+        }
+    }
+
+    @Override
+    public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
+        logger.trace("{} Flushing", session);
+        remoteOutputStream.flush();
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+        // Close the connection when an exception is raised.
+        logger.warn("{} Unexpected exception from downstream", session, cause);
+        ctx.close();
+    }
+
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+        logger.trace("{} channelInactive() called, closing remote client ctx", session);
+        remoteConnection.close();//this should close socket and all threads created for this client
+        this.channelHandlerContext = null;
+    }
+
+    public ChannelHandlerContext getChannelHandlerContext() {
+        return checkNotNull(channelHandlerContext, "Channel is not active");
+    }
+}
+
+/**
+ * Ganymed handler that gets unencrypted input and output streams, connects them to netty.
+ * Checks that 'netconf' subsystem is chosen by user.
+ * Launches new ClientInputStreamPoolingThread thread once session is established.
+ * Writes custom header to netty server, to inform it about IP address and username.
+ */
+class ServerConnectionCallbackImpl implements ServerConnectionCallback {
+    private static final Logger logger = LoggerFactory.getLogger(ServerConnectionCallbackImpl.class);
+    public static final String NETCONF_SUBSYSTEM = "netconf";
+
+    private final Supplier<String> currentUserSupplier;
+    private final String remoteAddress;
+    private final String remotePort;
+    private final String session;
+    private final AutoCloseable ganymedConnection;
+    private final LocalAddress localAddress;
+    private final EventLoopGroup bossGroup;
+
+    ServerConnectionCallbackImpl(Supplier<String> currentUserSupplier, String remoteAddress, String remotePort, String session,
+                                 AutoCloseable ganymedConnection, LocalAddress localAddress, EventLoopGroup bossGroup) {
+        this.currentUserSupplier = currentUserSupplier;
+        this.remoteAddress = remoteAddress;
+        this.remotePort = remotePort;
+        this.session = session;
+        this.ganymedConnection = ganymedConnection;
+        // initialize netty local connection
+        this.localAddress = localAddress;
+        this.bossGroup = bossGroup;
+    }
+
+    private static ChannelFuture initializeNettyConnection(LocalAddress localAddress, EventLoopGroup bossGroup,
+                                                           final SSHClientHandler sshClientHandler) {
+        Bootstrap clientBootstrap = new Bootstrap();
+        clientBootstrap.group(bossGroup).channel(LocalChannel.class);
+
+        clientBootstrap.handler(new ChannelInitializer<LocalChannel>() {
+            @Override
+            public void initChannel(LocalChannel ch) throws Exception {
+                ch.pipeline().addLast(sshClientHandler);
+            }
+        });
+        // asynchronously initialize local connection to netconf server
+        return clientBootstrap.connect(localAddress);
+    }
+
+    @Override
+    public ServerSessionCallback acceptSession(final ServerSession serverSession) {
+        String currentUser = currentUserSupplier.get();
+        final String additionalHeader = new NetconfHelloMessageAdditionalHeader(currentUser, remoteAddress,
+                remotePort, "ssh", "client").toFormattedString();
+
+
+        return new SimpleServerSessionCallback() {
+            @Override
+            public Runnable requestSubsystem(final ServerSession ss, final String subsystem) throws IOException {
+                return new Runnable() {
+                    @Override
+                    public void run() {
+                        if (NETCONF_SUBSYSTEM.equals(subsystem)) {
+                            // connect
+                            final SSHClientHandler sshClientHandler = new SSHClientHandler(ganymedConnection, ss.getStdin(), session);
+                            ChannelFuture clientChannelFuture = initializeNettyConnection(localAddress, bossGroup, sshClientHandler);
+                            // get channel
+                            final Channel channel = clientChannelFuture.awaitUninterruptibly().channel();
+                            new ClientInputStreamPoolingThread(session, ss.getStdout(), channel, new AutoCloseable() {
+                                @Override
+                                public void close() throws Exception {
+                                    logger.trace("Closing both ganymed and local connection");
+                                    try {
+                                        ganymedConnection.close();
+                                    } catch (Exception e) {
+                                        logger.warn("Ignoring exception while closing ganymed", e);
+                                    }
+                                    try {
+                                        channel.close();
+                                    } catch (Exception e) {
+                                        logger.warn("Ignoring exception while closing channel", e);
+                                    }
+                                }
+                            }, sshClientHandler.getChannelHandlerContext()).start();
+
+                            // write additional header
+                            channel.writeAndFlush(Unpooled.copiedBuffer(additionalHeader.getBytes()));
+                        } else {
+                            logger.debug("{} Wrong subsystem requested:'{}', closing ssh session", serverSession, subsystem);
+                            String reason = "Only netconf subsystem is supported, requested:" + subsystem;
+                            closeSession(ss, reason);
+                        }
+                    }
+                };
+            }
+
+            public void closeSession(ServerSession ss, String reason) {
+                logger.trace("{} Closing session - {}", serverSession, reason);
+                try {
+                    ss.getStdin().write(reason.getBytes());
+                } catch (IOException e) {
+                    logger.warn("{} Exception while closing session", serverSession, e);
+                }
+                ss.close();
+            }
+
+            @Override
+            public Runnable requestPtyReq(final ServerSession ss, final PtySettings pty) throws IOException {
+                return new Runnable() {
+                    @Override
+                    public void run() {
+                        closeSession(ss, "PTY request not supported");
+                    }
+                };
+            }
+
+            @Override
+            public Runnable requestShell(final ServerSession ss) throws IOException {
+                return new Runnable() {
+                    @Override
+                    public void run() {
+                        closeSession(ss, "Shell not supported");
+                    }
+                };
+            }
+        };
+    }
+}
+
+/**
+ * Only thread that is required during ssh session, forwards client's input to netty.
+ * When user closes connection, onEndOfInput.close() is called to tear down the local channel.
+ */
+class ClientInputStreamPoolingThread extends Thread {
+    private static final Logger logger = LoggerFactory.getLogger(ClientInputStreamPoolingThread.class);
+
+    private final InputStream fromClientIS;
+    private final Channel serverChannel;
+    private final AutoCloseable onEndOfInput;
+    private final ChannelHandlerContext channelHandlerContext;
+
+    ClientInputStreamPoolingThread(String session, InputStream fromClientIS, Channel serverChannel, AutoCloseable onEndOfInput,
+                                   ChannelHandlerContext channelHandlerContext) {
+        super(ClientInputStreamPoolingThread.class.getSimpleName() + " " + session);
+        this.fromClientIS = fromClientIS;
+        this.serverChannel = serverChannel;
+        this.onEndOfInput = onEndOfInput;
+        this.channelHandlerContext = channelHandlerContext;
+    }
+
+    @Override
+    public void run() {
+        ChunkedStream chunkedStream = new ChunkedStream(fromClientIS);
+        try {
+            ByteBuf byteBuf;
+            while ((byteBuf = chunkedStream.readChunk(channelHandlerContext/*only needed for ByteBuf alloc */)) != null) {
+                serverChannel.writeAndFlush(byteBuf);
+            }
+        } catch (Exception e) {
+            logger.warn("Exception", e);
+        } finally {
+            logger.trace("End of input");
+            // tear down connection
+            try {
+                onEndOfInput.close();
+            } catch (Exception e) {
+                logger.warn("Ignoring exception while closing socket", e);
+            }
+        }
+    }
+}
+
+/**
+ * Authentication handler for ganymed.
+ * Provides current user name after authenticating using supplied AuthProvider.
+ */
+@NotThreadSafe
+class ServerAuthenticationCallbackImpl implements ServerAuthenticationCallback, Supplier<String> {
+    private static final Logger logger = LoggerFactory.getLogger(ServerAuthenticationCallbackImpl.class);
+    private final AuthProvider authProvider;
+    private final String session;
+    private String currentUser;
+
+    ServerAuthenticationCallbackImpl(AuthProvider authProvider, String session) {
+        this.authProvider = authProvider;
+        this.session = session;
+    }
+
+    @Override
+    public String initAuthentication(ServerConnection sc) {
+        logger.trace("{} Established connection", session);
+        return "Established connection" + "\r\n";
+    }
+
+    @Override
+    public String[] getRemainingAuthMethods(ServerConnection sc) {
+        return new String[]{ServerAuthenticationCallback.METHOD_PASSWORD};
+    }
+
+    @Override
+    public AuthenticationResult authenticateWithNone(ServerConnection sc, String username) {
+        return AuthenticationResult.FAILURE;
+    }
+
+    @Override
+    public AuthenticationResult authenticateWithPassword(ServerConnection sc, String username, String password) {
+        checkState(currentUser == null);
+        try {
+            if (authProvider.authenticated(username, password)) {
+                currentUser = username;
+                logger.trace("{} user {} authenticated", session, currentUser);
+                return AuthenticationResult.SUCCESS;
+            }
+        } catch (Exception e) {
+            logger.warn("{} Authentication failed", session, e);
+        }
+        return AuthenticationResult.FAILURE;
+    }
+
+    @Override
+    public AuthenticationResult authenticateWithPublicKey(ServerConnection sc, String username, String algorithm,
+                                                          byte[] publicKey, byte[] signature) {
+        return AuthenticationResult.FAILURE;
+    }
+
+    @Override
+    public String get() {
+        return currentUser;
+    }
+}