Merge "BUG-625: migrate InventoryAndReadAdapter"
[controller.git] / opendaylight / netconf / netconf-ssh / src / main / java / org / opendaylight / controller / netconf / ssh / threads / Handshaker.java
1 /*
2  * Copyright (c) 2013 Cisco Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.controller.netconf.ssh.threads;
9
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

import ch.ethz.ssh2.AuthenticationResult;
import ch.ethz.ssh2.PtySettings;
import ch.ethz.ssh2.ServerAuthenticationCallback;
import ch.ethz.ssh2.ServerConnection;
import ch.ethz.ssh2.ServerConnectionCallback;
import ch.ethz.ssh2.ServerSession;
import ch.ethz.ssh2.ServerSessionCallback;
import ch.ethz.ssh2.SimpleServerSessionCallback;
import com.google.common.base.Supplier;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufProcessor;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.stream.ChunkedStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import javax.annotation.concurrent.NotThreadSafe;
import javax.annotation.concurrent.ThreadSafe;
import org.opendaylight.controller.netconf.ssh.authentication.AuthProvider;
import org.opendaylight.controller.netconf.util.messages.NetconfHelloMessageAdditionalHeader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
45
46 /**
47  * One instance represents per connection, responsible for ssh handshake.
48  * Once auth succeeds and correct subsystem is chosen, backend connection with
49  * netty netconf server is made. This task finishes right after negotiation is done.
50  */
51 @ThreadSafe
52 public class Handshaker implements Runnable {
53     private static final Logger logger = LoggerFactory.getLogger(Handshaker.class);
54
55     private final ServerConnection ganymedConnection;
56     private final String session;
57
58
59     public Handshaker(Socket socket, LocalAddress localAddress, long sessionId, AuthProvider authProvider,
60                       EventLoopGroup bossGroup) throws IOException {
61
62         this.session = "Session " + sessionId;
63
64         String remoteAddressWithPort = socket.getRemoteSocketAddress().toString().replace("/", "");
65         logger.debug("{} started with {}", session, remoteAddressWithPort);
66         String remoteAddress, remotePort;
67         if (remoteAddressWithPort.contains(":")) {
68             String[] split = remoteAddressWithPort.split(":");
69             remoteAddress = split[0];
70             remotePort = split[1];
71         } else {
72             remoteAddress = remoteAddressWithPort;
73             remotePort = "";
74         }
75         ServerAuthenticationCallbackImpl serverAuthenticationCallback = new ServerAuthenticationCallbackImpl(
76                 authProvider, session);
77
78         ganymedConnection = new ServerConnection(socket);
79
80         ServerConnectionCallbackImpl serverConnectionCallback = new ServerConnectionCallbackImpl(
81                 serverAuthenticationCallback, remoteAddress, remotePort, session,
82                 getGanymedAutoCloseable(ganymedConnection), localAddress, bossGroup);
83
84         // initialize ganymed
85         ganymedConnection.setPEMHostKey(authProvider.getPEMAsCharArray(), null);
86         ganymedConnection.setAuthenticationCallback(serverAuthenticationCallback);
87         ganymedConnection.setServerConnectionCallback(serverConnectionCallback);
88     }
89
90
91     private static AutoCloseable getGanymedAutoCloseable(final ServerConnection ganymedConnection) {
92         return new AutoCloseable() {
93             @Override
94             public void close() throws Exception {
95                 ganymedConnection.close();
96             }
97         };
98     }
99
100     @Override
101     public void run() {
102         // let ganymed process handshake
103         logger.trace("{} SocketThread is started", session);
104         try {
105             // TODO this should be guarded with a timer to prevent resource exhaustion
106             ganymedConnection.connect();
107         } catch (IOException e) {
108             logger.warn("{} SocketThread error ", session, e);
109         }
110         logger.trace("{} SocketThread is exiting", session);
111     }
112 }
113
114 /**
115  * Netty client handler that forwards bytes from backed server to supplied output stream.
116  * When backend server closes the connection, remoteConnection.close() is called to tear
117  * down ssh connection.
118  */
119 class SSHClientHandler extends ChannelInboundHandlerAdapter {
120     private static final Logger logger = LoggerFactory.getLogger(SSHClientHandler.class);
121     private final AutoCloseable remoteConnection;
122     private final OutputStream remoteOutputStream;
123     private final String session;
124     private ChannelHandlerContext channelHandlerContext;
125
126     public SSHClientHandler(AutoCloseable remoteConnection, OutputStream remoteOutputStream,
127                             String session) {
128         this.remoteConnection = remoteConnection;
129         this.remoteOutputStream = remoteOutputStream;
130         this.session = session;
131     }
132
133     @Override
134     public void channelActive(ChannelHandlerContext ctx) {
135         this.channelHandlerContext = ctx;
136         logger.debug("{} Client active", session);
137     }
138
139     @Override
140     public void channelRead(ChannelHandlerContext ctx, Object msg) {
141         ByteBuf bb = (ByteBuf) msg;
142         // we can block the server here so that slow client does not cause memory pressure
143         try {
144             bb.forEachByte(new ByteBufProcessor() {
145                 @Override
146                 public boolean process(byte value) throws Exception {
147                     remoteOutputStream.write(value);
148                     return true;
149                 }
150             });
151         } finally {
152             bb.release();
153         }
154     }
155
156     @Override
157     public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
158         logger.trace("{} Flushing", session);
159         remoteOutputStream.flush();
160     }
161
162     @Override
163     public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
164         // Close the connection when an exception is raised.
165         logger.warn("{} Unexpected exception from downstream", session, cause);
166         ctx.close();
167     }
168
169     @Override
170     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
171         logger.trace("{} channelInactive() called, closing remote client ctx", session);
172         remoteConnection.close();//this should close socket and all threads created for this client
173         this.channelHandlerContext = null;
174     }
175
176     public ChannelHandlerContext getChannelHandlerContext() {
177         return checkNotNull(channelHandlerContext, "Channel is not active");
178     }
179 }
180
181 /**
182  * Ganymed handler that gets unencrypted input and output streams, connects them to netty.
183  * Checks that 'netconf' subsystem is chosen by user.
184  * Launches new ClientInputStreamPoolingThread thread once session is established.
185  * Writes custom header to netty server, to inform it about IP address and username.
186  */
187 class ServerConnectionCallbackImpl implements ServerConnectionCallback {
188     private static final Logger logger = LoggerFactory.getLogger(ServerConnectionCallbackImpl.class);
189     public static final String NETCONF_SUBSYSTEM = "netconf";
190
191     private final Supplier<String> currentUserSupplier;
192     private final String remoteAddress;
193     private final String remotePort;
194     private final String session;
195     private final AutoCloseable ganymedConnection;
196     private final LocalAddress localAddress;
197     private final EventLoopGroup bossGroup;
198
199     ServerConnectionCallbackImpl(Supplier<String> currentUserSupplier, String remoteAddress, String remotePort, String session,
200                                  AutoCloseable ganymedConnection, LocalAddress localAddress, EventLoopGroup bossGroup) {
201         this.currentUserSupplier = currentUserSupplier;
202         this.remoteAddress = remoteAddress;
203         this.remotePort = remotePort;
204         this.session = session;
205         this.ganymedConnection = ganymedConnection;
206         // initialize netty local connection
207         this.localAddress = localAddress;
208         this.bossGroup = bossGroup;
209     }
210
211     private static ChannelFuture initializeNettyConnection(LocalAddress localAddress, EventLoopGroup bossGroup,
212                                                            final SSHClientHandler sshClientHandler) {
213         Bootstrap clientBootstrap = new Bootstrap();
214         clientBootstrap.group(bossGroup).channel(LocalChannel.class);
215
216         clientBootstrap.handler(new ChannelInitializer<LocalChannel>() {
217             @Override
218             public void initChannel(LocalChannel ch) throws Exception {
219                 ch.pipeline().addLast(sshClientHandler);
220             }
221         });
222         // asynchronously initialize local connection to netconf server
223         return clientBootstrap.connect(localAddress);
224     }
225
226     @Override
227     public ServerSessionCallback acceptSession(final ServerSession serverSession) {
228         String currentUser = currentUserSupplier.get();
229         final String additionalHeader = new NetconfHelloMessageAdditionalHeader(currentUser, remoteAddress,
230                 remotePort, "ssh", "client").toFormattedString();
231
232
233         return new SimpleServerSessionCallback() {
234             @Override
235             public Runnable requestSubsystem(final ServerSession ss, final String subsystem) throws IOException {
236                 return new Runnable() {
237                     @Override
238                     public void run() {
239                         if (NETCONF_SUBSYSTEM.equals(subsystem)) {
240                             // connect
241                             final SSHClientHandler sshClientHandler = new SSHClientHandler(ganymedConnection, ss.getStdin(), session);
242                             ChannelFuture clientChannelFuture = initializeNettyConnection(localAddress, bossGroup, sshClientHandler);
243                             // get channel
244                             final Channel channel = clientChannelFuture.awaitUninterruptibly().channel();
245                             new ClientInputStreamPoolingThread(session, ss.getStdout(), channel, new AutoCloseable() {
246                                 @Override
247                                 public void close() throws Exception {
248                                     logger.trace("Closing both ganymed and local connection");
249                                     try {
250                                         ganymedConnection.close();
251                                     } catch (Exception e) {
252                                         logger.warn("Ignoring exception while closing ganymed", e);
253                                     }
254                                     try {
255                                         channel.close();
256                                     } catch (Exception e) {
257                                         logger.warn("Ignoring exception while closing channel", e);
258                                     }
259                                 }
260                             }, sshClientHandler.getChannelHandlerContext()).start();
261
262                             // write additional header
263                             channel.writeAndFlush(Unpooled.copiedBuffer(additionalHeader.getBytes()));
264                         } else {
265                             logger.debug("{} Wrong subsystem requested:'{}', closing ssh session", serverSession, subsystem);
266                             String reason = "Only netconf subsystem is supported, requested:" + subsystem;
267                             closeSession(ss, reason);
268                         }
269                     }
270                 };
271             }
272
273             public void closeSession(ServerSession ss, String reason) {
274                 logger.trace("{} Closing session - {}", serverSession, reason);
275                 try {
276                     ss.getStdin().write(reason.getBytes());
277                 } catch (IOException e) {
278                     logger.warn("{} Exception while closing session", serverSession, e);
279                 }
280                 ss.close();
281             }
282
283             @Override
284             public Runnable requestPtyReq(final ServerSession ss, final PtySettings pty) throws IOException {
285                 return new Runnable() {
286                     @Override
287                     public void run() {
288                         closeSession(ss, "PTY request not supported");
289                     }
290                 };
291             }
292
293             @Override
294             public Runnable requestShell(final ServerSession ss) throws IOException {
295                 return new Runnable() {
296                     @Override
297                     public void run() {
298                         closeSession(ss, "Shell not supported");
299                     }
300                 };
301             }
302         };
303     }
304 }
305
306 /**
307  * Only thread that is required during ssh session, forwards client's input to netty.
308  * When user closes connection, onEndOfInput.close() is called to tear down the local channel.
309  */
310 class ClientInputStreamPoolingThread extends Thread {
311     private static final Logger logger = LoggerFactory.getLogger(ClientInputStreamPoolingThread.class);
312
313     private final InputStream fromClientIS;
314     private final Channel serverChannel;
315     private final AutoCloseable onEndOfInput;
316     private final ChannelHandlerContext channelHandlerContext;
317
318     ClientInputStreamPoolingThread(String session, InputStream fromClientIS, Channel serverChannel, AutoCloseable onEndOfInput,
319                                    ChannelHandlerContext channelHandlerContext) {
320         super(ClientInputStreamPoolingThread.class.getSimpleName() + " " + session);
321         this.fromClientIS = fromClientIS;
322         this.serverChannel = serverChannel;
323         this.onEndOfInput = onEndOfInput;
324         this.channelHandlerContext = channelHandlerContext;
325     }
326
327     @Override
328     public void run() {
329         ChunkedStream chunkedStream = new ChunkedStream(fromClientIS);
330         try {
331             ByteBuf byteBuf;
332             while ((byteBuf = chunkedStream.readChunk(channelHandlerContext/*only needed for ByteBuf alloc */)) != null) {
333                 serverChannel.writeAndFlush(byteBuf);
334             }
335         } catch (Exception e) {
336             logger.warn("Exception", e);
337         } finally {
338             logger.trace("End of input");
339             // tear down connection
340             try {
341                 onEndOfInput.close();
342             } catch (Exception e) {
343                 logger.warn("Ignoring exception while closing socket", e);
344             }
345         }
346     }
347 }
348
349 /**
350  * Authentication handler for ganymed.
351  * Provides current user name after authenticating using supplied AuthProvider.
352  */
353 @NotThreadSafe
354 class ServerAuthenticationCallbackImpl implements ServerAuthenticationCallback, Supplier<String> {
355     private static final Logger logger = LoggerFactory.getLogger(ServerAuthenticationCallbackImpl.class);
356     private final AuthProvider authProvider;
357     private final String session;
358     private String currentUser;
359
360     ServerAuthenticationCallbackImpl(AuthProvider authProvider, String session) {
361         this.authProvider = authProvider;
362         this.session = session;
363     }
364
365     @Override
366     public String initAuthentication(ServerConnection sc) {
367         logger.trace("{} Established connection", session);
368         return "Established connection" + "\r\n";
369     }
370
371     @Override
372     public String[] getRemainingAuthMethods(ServerConnection sc) {
373         return new String[]{ServerAuthenticationCallback.METHOD_PASSWORD};
374     }
375
376     @Override
377     public AuthenticationResult authenticateWithNone(ServerConnection sc, String username) {
378         return AuthenticationResult.FAILURE;
379     }
380
381     @Override
382     public AuthenticationResult authenticateWithPassword(ServerConnection sc, String username, String password) {
383         checkState(currentUser == null);
384         try {
385             if (authProvider.authenticated(username, password)) {
386                 currentUser = username;
387                 logger.trace("{} user {} authenticated", session, currentUser);
388                 return AuthenticationResult.SUCCESS;
389             }
390         } catch (Exception e) {
391             logger.warn("{} Authentication failed", session, e);
392         }
393         return AuthenticationResult.FAILURE;
394     }
395
396     @Override
397     public AuthenticationResult authenticateWithPublicKey(ServerConnection sc, String username, String algorithm,
398                                                           byte[] publicKey, byte[] signature) {
399         return AuthenticationResult.FAILURE;
400     }
401
402     @Override
403     public String get() {
404         return currentUser;
405     }
406 }