package org.opendaylight.controller.netconf.it;
import static java.util.Arrays.asList;
+import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
-import io.netty.channel.ChannelFuture;
+import com.google.common.collect.Lists;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.GlobalEventExecutor;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
import junit.framework.Assert;
import org.junit.After;
import org.junit.Before;
super.initConfigTransactionManagerImpl(new HardcodedModuleFactoriesResolver(mockedContext, getModuleFactories().toArray(
new ModuleFactory[0])));
- NetconfOperationServiceFactoryListenerImpl factoriesListener = new NetconfOperationServiceFactoryListenerImpl();
+ final NetconfOperationServiceFactoryListenerImpl factoriesListener = new NetconfOperationServiceFactoryListenerImpl();
factoriesListener.onAddNetconfOperationServiceFactory(new NetconfOperationServiceFactoryImpl(getYangStore()));
commitNot = new DefaultCommitNotificationProducer(ManagementFactory.getPlatformMBeanServer());
final NetconfServerDispatcher dispatchS = createDispatcher(factoriesListener);
- ChannelFuture s = dispatchS.createLocalServer(NetconfConfigUtil.getNetconfLocalAddress());
- s.await();
- EventLoopGroup bossGroup = new NioEventLoopGroup();
+ dispatchS.createLocalServer(NetconfConfigUtil.getNetconfLocalAddress()).await();
+ final EventLoopGroup bossGroup = new NioEventLoopGroup();
sshServer = NetconfSSHServer.start(tlsAddress.getPort(), NetconfConfigUtil.getNetconfLocalAddress(), getAuthProvider(), bossGroup);
}
- private NetconfServerDispatcher createDispatcher(NetconfOperationServiceFactoryListenerImpl factoriesListener) {
+ private NetconfServerDispatcher createDispatcher(final NetconfOperationServiceFactoryListenerImpl factoriesListener) {
return super.createDispatcher(factoriesListener, NetconfITTest.getNetconfMonitoringListenerService(), commitNot);
}
@After
public void tearDown() throws Exception {
- sshServer.stop();
+ sshServer.close();
commitNot.close();
+ sshServer.join();
}
private HardcodedYangStoreService getYangStore() throws YangStoreException, IOException {
@Test
public void testSecure() throws Exception {
- NetconfClientDispatcher dispatch = new NetconfClientDispatcherImpl(getNettyThreadgroup(), getNettyThreadgroup(), getHashedWheelTimer());
+ final NetconfClientDispatcher dispatch = new NetconfClientDispatcherImpl(getNettyThreadgroup(), getNettyThreadgroup(), getHashedWheelTimer());
try (TestingNetconfClient netconfClient = new TestingNetconfClient("testing-ssh-client", dispatch, getClientConfiguration())) {
NetconfMessage response = netconfClient.sendMessage(getConfig);
Assert.assertFalse("Unexpected error message " + XmlUtil.toString(response.getDocument()),
NetconfMessageUtil.isErrorMessage(response));
- NetconfMessage gs = new NetconfMessage(XmlUtil.readXmlToDocument("<rpc message-id=\"2\"\n" +
+ final NetconfMessage gs = new NetconfMessage(XmlUtil.readXmlToDocument("<rpc message-id=\"2\"\n" +
" xmlns=\"urn:ietf:params:xml:ns:netconf:base:1.0\">\n" +
" <get-schema xmlns=\"urn:ietf:params:xml:ns:yang:ietf-netconf-monitoring\">\n" +
" <identifier>config</identifier>\n" +
}
}
+ /**
+ * Test all requests are handled properly and no mismatch occurs in listener
+ */
+ @Test(timeout = 3*60*1000)
+ public void testSecureStress() throws Exception {
+ final NetconfClientDispatcher dispatch = new NetconfClientDispatcherImpl(getNettyThreadgroup(), getNettyThreadgroup(), getHashedWheelTimer());
+ try (TestingNetconfClient netconfClient = new TestingNetconfClient("testing-ssh-client", dispatch, getClientConfiguration())) {
+
+ final AtomicInteger responseCounter = new AtomicInteger(0);
+ final List<Future<?>> futures = Lists.newArrayList();
+
+ final int requests = 1000;
+ for (int i = 0; i < requests; i++) {
+ final Future<NetconfMessage> netconfMessageFuture = netconfClient.sendRequest(getConfig);
+ futures.add(netconfMessageFuture);
+ netconfMessageFuture.addListener(new GenericFutureListener<Future<? super NetconfMessage>>() {
+ @Override
+ public void operationComplete(final Future<? super NetconfMessage> future) throws Exception {
+ assertTrue("Request unsuccessful " + future.cause(), future.isSuccess());
+ responseCounter.incrementAndGet();
+ }
+ });
+ }
+
+ for (final Future<?> future : futures) {
+ future.await();
+ }
+
+ // Give future listeners some time to finish counter incrementation
+ Thread.sleep(5000);
+
+ org.junit.Assert.assertEquals(requests, responseCounter.get());
+ }
+ }
+
public NetconfClientConfiguration getClientConfiguration() throws IOException {
final NetconfClientConfigurationBuilder b = NetconfClientConfigurationBuilder.create();
b.withAddress(tlsAddress);
}
public AuthProvider getAuthProvider() throws Exception {
- AuthProvider mock = mock(AuthProviderImpl.class);
+ final AuthProvider mock = mock(AuthProviderImpl.class);
doReturn(true).when(mock).authenticated(anyString(), anyString());
doReturn(PEMGenerator.generate().toCharArray()).when(mock).getPEMAsCharArray();
return mock;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import java.io.IOException;
import org.apache.sshd.common.future.CloseFuture;
import org.apache.sshd.common.future.SshFutureListener;
import org.apache.sshd.common.io.IoInputStream;
+import org.apache.sshd.common.io.IoOutputStream;
import org.apache.sshd.common.io.IoReadFuture;
+import org.apache.sshd.common.io.IoWriteFuture;
+import org.apache.sshd.common.io.WritePendingException;
import org.apache.sshd.common.util.Buffer;
import org.opendaylight.controller.netconf.nettyutil.handler.ssh.authentication.AuthenticationHandler;
import org.slf4j.Logger;
private final SshClient sshClient;
private SshReadAsyncListener sshReadAsyncListener;
+ private SshWriteAsyncHandler sshWriteAsyncHandler;
+
private ClientChannel channel;
private ClientSession session;
private ChannelPromise connectPromise;
+
public static AsyncSshHandler createForNetconfSubsystem(final AuthenticationHandler authenticationHandler) throws IOException {
return new AsyncSshHandler(authenticationHandler, DEFAULT_CLIENT);
}
connectPromise.setSuccess();
connectPromise = null;
- ctx.fireChannelActive();
- final IoInputStream asyncOut = channel.getAsyncOut();
- sshReadAsyncListener = new SshReadAsyncListener(ctx, asyncOut);
+ sshReadAsyncListener = new SshReadAsyncListener(ctx, channel.getAsyncOut());
+ sshWriteAsyncHandler = new SshWriteAsyncHandler(this, channel.getAsyncIn());
+
+ ctx.fireChannelActive();
}
private synchronized void handleSshSetupFailure(final ChannelHandlerContext ctx, final Throwable e) {
@Override
public synchronized void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
- try {
- if(channel.getAsyncIn().isClosed() || channel.getAsyncIn().isClosing()) {
- handleSshSessionClosed(ctx);
- } else {
- channel.getAsyncIn().write(toBuffer(msg));
- ((ByteBuf) msg).release();
- }
- } catch (final Exception e) {
- logger.warn("Exception while writing to SSH remote on channel {}", ctx.channel(), e);
- throw new IllegalStateException("Exception while writing to SSH remote on channel " + ctx.channel(),e);
- }
+ sshWriteAsyncHandler.write(ctx, msg, promise);
}
private static void handleSshSessionClosed(final ChannelHandlerContext ctx) {
ctx.fireChannelInactive();
}
- private Buffer toBuffer(final Object msg) {
- // TODO Buffer vs ByteBuf translate, Can we handle that better ?
- Preconditions.checkState(msg instanceof ByteBuf);
- final ByteBuf byteBuf = (ByteBuf) msg;
- final byte[] temp = new byte[byteBuf.readableBytes()];
- byteBuf.readBytes(temp, 0, byteBuf.readableBytes());
- return new Buffer(temp);
- }
-
@Override
public synchronized void connect(final ChannelHandlerContext ctx, final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise promise) throws Exception {
this.connectPromise = promise;
}
@Override
- public synchronized void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception {
+ public synchronized void disconnect(final ChannelHandlerContext ctx, final ChannelPromise promise) {
if(sshReadAsyncListener != null) {
sshReadAsyncListener.close();
}
- session.close(false).addListener(new SshFutureListener<CloseFuture>() {
- @Override
- public void operationComplete(final CloseFuture future) {
- if(future.isClosed() == false) {
- session.close(true);
+ if(sshWriteAsyncHandler != null) {
+ sshWriteAsyncHandler.close();
+ }
+
+ if(session!= null && !session.isClosed() && !session.isClosing()) {
+ session.close(false).addListener(new SshFutureListener<CloseFuture>() {
+ @Override
+ public void operationComplete(final CloseFuture future) {
+ if (future.isClosed() == false) {
+ session.close(true);
+ }
+ session = null;
}
- session = null;
- }
- });
+ });
+ }
channel = null;
+ promise.setSuccess();
+
+ handleSshSessionClosed(ctx);
}
/**
}
@Override
- public synchronized void close() throws Exception {
+ public synchronized void close() {
// Remove self as listener on close to prevent reading from closed input
if(currentReadFuture != null) {
currentReadFuture.removeListener(this);
asyncOut = null;
}
}
+
+ private static final class SshWriteAsyncHandler implements AutoCloseable {
+ public static final int MAX_PENDING_WRITES = 100;
+
+ private final ChannelOutboundHandler channelHandler;
+ private IoOutputStream asyncIn;
+
+ // Counter that holds the amount of pending write messages
+ // Pending write can occur in case remote window is full
+ // In such case, we need to wait for the pending write to finish
+ private int pendingWriteCounter;
+ // Last write future, that can be pending
+ private IoWriteFuture lastWriteFuture;
+
+ public SshWriteAsyncHandler(final ChannelOutboundHandler channelHandler, final IoOutputStream asyncIn) {
+ this.channelHandler = channelHandler;
+ this.asyncIn = asyncIn;
+ }
+
+ public synchronized void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
+ try {
+ if(asyncIn.isClosed() || asyncIn.isClosing()) {
+ handleSshSessionClosed(ctx);
+ } else {
+ lastWriteFuture = asyncIn.write(toBuffer(msg));
+ lastWriteFuture.addListener(new SshFutureListener<IoWriteFuture>() {
+
+ @Override
+ public void operationComplete(final IoWriteFuture future) {
+ ((ByteBuf) msg).release();
+
+ // Notify success or failure
+ if (future.isWritten()) {
+ promise.setSuccess();
+ }
+ promise.setFailure(future.getException());
+
+ // Reset last pending future
+ synchronized (SshWriteAsyncHandler.this) {
+ lastWriteFuture = null;
+ }
+ }
+ });
+ }
+ } catch (final WritePendingException e) {
+ // Check limit for pending writes
+ pendingWriteCounter++;
+ if(pendingWriteCounter > MAX_PENDING_WRITES) {
+ handlePendingFailed(ctx, new IllegalStateException("Too much pending writes(" + MAX_PENDING_WRITES + ") on channel: " + ctx.channel() +
+ ", remote window is not getting read or is too small"));
+ }
+
+ logger.debug("Write pending to SSH remote on channel: {}, current pending count: {}", ctx.channel(), pendingWriteCounter);
+
+ // In case of pending, re-invoke write after pending is finished
+ lastWriteFuture.addListener(new SshFutureListener<IoWriteFuture>() {
+ @Override
+ public void operationComplete(final IoWriteFuture future) {
+ if(future.isWritten()) {
+ synchronized (SshWriteAsyncHandler.this) {
+ // Pending done, decrease counter
+ pendingWriteCounter--;
+ }
+ write(ctx, msg, promise);
+ } else {
+ // Cannot reschedule pending, fail
+ handlePendingFailed(ctx, e);
+ }
+ }
+
+ });
+ }
+ }
+
+ private void handlePendingFailed(final ChannelHandlerContext ctx, final Exception e) {
+ logger.warn("Exception while writing to SSH remote on channel {}", ctx.channel(), e);
+ try {
+ channelHandler.disconnect(ctx, ctx.newPromise());
+ } catch (final Exception ex) {
+ // This should not happen
+ throw new IllegalStateException(ex);
+ }
+ }
+
+ @Override
+ public void close() {
+ asyncIn = null;
+ }
+
+ private Buffer toBuffer(final Object msg) {
+ // TODO Buffer vs ByteBuf translate, Can we handle that better ?
+ Preconditions.checkState(msg instanceof ByteBuf);
+ final ByteBuf byteBuf = (ByteBuf) msg;
+ final byte[] temp = new byte[byteBuf.readableBytes()];
+ byteBuf.readBytes(temp, 0, byteBuf.readableBytes());
+ return new Buffer(temp);
+ }
+
+ }
}