2 * Copyright (c) 2014 Cisco Systems, Inc. and others. All rights reserved.
4 * This program and the accompanying materials are made available under the
5 * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6 * and is available at http://www.eclipse.org/legal/epl-v10.html
8 package org.opendaylight.netconf.nettyutil.handler.ssh.client;
10 import static org.mockito.ArgumentMatchers.any;
11 import static org.mockito.ArgumentMatchers.anyBoolean;
12 import static org.mockito.ArgumentMatchers.eq;
13 import static org.mockito.Mockito.doAnswer;
14 import static org.mockito.Mockito.doNothing;
15 import static org.mockito.Mockito.doReturn;
16 import static org.mockito.Mockito.doThrow;
17 import static org.mockito.Mockito.mock;
18 import static org.mockito.Mockito.spy;
19 import static org.mockito.Mockito.times;
20 import static org.mockito.Mockito.verify;
21 import static org.mockito.Mockito.verifyNoMoreInteractions;
23 import com.google.common.util.concurrent.FutureCallback;
24 import com.google.common.util.concurrent.Futures;
25 import com.google.common.util.concurrent.ListenableFuture;
26 import com.google.common.util.concurrent.MoreExecutors;
27 import com.google.common.util.concurrent.SettableFuture;
28 import io.netty.buffer.Unpooled;
29 import io.netty.channel.Channel;
30 import io.netty.channel.ChannelConfig;
31 import io.netty.channel.ChannelFuture;
32 import io.netty.channel.ChannelHandlerContext;
33 import io.netty.channel.ChannelPromise;
34 import io.netty.channel.DefaultChannelPromise;
35 import io.netty.util.concurrent.EventExecutor;
36 import java.io.IOException;
37 import java.net.SocketAddress;
38 import java.util.concurrent.TimeUnit;
39 import org.junit.After;
40 import org.junit.Before;
41 import org.junit.Ignore;
42 import org.junit.Test;
43 import org.junit.runner.RunWith;
44 import org.mockito.Mock;
45 import org.mockito.junit.MockitoJUnitRunner;
46 import org.opendaylight.netconf.nettyutil.handler.ssh.authentication.AuthenticationHandler;
47 import org.opendaylight.netconf.shaded.sshd.client.channel.ChannelSubsystem;
48 import org.opendaylight.netconf.shaded.sshd.client.channel.ClientChannel;
49 import org.opendaylight.netconf.shaded.sshd.client.future.AuthFuture;
50 import org.opendaylight.netconf.shaded.sshd.client.future.ConnectFuture;
51 import org.opendaylight.netconf.shaded.sshd.client.future.OpenFuture;
52 import org.opendaylight.netconf.shaded.sshd.client.session.ClientSession;
53 import org.opendaylight.netconf.shaded.sshd.common.future.CloseFuture;
54 import org.opendaylight.netconf.shaded.sshd.common.future.SshFuture;
55 import org.opendaylight.netconf.shaded.sshd.common.future.SshFutureListener;
56 import org.opendaylight.netconf.shaded.sshd.common.io.IoInputStream;
57 import org.opendaylight.netconf.shaded.sshd.common.io.IoOutputStream;
58 import org.opendaylight.netconf.shaded.sshd.common.io.IoReadFuture;
59 import org.opendaylight.netconf.shaded.sshd.common.io.IoWriteFuture;
60 import org.opendaylight.netconf.shaded.sshd.common.io.WritePendingException;
61 import org.opendaylight.netconf.shaded.sshd.common.util.buffer.Buffer;
62 import org.opendaylight.netconf.shaded.sshd.common.util.buffer.ByteArrayBuffer;
64 @RunWith(MockitoJUnitRunner.StrictStubs.class)
65 public class AsyncSshHandlerTest {
68 private NetconfSshClient sshClient;
70 private AuthenticationHandler authHandler;
72 private ChannelHandlerContext ctx;
74 private Channel channel;
76 private SocketAddress remoteAddress;
78 private SocketAddress localAddress;
80 private ChannelConfig channelConfig;
82 private EventExecutor executor;
84 private AsyncSshHandler asyncSshHandler;
86 private SshFutureListener<ConnectFuture> sshConnectListener;
87 private SshFutureListener<AuthFuture> sshAuthListener;
88 private SshFutureListener<OpenFuture> sshChannelOpenListener;
89 private ChannelPromise promise;
92 public void setUp() throws Exception {
98 promise = getMockedPromise();
100 asyncSshHandler = new AsyncSshHandler(authHandler, sshClient);
104 public void tearDown() throws Exception {
105 sshConnectListener = null;
106 sshAuthListener = null;
107 sshChannelOpenListener = null;
109 asyncSshHandler.close(ctx, getMockedPromise());
112 private void stubAuth() throws IOException {
113 doReturn("usr").when(authHandler).getUsername();
115 final AuthFuture authFuture = mock(AuthFuture.class);
116 Futures.addCallback(stubAddListener(authFuture), new SuccessFutureListener<AuthFuture>() {
118 public void onSuccess(final SshFutureListener<AuthFuture> result) {
119 sshAuthListener = result;
121 }, MoreExecutors.directExecutor());
122 doReturn(authFuture).when(authHandler).authenticate(any(ClientSession.class));
125 @SuppressWarnings("unchecked")
126 private static <T extends SshFuture<T>> ListenableFuture<SshFutureListener<T>> stubAddListener(final T future) {
127 final SettableFuture<SshFutureListener<T>> listenerSettableFuture = SettableFuture.create();
129 doAnswer(invocation -> {
130 listenerSettableFuture.set((SshFutureListener<T>) invocation.getArguments()[0]);
132 }).when(future).addListener(any(SshFutureListener.class));
134 return listenerSettableFuture;
137 private void stubCtx() {
138 doReturn(channel).when(ctx).channel();
139 doReturn(ctx).when(ctx).fireChannelActive();
140 doReturn(ctx).when(ctx).fireChannelInactive();
141 doReturn(mock(ChannelFuture.class)).when(ctx).disconnect(any(ChannelPromise.class));
142 doReturn(getMockedPromise()).when(ctx).newPromise();
143 doReturn(executor).when(ctx).executor();
144 doAnswer(invocation -> {
145 invocation.getArgument(0, Runnable.class).run();
147 }).when(executor).execute(any());
150 private void stubChannel() {
151 doReturn("channel").when(channel).toString();
154 private void stubSshClient() throws IOException {
155 final ConnectFuture connectFuture = mock(ConnectFuture.class);
156 Futures.addCallback(stubAddListener(connectFuture), new SuccessFutureListener<ConnectFuture>() {
158 public void onSuccess(final SshFutureListener<ConnectFuture> result) {
159 sshConnectListener = result;
161 }, MoreExecutors.directExecutor());
162 doReturn(connectFuture).when(sshClient).connect("usr", remoteAddress);
163 doReturn(channelConfig).when(channel).config();
164 doReturn(1).when(channelConfig).getConnectTimeoutMillis();
165 doReturn(connectFuture).when(connectFuture).verify(1,TimeUnit.MILLISECONDS);
169 public void testConnectSuccess() throws Exception {
170 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
172 final IoInputStream asyncOut = getMockedIoInputStream();
173 final IoOutputStream asyncIn = getMockedIoOutputStream();
174 final ChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
175 final ClientSession sshSession = getMockedSshSession(subsystemChannel);
176 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
178 sshConnectListener.operationComplete(connectFuture);
179 sshAuthListener.operationComplete(getSuccessAuthFuture());
180 sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
182 verify(subsystemChannel).setStreaming(ClientChannel.Streaming.Async);
184 verify(promise).setSuccess();
185 verify(ctx).fireChannelActive();
186 asyncSshHandler.close(ctx, getMockedPromise());
187 verify(ctx).fireChannelInactive();
191 public void testWrite() throws Exception {
192 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
194 final IoInputStream asyncOut = getMockedIoInputStream();
195 final IoOutputStream asyncIn = getMockedIoOutputStream();
196 final ChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
197 final ClientSession sshSession = getMockedSshSession(subsystemChannel);
198 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
200 sshConnectListener.operationComplete(connectFuture);
201 sshAuthListener.operationComplete(getSuccessAuthFuture());
202 sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
204 final ChannelPromise writePromise = getMockedPromise();
205 asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0, 1, 2, 3, 4, 5}), writePromise);
207 verify(writePromise).setSuccess();
211 public void testWriteClosed() throws Exception {
212 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
214 final IoInputStream asyncOut = getMockedIoInputStream();
215 final IoOutputStream asyncIn = getMockedIoOutputStream();
217 final IoWriteFuture ioWriteFuture = asyncIn.writeBuffer(new ByteArrayBuffer());
219 Futures.addCallback(stubAddListener(ioWriteFuture), new SuccessFutureListener<IoWriteFuture>() {
221 public void onSuccess(final SshFutureListener<IoWriteFuture> result) {
222 doReturn(new IllegalStateException()).when(ioWriteFuture).getException();
223 result.operationComplete(ioWriteFuture);
225 }, MoreExecutors.directExecutor());
227 final ChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
228 final ClientSession sshSession = getMockedSshSession(subsystemChannel);
229 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
231 sshConnectListener.operationComplete(connectFuture);
232 sshAuthListener.operationComplete(getSuccessAuthFuture());
233 sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
235 final ChannelPromise writePromise = getMockedPromise();
236 asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0,1,2,3,4,5}), writePromise);
238 verify(writePromise).setFailure(any(Throwable.class));
242 public void testWritePendingOne() throws Exception {
243 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
245 final IoInputStream asyncOut = getMockedIoInputStream();
246 final IoOutputStream asyncIn = getMockedIoOutputStream();
247 final IoWriteFuture ioWriteFuture = asyncIn.writeBuffer(new ByteArrayBuffer());
249 final ChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
250 final ClientSession sshSession = getMockedSshSession(subsystemChannel);
251 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
253 sshConnectListener.operationComplete(connectFuture);
254 sshAuthListener.operationComplete(getSuccessAuthFuture());
255 sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
257 final ChannelPromise firstWritePromise = getMockedPromise();
259 // intercept listener for first write,
260 // so we can invoke successful write later thus simulate pending of the first write
261 final ListenableFuture<SshFutureListener<IoWriteFuture>> firstWriteListenerFuture =
262 stubAddListener(ioWriteFuture);
263 asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0,1,2,3,4,5}), firstWritePromise);
264 final SshFutureListener<IoWriteFuture> firstWriteListener = firstWriteListenerFuture.get();
265 // intercept second listener,
266 // this is the listener for pending write for the pending write to know when pending state ended
267 final ListenableFuture<SshFutureListener<IoWriteFuture>> pendingListener = stubAddListener(ioWriteFuture);
269 final ChannelPromise secondWritePromise = getMockedPromise();
270 asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0, 1, 2, 3, 4, 5}), secondWritePromise);
272 doReturn(ioWriteFuture).when(asyncIn).writeBuffer(any(Buffer.class));
274 verifyNoMoreInteractions(firstWritePromise, secondWritePromise);
276 // make first write stop pending
277 firstWriteListener.operationComplete(ioWriteFuture);
279 // notify listener for second write that pending has ended
280 pendingListener.get().operationComplete(ioWriteFuture);
282 // verify both write promises successful
283 verify(firstWritePromise).setSuccess();
284 verify(secondWritePromise).setSuccess();
287 @Ignore("Pending queue is not limited")
289 public void testWritePendingMax() throws Exception {
290 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
292 final IoInputStream asyncOut = getMockedIoInputStream();
293 final IoOutputStream asyncIn = getMockedIoOutputStream();
294 final IoWriteFuture ioWriteFuture = asyncIn.writeBuffer(new ByteArrayBuffer());
296 final ChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
297 final ClientSession sshSession = getMockedSshSession(subsystemChannel);
298 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
300 sshConnectListener.operationComplete(connectFuture);
301 sshAuthListener.operationComplete(getSuccessAuthFuture());
302 sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
304 final ChannelPromise firstWritePromise = getMockedPromise();
306 // intercept listener for first write,
307 // so we can invoke successful write later thus simulate pending of the first write
308 final ListenableFuture<SshFutureListener<IoWriteFuture>> firstWriteListenerFuture =
309 stubAddListener(ioWriteFuture);
310 asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0,1,2,3,4,5}), firstWritePromise);
312 final ChannelPromise secondWritePromise = getMockedPromise();
313 // now make write throw pending exception
314 doThrow(WritePendingException.class).when(asyncIn).writeBuffer(any(Buffer.class));
315 for (int i = 0; i < 1001; i++) {
316 asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0, 1, 2, 3, 4, 5}), secondWritePromise);
319 verify(secondWritePromise, times(1)).setFailure(any(Throwable.class));
323 public void testDisconnect() throws Exception {
324 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
326 final IoInputStream asyncOut = getMockedIoInputStream();
327 final IoOutputStream asyncIn = getMockedIoOutputStream();
328 final ChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
329 final ClientSession sshSession = getMockedSshSession(subsystemChannel);
330 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
332 sshConnectListener.operationComplete(connectFuture);
333 sshAuthListener.operationComplete(getSuccessAuthFuture());
334 sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
336 final ChannelPromise disconnectPromise = getMockedPromise();
337 asyncSshHandler.disconnect(ctx, disconnectPromise);
339 verify(sshSession).close(anyBoolean());
340 verify(disconnectPromise).setSuccess();
341 //verify(ctx).fireChannelInactive();
344 private static OpenFuture getSuccessOpenFuture() {
345 final OpenFuture openFuture = mock(OpenFuture.class);
346 doReturn(null).when(openFuture).getException();
350 private static AuthFuture getSuccessAuthFuture() {
351 final AuthFuture authFuture = mock(AuthFuture.class);
352 doReturn(null).when(authFuture).getException();
356 private static ConnectFuture getSuccessConnectFuture(final ClientSession sshSession) {
357 final ConnectFuture connectFuture = mock(ConnectFuture.class);
358 doReturn(null).when(connectFuture).getException();
360 doReturn(sshSession).when(connectFuture).getSession();
361 return connectFuture;
364 private static NettyAwareClientSession getMockedSshSession(final ChannelSubsystem subsystemChannel)
366 final NettyAwareClientSession sshSession = mock(NettyAwareClientSession.class);
368 doReturn("serverVersion").when(sshSession).getServerVersion();
369 doReturn(false).when(sshSession).isClosed();
370 doReturn(false).when(sshSession).isClosing();
371 final CloseFuture closeFuture = mock(CloseFuture.class);
372 Futures.addCallback(stubAddListener(closeFuture), new SuccessFutureListener<>() {
374 public void onSuccess(final SshFutureListener<CloseFuture> result) {
375 doReturn(true).when(closeFuture).isClosed();
376 result.operationComplete(closeFuture);
378 }, MoreExecutors.directExecutor());
379 doReturn(closeFuture).when(sshSession).close(false);
381 doReturn(subsystemChannel).when(sshSession).createSubsystemChannel(eq("netconf"),
382 any(ChannelHandlerContext.class));
387 private ChannelSubsystem getMockedSubsystemChannel(final IoInputStream asyncOut,
388 final IoOutputStream asyncIn) throws IOException {
389 final ChannelSubsystem subsystemChannel = mock(ChannelSubsystem.class);
391 doNothing().when(subsystemChannel).setStreaming(any(ClientChannel.Streaming.class));
392 final OpenFuture openFuture = mock(OpenFuture.class);
394 Futures.addCallback(stubAddListener(openFuture), new SuccessFutureListener<OpenFuture>() {
396 public void onSuccess(final SshFutureListener<OpenFuture> result) {
397 sshChannelOpenListener = result;
399 }, MoreExecutors.directExecutor());
401 doReturn(openFuture).when(subsystemChannel).open();
402 doReturn(asyncIn).when(subsystemChannel).getAsyncIn();
403 doNothing().when(subsystemChannel).onClose(any());
404 doReturn(null).when(subsystemChannel).close(false);
405 return subsystemChannel;
408 private static IoOutputStream getMockedIoOutputStream() throws IOException {
409 final IoOutputStream mock = mock(IoOutputStream.class);
410 final IoWriteFuture ioWriteFuture = mock(IoWriteFuture.class);
411 doReturn(null).when(ioWriteFuture).getException();
413 Futures.addCallback(stubAddListener(ioWriteFuture), new SuccessFutureListener<IoWriteFuture>() {
415 public void onSuccess(final SshFutureListener<IoWriteFuture> result) {
416 result.operationComplete(ioWriteFuture);
418 }, MoreExecutors.directExecutor());
420 doReturn(ioWriteFuture).when(mock).writeBuffer(any(Buffer.class));
421 doReturn(false).when(mock).isClosed();
422 doReturn(false).when(mock).isClosing();
426 private static IoInputStream getMockedIoInputStream() {
427 final IoInputStream mock = mock(IoInputStream.class);
428 final IoReadFuture ioReadFuture = mock(IoReadFuture.class);
429 // Always success for read
430 Futures.addCallback(stubAddListener(ioReadFuture), new SuccessFutureListener<IoReadFuture>() {
432 public void onSuccess(final SshFutureListener<IoReadFuture> result) {
433 result.operationComplete(ioReadFuture);
435 }, MoreExecutors.directExecutor());
440 public void testConnectFailOpenChannel() throws Exception {
441 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
443 final IoInputStream asyncOut = getMockedIoInputStream();
444 final IoOutputStream asyncIn = getMockedIoOutputStream();
445 final ChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
446 final ClientSession sshSession = getMockedSshSession(subsystemChannel);
447 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
449 sshConnectListener.operationComplete(connectFuture);
451 sshAuthListener.operationComplete(getSuccessAuthFuture());
453 verify(subsystemChannel).setStreaming(ClientChannel.Streaming.Async);
455 sshChannelOpenListener.operationComplete(getFailedOpenFuture());
456 verify(promise).setFailure(any(Throwable.class));
460 public void testConnectFailAuth() throws Exception {
461 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
463 final NettyAwareClientSession sshSession = mock(NettyAwareClientSession.class);
464 doReturn(true).when(sshSession).isClosed();
465 final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
467 sshConnectListener.operationComplete(connectFuture);
469 final AuthFuture authFuture = getFailedAuthFuture();
471 sshAuthListener.operationComplete(authFuture);
472 verify(promise).setFailure(any(Throwable.class));
473 asyncSshHandler.close(ctx, getMockedPromise());
474 verify(ctx, times(0)).fireChannelInactive();
477 private static AuthFuture getFailedAuthFuture() {
478 final AuthFuture authFuture = mock(AuthFuture.class);
479 doReturn(new IllegalStateException()).when(authFuture).getException();
483 private static OpenFuture getFailedOpenFuture() {
484 final OpenFuture openFuture = mock(OpenFuture.class);
485 doReturn(new IllegalStateException()).when(openFuture).getException();
490 public void testConnectFail() throws Exception {
491 asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
493 final ConnectFuture connectFuture = getFailedConnectFuture();
494 sshConnectListener.operationComplete(connectFuture);
495 verify(promise).setFailure(any(Throwable.class));
498 private static ConnectFuture getFailedConnectFuture() {
499 final ConnectFuture connectFuture = mock(ConnectFuture.class);
500 doReturn(new IllegalStateException()).when(connectFuture).getException();
501 return connectFuture;
504 private ChannelPromise getMockedPromise() {
505 return spy(new DefaultChannelPromise(channel));
508 private abstract static class SuccessFutureListener<T extends SshFuture<T>>
509 implements FutureCallback<SshFutureListener<T>> {
512 public abstract void onSuccess(SshFutureListener<T> result);
515 public void onFailure(final Throwable throwable) {
516 throw new RuntimeException(throwable);