Use ConcurrentHashMap in TesttoolNegotiationFactory
[netconf.git] / netconf / netconf-netty-util / src / main / java / org / opendaylight / netconf / nettyutil / AbstractNetconfSessionNegotiator.java
index afd77d17b17ea1c59b5d970cac86a24d5d1f6fd2..a5d65a37c9959fdf58558a935ad5ced2d44d2de4 100644 (file)
@@ -19,6 +19,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.handler.ssl.SslHandler;
 import io.netty.util.Timeout;
 import io.netty.util.Timer;
+import io.netty.util.concurrent.Future;
 import io.netty.util.concurrent.Promise;
 import java.util.concurrent.TimeUnit;
 import org.checkerframework.checker.index.qual.NonNegative;
@@ -26,17 +27,18 @@ import org.checkerframework.checker.lock.qual.GuardedBy;
 import org.checkerframework.checker.lock.qual.Holding;
 import org.eclipse.jdt.annotation.NonNull;
 import org.eclipse.jdt.annotation.Nullable;
+import org.opendaylight.netconf.api.CapabilityURN;
+import org.opendaylight.netconf.api.NamespaceURN;
 import org.opendaylight.netconf.api.NetconfDocumentedException;
-import org.opendaylight.netconf.api.NetconfMessage;
 import org.opendaylight.netconf.api.NetconfSessionListener;
-import org.opendaylight.netconf.api.messages.NetconfHelloMessage;
+import org.opendaylight.netconf.api.messages.HelloMessage;
+import org.opendaylight.netconf.api.messages.NetconfMessage;
 import org.opendaylight.netconf.api.xml.XmlNetconfConstants;
-import org.opendaylight.netconf.nettyutil.handler.FramingMechanismHandlerFactory;
+import org.opendaylight.netconf.nettyutil.handler.ChunkedFramingMechanismEncoder;
 import org.opendaylight.netconf.nettyutil.handler.NetconfChunkAggregator;
 import org.opendaylight.netconf.nettyutil.handler.NetconfMessageToXMLEncoder;
 import org.opendaylight.netconf.nettyutil.handler.NetconfXMLToHelloMessageDecoder;
 import org.opendaylight.netconf.nettyutil.handler.NetconfXMLToMessageDecoder;
-import org.opendaylight.netconf.util.messages.FramingMechanism;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.w3c.dom.Document;
@@ -76,7 +78,7 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         LOG.debug("Default maximum incoming NETCONF chunk size is {} bytes", DEFAULT_MAXIMUM_INCOMING_CHUNK_SIZE);
     }
 
-    private final @NonNull NetconfHelloMessage localHello;
+    private final @NonNull HelloMessage localHello;
     protected final Channel channel;
 
     private final @NonNegative int maximumIncomingChunkSize;
@@ -90,11 +92,11 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
     @GuardedBy("this")
     private State state = State.IDLE;
 
-    protected AbstractNetconfSessionNegotiator(final NetconfHelloMessage hello, final Promise<S> promise,
+    protected AbstractNetconfSessionNegotiator(final HelloMessage hello, final Promise<S> promise,
                                                final Channel channel, final Timer timer, final L sessionListener,
                                                final long connectionTimeoutMillis,
                                                final @NonNegative int maximumIncomingChunkSize) {
-        this.localHello = requireNonNull(hello);
+        localHello = requireNonNull(hello);
         this.promise = requireNonNull(promise);
         this.channel = requireNonNull(channel);
         this.timer = timer;
@@ -104,15 +106,7 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         checkArgument(maximumIncomingChunkSize > 0, "Invalid maximum incoming chunk size %s", maximumIncomingChunkSize);
     }
 
-    @Deprecated(since = "4.0.1", forRemoval = true)
-    protected AbstractNetconfSessionNegotiator(final NetconfHelloMessage hello, final Promise<S> promise,
-                                               final Channel channel, final Timer timer,
-                                               final L sessionListener, final long connectionTimeoutMillis) {
-        this(hello, promise, channel, timer, sessionListener, connectionTimeoutMillis,
-            DEFAULT_MAXIMUM_INCOMING_CHUNK_SIZE);
-    }
-
-    protected final @NonNull NetconfHelloMessage localHello() {
+    protected final @NonNull HelloMessage localHello() {
         return localHello;
     }
 
@@ -133,9 +127,13 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         }
     }
 
-    protected final synchronized boolean ifNegotiatedAlready() {
+    protected final boolean ifNegotiatedAlready() {
         // Indicates whether negotiation already started
-        return this.state != State.IDLE;
+        return state() != State.IDLE;
+    }
+
+    private synchronized State state() {
+        return state;
     }
 
     private static @Nullable SslHandler getSslHandler(final Channel channel) {
@@ -143,13 +141,37 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
     }
 
     private void start() {
-        LOG.debug("Session negotiation started with hello message {} on channel {}", localHello, channel);
+        LOG.debug("Sending negotiation proposal {} on channel {}", localHello, channel);
 
-        channel.pipeline().addLast(NAME_OF_EXCEPTION_HANDLER, new ExceptionHandlingInboundChannelHandler());
+        // Send the message out, but to not run listeners just yet, as we have some more state transitions to go through
+        final var helloFuture = channel.writeAndFlush(localHello);
+
+        // Quick check: if the future has already failed we call it quits before negotiation even started
+        final var helloCause = helloFuture.cause();
+        if (helloCause != null) {
+            LOG.warn("Failed to send negotiation proposal on channel {}", channel, helloCause);
+            failAndClose();
+            return;
+        }
 
-        sendMessage(localHello);
+        // Catch any exceptions from this point on. Use a named class to ease debugging.
+        final class ExceptionHandlingInboundChannelHandler extends ChannelInboundHandlerAdapter {
+            @Override
+            public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
+                LOG.warn("An exception occurred during negotiation with {} on channel {}",
+                        channel.remoteAddress(), channel, cause);
+                // FIXME: this is quite suspect as it is competing with timeoutExpired() without synchronization
+                cancelTimeout();
+                negotiationFailed(cause);
+                changeState(State.FAILED);
+            }
+        }
+
+        channel.pipeline().addLast(NAME_OF_EXCEPTION_HANDLER, new ExceptionHandlingInboundChannelHandler());
 
-        replaceHelloMessageOutboundHandler();
+        // Remove special outbound handler for hello message. Insert regular netconf xml message (en|de)coders.
+        replaceChannelHandler(channel, AbstractChannelInitializer.NETCONF_MESSAGE_ENCODER,
+            new NetconfMessageToXMLEncoder());
 
         synchronized (this) {
             lockedChangeState(State.OPEN_WAIT);
@@ -158,6 +180,21 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
             timeoutTask = timer.newTimeout(unused -> channel.eventLoop().execute(this::timeoutExpired),
                 connectionTimeoutMillis, TimeUnit.MILLISECONDS);
         }
+
+        LOG.debug("Session negotiation started on channel {}", channel);
+
+        // State transition completed, now run any additional processing
+        helloFuture.addListener(this::onHelloWriteComplete);
+    }
+
+    private void onHelloWriteComplete(final Future<?> future) {
+        final var cause = future.cause();
+        if (cause != null) {
+            LOG.info("Failed to send message {} on channel {}", localHello, channel, cause);
+            negotiationFailed(cause);
+        } else {
+            LOG.trace("Message {} sent to socket on channel {}", localHello, channel);
+        }
     }
 
     private synchronized void timeoutExpired() {
@@ -176,22 +213,27 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
             if (!promise.isDone() && !promise.isCancelled()) {
                 LOG.warn("Netconf session backed by channel {} was not established after {}", channel,
                     connectionTimeoutMillis);
-                changeState(State.FAILED);
-
-                channel.close().addListener(future -> {
-                    final var cause = future.cause();
-                    if (cause != null) {
-                        LOG.warn("Channel {} closed: fail", channel, cause);
-                    } else {
-                        LOG.debug("Channel {} closed: success", channel);
-                    }
-                });
+                failAndClose();
             }
         } else if (channel.isOpen()) {
             channel.pipeline().remove(NAME_OF_EXCEPTION_HANDLER);
         }
     }
 
+    private void failAndClose() {
+        changeState(State.FAILED);
+        channel.close().addListener(this::onChannelClosed);
+    }
+
+    private void onChannelClosed(final Future<?> future) {
+        final var cause = future.cause();
+        if (cause != null) {
+            LOG.warn("Channel {} closed: fail", channel, cause);
+        } else {
+            LOG.debug("Channel {} closed: success", channel);
+        }
+    }
+
     private synchronized void cancelTimeout() {
         if (timeoutTask != null && !timeoutTask.cancel()) {
             // Late-coming cancel: make sure the task does not actually run
@@ -199,7 +241,7 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         }
     }
 
-    protected final S getSessionForHelloMessage(final NetconfHelloMessage netconfMessage)
+    protected final S getSessionForHelloMessage(final HelloMessage netconfMessage)
             throws NetconfDocumentedException {
         final Document doc = netconfMessage.getDocument();
 
@@ -211,7 +253,7 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         return getSession(sessionListener, channel, netconfMessage);
     }
 
-    protected abstract S getSession(L sessionListener, Channel channel, NetconfHelloMessage message)
+    protected abstract S getSession(L sessionListener, Channel channel, HelloMessage message)
         throws NetconfDocumentedException;
 
     /**
@@ -219,9 +261,9 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
      */
     private void insertChunkFramingToPipeline() {
         replaceChannelHandler(channel, AbstractChannelInitializer.NETCONF_MESSAGE_FRAME_ENCODER,
-                FramingMechanismHandlerFactory.createHandler(FramingMechanism.CHUNK));
+            new ChunkedFramingMechanismEncoder());
         replaceChannelHandler(channel, AbstractChannelInitializer.NETCONF_MESSAGE_AGGREGATOR,
-                new NetconfChunkAggregator(maximumIncomingChunkSize));
+            new NetconfChunkAggregator(maximumIncomingChunkSize));
     }
 
     private boolean shouldUseChunkFraming(final Document doc) {
@@ -253,14 +295,6 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         }
     }
 
-    /**
-     * Remove special outbound handler for hello message. Insert regular netconf xml message (en|de)coders.
-     */
-    private void replaceHelloMessageOutboundHandler() {
-        replaceChannelHandler(channel, AbstractChannelInitializer.NETCONF_MESSAGE_ENCODER,
-                new NetconfMessageToXMLEncoder());
-    }
-
     private static ChannelHandler replaceChannelHandler(final Channel channel, final String handlerKey,
                                                         final ChannelHandler decoder) {
         return channel.pipeline().replace(handlerKey, handlerKey, decoder);
@@ -275,15 +309,13 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         LOG.debug("Changing state from : {} to : {} for channel: {}", state, newState, channel);
         checkState(isStateChangePermitted(state, newState),
                 "Cannot change state from %s to %s for channel %s", state, newState, channel);
-        this.state = newState;
+        state = newState;
     }
 
     private static boolean containsBase11Capability(final Document doc) {
-        final NodeList nList = doc.getElementsByTagNameNS(
-            XmlNetconfConstants.URN_IETF_PARAMS_XML_NS_NETCONF_BASE_1_0,
-            XmlNetconfConstants.CAPABILITY);
+        final NodeList nList = doc.getElementsByTagNameNS(NamespaceURN.BASE, XmlNetconfConstants.CAPABILITY);
         for (int i = 0; i < nList.getLength(); i++) {
-            if (nList.item(i).getTextContent().contains(XmlNetconfConstants.URN_IETF_PARAMS_NETCONF_BASE_1_1)) {
+            if (nList.item(i).getTextContent().contains(CapabilityURN.BASE_1_1)) {
                 return true;
             }
         }
@@ -291,34 +323,16 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
     }
 
     private static boolean isStateChangePermitted(final State state, final State newState) {
-        if (state == State.IDLE && newState == State.OPEN_WAIT) {
-            return true;
-        }
-        if (state == State.OPEN_WAIT && newState == State.ESTABLISHED) {
+        if (state == State.IDLE && (newState == State.OPEN_WAIT || newState == State.FAILED)) {
             return true;
         }
-        if (state == State.OPEN_WAIT && newState == State.FAILED) {
+        if (state == State.OPEN_WAIT && (newState == State.ESTABLISHED || newState == State.FAILED)) {
             return true;
         }
         LOG.debug("Transition from {} to {} is not allowed", state, newState);
         return false;
     }
 
-    /**
-     * Handler to catch exceptions in pipeline during negotiation.
-     */
-    private final class ExceptionHandlingInboundChannelHandler extends ChannelInboundHandlerAdapter {
-        @Override
-        public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
-            LOG.warn("An exception occurred during negotiation with {} on channel {}",
-                    channel.remoteAddress(), channel, cause);
-            // FIXME: this is quite suspect as it is competing with timeoutExpired() without synchronization
-            cancelTimeout();
-            negotiationFailed(cause);
-            changeState(State.FAILED);
-        }
-    }
-
     protected final void negotiationSuccessful(final S session) {
         LOG.debug("Negotiation on channel {} successful with session {}", channel, session);
         channel.pipeline().replace(this, "session", session);
@@ -331,24 +345,6 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         promise.setFailure(cause);
     }
 
-    /**
-     * Send a message to peer and fail negotiation if it does not reach
-     * the peer.
-     *
-     * @param msg Message which should be sent.
-     */
-    protected void sendMessage(final NetconfMessage msg) {
-        channel.writeAndFlush(msg).addListener(f -> {
-            final var cause = f.cause();
-            if (cause != null) {
-                LOG.info("Failed to send message {} on channel {}", msg, channel, cause);
-                negotiationFailed(cause);
-            } else {
-                LOG.trace("Message {} sent to socket on channel {}", msg, channel);
-            }
-        });
-    }
-
     @Override
     @SuppressWarnings("checkstyle:illegalCatch")
     public final void channelActive(final ChannelHandlerContext ctx) {
@@ -364,9 +360,14 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
     @Override
     @SuppressWarnings("checkstyle:illegalCatch")
     public final void channelRead(final ChannelHandlerContext ctx, final Object msg) {
+        if (state() == State.FAILED) {
+            // We have already failed -- do not process any more messages
+            return;
+        }
+
         LOG.debug("Negotiation read invoked on channel {}", channel);
         try {
-            handleMessage((NetconfHelloMessage) msg);
+            handleMessage((HelloMessage) msg);
         } catch (final Exception e) {
             LOG.debug("Unexpected error while handling negotiation message {} on channel {}", msg, channel, e);
             negotiationFailed(e);
@@ -379,5 +380,5 @@ public abstract class AbstractNetconfSessionNegotiator<S extends AbstractNetconf
         negotiationFailed(cause);
     }
 
-    protected abstract void handleMessage(NetconfHelloMessage msg) throws Exception;
+    protected abstract void handleMessage(HelloMessage msg) throws Exception;
 }