Refactor FramingMechanismHandlerFactory
[netconf.git] / netconf / netconf-netty-util / src / test / java / org / opendaylight / netconf / nettyutil / AbstractNetconfSessionNegotiatorTest.java
1 /*
2  * Copyright (c) 2016 Cisco Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.netconf.nettyutil;
9
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opendaylight.netconf.nettyutil.AbstractChannelInitializer.NETCONF_MESSAGE_AGGREGATOR;
import static org.opendaylight.netconf.nettyutil.AbstractChannelInitializer.NETCONF_MESSAGE_FRAME_ENCODER;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.Timeout;
import io.netty.util.Timer;
import io.netty.util.TimerTask;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.opendaylight.netconf.api.CapabilityURN;
import org.opendaylight.netconf.api.NetconfSessionListener;
import org.opendaylight.netconf.api.messages.HelloMessage;
import org.opendaylight.netconf.api.xml.XmlUtil;
import org.opendaylight.netconf.nettyutil.handler.ChunkedFramingMechanismEncoder;
import org.opendaylight.netconf.nettyutil.handler.EOMFramingMechanismEncoder;
import org.opendaylight.netconf.nettyutil.handler.NetconfChunkAggregator;
import org.opendaylight.netconf.nettyutil.handler.NetconfEOMAggregator;
import org.opendaylight.netconf.nettyutil.handler.NetconfXMLToHelloMessageDecoder;
59
60 @RunWith(MockitoJUnitRunner.StrictStubs.class)
61 public class AbstractNetconfSessionNegotiatorTest {
62     @Mock
63     private NetconfSessionListener<TestingNetconfSession> listener;
64     @Mock
65     private Promise<TestingNetconfSession> promise;
66     @Mock
67     private SslHandler sslHandler;
68     @Mock
69     private Timer timer;
70     @Mock
71     private Timeout timeout;
72     private EmbeddedChannel channel;
73     private TestSessionNegotiator negotiator;
74     private HelloMessage hello;
75     private HelloMessage helloBase11;
76     private NetconfXMLToHelloMessageDecoder xmlToHello;
77
78     @Before
79     public void setUp() {
80         channel = new EmbeddedChannel();
81         xmlToHello = new NetconfXMLToHelloMessageDecoder();
82         channel.pipeline().addLast(AbstractChannelInitializer.NETCONF_MESSAGE_ENCODER,
83                 new ChannelInboundHandlerAdapter());
84         channel.pipeline().addLast(AbstractChannelInitializer.NETCONF_MESSAGE_DECODER, xmlToHello);
85         channel.pipeline().addLast(NETCONF_MESSAGE_FRAME_ENCODER, new EOMFramingMechanismEncoder());
86         channel.pipeline().addLast(NETCONF_MESSAGE_AGGREGATOR, new NetconfEOMAggregator());
87         hello = HelloMessage.createClientHello(Set.of(), Optional.empty());
88         helloBase11 = HelloMessage.createClientHello(Set.of(CapabilityURN.BASE_1_1), Optional.empty());
89         doReturn(promise).when(promise).setFailure(any());
90         negotiator = new TestSessionNegotiator(helloBase11, promise, channel, timer, listener, 100L);
91     }
92
93     @Test
94     public void testStartNegotiation() {
95         enableTimerTask();
96         negotiator.startNegotiation();
97         assertEquals(helloBase11, channel.readOutbound());
98     }
99
100     @Test
101     public void testStartNegotiationSsl() throws Exception {
102         doReturn(true).when(sslHandler).isSharable();
103         doNothing().when(sslHandler).handlerAdded(any());
104         doNothing().when(sslHandler).write(any(), any(), any());
105         final Future<EmbeddedChannel> handshakeFuture = channel.eventLoop().newSucceededFuture(channel);
106         doReturn(handshakeFuture).when(sslHandler).handshakeFuture();
107         doNothing().when(sslHandler).flush(any());
108         channel.pipeline().addLast(sslHandler);
109
110         enableTimerTask();
111         negotiator.startNegotiation();
112         verify(sslHandler).write(any(), eq(helloBase11), any());
113     }
114
115     @Test
116     public void testStartNegotiationNotEstablished() throws Exception {
117         final ChannelOutboundHandler closedDetector = spy(new CloseDetector());
118         channel.pipeline().addLast("closedDetector", closedDetector);
119         doReturn(false).when(promise).isDone();
120         doReturn(false).when(promise).isCancelled();
121
122         final ArgumentCaptor<TimerTask> captor = ArgumentCaptor.forClass(TimerTask.class);
123         doReturn(timeout).when(timer).newTimeout(captor.capture(), eq(100L), eq(TimeUnit.MILLISECONDS));
124         negotiator.startNegotiation();
125
126         captor.getValue().run(timeout);
127         channel.runPendingTasks();
128         verify(closedDetector).close(any(), any());
129     }
130
131     @Test
132     public void testGetSessionForHelloMessage() throws Exception {
133         enableTimerTask();
134         negotiator.startNegotiation();
135         final TestingNetconfSession session = negotiator.getSessionForHelloMessage(hello);
136         assertNotNull(session);
137         assertThat(channel.pipeline().get(NETCONF_MESSAGE_AGGREGATOR), instanceOf(NetconfEOMAggregator.class));
138         assertThat(channel.pipeline().get(NETCONF_MESSAGE_FRAME_ENCODER), instanceOf(EOMFramingMechanismEncoder.class));
139     }
140
141     @Test
142     public void testGetSessionForHelloMessageBase11() throws Exception {
143         enableTimerTask();
144         negotiator.startNegotiation();
145         final TestingNetconfSession session = negotiator.getSessionForHelloMessage(helloBase11);
146         assertNotNull(session);
147         assertThat(channel.pipeline().get(NETCONF_MESSAGE_AGGREGATOR), instanceOf(NetconfChunkAggregator.class));
148         assertThat(channel.pipeline().get(NETCONF_MESSAGE_FRAME_ENCODER),
149             instanceOf(ChunkedFramingMechanismEncoder.class));
150     }
151
152     @Test
153     public void testReplaceHelloMessageInboundHandler() throws Exception {
154         final List<Object> out = new ArrayList<>();
155         final byte[] msg = "<rpc/>".getBytes();
156         final ByteBuf msgBuf = Unpooled.wrappedBuffer(msg);
157         final ByteBuf helloBuf = Unpooled.wrappedBuffer(XmlUtil.toString(hello.getDocument()).getBytes());
158
159         enableTimerTask();
160         negotiator.startNegotiation();
161
162         xmlToHello.decode(null, helloBuf, out);
163         xmlToHello.decode(null, msgBuf, out);
164         final TestingNetconfSession session = mock(TestingNetconfSession.class);
165         doNothing().when(session).handleMessage(any());
166         negotiator.replaceHelloMessageInboundHandler(session);
167         verify(session, times(1)).handleMessage(any());
168     }
169
170     @Test
171     public void testNegotiationFail() {
172         enableTimerTask();
173         doReturn(true).when(timeout).cancel();
174         negotiator.startNegotiation();
175         final RuntimeException cause = new RuntimeException("failure cause");
176         channel.pipeline().fireExceptionCaught(cause);
177         verify(promise).setFailure(cause);
178     }
179
180     private void enableTimerTask() {
181         doReturn(timeout).when(timer).newTimeout(any(), eq(100L), eq(TimeUnit.MILLISECONDS));
182     }
183
184     private static final class CloseDetector extends ChannelOutboundHandlerAdapter {
185         @Override
186         public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) {
187             // Override needed so @Skip from superclass is not effective
188         }
189     }
190 }