Migrate netconf/netconf-netty-util
[netconf.git] / netconf / netconf-netty-util / src / test / java / org / opendaylight / netconf / nettyutil / AbstractNetconfSessionNegotiatorTest.java
1 /*
2  * Copyright (c) 2016 Cisco Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.netconf.nettyutil;
9
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opendaylight.netconf.nettyutil.AbstractChannelInitializer.NETCONF_MESSAGE_AGGREGATOR;
import static org.opendaylight.netconf.nettyutil.AbstractChannelInitializer.NETCONF_MESSAGE_FRAME_ENCODER;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.Timeout;
import io.netty.util.TimerTask;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opendaylight.netconf.api.CapabilityURN;
import org.opendaylight.netconf.api.NetconfSessionListener;
import org.opendaylight.netconf.api.messages.HelloMessage;
import org.opendaylight.netconf.api.xml.XmlUtil;
import org.opendaylight.netconf.common.NetconfTimer;
import org.opendaylight.netconf.nettyutil.handler.ChunkedFramingMechanismEncoder;
import org.opendaylight.netconf.nettyutil.handler.EOMFramingMechanismEncoder;
import org.opendaylight.netconf.nettyutil.handler.NetconfChunkAggregator;
import org.opendaylight.netconf.nettyutil.handler.NetconfEOMAggregator;
import org.opendaylight.netconf.nettyutil.handler.NetconfXMLToHelloMessageDecoder;
58
59 @ExtendWith(MockitoExtension.class)
60 class AbstractNetconfSessionNegotiatorTest {
61     @Mock
62     private NetconfSessionListener<TestingNetconfSession> listener;
63     @Mock
64     private Promise<TestingNetconfSession> promise;
65     @Mock
66     private SslHandler sslHandler;
67     @Mock
68     private NetconfTimer timer;
69     @Mock
70     private Timeout timeout;
71     private EmbeddedChannel channel;
72     private TestSessionNegotiator negotiator;
73     private HelloMessage hello;
74     private HelloMessage helloBase11;
75     private NetconfXMLToHelloMessageDecoder xmlToHello;
76
77     @BeforeEach
78     void setUp() {
79         channel = new EmbeddedChannel();
80         xmlToHello = new NetconfXMLToHelloMessageDecoder();
81         channel.pipeline().addLast(AbstractChannelInitializer.NETCONF_MESSAGE_ENCODER,
82                 new ChannelInboundHandlerAdapter());
83         channel.pipeline().addLast(AbstractChannelInitializer.NETCONF_MESSAGE_DECODER, xmlToHello);
84         channel.pipeline().addLast(NETCONF_MESSAGE_FRAME_ENCODER, new EOMFramingMechanismEncoder());
85         channel.pipeline().addLast(NETCONF_MESSAGE_AGGREGATOR, new NetconfEOMAggregator());
86         hello = HelloMessage.createClientHello(Set.of(), Optional.empty());
87         helloBase11 = HelloMessage.createClientHello(Set.of(CapabilityURN.BASE_1_1), Optional.empty());
88         negotiator = new TestSessionNegotiator(helloBase11, promise, channel, timer, listener, 100L);
89     }
90
91     @Test
92     void testStartNegotiation() {
93         enableTimerTask();
94         negotiator.startNegotiation();
95         assertEquals(helloBase11, channel.readOutbound());
96     }
97
98     @Test
99     void testStartNegotiationSsl() throws Exception {
100         doReturn(true).when(sslHandler).isSharable();
101         doNothing().when(sslHandler).handlerAdded(any());
102         doNothing().when(sslHandler).write(any(), any(), any());
103         final Future<EmbeddedChannel> handshakeFuture = channel.eventLoop().newSucceededFuture(channel);
104         doReturn(handshakeFuture).when(sslHandler).handshakeFuture();
105         doNothing().when(sslHandler).flush(any());
106         channel.pipeline().addLast(sslHandler);
107
108         enableTimerTask();
109         negotiator.startNegotiation();
110         verify(sslHandler).write(any(), eq(helloBase11), any());
111     }
112
113     @Test
114     void testStartNegotiationNotEstablished() throws Exception {
115         final ChannelOutboundHandler closedDetector = spy(new CloseDetector());
116         channel.pipeline().addLast("closedDetector", closedDetector);
117         doReturn(false).when(promise).isDone();
118         doReturn(false).when(promise).isCancelled();
119
120         final ArgumentCaptor<TimerTask> captor = ArgumentCaptor.forClass(TimerTask.class);
121         doReturn(timeout).when(timer).newTimeout(captor.capture(), eq(100L), eq(TimeUnit.MILLISECONDS));
122         negotiator.startNegotiation();
123
124         captor.getValue().run(timeout);
125         channel.runPendingTasks();
126         verify(closedDetector).close(any(), any());
127     }
128
129     @Test
130     void testGetSessionForHelloMessage() throws Exception {
131         enableTimerTask();
132         negotiator.startNegotiation();
133         final TestingNetconfSession session = negotiator.getSessionForHelloMessage(hello);
134         assertNotNull(session);
135         assertInstanceOf(NetconfEOMAggregator.class, channel.pipeline().get(NETCONF_MESSAGE_AGGREGATOR));
136         assertInstanceOf(EOMFramingMechanismEncoder.class, channel.pipeline().get(NETCONF_MESSAGE_FRAME_ENCODER));
137     }
138
139     @Test
140     void testGetSessionForHelloMessageBase11() throws Exception {
141         enableTimerTask();
142         negotiator.startNegotiation();
143         final TestingNetconfSession session = negotiator.getSessionForHelloMessage(helloBase11);
144         assertNotNull(session);
145         assertInstanceOf(NetconfChunkAggregator.class, channel.pipeline().get(NETCONF_MESSAGE_AGGREGATOR));
146         assertInstanceOf(ChunkedFramingMechanismEncoder.class, channel.pipeline().get(NETCONF_MESSAGE_FRAME_ENCODER));
147     }
148
149     @Test
150     void testReplaceHelloMessageInboundHandler() throws Exception {
151         final List<Object> out = new ArrayList<>();
152         final byte[] msg = "<rpc/>".getBytes();
153         final ByteBuf msgBuf = Unpooled.wrappedBuffer(msg);
154         final ByteBuf helloBuf = Unpooled.wrappedBuffer(XmlUtil.toString(hello.getDocument()).getBytes());
155
156         enableTimerTask();
157         negotiator.startNegotiation();
158
159         xmlToHello.decode(null, helloBuf, out);
160         xmlToHello.decode(null, msgBuf, out);
161         final TestingNetconfSession session = mock(TestingNetconfSession.class);
162         doNothing().when(session).handleMessage(any());
163         negotiator.replaceHelloMessageInboundHandler(session);
164         verify(session, times(1)).handleMessage(any());
165     }
166
167     @Test
168     void testNegotiationFail() {
169         doReturn(promise).when(promise).setFailure(any());
170
171         enableTimerTask();
172         doReturn(true).when(timeout).cancel();
173         negotiator.startNegotiation();
174         final RuntimeException cause = new RuntimeException("failure cause");
175         channel.pipeline().fireExceptionCaught(cause);
176         verify(promise).setFailure(cause);
177     }
178
179     private void enableTimerTask() {
180         doReturn(timeout).when(timer).newTimeout(any(), eq(100L), eq(TimeUnit.MILLISECONDS));
181     }
182
183     private static final class CloseDetector extends ChannelOutboundHandlerAdapter {
184         @Override
185         public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) {
186             // Override needed so @Skip from superclass is not effective
187         }
188     }
189 }