Merge "Bug 1362: New AsyncWriteTransaction#submit method"
[controller.git] / opendaylight / netconf / netconf-ssh / src / main / java / org / opendaylight / controller / netconf / ssh / threads / Handshaker.java
1 /*
2  * Copyright (c) 2013 Cisco Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.controller.netconf.ssh.threads;
9
10 import static com.google.common.base.Preconditions.checkNotNull;
11 import static com.google.common.base.Preconditions.checkState;
12
13 import ch.ethz.ssh2.AuthenticationResult;
14 import ch.ethz.ssh2.PtySettings;
15 import ch.ethz.ssh2.ServerAuthenticationCallback;
16 import ch.ethz.ssh2.ServerConnection;
17 import ch.ethz.ssh2.ServerConnectionCallback;
18 import ch.ethz.ssh2.ServerSession;
19 import ch.ethz.ssh2.ServerSessionCallback;
20 import ch.ethz.ssh2.SimpleServerSessionCallback;
21 import com.google.common.base.Supplier;
22 import io.netty.bootstrap.Bootstrap;
23 import io.netty.buffer.ByteBuf;
24 import io.netty.buffer.ByteBufProcessor;
25 import io.netty.buffer.Unpooled;
26 import io.netty.channel.Channel;
27 import io.netty.channel.ChannelFuture;
28 import io.netty.channel.ChannelHandlerContext;
29 import io.netty.channel.ChannelInboundHandlerAdapter;
30 import io.netty.channel.ChannelInitializer;
31 import io.netty.channel.EventLoopGroup;
32 import io.netty.channel.local.LocalAddress;
33 import io.netty.channel.local.LocalChannel;
34 import io.netty.handler.stream.ChunkedStream;
35 import java.io.BufferedOutputStream;
36 import java.io.IOException;
37 import java.io.InputStream;
38 import java.io.OutputStream;
39 import java.net.Socket;
40 import javax.annotation.concurrent.NotThreadSafe;
41 import javax.annotation.concurrent.ThreadSafe;
42 import org.opendaylight.controller.netconf.ssh.authentication.AuthProvider;
43 import org.opendaylight.controller.netconf.util.messages.NetconfHelloMessageAdditionalHeader;
44 import org.slf4j.Logger;
45 import org.slf4j.LoggerFactory;
46
47 /**
48  * One instance represents per connection, responsible for ssh handshake.
49  * Once auth succeeds and correct subsystem is chosen, backend connection with
50  * netty netconf server is made. This task finishes right after negotiation is done.
51  */
52 @ThreadSafe
53 public class Handshaker implements Runnable {
54     private static final Logger logger = LoggerFactory.getLogger(Handshaker.class);
55
56     private final ServerConnection ganymedConnection;
57     private final String session;
58
59
60     public Handshaker(Socket socket, LocalAddress localAddress, long sessionId, AuthProvider authProvider,
61                       EventLoopGroup bossGroup) throws IOException {
62
63         this.session = "Session " + sessionId;
64
65         String remoteAddressWithPort = socket.getRemoteSocketAddress().toString().replace("/", "");
66         logger.debug("{} started with {}", session, remoteAddressWithPort);
67         String remoteAddress, remotePort;
68         if (remoteAddressWithPort.contains(":")) {
69             String[] split = remoteAddressWithPort.split(":");
70             remoteAddress = split[0];
71             remotePort = split[1];
72         } else {
73             remoteAddress = remoteAddressWithPort;
74             remotePort = "";
75         }
76         ServerAuthenticationCallbackImpl serverAuthenticationCallback = new ServerAuthenticationCallbackImpl(
77                 authProvider, session);
78
79         ganymedConnection = new ServerConnection(socket);
80
81         ServerConnectionCallbackImpl serverConnectionCallback = new ServerConnectionCallbackImpl(
82                 serverAuthenticationCallback, remoteAddress, remotePort, session,
83                 getGanymedAutoCloseable(ganymedConnection), localAddress, bossGroup);
84
85         // initialize ganymed
86         ganymedConnection.setPEMHostKey(authProvider.getPEMAsCharArray(), null);
87         ganymedConnection.setAuthenticationCallback(serverAuthenticationCallback);
88         ganymedConnection.setServerConnectionCallback(serverConnectionCallback);
89     }
90
91
92     private static AutoCloseable getGanymedAutoCloseable(final ServerConnection ganymedConnection) {
93         return new AutoCloseable() {
94             @Override
95             public void close() throws Exception {
96                 ganymedConnection.close();
97             }
98         };
99     }
100
101     @Override
102     public void run() {
103         // let ganymed process handshake
104         logger.trace("{} is started", session);
105         try {
106             // TODO this should be guarded with a timer to prevent resource exhaustion
107             ganymedConnection.connect();
108         } catch (IOException e) {
109             logger.debug("{} connection error", session, e);
110         }
111         logger.trace("{} is exiting", session);
112     }
113 }
114
115 /**
116  * Netty client handler that forwards bytes from backed server to supplied output stream.
117  * When backend server closes the connection, remoteConnection.close() is called to tear
118  * down ssh connection.
119  */
120 class SSHClientHandler extends ChannelInboundHandlerAdapter {
121     private static final Logger logger = LoggerFactory.getLogger(SSHClientHandler.class);
122     private final AutoCloseable remoteConnection;
123     private final BufferedOutputStream remoteOutputStream;
124     private final String session;
125     private ChannelHandlerContext channelHandlerContext;
126
127     public SSHClientHandler(AutoCloseable remoteConnection, OutputStream remoteOutputStream,
128                             String session) {
129         this.remoteConnection = remoteConnection;
130         this.remoteOutputStream = new BufferedOutputStream(remoteOutputStream);
131         this.session = session;
132     }
133
134     @Override
135     public void channelActive(ChannelHandlerContext ctx) {
136         this.channelHandlerContext = ctx;
137         logger.debug("{} Client active", session);
138     }
139
140     @Override
141     public void channelRead(ChannelHandlerContext ctx, Object msg) throws IOException {
142         ByteBuf bb = (ByteBuf) msg;
143         // we can block the server here so that slow client does not cause memory pressure
144         try {
145             bb.forEachByte(new ByteBufProcessor() {
146                 @Override
147                 public boolean process(byte value) throws Exception {
148                     remoteOutputStream.write(value);
149                     return true;
150                 }
151             });
152         } finally {
153             bb.release();
154         }
155     }
156
157     @Override
158     public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
159         logger.trace("{} Flushing", session);
160         remoteOutputStream.flush();
161     }
162
163     @Override
164     public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
165         // Close the connection when an exception is raised.
166         logger.warn("{} Unexpected exception from downstream", session, cause);
167         ctx.close();
168     }
169
170     @Override
171     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
172         logger.trace("{} channelInactive() called, closing remote client ctx", session);
173         remoteConnection.close();//this should close socket and all threads created for this client
174         this.channelHandlerContext = null;
175     }
176
177     public ChannelHandlerContext getChannelHandlerContext() {
178         return checkNotNull(channelHandlerContext, "Channel is not active");
179     }
180 }
181
182 /**
183  * Ganymed handler that gets unencrypted input and output streams, connects them to netty.
184  * Checks that 'netconf' subsystem is chosen by user.
185  * Launches new ClientInputStreamPoolingThread thread once session is established.
186  * Writes custom header to netty server, to inform it about IP address and username.
187  */
188 class ServerConnectionCallbackImpl implements ServerConnectionCallback {
189     private static final Logger logger = LoggerFactory.getLogger(ServerConnectionCallbackImpl.class);
190     public static final String NETCONF_SUBSYSTEM = "netconf";
191
192     private final Supplier<String> currentUserSupplier;
193     private final String remoteAddress;
194     private final String remotePort;
195     private final String session;
196     private final AutoCloseable ganymedConnection;
197     private final LocalAddress localAddress;
198     private final EventLoopGroup bossGroup;
199
200     ServerConnectionCallbackImpl(Supplier<String> currentUserSupplier, String remoteAddress, String remotePort, String session,
201                                  AutoCloseable ganymedConnection, LocalAddress localAddress, EventLoopGroup bossGroup) {
202         this.currentUserSupplier = currentUserSupplier;
203         this.remoteAddress = remoteAddress;
204         this.remotePort = remotePort;
205         this.session = session;
206         this.ganymedConnection = ganymedConnection;
207         // initialize netty local connection
208         this.localAddress = localAddress;
209         this.bossGroup = bossGroup;
210     }
211
212     private static ChannelFuture initializeNettyConnection(LocalAddress localAddress, EventLoopGroup bossGroup,
213                                                            final SSHClientHandler sshClientHandler) {
214         Bootstrap clientBootstrap = new Bootstrap();
215         clientBootstrap.group(bossGroup).channel(LocalChannel.class);
216
217         clientBootstrap.handler(new ChannelInitializer<LocalChannel>() {
218             @Override
219             public void initChannel(LocalChannel ch) throws Exception {
220                 ch.pipeline().addLast(sshClientHandler);
221             }
222         });
223         // asynchronously initialize local connection to netconf server
224         return clientBootstrap.connect(localAddress);
225     }
226
227     @Override
228     public ServerSessionCallback acceptSession(final ServerSession serverSession) {
229         String currentUser = currentUserSupplier.get();
230         final String additionalHeader = new NetconfHelloMessageAdditionalHeader(currentUser, remoteAddress,
231                 remotePort, "ssh", "client").toFormattedString();
232
233
234         return new SimpleServerSessionCallback() {
235             @Override
236             public Runnable requestSubsystem(final ServerSession ss, final String subsystem) throws IOException {
237                 return new Runnable() {
238                     @Override
239                     public void run() {
240                         if (NETCONF_SUBSYSTEM.equals(subsystem)) {
241                             // connect
242                             final SSHClientHandler sshClientHandler = new SSHClientHandler(ganymedConnection, ss.getStdin(), session);
243                             ChannelFuture clientChannelFuture = initializeNettyConnection(localAddress, bossGroup, sshClientHandler);
244                             // get channel
245                             final Channel channel = clientChannelFuture.awaitUninterruptibly().channel();
246                             new ClientInputStreamPoolingThread(session, ss.getStdout(), channel, new AutoCloseable() {
247                                 @Override
248                                 public void close() throws Exception {
249                                     logger.trace("Closing both ganymed and local connection");
250                                     try {
251                                         ganymedConnection.close();
252                                     } catch (Exception e) {
253                                         logger.warn("Ignoring exception while closing ganymed", e);
254                                     }
255                                     try {
256                                         channel.close();
257                                     } catch (Exception e) {
258                                         logger.warn("Ignoring exception while closing channel", e);
259                                     }
260                                 }
261                             }, sshClientHandler.getChannelHandlerContext()).start();
262
263                             // write additional header
264                             channel.writeAndFlush(Unpooled.copiedBuffer(additionalHeader.getBytes()));
265                         } else {
266                             logger.debug("{} Wrong subsystem requested:'{}', closing ssh session", serverSession, subsystem);
267                             String reason = "Only netconf subsystem is supported, requested:" + subsystem;
268                             closeSession(ss, reason);
269                         }
270                     }
271                 };
272             }
273
274             public void closeSession(ServerSession ss, String reason) {
275                 logger.trace("{} Closing session - {}", serverSession, reason);
276                 try {
277                     ss.getStdin().write(reason.getBytes());
278                 } catch (IOException e) {
279                     logger.warn("{} Exception while closing session", serverSession, e);
280                 }
281                 ss.close();
282             }
283
284             @Override
285             public Runnable requestPtyReq(final ServerSession ss, final PtySettings pty) throws IOException {
286                 return new Runnable() {
287                     @Override
288                     public void run() {
289                         closeSession(ss, "PTY request not supported");
290                     }
291                 };
292             }
293
294             @Override
295             public Runnable requestShell(final ServerSession ss) throws IOException {
296                 return new Runnable() {
297                     @Override
298                     public void run() {
299                         closeSession(ss, "Shell not supported");
300                     }
301                 };
302             }
303         };
304     }
305 }
306
307 /**
308  * Only thread that is required during ssh session, forwards client's input to netty.
309  * When user closes connection, onEndOfInput.close() is called to tear down the local channel.
310  */
311 class ClientInputStreamPoolingThread extends Thread {
312     private static final Logger logger = LoggerFactory.getLogger(ClientInputStreamPoolingThread.class);
313
314     private final InputStream fromClientIS;
315     private final Channel serverChannel;
316     private final AutoCloseable onEndOfInput;
317     private final ChannelHandlerContext channelHandlerContext;
318
319     ClientInputStreamPoolingThread(String session, InputStream fromClientIS, Channel serverChannel, AutoCloseable onEndOfInput,
320                                    ChannelHandlerContext channelHandlerContext) {
321         super(ClientInputStreamPoolingThread.class.getSimpleName() + " " + session);
322         this.fromClientIS = fromClientIS;
323         this.serverChannel = serverChannel;
324         this.onEndOfInput = onEndOfInput;
325         this.channelHandlerContext = channelHandlerContext;
326     }
327
328     @Override
329     public void run() {
330         ChunkedStream chunkedStream = new ChunkedStream(fromClientIS);
331         try {
332             ByteBuf byteBuf;
333             while ((byteBuf = chunkedStream.readChunk(channelHandlerContext/*only needed for ByteBuf alloc */)) != null) {
334                 serverChannel.writeAndFlush(byteBuf);
335             }
336         } catch (Exception e) {
337             logger.warn("Exception", e);
338         } finally {
339             logger.trace("End of input");
340             // tear down connection
341             try {
342                 onEndOfInput.close();
343             } catch (Exception e) {
344                 logger.warn("Ignoring exception while closing socket", e);
345             }
346         }
347     }
348 }
349
350 /**
351  * Authentication handler for ganymed.
352  * Provides current user name after authenticating using supplied AuthProvider.
353  */
354 @NotThreadSafe
355 class ServerAuthenticationCallbackImpl implements ServerAuthenticationCallback, Supplier<String> {
356     private static final Logger logger = LoggerFactory.getLogger(ServerAuthenticationCallbackImpl.class);
357     private final AuthProvider authProvider;
358     private final String session;
359     private String currentUser;
360
361     ServerAuthenticationCallbackImpl(AuthProvider authProvider, String session) {
362         this.authProvider = authProvider;
363         this.session = session;
364     }
365
366     @Override
367     public String initAuthentication(ServerConnection sc) {
368         logger.trace("{} Established connection", session);
369         return "Established connection" + "\r\n";
370     }
371
372     @Override
373     public String[] getRemainingAuthMethods(ServerConnection sc) {
374         return new String[]{ServerAuthenticationCallback.METHOD_PASSWORD};
375     }
376
377     @Override
378     public AuthenticationResult authenticateWithNone(ServerConnection sc, String username) {
379         return AuthenticationResult.FAILURE;
380     }
381
382     @Override
383     public AuthenticationResult authenticateWithPassword(ServerConnection sc, String username, String password) {
384         checkState(currentUser == null);
385         try {
386             if (authProvider.authenticated(username, password)) {
387                 currentUser = username;
388                 logger.trace("{} user {} authenticated", session, currentUser);
389                 return AuthenticationResult.SUCCESS;
390             }
391         } catch (Exception e) {
392             logger.warn("{} Authentication failed", session, e);
393         }
394         return AuthenticationResult.FAILURE;
395     }
396
397     @Override
398     public AuthenticationResult authenticateWithPublicKey(ServerConnection sc, String username, String algorithm,
399                                                           byte[] publicKey, byte[] signature) {
400         return AuthenticationResult.FAILURE;
401     }
402
403     @Override
404     public String get() {
405         return currentUser;
406     }
407 }