--- /dev/null
+/*
+ * Copyright (c) 2018 ZTE Corporation. and others. All rights reserved.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License v1.0 which accompanies this distribution,
+ * and is available at http://www.eclipse.org/legal/epl-v10.html
+ */
+package org.opendaylight.netconf.client;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.util.concurrent.Promise;
+import org.opendaylight.netconf.nettyutil.AbstractChannelInitializer;
+
+final class TlsClientChannelInitializer extends AbstractChannelInitializer<NetconfClientSession> {
+ public static final String CHANNEL_ACTIVE_SENTRY = "channelActiveSentry";
+
+ private final SslHandlerFactory sslHandlerFactory;
+ private final NetconfClientSessionNegotiatorFactory negotiatorFactory;
+ private final NetconfClientSessionListener sessionListener;
+
+ TlsClientChannelInitializer(final SslHandlerFactory sslHandlerFactory,
+ final NetconfClientSessionNegotiatorFactory negotiatorFactory,
+ final NetconfClientSessionListener sessionListener) {
+ this.sslHandlerFactory = sslHandlerFactory;
+ this.negotiatorFactory = negotiatorFactory;
+ this.sessionListener = sessionListener;
+ }
+
+ @Override
+ public void initialize(Channel ch, Promise<NetconfClientSession> promise) {
+ // When ssl handshake fails due to the certificate mismatch, the connection will try again,
+ // then we have a chance to create a new SslHandler using the latest certificates with the
+ // help of the sentry. We will replace the sentry with the new SslHandler once the channel
+ // is active.
+ ch.pipeline().addFirst(CHANNEL_ACTIVE_SENTRY, new ChannelActiveSentry(sslHandlerFactory));
+ super.initialize(ch, promise);
+ }
+
+ @Override
+ protected void initializeSessionNegotiator(Channel ch, Promise<NetconfClientSession> promise) {
+ ch.pipeline().addAfter(NETCONF_MESSAGE_DECODER, AbstractChannelInitializer.NETCONF_SESSION_NEGOTIATOR,
+ negotiatorFactory.getSessionNegotiator(() -> sessionListener, ch, promise));
+ }
+
+ private static final class ChannelActiveSentry extends ChannelInboundHandlerAdapter {
+ private final SslHandlerFactory sslHandlerFactory;
+
+ ChannelActiveSentry(final SslHandlerFactory sslHandlerFactory) {
+ this.sslHandlerFactory = sslHandlerFactory;
+ }
+
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
+ ctx.pipeline().replace(this, "sslHandler", sslHandlerFactory.createSslHandler())
+ .fireChannelActive();
+ }
+ }
+}
--- /dev/null
+/*
+ * Copyright (c) 2018 ZTE Corporation. and others. All rights reserved.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License v1.0 which accompanies this distribution,
+ * and is available at http://www.eclipse.org/legal/epl-v10.html
+ */
+package org.opendaylight.netconf.client;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelPipeline;
+import io.netty.util.concurrent.Promise;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.opendaylight.protocol.framework.SessionListenerFactory;
+import org.opendaylight.protocol.framework.SessionNegotiator;
+
+public class TlsClientChannelInitializerTest {
+ @Mock
+ private SslHandlerFactory sslHandlerFactory;
+ @Mock
+ private NetconfClientSessionNegotiatorFactory negotiatorFactory;
+ @Mock
+ private NetconfClientSessionListener sessionListener;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testInitialize() throws Exception {
+ SessionNegotiator<?> sessionNegotiator = mock(SessionNegotiator.class);
+ doReturn(sessionNegotiator).when(negotiatorFactory).getSessionNegotiator(any(SessionListenerFactory.class),
+ any(Channel.class), any(Promise.class));
+ ChannelPipeline pipeline = mock(ChannelPipeline.class);
+ doReturn(pipeline).when(pipeline).addAfter(anyString(), anyString(), any(ChannelHandler.class));
+ Channel channel = mock(Channel.class);
+ doReturn(pipeline).when(channel).pipeline();
+
+ doReturn(pipeline).when(pipeline).addFirst(anyString(), any(ChannelHandler.class));
+ doReturn(pipeline).when(pipeline).addLast(anyString(), any(ChannelHandler.class));
+
+ Promise<NetconfClientSession> promise = mock(Promise.class);
+
+ TlsClientChannelInitializer initializer = new TlsClientChannelInitializer(sslHandlerFactory,
+ negotiatorFactory, sessionListener);
+ initializer.initialize(channel, promise);
+ verify(pipeline, times(1)).addFirst(anyString(), any(ChannelHandler.class));
+ }
+}