Merge "Fix for possible NPE if Bundle is stopped."
[controller.git] / opendaylight / netconf / netconf-ssh / src / main / java / org / opendaylight / controller / netconf / ssh / threads / Handshaker.java
1 /*
2  * Copyright (c) 2013 Cisco Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.controller.netconf.ssh.threads;
9
10 import static com.google.common.base.Preconditions.checkNotNull;
11 import static com.google.common.base.Preconditions.checkState;
12
13 import ch.ethz.ssh2.AuthenticationResult;
14 import ch.ethz.ssh2.PtySettings;
15 import ch.ethz.ssh2.ServerAuthenticationCallback;
16 import ch.ethz.ssh2.ServerConnection;
17 import ch.ethz.ssh2.ServerConnectionCallback;
18 import ch.ethz.ssh2.ServerSession;
19 import ch.ethz.ssh2.ServerSessionCallback;
20 import ch.ethz.ssh2.SimpleServerSessionCallback;
21 import com.google.common.base.Supplier;
22 import io.netty.bootstrap.Bootstrap;
23 import io.netty.buffer.ByteBuf;
24 import io.netty.buffer.ByteBufProcessor;
25 import io.netty.buffer.Unpooled;
26 import io.netty.channel.Channel;
27 import io.netty.channel.ChannelFuture;
28 import io.netty.channel.ChannelHandlerContext;
29 import io.netty.channel.ChannelInboundHandlerAdapter;
30 import io.netty.channel.ChannelInitializer;
31 import io.netty.channel.EventLoopGroup;
32 import io.netty.channel.local.LocalAddress;
33 import io.netty.channel.local.LocalChannel;
34 import io.netty.handler.stream.ChunkedStream;
35 import java.io.BufferedOutputStream;
36 import java.io.IOException;
37 import java.io.InputStream;
38 import java.io.OutputStream;
39 import java.net.Socket;
40 import javax.annotation.concurrent.NotThreadSafe;
41 import javax.annotation.concurrent.ThreadSafe;
42 import org.opendaylight.controller.netconf.ssh.authentication.AuthProvider;
43 import org.opendaylight.controller.netconf.util.messages.NetconfHelloMessageAdditionalHeader;
44 import org.slf4j.Logger;
45 import org.slf4j.LoggerFactory;
46
47 /**
48  * One instance represents per connection, responsible for ssh handshake.
49  * Once auth succeeds and correct subsystem is chosen, backend connection with
50  * netty netconf server is made. This task finishes right after negotiation is done.
51  */
52 @ThreadSafe
53 public class Handshaker implements Runnable {
54     private static final Logger logger = LoggerFactory.getLogger(Handshaker.class);
55
56     private final ServerConnection ganymedConnection;
57     private final String session;
58
59
60     public Handshaker(Socket socket, LocalAddress localAddress, long sessionId, AuthProvider authProvider,
61                       EventLoopGroup bossGroup) throws IOException {
62
63         this.session = "Session " + sessionId;
64
65         String remoteAddressWithPort = socket.getRemoteSocketAddress().toString().replace("/", "");
66         logger.debug("{} started with {}", session, remoteAddressWithPort);
67         String remoteAddress, remotePort;
68         if (remoteAddressWithPort.contains(":")) {
69             String[] split = remoteAddressWithPort.split(":");
70             remoteAddress = split[0];
71             remotePort = split[1];
72         } else {
73             remoteAddress = remoteAddressWithPort;
74             remotePort = "";
75         }
76         ServerAuthenticationCallbackImpl serverAuthenticationCallback = new ServerAuthenticationCallbackImpl(
77                 authProvider, session);
78
79         ganymedConnection = new ServerConnection(socket);
80
81         ServerConnectionCallbackImpl serverConnectionCallback = new ServerConnectionCallbackImpl(
82                 serverAuthenticationCallback, remoteAddress, remotePort, session,
83                 getGanymedAutoCloseable(ganymedConnection), localAddress, bossGroup);
84
85         // initialize ganymed
86         ganymedConnection.setPEMHostKey(authProvider.getPEMAsCharArray(), null);
87         ganymedConnection.setAuthenticationCallback(serverAuthenticationCallback);
88         ganymedConnection.setServerConnectionCallback(serverConnectionCallback);
89     }
90
91
92     private static AutoCloseable getGanymedAutoCloseable(final ServerConnection ganymedConnection) {
93         return new AutoCloseable() {
94             @Override
95             public void close() throws Exception {
96                 ganymedConnection.close();
97             }
98         };
99     }
100
101     @Override
102     public void run() {
103         // let ganymed process handshake
104         logger.trace("{} is started", session);
105         try {
106             // TODO this should be guarded with a timer to prevent resource exhaustion
107             ganymedConnection.connect();
108         } catch (IOException e) {
109             logger.debug("{} connection error", session, e);
110         }
111         logger.trace("{} is exiting", session);
112     }
113 }
114
115 /**
116  * Netty client handler that forwards bytes from backed server to supplied output stream.
117  * When backend server closes the connection, remoteConnection.close() is called to tear
118  * down ssh connection.
119  */
120 class SSHClientHandler extends ChannelInboundHandlerAdapter {
121     private static final Logger logger = LoggerFactory.getLogger(SSHClientHandler.class);
122     private final AutoCloseable remoteConnection;
123     private final BufferedOutputStream remoteOutputStream;
124     private final String session;
125     private ChannelHandlerContext channelHandlerContext;
126
127     public SSHClientHandler(AutoCloseable remoteConnection, OutputStream remoteOutputStream,
128                             String session) {
129         this.remoteConnection = remoteConnection;
130         this.remoteOutputStream = new BufferedOutputStream(remoteOutputStream);
131         this.session = session;
132     }
133
134     @Override
135     public void channelActive(ChannelHandlerContext ctx) {
136         this.channelHandlerContext = ctx;
137         logger.debug("{} Client active", session);
138     }
139
140     @Override
141     public void channelRead(ChannelHandlerContext ctx, Object msg) throws IOException {
142         ByteBuf bb = (ByteBuf) msg;
143         // we can block the server here so that slow client does not cause memory pressure
144         try {
145             bb.forEachByte(new ByteBufProcessor() {
146                 @Override
147                 public boolean process(byte value) throws Exception {
148                     remoteOutputStream.write(value);
149                     return true;
150                 }
151             });
152         } finally {
153             bb.release();
154         }
155     }
156
157     @Override
158     public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
159         logger.trace("{} Flushing", session);
160         remoteOutputStream.flush();
161     }
162
163     @Override
164     public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
165         // Close the connection when an exception is raised.
166         logger.warn("{} Unexpected exception from downstream", session, cause);
167         ctx.close();
168     }
169
170     @Override
171     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
172         logger.trace("{} channelInactive() called, closing remote client ctx", session);
173         remoteConnection.close();//this should close socket and all threads created for this client
174         this.channelHandlerContext = null;
175     }
176
177     public ChannelHandlerContext getChannelHandlerContext() {
178         return checkNotNull(channelHandlerContext, "Channel is not active");
179     }
180 }
181
182 /**
183  * Ganymed handler that gets unencrypted input and output streams, connects them to netty.
184  * Checks that 'netconf' subsystem is chosen by user.
185  * Launches new ClientInputStreamPoolingThread thread once session is established.
186  * Writes custom header to netty server, to inform it about IP address and username.
187  */
188 class ServerConnectionCallbackImpl implements ServerConnectionCallback {
189     private static final Logger logger = LoggerFactory.getLogger(ServerConnectionCallbackImpl.class);
190     public static final String NETCONF_SUBSYSTEM = "netconf";
191
192     private final Supplier<String> currentUserSupplier;
193     private final String remoteAddress;
194     private final String remotePort;
195     private final String session;
196     private final AutoCloseable ganymedConnection;
197     private final LocalAddress localAddress;
198     private final EventLoopGroup bossGroup;
199
200     ServerConnectionCallbackImpl(Supplier<String> currentUserSupplier, String remoteAddress, String remotePort, String session,
201                                  AutoCloseable ganymedConnection, LocalAddress localAddress, EventLoopGroup bossGroup) {
202         this.currentUserSupplier = currentUserSupplier;
203         this.remoteAddress = remoteAddress;
204         this.remotePort = remotePort;
205         this.session = session;
206         this.ganymedConnection = ganymedConnection;
207         // initialize netty local connection
208         this.localAddress = localAddress;
209         this.bossGroup = bossGroup;
210     }
211
212     private static ChannelFuture initializeNettyConnection(LocalAddress localAddress, EventLoopGroup bossGroup,
213                                                            final SSHClientHandler sshClientHandler) {
214         Bootstrap clientBootstrap = new Bootstrap();
215         clientBootstrap.group(bossGroup).channel(LocalChannel.class);
216
217         clientBootstrap.handler(new ChannelInitializer<LocalChannel>() {
218             @Override
219             public void initChannel(LocalChannel ch) throws Exception {
220                 ch.pipeline().addLast(sshClientHandler);
221             }
222         });
223         // asynchronously initialize local connection to netconf server
224         return clientBootstrap.connect(localAddress);
225     }
226
227     @Override
228     public ServerSessionCallback acceptSession(final ServerSession serverSession) {
229         String currentUser = currentUserSupplier.get();
230         final String additionalHeader = new NetconfHelloMessageAdditionalHeader(currentUser, remoteAddress,
231                 remotePort, "ssh", "client").toFormattedString();
232
233
234         return new SimpleServerSessionCallback() {
235             @Override
236             public Runnable requestSubsystem(final ServerSession ss, final String subsystem) throws IOException {
237                 return new Runnable() {
238                     @Override
239                     public void run() {
240                         if (NETCONF_SUBSYSTEM.equals(subsystem)) {
241                             // connect
242                             final SSHClientHandler sshClientHandler = new SSHClientHandler(ganymedConnection, ss.getStdin(), session);
243                             ChannelFuture clientChannelFuture = initializeNettyConnection(localAddress, bossGroup, sshClientHandler);
244                             // get channel
245                             final Channel channel = clientChannelFuture.awaitUninterruptibly().channel();
246
247                             // write additional header before polling thread is started
248                             // polling thread could process and forward data before additional header is written
249                             // This will result into unexpected state:  hello message without additional header and the next message with additional header
250                             channel.writeAndFlush(Unpooled.copiedBuffer(additionalHeader.getBytes()));
251
252                             new ClientInputStreamPoolingThread(session, ss.getStdout(), channel, new AutoCloseable() {
253                                 @Override
254                                 public void close() throws Exception {
255                                     logger.trace("Closing both ganymed and local connection");
256                                     try {
257                                         ganymedConnection.close();
258                                     } catch (Exception e) {
259                                         logger.warn("Ignoring exception while closing ganymed", e);
260                                     }
261                                     try {
262                                         channel.close();
263                                     } catch (Exception e) {
264                                         logger.warn("Ignoring exception while closing channel", e);
265                                     }
266                                 }
267                             }, sshClientHandler.getChannelHandlerContext()).start();
268                         } else {
269                             logger.debug("{} Wrong subsystem requested:'{}', closing ssh session", serverSession, subsystem);
270                             String reason = "Only netconf subsystem is supported, requested:" + subsystem;
271                             closeSession(ss, reason);
272                         }
273                     }
274                 };
275             }
276
277             public void closeSession(ServerSession ss, String reason) {
278                 logger.trace("{} Closing session - {}", serverSession, reason);
279                 try {
280                     ss.getStdin().write(reason.getBytes());
281                 } catch (IOException e) {
282                     logger.warn("{} Exception while closing session", serverSession, e);
283                 }
284                 ss.close();
285             }
286
287             @Override
288             public Runnable requestPtyReq(final ServerSession ss, final PtySettings pty) throws IOException {
289                 return new Runnable() {
290                     @Override
291                     public void run() {
292                         closeSession(ss, "PTY request not supported");
293                     }
294                 };
295             }
296
297             @Override
298             public Runnable requestShell(final ServerSession ss) throws IOException {
299                 return new Runnable() {
300                     @Override
301                     public void run() {
302                         closeSession(ss, "Shell not supported");
303                     }
304                 };
305             }
306         };
307     }
308 }
309
310 /**
311  * Only thread that is required during ssh session, forwards client's input to netty.
312  * When user closes connection, onEndOfInput.close() is called to tear down the local channel.
313  */
314 class ClientInputStreamPoolingThread extends Thread {
315     private static final Logger logger = LoggerFactory.getLogger(ClientInputStreamPoolingThread.class);
316
317     private final InputStream fromClientIS;
318     private final Channel serverChannel;
319     private final AutoCloseable onEndOfInput;
320     private final ChannelHandlerContext channelHandlerContext;
321
322     ClientInputStreamPoolingThread(String session, InputStream fromClientIS, Channel serverChannel, AutoCloseable onEndOfInput,
323                                    ChannelHandlerContext channelHandlerContext) {
324         super(ClientInputStreamPoolingThread.class.getSimpleName() + " " + session);
325         this.fromClientIS = fromClientIS;
326         this.serverChannel = serverChannel;
327         this.onEndOfInput = onEndOfInput;
328         this.channelHandlerContext = channelHandlerContext;
329     }
330
331     @Override
332     public void run() {
333         ChunkedStream chunkedStream = new ChunkedStream(fromClientIS);
334         try {
335             ByteBuf byteBuf;
336             while ((byteBuf = chunkedStream.readChunk(channelHandlerContext/*only needed for ByteBuf alloc */)) != null) {
337                 serverChannel.writeAndFlush(byteBuf);
338             }
339         } catch (Exception e) {
340             logger.warn("Exception", e);
341         } finally {
342             logger.trace("End of input");
343             // tear down connection
344             try {
345                 onEndOfInput.close();
346             } catch (Exception e) {
347                 logger.warn("Ignoring exception while closing socket", e);
348             }
349         }
350     }
351 }
352
353 /**
354  * Authentication handler for ganymed.
355  * Provides current user name after authenticating using supplied AuthProvider.
356  */
357 @NotThreadSafe
358 class ServerAuthenticationCallbackImpl implements ServerAuthenticationCallback, Supplier<String> {
359     private static final Logger logger = LoggerFactory.getLogger(ServerAuthenticationCallbackImpl.class);
360     private final AuthProvider authProvider;
361     private final String session;
362     private String currentUser;
363
364     ServerAuthenticationCallbackImpl(AuthProvider authProvider, String session) {
365         this.authProvider = authProvider;
366         this.session = session;
367     }
368
369     @Override
370     public String initAuthentication(ServerConnection sc) {
371         logger.trace("{} Established connection", session);
372         return "Established connection" + "\r\n";
373     }
374
375     @Override
376     public String[] getRemainingAuthMethods(ServerConnection sc) {
377         return new String[]{ServerAuthenticationCallback.METHOD_PASSWORD};
378     }
379
380     @Override
381     public AuthenticationResult authenticateWithNone(ServerConnection sc, String username) {
382         return AuthenticationResult.FAILURE;
383     }
384
385     @Override
386     public AuthenticationResult authenticateWithPassword(ServerConnection sc, String username, String password) {
387         checkState(currentUser == null);
388         try {
389             if (authProvider.authenticated(username, password)) {
390                 currentUser = username;
391                 logger.trace("{} user {} authenticated", session, currentUser);
392                 return AuthenticationResult.SUCCESS;
393             }
394         } catch (Exception e) {
395             logger.warn("{} Authentication failed", session, e);
396         }
397         return AuthenticationResult.FAILURE;
398     }
399
400     @Override
401     public AuthenticationResult authenticateWithPublicKey(ServerConnection sc, String username, String algorithm,
402                                                           byte[] publicKey, byte[] signature) {
403         return AuthenticationResult.FAILURE;
404     }
405
406     @Override
407     public String get() {
408         return currentUser;
409     }
410 }