Add AsyncSshHandler.onConnectComplete()
[netconf.git] / netconf / netconf-netty-util / src / test / java / org / opendaylight / netconf / nettyutil / handler / ssh / client / AsyncSshHandlerTest.java
1 /*
2  * Copyright (c) 2014 Cisco Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.netconf.nettyutil.handler.ssh.client;
9
10 import static org.mockito.ArgumentMatchers.any;
11 import static org.mockito.ArgumentMatchers.anyBoolean;
12 import static org.mockito.ArgumentMatchers.eq;
13 import static org.mockito.Mockito.doAnswer;
14 import static org.mockito.Mockito.doNothing;
15 import static org.mockito.Mockito.doReturn;
16 import static org.mockito.Mockito.doThrow;
17 import static org.mockito.Mockito.mock;
18 import static org.mockito.Mockito.spy;
19 import static org.mockito.Mockito.times;
20 import static org.mockito.Mockito.verify;
21 import static org.mockito.Mockito.verifyNoMoreInteractions;
22
23 import com.google.common.util.concurrent.FutureCallback;
24 import com.google.common.util.concurrent.Futures;
25 import com.google.common.util.concurrent.ListenableFuture;
26 import com.google.common.util.concurrent.MoreExecutors;
27 import com.google.common.util.concurrent.SettableFuture;
28 import io.netty.buffer.Unpooled;
29 import io.netty.channel.Channel;
30 import io.netty.channel.ChannelConfig;
31 import io.netty.channel.ChannelFuture;
32 import io.netty.channel.ChannelHandlerContext;
33 import io.netty.channel.ChannelPromise;
34 import io.netty.channel.DefaultChannelPromise;
35 import java.io.IOException;
36 import java.net.SocketAddress;
37 import java.util.concurrent.TimeUnit;
38 import org.junit.After;
39 import org.junit.Before;
40 import org.junit.Ignore;
41 import org.junit.Test;
42 import org.junit.runner.RunWith;
43 import org.mockito.Mock;
44 import org.mockito.junit.MockitoJUnitRunner;
45 import org.opendaylight.netconf.nettyutil.handler.ssh.authentication.AuthenticationHandler;
46 import org.opendaylight.netconf.shaded.sshd.client.channel.ClientChannel;
47 import org.opendaylight.netconf.shaded.sshd.client.future.AuthFuture;
48 import org.opendaylight.netconf.shaded.sshd.client.future.ConnectFuture;
49 import org.opendaylight.netconf.shaded.sshd.client.future.OpenFuture;
50 import org.opendaylight.netconf.shaded.sshd.client.session.ClientSession;
51 import org.opendaylight.netconf.shaded.sshd.common.future.CloseFuture;
52 import org.opendaylight.netconf.shaded.sshd.common.future.SshFuture;
53 import org.opendaylight.netconf.shaded.sshd.common.future.SshFutureListener;
54 import org.opendaylight.netconf.shaded.sshd.common.io.IoInputStream;
55 import org.opendaylight.netconf.shaded.sshd.common.io.IoOutputStream;
56 import org.opendaylight.netconf.shaded.sshd.common.io.IoReadFuture;
57 import org.opendaylight.netconf.shaded.sshd.common.io.IoWriteFuture;
58 import org.opendaylight.netconf.shaded.sshd.common.io.WritePendingException;
59 import org.opendaylight.netconf.shaded.sshd.common.util.buffer.Buffer;
60 import org.opendaylight.netconf.shaded.sshd.common.util.buffer.ByteArrayBuffer;
61
62 @RunWith(MockitoJUnitRunner.StrictStubs.class)
63 public class AsyncSshHandlerTest {
64
65     @Mock
66     private NetconfSshClient sshClient;
67     @Mock
68     private AuthenticationHandler authHandler;
69     @Mock
70     private ChannelHandlerContext ctx;
71     @Mock
72     private Channel channel;
73     @Mock
74     private SocketAddress remoteAddress;
75     @Mock
76     private SocketAddress localAddress;
77     @Mock
78     private ChannelConfig channelConfig;
79
80     private AsyncSshHandler asyncSshHandler;
81
82     private SshFutureListener<ConnectFuture> sshConnectListener;
83     private SshFutureListener<AuthFuture> sshAuthListener;
84     private SshFutureListener<OpenFuture> sshChannelOpenListener;
85     private ChannelPromise promise;
86
87     @Before
88     public void setUp() throws Exception {
89         stubAuth();
90         stubSshClient();
91         stubChannel();
92         stubCtx();
93
94         promise = getMockedPromise();
95
96         asyncSshHandler = new AsyncSshHandler(authHandler, sshClient);
97     }
98
99     @After
100     public void tearDown() throws Exception {
101         sshConnectListener = null;
102         sshAuthListener = null;
103         sshChannelOpenListener = null;
104         promise = null;
105         asyncSshHandler.close(ctx, getMockedPromise());
106     }
107
108     private void stubAuth() throws IOException {
109         doReturn("usr").when(authHandler).getUsername();
110
111         final AuthFuture authFuture = mock(AuthFuture.class);
112         Futures.addCallback(stubAddListener(authFuture), new SuccessFutureListener<AuthFuture>() {
113             @Override
114             public void onSuccess(final SshFutureListener<AuthFuture> result) {
115                 sshAuthListener = result;
116             }
117         }, MoreExecutors.directExecutor());
118         doReturn(authFuture).when(authHandler).authenticate(any(ClientSession.class));
119     }
120
121     @SuppressWarnings("unchecked")
122     private static <T extends SshFuture<T>> ListenableFuture<SshFutureListener<T>> stubAddListener(final T future) {
123         final SettableFuture<SshFutureListener<T>> listenerSettableFuture = SettableFuture.create();
124
125         doAnswer(invocation -> {
126             listenerSettableFuture.set((SshFutureListener<T>) invocation.getArguments()[0]);
127             return null;
128         }).when(future).addListener(any(SshFutureListener.class));
129
130         return listenerSettableFuture;
131     }
132
133     private void stubCtx() {
134         doReturn(channel).when(ctx).channel();
135         doReturn(ctx).when(ctx).fireChannelActive();
136         doReturn(ctx).when(ctx).fireChannelInactive();
137         doReturn(mock(ChannelFuture.class)).when(ctx).disconnect(any(ChannelPromise.class));
138         doReturn(getMockedPromise()).when(ctx).newPromise();
139     }
140
141     private void stubChannel() {
142         doReturn("channel").when(channel).toString();
143     }
144
145     private void stubSshClient() throws IOException {
146         final ConnectFuture connectFuture = mock(ConnectFuture.class);
147         Futures.addCallback(stubAddListener(connectFuture), new SuccessFutureListener<ConnectFuture>() {
148             @Override
149             public void onSuccess(final SshFutureListener<ConnectFuture> result) {
150                 sshConnectListener = result;
151             }
152         }, MoreExecutors.directExecutor());
153         doReturn(connectFuture).when(sshClient).connect("usr", remoteAddress);
154         doReturn(channelConfig).when(channel).config();
155         doReturn(1).when(channelConfig).getConnectTimeoutMillis();
156         doReturn(connectFuture).when(connectFuture).verify(1,TimeUnit.MILLISECONDS);
157     }
158
159     @Test
160     public void testConnectSuccess() throws Exception {
161         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
162
163         final IoInputStream asyncOut = getMockedIoInputStream();
164         final IoOutputStream asyncIn = getMockedIoOutputStream();
165         final NettyAwareChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
166         final ClientSession sshSession = getMockedSshSession(subsystemChannel);
167         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
168
169         sshConnectListener.operationComplete(connectFuture);
170         sshAuthListener.operationComplete(getSuccessAuthFuture());
171         sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
172
173         verify(subsystemChannel).setStreaming(ClientChannel.Streaming.Async);
174
175         verify(promise).setSuccess();
176         verify(ctx).fireChannelActive();
177         asyncSshHandler.close(ctx, getMockedPromise());
178         verify(ctx).fireChannelInactive();
179     }
180
181     @Test
182     public void testWrite() throws Exception {
183         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
184
185         final IoInputStream asyncOut = getMockedIoInputStream();
186         final IoOutputStream asyncIn = getMockedIoOutputStream();
187         final NettyAwareChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
188         final ClientSession sshSession = getMockedSshSession(subsystemChannel);
189         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
190
191         sshConnectListener.operationComplete(connectFuture);
192         sshAuthListener.operationComplete(getSuccessAuthFuture());
193         sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
194
195         final ChannelPromise writePromise = getMockedPromise();
196         asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0, 1, 2, 3, 4, 5}), writePromise);
197
198         verify(writePromise).setSuccess();
199     }
200
201     @Test
202     public void testWriteClosed() throws Exception {
203         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
204
205         final IoInputStream asyncOut = getMockedIoInputStream();
206         final IoOutputStream asyncIn = getMockedIoOutputStream();
207
208         final IoWriteFuture ioWriteFuture = asyncIn.writeBuffer(new ByteArrayBuffer());
209
210         Futures.addCallback(stubAddListener(ioWriteFuture), new SuccessFutureListener<IoWriteFuture>() {
211             @Override
212             public void onSuccess(final SshFutureListener<IoWriteFuture> result) {
213                 doReturn(false).when(ioWriteFuture).isWritten();
214                 doReturn(new IllegalStateException()).when(ioWriteFuture).getException();
215                 result.operationComplete(ioWriteFuture);
216             }
217         }, MoreExecutors.directExecutor());
218
219         final NettyAwareChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
220         final ClientSession sshSession = getMockedSshSession(subsystemChannel);
221         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
222
223         sshConnectListener.operationComplete(connectFuture);
224         sshAuthListener.operationComplete(getSuccessAuthFuture());
225         sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
226
227         final ChannelPromise writePromise = getMockedPromise();
228         asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0,1,2,3,4,5}), writePromise);
229
230         verify(writePromise).setFailure(any(Throwable.class));
231     }
232
233     @Test
234     public void testWritePendingOne() throws Exception {
235         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
236
237         final IoInputStream asyncOut = getMockedIoInputStream();
238         final IoOutputStream asyncIn = getMockedIoOutputStream();
239         final IoWriteFuture ioWriteFuture = asyncIn.writeBuffer(new ByteArrayBuffer());
240
241         final NettyAwareChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
242         final ClientSession sshSession = getMockedSshSession(subsystemChannel);
243         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
244
245         sshConnectListener.operationComplete(connectFuture);
246         sshAuthListener.operationComplete(getSuccessAuthFuture());
247         sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
248
249         final ChannelPromise firstWritePromise = getMockedPromise();
250
251         // intercept listener for first write,
252         // so we can invoke successful write later thus simulate pending of the first write
253         final ListenableFuture<SshFutureListener<IoWriteFuture>> firstWriteListenerFuture =
254                 stubAddListener(ioWriteFuture);
255         asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0,1,2,3,4,5}), firstWritePromise);
256         final SshFutureListener<IoWriteFuture> firstWriteListener = firstWriteListenerFuture.get();
257         // intercept second listener,
258         // this is the listener for pending write for the pending write to know when pending state ended
259         final ListenableFuture<SshFutureListener<IoWriteFuture>> pendingListener = stubAddListener(ioWriteFuture);
260
261         final ChannelPromise secondWritePromise = getMockedPromise();
262         asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0, 1, 2, 3, 4, 5}), secondWritePromise);
263
264         doReturn(ioWriteFuture).when(asyncIn).writeBuffer(any(Buffer.class));
265
266         verifyNoMoreInteractions(firstWritePromise, secondWritePromise);
267
268         // make first write stop pending
269         firstWriteListener.operationComplete(ioWriteFuture);
270
271         // notify listener for second write that pending has ended
272         pendingListener.get().operationComplete(ioWriteFuture);
273
274         // verify both write promises successful
275         verify(firstWritePromise).setSuccess();
276         verify(secondWritePromise).setSuccess();
277     }
278
279     @Ignore("Pending queue is not limited")
280     @Test
281     public void testWritePendingMax() throws Exception {
282         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
283
284         final IoInputStream asyncOut = getMockedIoInputStream();
285         final IoOutputStream asyncIn = getMockedIoOutputStream();
286         final IoWriteFuture ioWriteFuture = asyncIn.writeBuffer(new ByteArrayBuffer());
287
288         final NettyAwareChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
289         final ClientSession sshSession = getMockedSshSession(subsystemChannel);
290         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
291
292         sshConnectListener.operationComplete(connectFuture);
293         sshAuthListener.operationComplete(getSuccessAuthFuture());
294         sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
295
296         final ChannelPromise firstWritePromise = getMockedPromise();
297
298         // intercept listener for first write,
299         // so we can invoke successful write later thus simulate pending of the first write
300         final ListenableFuture<SshFutureListener<IoWriteFuture>> firstWriteListenerFuture =
301                 stubAddListener(ioWriteFuture);
302         asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0,1,2,3,4,5}), firstWritePromise);
303
304         final ChannelPromise secondWritePromise = getMockedPromise();
305         // now make write throw pending exception
306         doThrow(WritePendingException.class).when(asyncIn).writeBuffer(any(Buffer.class));
307         for (int i = 0; i < 1001; i++) {
308             asyncSshHandler.write(ctx, Unpooled.copiedBuffer(new byte[]{0, 1, 2, 3, 4, 5}), secondWritePromise);
309         }
310
311         verify(secondWritePromise, times(1)).setFailure(any(Throwable.class));
312     }
313
314     @Test
315     public void testDisconnect() throws Exception {
316         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
317
318         final IoInputStream asyncOut = getMockedIoInputStream();
319         final IoOutputStream asyncIn = getMockedIoOutputStream();
320         final NettyAwareChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
321         final ClientSession sshSession = getMockedSshSession(subsystemChannel);
322         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
323
324         sshConnectListener.operationComplete(connectFuture);
325         sshAuthListener.operationComplete(getSuccessAuthFuture());
326         sshChannelOpenListener.operationComplete(getSuccessOpenFuture());
327
328         final ChannelPromise disconnectPromise = getMockedPromise();
329         asyncSshHandler.disconnect(ctx, disconnectPromise);
330
331         verify(sshSession).close(anyBoolean());
332         verify(disconnectPromise).setSuccess();
333         //verify(ctx).fireChannelInactive();
334     }
335
336     private static OpenFuture getSuccessOpenFuture() {
337         final OpenFuture failedOpenFuture = mock(OpenFuture.class);
338         doReturn(true).when(failedOpenFuture).isOpened();
339         return failedOpenFuture;
340     }
341
342     private static AuthFuture getSuccessAuthFuture() {
343         final AuthFuture authFuture = mock(AuthFuture.class);
344         doReturn(true).when(authFuture).isSuccess();
345         return authFuture;
346     }
347
348     private static ConnectFuture getSuccessConnectFuture(final ClientSession sshSession) {
349         final ConnectFuture connectFuture = mock(ConnectFuture.class);
350         doReturn(null).when(connectFuture).getException();
351
352         doReturn(sshSession).when(connectFuture).getSession();
353         return connectFuture;
354     }
355
356     private static NettyAwareClientSession getMockedSshSession(final NettyAwareChannelSubsystem subsystemChannel)
357             throws IOException {
358         final NettyAwareClientSession sshSession = mock(NettyAwareClientSession.class);
359
360         doReturn("serverVersion").when(sshSession).getServerVersion();
361         doReturn(false).when(sshSession).isClosed();
362         doReturn(false).when(sshSession).isClosing();
363         final CloseFuture closeFuture = mock(CloseFuture.class);
364         Futures.addCallback(stubAddListener(closeFuture), new SuccessFutureListener<CloseFuture>() {
365             @Override
366             public void onSuccess(final SshFutureListener<CloseFuture> result) {
367                 doReturn(true).when(closeFuture).isClosed();
368                 result.operationComplete(closeFuture);
369             }
370         }, MoreExecutors.directExecutor());
371         doReturn(closeFuture).when(sshSession).close(false);
372
373         doReturn(subsystemChannel).when(sshSession).createSubsystemChannel(eq("netconf"), any());
374
375         return sshSession;
376     }
377
378     private NettyAwareChannelSubsystem getMockedSubsystemChannel(final IoInputStream asyncOut,
379                                                        final IoOutputStream asyncIn) throws IOException {
380         final NettyAwareChannelSubsystem subsystemChannel = mock(NettyAwareChannelSubsystem.class);
381
382         doNothing().when(subsystemChannel).setStreaming(any(ClientChannel.Streaming.class));
383         final OpenFuture openFuture = mock(OpenFuture.class);
384
385         Futures.addCallback(stubAddListener(openFuture), new SuccessFutureListener<OpenFuture>() {
386             @Override
387             public void onSuccess(final SshFutureListener<OpenFuture> result) {
388                 sshChannelOpenListener = result;
389             }
390         }, MoreExecutors.directExecutor());
391
392         doReturn(openFuture).when(subsystemChannel).open();
393         doReturn(asyncIn).when(subsystemChannel).getAsyncIn();
394         doNothing().when(subsystemChannel).onClose(any());
395         doNothing().when(subsystemChannel).close();
396         return subsystemChannel;
397     }
398
399     private static IoOutputStream getMockedIoOutputStream() throws IOException {
400         final IoOutputStream mock = mock(IoOutputStream.class);
401         final IoWriteFuture ioWriteFuture = mock(IoWriteFuture.class);
402         doReturn(true).when(ioWriteFuture).isWritten();
403
404         Futures.addCallback(stubAddListener(ioWriteFuture), new SuccessFutureListener<IoWriteFuture>() {
405             @Override
406             public void onSuccess(final SshFutureListener<IoWriteFuture> result) {
407                 result.operationComplete(ioWriteFuture);
408             }
409         }, MoreExecutors.directExecutor());
410
411         doReturn(ioWriteFuture).when(mock).writeBuffer(any(Buffer.class));
412         doReturn(false).when(mock).isClosed();
413         doReturn(false).when(mock).isClosing();
414         return mock;
415     }
416
417     private static IoInputStream getMockedIoInputStream() {
418         final IoInputStream mock = mock(IoInputStream.class);
419         final IoReadFuture ioReadFuture = mock(IoReadFuture.class);
420         // Always success for read
421         Futures.addCallback(stubAddListener(ioReadFuture), new SuccessFutureListener<IoReadFuture>() {
422             @Override
423             public void onSuccess(final SshFutureListener<IoReadFuture> result) {
424                 result.operationComplete(ioReadFuture);
425             }
426         }, MoreExecutors.directExecutor());
427         return mock;
428     }
429
430     @Test
431     public void testConnectFailOpenChannel() throws Exception {
432         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
433
434         final IoInputStream asyncOut = getMockedIoInputStream();
435         final IoOutputStream asyncIn = getMockedIoOutputStream();
436         final NettyAwareChannelSubsystem subsystemChannel = getMockedSubsystemChannel(asyncOut, asyncIn);
437         final ClientSession sshSession = getMockedSshSession(subsystemChannel);
438         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
439
440         sshConnectListener.operationComplete(connectFuture);
441
442         sshAuthListener.operationComplete(getSuccessAuthFuture());
443
444         verify(subsystemChannel).setStreaming(ClientChannel.Streaming.Async);
445
446         sshChannelOpenListener.operationComplete(getFailedOpenFuture());
447         verify(promise).setFailure(any(Throwable.class));
448     }
449
450     @Test
451     public void testConnectFailAuth() throws Exception {
452         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
453
454         final NettyAwareClientSession sshSession = mock(NettyAwareClientSession.class);
455         doReturn(true).when(sshSession).isClosed();
456         final ConnectFuture connectFuture = getSuccessConnectFuture(sshSession);
457
458         sshConnectListener.operationComplete(connectFuture);
459
460         final AuthFuture authFuture = getFailedAuthFuture();
461
462         sshAuthListener.operationComplete(authFuture);
463         verify(promise).setFailure(any(Throwable.class));
464         asyncSshHandler.close(ctx, getMockedPromise());
465         verify(ctx, times(0)).fireChannelInactive();
466     }
467
468     private static AuthFuture getFailedAuthFuture() {
469         final AuthFuture authFuture = mock(AuthFuture.class);
470         doReturn(false).when(authFuture).isSuccess();
471         doReturn(new IllegalStateException()).when(authFuture).getException();
472         return authFuture;
473     }
474
475     private static OpenFuture getFailedOpenFuture() {
476         final OpenFuture authFuture = mock(OpenFuture.class);
477         doReturn(false).when(authFuture).isOpened();
478         doReturn(new IllegalStateException()).when(authFuture).getException();
479         return authFuture;
480     }
481
482     @Test
483     public void testConnectFail() throws Exception {
484         asyncSshHandler.connect(ctx, remoteAddress, localAddress, promise);
485
486         final ConnectFuture connectFuture = getFailedConnectFuture();
487         sshConnectListener.operationComplete(connectFuture);
488         verify(promise).setFailure(any(Throwable.class));
489     }
490
491     private static ConnectFuture getFailedConnectFuture() {
492         final ConnectFuture connectFuture = mock(ConnectFuture.class);
493         doReturn(new IllegalStateException()).when(connectFuture).getException();
494         return connectFuture;
495     }
496
497     private ChannelPromise getMockedPromise() {
498         return spy(new DefaultChannelPromise(channel));
499     }
500
501     private abstract static class SuccessFutureListener<T extends SshFuture<T>>
502             implements FutureCallback<SshFutureListener<T>> {
503
504         @Override
505         public abstract void onSuccess(SshFutureListener<T> result);
506
507         @Override
508         public void onFailure(final Throwable throwable) {
509             throw new RuntimeException(throwable);
510         }
511     }
512 }