Bind OutboundChannelHandler to ChannelAsyncOutputStream
[netconf.git] / transport / transport-ssh / src / main / java / org / opendaylight / netconf / transport / ssh / OutboundChannelHandler.java
index c7db99a9f9e3004aa9e915614c85aae50618a901..cceef609b0d59a6311de828676621a39ced6ba2b 100644 (file)
@@ -11,9 +11,12 @@ import static java.util.Objects.requireNonNull;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandler;
 import io.netty.channel.ChannelOutboundHandlerAdapter;
 import io.netty.channel.ChannelPromise;
 import java.io.IOException;
+import java.util.ArrayDeque;
+import org.opendaylight.netconf.shaded.sshd.common.channel.ChannelAsyncOutputStream;
 import org.opendaylight.netconf.shaded.sshd.common.io.IoOutputStream;
 import org.opendaylight.netconf.shaded.sshd.common.io.IoWriteFuture;
 import org.opendaylight.netconf.shaded.sshd.common.util.buffer.ByteArrayBuffer;
@@ -21,45 +24,111 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * A ChannelOutboundHandler responsible for redirecting whatever bytes need to be written out on the Netty channel so
- * that they pass into SSHD's output.
+ * A {@link ChannelOutboundHandler} responsible for redirecting whatever bytes need to be written out on the Netty
+ * channel so that they pass into SSHD's output.
+ *
+ * <p>
+ * This class is specialized for {@link ChannelAsyncOutputStream} on purpose, as this handler is invoked from the Netty
+ * thread and we do not want to block those. We therefore rely on {@link ChannelAsyncOutputStream}'s single-async-write
+ * promise and perform queueing here.
  */
 final class OutboundChannelHandler extends ChannelOutboundHandlerAdapter {
+    // A write enqueued in pending queue
+    private record Write(ByteBuf buf, ChannelPromise promise) {
+        Write {
+            requireNonNull(buf);
+            requireNonNull(promise);
+        }
+    }
+
     private static final Logger LOG = LoggerFactory.getLogger(OutboundChannelHandler.class);
 
     private final IoOutputStream out;
 
-    OutboundChannelHandler(final IoOutputStream out) {
+    // write requests that need to be sent once currently-outstanding write completes
+    private ArrayDeque<Write> pending;
+    // indicates we have an outstanding write
+    private boolean writePending;
+
+    OutboundChannelHandler(final ChannelAsyncOutputStream out) {
         this.out = requireNonNull(out);
     }
 
     @Override
     public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
         // redirect channel outgoing packets to output stream linked to transport
-        if (!(msg instanceof ByteBuf byteBuf)) {
+        if (msg instanceof ByteBuf buf) {
+            write(buf, promise);
+        } else {
             LOG.trace("Ignoring unrecognized {}", msg == null ? null : msg.getClass());
-            return;
         }
+    }
+
+    private void write(final ByteBuf buf, final ChannelPromise promise) {
+        if (writePending) {
+            LOG.trace("A write is already pending, delaying write");
+            delayWrite(buf, promise);
+        } else {
+            LOG.trace("Issuing immediate write");
+            startWrite(buf, promise);
+        }
+    }
+
+    private void delayWrite(final ByteBuf buf, final ChannelPromise promise) {
+        if (pending == null) {
+            // these are per-session, hence we want to start out small
+            pending = new ArrayDeque<>(1);
+        }
+        pending.addLast(new Write(buf, promise));
+    }
 
-        final var sshBuf = toSshBuffer(byteBuf);
+    private void startWrite(final ByteBuf buf, final ChannelPromise promise) {
+        final var sshBuf = toSshBuffer(buf);
         final IoWriteFuture writeFuture;
         try {
             writeFuture = out.writeBuffer(sshBuf);
         } catch (IOException e) {
-            LOG.error("Error writing buffer", e);
-            promise.setFailure(e);
+            failWrites(promise, e);
             return;
         }
 
-        writeFuture.addListener(future -> {
-            if (future.isWritten()) {
-                // report outbound message being handled
-                promise.setSuccess();
-            } else if (future.getException() != null) {
-                LOG.error("Error writing buffer", future.getException());
-                promise.setFailure(future.getException());
+        writePending = true;
+        writeFuture.addListener(future -> finishWrite(future, promise));
+    }
+
+    private void finishWrite(final IoWriteFuture future, final ChannelPromise promise) {
+        writePending = false;
+
+        if (future.isWritten()) {
+            // report outbound message being handled
+            promise.setSuccess();
+
+            if (pending != null) {
+                // TODO: here we could be coalescing multiple ByteBufs into a single Buffer
+                final var next = pending.pollFirst();
+                if (next != null) {
+                    LOG.trace("Issuing next write");
+                    startWrite(next.buf, next.promise);
+                }
             }
-        });
+            return;
+        }
+
+        final var cause = future.getException();
+        if (cause != null) {
+            failWrites(promise, cause);
+        }
+    }
+
+    private void failWrites(final ChannelPromise promise, final Throwable cause) {
+        LOG.error("Error writing buffer", cause);
+        promise.setFailure(cause);
+
+        // Cascade to all delayed messages
+        if (pending != null) {
+            pending.forEach(msg -> msg.promise.setFailure(cause));
+            pending = null;
+        }
     }
 
     // TODO: This can amount to a lot of copying around. Is it worth our while to create a ByteBufBuffer, which