BUG-1618 Handle pending writes in ssh netconfclient 03/10303/2
authorMaros Marsalek <mmarsale@cisco.com>
Tue, 26 Aug 2014 11:46:32 +0000 (13:46 +0200)
committerMaros Marsalek <mmarsale@cisco.com>
Tue, 26 Aug 2014 12:24:08 +0000 (14:24 +0200)
Change-Id: If4371860e81cf4153c4baaa8a9b0d3c45334ab5c
Signed-off-by: Maros Marsalek <mmarsale@cisco.com>
opendaylight/netconf/netconf-it/src/test/java/org/opendaylight/controller/netconf/it/NetconfITSecureTest.java
opendaylight/netconf/netconf-netty-util/src/main/java/org/opendaylight/controller/netconf/nettyutil/handler/ssh/client/AsyncSshHandler.java
opendaylight/netconf/netconf-ssh/src/test/java/org/opendaylight/controller/netconf/netty/SSHTest.java

index 56f674bc34f287d74224ad269a1274c657109e68..a9e8dbe86b06b76087c902521022ff06e239b870 100644 (file)
@@ -9,13 +9,16 @@
 package org.opendaylight.controller.netconf.it;
 
 import static java.util.Arrays.asList;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 
-import io.netty.channel.ChannelFuture;
+import com.google.common.collect.Lists;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
 import io.netty.util.concurrent.GlobalEventExecutor;
 import java.io.IOException;
 import java.io.InputStream;
@@ -23,6 +26,7 @@ import java.lang.management.ManagementFactory;
 import java.net.InetSocketAddress;
 import java.util.Collection;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 import junit.framework.Assert;
 import org.junit.After;
 import org.junit.Before;
@@ -68,27 +72,27 @@ public class NetconfITSecureTest extends AbstractNetconfConfigTest {
         super.initConfigTransactionManagerImpl(new HardcodedModuleFactoriesResolver(mockedContext, getModuleFactories().toArray(
                 new ModuleFactory[0])));
 
-        NetconfOperationServiceFactoryListenerImpl factoriesListener = new NetconfOperationServiceFactoryListenerImpl();
+        final NetconfOperationServiceFactoryListenerImpl factoriesListener = new NetconfOperationServiceFactoryListenerImpl();
         factoriesListener.onAddNetconfOperationServiceFactory(new NetconfOperationServiceFactoryImpl(getYangStore()));
 
         commitNot = new DefaultCommitNotificationProducer(ManagementFactory.getPlatformMBeanServer());
 
 
         final NetconfServerDispatcher dispatchS = createDispatcher(factoriesListener);
-        ChannelFuture s = dispatchS.createLocalServer(NetconfConfigUtil.getNetconfLocalAddress());
-        s.await();
-        EventLoopGroup bossGroup  = new NioEventLoopGroup();
+        dispatchS.createLocalServer(NetconfConfigUtil.getNetconfLocalAddress()).await();
+        final EventLoopGroup bossGroup  = new NioEventLoopGroup();
         sshServer = NetconfSSHServer.start(tlsAddress.getPort(), NetconfConfigUtil.getNetconfLocalAddress(), getAuthProvider(), bossGroup);
     }
 
-    private NetconfServerDispatcher createDispatcher(NetconfOperationServiceFactoryListenerImpl factoriesListener) {
+    private NetconfServerDispatcher createDispatcher(final NetconfOperationServiceFactoryListenerImpl factoriesListener) {
         return super.createDispatcher(factoriesListener, NetconfITTest.getNetconfMonitoringListenerService(), commitNot);
     }
 
     @After
     public void tearDown() throws Exception {
-        sshServer.stop();
+        sshServer.close();
         commitNot.close();
+        sshServer.join();
     }
 
     private HardcodedYangStoreService getYangStore() throws YangStoreException, IOException {
@@ -102,13 +106,13 @@ public class NetconfITSecureTest extends AbstractNetconfConfigTest {
 
     @Test
     public void testSecure() throws Exception {
-        NetconfClientDispatcher dispatch = new NetconfClientDispatcherImpl(getNettyThreadgroup(), getNettyThreadgroup(), getHashedWheelTimer());
+        final NetconfClientDispatcher dispatch = new NetconfClientDispatcherImpl(getNettyThreadgroup(), getNettyThreadgroup(), getHashedWheelTimer());
         try (TestingNetconfClient netconfClient = new TestingNetconfClient("testing-ssh-client", dispatch, getClientConfiguration())) {
             NetconfMessage response = netconfClient.sendMessage(getConfig);
             Assert.assertFalse("Unexpected error message " + XmlUtil.toString(response.getDocument()),
                     NetconfMessageUtil.isErrorMessage(response));
 
-            NetconfMessage gs = new NetconfMessage(XmlUtil.readXmlToDocument("<rpc message-id=\"2\"\n" +
+            final NetconfMessage gs = new NetconfMessage(XmlUtil.readXmlToDocument("<rpc message-id=\"2\"\n" +
                     "     xmlns=\"urn:ietf:params:xml:ns:netconf:base:1.0\">\n" +
                     "    <get-schema xmlns=\"urn:ietf:params:xml:ns:yang:ietf-netconf-monitoring\">\n" +
                     "        <identifier>config</identifier>\n" +
@@ -121,6 +125,41 @@ public class NetconfITSecureTest extends AbstractNetconfConfigTest {
         }
     }
 
+    /**
+     * Test all requests are handled properly and no mismatch occurs in listener
+     */
+    @Test(timeout = 3*60*1000)
+    public void testSecureStress() throws Exception {
+        final NetconfClientDispatcher dispatch = new NetconfClientDispatcherImpl(getNettyThreadgroup(), getNettyThreadgroup(), getHashedWheelTimer());
+        try (TestingNetconfClient netconfClient = new TestingNetconfClient("testing-ssh-client", dispatch, getClientConfiguration())) {
+
+            final AtomicInteger responseCounter = new AtomicInteger(0);
+            final List<Future<?>> futures = Lists.newArrayList();
+
+            final int requests = 1000;
+            for (int i = 0; i < requests; i++) {
+                final Future<NetconfMessage> netconfMessageFuture = netconfClient.sendRequest(getConfig);
+                futures.add(netconfMessageFuture);
+                netconfMessageFuture.addListener(new GenericFutureListener<Future<? super NetconfMessage>>() {
+                    @Override
+                    public void operationComplete(final Future<? super NetconfMessage> future) throws Exception {
+                        assertTrue("Request unsuccessful " + future.cause(), future.isSuccess());
+                        responseCounter.incrementAndGet();
+                    }
+                });
+            }
+
+            for (final Future<?> future : futures) {
+                future.await();
+            }
+
+            // Give future listeners some time to finish counter incrementation
+            Thread.sleep(5000);
+
+            org.junit.Assert.assertEquals(requests, responseCounter.get());
+        }
+    }
+
     public NetconfClientConfiguration getClientConfiguration() throws IOException {
         final NetconfClientConfigurationBuilder b = NetconfClientConfigurationBuilder.create();
         b.withAddress(tlsAddress);
@@ -133,7 +172,7 @@ public class NetconfITSecureTest extends AbstractNetconfConfigTest {
     }
 
     public AuthProvider getAuthProvider() throws Exception {
-        AuthProvider mock = mock(AuthProviderImpl.class);
+        final AuthProvider mock = mock(AuthProviderImpl.class);
         doReturn(true).when(mock).authenticated(anyString(), anyString());
         doReturn(PEMGenerator.generate().toCharArray()).when(mock).getPEMAsCharArray();
         return mock;
index 2761a45d03bedc8730cb69036748c266cdc7f412..935cb8dcd06ca966e6c560560d2030150cced460 100644 (file)
@@ -12,6 +12,7 @@ import com.google.common.base.Preconditions;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandler;
 import io.netty.channel.ChannelOutboundHandlerAdapter;
 import io.netty.channel.ChannelPromise;
 import java.io.IOException;
@@ -25,7 +26,10 @@ import org.apache.sshd.client.future.OpenFuture;
 import org.apache.sshd.common.future.CloseFuture;
 import org.apache.sshd.common.future.SshFutureListener;
 import org.apache.sshd.common.io.IoInputStream;
+import org.apache.sshd.common.io.IoOutputStream;
 import org.apache.sshd.common.io.IoReadFuture;
+import org.apache.sshd.common.io.IoWriteFuture;
+import org.apache.sshd.common.io.WritePendingException;
 import org.apache.sshd.common.util.Buffer;
 import org.opendaylight.controller.netconf.nettyutil.handler.ssh.authentication.AuthenticationHandler;
 import org.slf4j.Logger;
@@ -53,10 +57,13 @@ public class AsyncSshHandler extends ChannelOutboundHandlerAdapter {
     private final SshClient sshClient;
 
     private SshReadAsyncListener sshReadAsyncListener;
+    private SshWriteAsyncHandler sshWriteAsyncHandler;
+
     private ClientChannel channel;
     private ClientSession session;
     private ChannelPromise connectPromise;
 
+
     public static AsyncSshHandler createForNetconfSubsystem(final AuthenticationHandler authenticationHandler) throws IOException {
         return new AsyncSshHandler(authenticationHandler, DEFAULT_CLIENT);
     }
@@ -139,10 +146,11 @@ public class AsyncSshHandler extends ChannelOutboundHandlerAdapter {
 
         connectPromise.setSuccess();
         connectPromise = null;
-        ctx.fireChannelActive();
 
-        final IoInputStream asyncOut = channel.getAsyncOut();
-        sshReadAsyncListener = new SshReadAsyncListener(ctx, asyncOut);
+        sshReadAsyncListener = new SshReadAsyncListener(ctx, channel.getAsyncOut());
+        sshWriteAsyncHandler = new SshWriteAsyncHandler(this, channel.getAsyncIn());
+
+        ctx.fireChannelActive();
     }
 
     private synchronized void handleSshSetupFailure(final ChannelHandlerContext ctx, final Throwable e) {
@@ -154,17 +162,7 @@ public class AsyncSshHandler extends ChannelOutboundHandlerAdapter {
 
     @Override
     public synchronized void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
-        try {
-            if(channel.getAsyncIn().isClosed() || channel.getAsyncIn().isClosing()) {
-                handleSshSessionClosed(ctx);
-            } else {
-                channel.getAsyncIn().write(toBuffer(msg));
-                ((ByteBuf) msg).release();
-            }
-        } catch (final Exception e) {
-            logger.warn("Exception while writing to SSH remote on channel {}", ctx.channel(), e);
-            throw new IllegalStateException("Exception while writing to SSH remote on channel " + ctx.channel(),e);
-        }
+        sshWriteAsyncHandler.write(ctx, msg, promise);
     }
 
     private static void handleSshSessionClosed(final ChannelHandlerContext ctx) {
@@ -172,15 +170,6 @@ public class AsyncSshHandler extends ChannelOutboundHandlerAdapter {
         ctx.fireChannelInactive();
     }
 
-    private Buffer toBuffer(final Object msg) {
-        // TODO Buffer vs ByteBuf translate, Can we handle that better ?
-        Preconditions.checkState(msg instanceof ByteBuf);
-        final ByteBuf byteBuf = (ByteBuf) msg;
-        final byte[] temp = new byte[byteBuf.readableBytes()];
-        byteBuf.readBytes(temp, 0, byteBuf.readableBytes());
-        return new Buffer(temp);
-    }
-
     @Override
     public synchronized void connect(final ChannelHandlerContext ctx, final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise promise) throws Exception {
         this.connectPromise = promise;
@@ -193,22 +182,31 @@ public class AsyncSshHandler extends ChannelOutboundHandlerAdapter {
     }
 
     @Override
-    public synchronized void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception {
+    public synchronized void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise) {
         if(sshReadAsyncListener != null) {
             sshReadAsyncListener.close();
         }
 
-        session.close(false).addListener(new SshFutureListener<CloseFuture>() {
-            @Override
-            public void operationComplete(final CloseFuture future) {
-                if(future.isClosed() == false) {
-                    session.close(true);
+        if(sshWriteAsyncHandler != null) {
+            sshWriteAsyncHandler.close();
+        }
+
+        if(session!= null && !session.isClosed() && !session.isClosing()) {
+            session.close(false).addListener(new SshFutureListener<CloseFuture>() {
+                @Override
+                public void operationComplete(final CloseFuture future) {
+                    if (future.isClosed() == false) {
+                        session.close(true);
+                    }
+                    session = null;
                 }
-                session = null;
-            }
-        });
+            });
+        }
 
         channel = null;
+        promise.setSuccess();
+
+        handleSshSessionClosed(ctx);
     }
 
     /**
@@ -255,7 +253,7 @@ public class AsyncSshHandler extends ChannelOutboundHandlerAdapter {
         }
 
         @Override
-        public synchronized void close() throws Exception {
+        public synchronized void close() {
             // Remove self as listener on close to prevent reading from closed input
             if(currentReadFuture != null) {
                 currentReadFuture.removeListener(this);
@@ -264,4 +262,103 @@ public class AsyncSshHandler extends ChannelOutboundHandlerAdapter {
             asyncOut = null;
         }
     }
+
+    private static final class SshWriteAsyncHandler implements AutoCloseable {
+        public static final int MAX_PENDING_WRITES = 100;
+
+        private final ChannelOutboundHandler channelHandler;
+        private IoOutputStream asyncIn;
+
+        // Counter that holds the amount of pending write messages
+        // Pending write can occur in case remote window is full
+        // In such case, we need to wait for the pending write to finish
+        private int pendingWriteCounter;
+        // Last write future, that can be pending
+        private IoWriteFuture lastWriteFuture;
+
+        public SshWriteAsyncHandler(final ChannelOutboundHandler channelHandler, final IoOutputStream asyncIn) {
+            this.channelHandler = channelHandler;
+            this.asyncIn = asyncIn;
+        }
+
+        public synchronized void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
+            try {
+                if(asyncIn.isClosed() || asyncIn.isClosing()) {
+                    handleSshSessionClosed(ctx);
+                } else {
+                    lastWriteFuture = asyncIn.write(toBuffer(msg));
+                    lastWriteFuture.addListener(new SshFutureListener<IoWriteFuture>() {
+
+                        @Override
+                        public void operationComplete(final IoWriteFuture future) {
+                            ((ByteBuf) msg).release();
+
+                            // Notify success or failure
+                            if (future.isWritten()) {
+                                promise.setSuccess();
+                            }
+                            promise.setFailure(future.getException());
+
+                            // Reset last pending future
+                            synchronized (SshWriteAsyncHandler.this) {
+                                lastWriteFuture = null;
+                            }
+                        }
+                    });
+                }
+            } catch (final WritePendingException e) {
+                // Check limit for pending writes
+                pendingWriteCounter++;
+                if(pendingWriteCounter > MAX_PENDING_WRITES) {
+                    handlePendingFailed(ctx, new IllegalStateException("Too much pending writes(" + MAX_PENDING_WRITES + ") on channel: " + ctx.channel() +
+                            ", remote window is not getting read or is too small"));
+                }
+
+                logger.debug("Write pending to SSH remote on channel: {}, current pending count: {}", ctx.channel(), pendingWriteCounter);
+
+                // In case of pending, re-invoke write after pending is finished
+                lastWriteFuture.addListener(new SshFutureListener<IoWriteFuture>() {
+                    @Override
+                    public void operationComplete(final IoWriteFuture future) {
+                        if(future.isWritten()) {
+                            synchronized (SshWriteAsyncHandler.this) {
+                                // Pending done, decrease counter
+                                pendingWriteCounter--;
+                            }
+                            write(ctx, msg, promise);
+                        } else {
+                            // Cannot reschedule pending, fail
+                            handlePendingFailed(ctx, e);
+                        }
+                    }
+
+                });
+            }
+        }
+
+        private void handlePendingFailed(final ChannelHandlerContext ctx, final Exception e) {
+            logger.warn("Exception while writing to SSH remote on channel {}", ctx.channel(), e);
+            try {
+                channelHandler.disconnect(ctx, ctx.newPromise());
+            } catch (final Exception ex) {
+                // This should not happen
+                throw new IllegalStateException(ex);
+            }
+        }
+
+        @Override
+        public void close() {
+            asyncIn = null;
+        }
+
+        private Buffer toBuffer(final Object msg) {
+            // TODO Buffer vs ByteBuf translate, Can we handle that better ?
+            Preconditions.checkState(msg instanceof ByteBuf);
+            final ByteBuf byteBuf = (ByteBuf) msg;
+            final byte[] temp = new byte[byteBuf.readableBytes()];
+            byteBuf.readBytes(temp, 0, byteBuf.readableBytes());
+            return new Buffer(temp);
+        }
+
+    }
 }
index 1b2201170a2426dabce3174515bfe3784e374ced..b32e880537e06d44571875a3dda56c4a36dddf6f 100644 (file)
@@ -118,7 +118,7 @@ public class SSHTest {
             Thread.sleep(100);
         }
         assertFalse(echoClientHandler.isConnected());
-        assertEquals(State.FAILED_TO_CONNECT, echoClientHandler.getState());
+        assertEquals(State.CONNECTION_CLOSED, echoClientHandler.getState());
     }
 
 }