Fix frame encoder for netconf server
[netconf.git] / protocol / netconf-server / src / main / java / org / opendaylight / netconf / server / NetconfSubsystemFactory.java
1 /*
2  * Copyright (c) 2023 PANTHEON.tech s.r.o. and others. All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.netconf.server;
9
import static java.util.Objects.requireNonNull;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.GlobalEventExecutor;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import org.opendaylight.netconf.api.messages.NetconfHelloMessageAdditionalHeader;
import org.opendaylight.netconf.shaded.sshd.common.io.IoInputStream;
import org.opendaylight.netconf.shaded.sshd.common.io.IoOutputStream;
import org.opendaylight.netconf.shaded.sshd.common.util.buffer.ByteArrayBuffer;
import org.opendaylight.netconf.shaded.sshd.server.channel.ChannelDataReceiver;
import org.opendaylight.netconf.shaded.sshd.server.channel.ChannelSession;
import org.opendaylight.netconf.shaded.sshd.server.channel.ChannelSessionAware;
import org.opendaylight.netconf.shaded.sshd.server.command.AbstractCommandSupport;
import org.opendaylight.netconf.shaded.sshd.server.command.AsyncCommand;
import org.opendaylight.netconf.shaded.sshd.server.command.Command;
import org.opendaylight.netconf.shaded.sshd.server.subsystem.SubsystemFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
37
38 public final class NetconfSubsystemFactory implements SubsystemFactory {
39     private static final String NETCONF = "netconf";
40
41     private final ServerChannelInitializer channelInitializer;
42
43     public NetconfSubsystemFactory(final ServerChannelInitializer channelInitializer) {
44         this.channelInitializer = requireNonNull(channelInitializer);
45     }
46
47     @Override
48     public String getName() {
49         return NETCONF;
50     }
51
52     @Override
53     public Command createSubsystem(ChannelSession channel) throws IOException {
54         return new NetconfSubsystem(channelInitializer);
55     }
56
57     private static class NetconfSubsystem extends AbstractCommandSupport implements AsyncCommand, ChannelSessionAware {
58         private static final Logger LOG = LoggerFactory.getLogger(NetconfSubsystem.class);
59
60         private final ServerChannelInitializer channelInitializer;
61         private EmbeddedChannel innerChannel;
62         private IoOutputStream ioOutputStream;
63         private ChannelSession channelSession;
64
65         NetconfSubsystem(final ServerChannelInitializer channelInitializer) {
66             super(NETCONF, null);
67             this.channelInitializer = channelInitializer;
68         }
69
70         @Override
71         public void setIoInputStream(final IoInputStream in) {
72            // not used
73         }
74
75         @Override
76         public void setIoOutputStream(final IoOutputStream out) {
77             this.ioOutputStream = out;
78         }
79
80         @Override
81         public void setIoErrorStream(final IoOutputStream err) {
82             // not used
83         }
84
85         @Override
86         public void setChannelSession(final ChannelSession channelSession) {
87             this.channelSession = channelSession;
88         }
89
90         @Override
91         public void run() {
92
93             /*
94              * While NETCONF protocol handlers are designed to operate over Netty channel,
95              * the inner channel is used to serve NETCONF over SSH.
96              */
97
98             this.innerChannel = new EmbeddedChannel();
99
100             // inbound packets handler
101             channelSession.setDataReceiver(new ChannelDataReceiver() {
102                 @Override
103                 public int data(ChannelSession channel, byte[] buf, int start, int len) throws IOException {
104                     innerChannel.writeInbound(Unpooled.copiedBuffer(buf, start, len));
105                     return len;
106                 }
107
108                 @Override
109                 public void close() throws IOException {
110                     innerChannel.close();
111                 }
112             });
113
114             // outbound packet handler, adding fist means it will be invoked last bc of flow direction
115             innerChannel.pipeline().addFirst(
116                 new ChannelOutboundHandlerAdapter() {
117                     @Override
118                     public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
119                         if (msg instanceof ByteBuf byteBuf) {
120                             // redirect channel outgoing packets to output stream linked to transport
121                             final byte[] bytes = new byte[byteBuf.readableBytes()];
122                             byteBuf.readBytes(bytes);
123                             try {
124                                 ioOutputStream.writeBuffer(new ByteArrayBuffer(bytes))
125                                     .addListener(future -> {
126                                         if (future.isWritten()) {
127                                             byteBuf.release(); // report outbound message being handled
128                                             promise.setSuccess();
129                                         } else if (future.getException() != null) {
130                                             LOG.error("Error writing buffer", future.getException());
131                                             promise.setFailure(future.getException());
132                                         }
133                                     });
134                             } catch (IOException e) {
135                                 LOG.error("Error writing buffer", e);
136                             }
137                         }
138                     }
139                 });
140
141             // inner channel termination handler
142             innerChannel.pipeline().addLast(
143                 new ChannelInboundHandlerAdapter() {
144                     @Override
145                     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
146                         onExit(0);
147                     }
148                 }
149             );
150
151             // NETCONF protocol handlers
152             channelInitializer.initialize(innerChannel, new DefaultPromise<>(GlobalEventExecutor.INSTANCE));
153             // trigger negotiation flow
154             innerChannel.pipeline().fireChannelActive();
155             // set additional info for upcoming netconf session
156             innerChannel.writeInbound(Unpooled.wrappedBuffer(getHelloAdditionalMessageBytes()));
157         }
158
159         @Override
160         protected void onExit(int exitValue, String exitMessage) {
161             super.onExit(exitValue, exitMessage);
162             if (innerChannel != null) {
163                 innerChannel.close();
164             }
165         }
166
167         private byte[] getHelloAdditionalMessageBytes() {
168             final var session = getServerSession();
169             final var address = (InetSocketAddress) session.getClientAddress();
170             final var header = new NetconfHelloMessageAdditionalHeader(
171                 session.getUsername(),
172                 address.getAddress().getHostAddress(),
173                 String.valueOf(address.getPort()),
174                 "ssh", "client").toFormattedString();
175             return header.getBytes(StandardCharsets.UTF_8);
176         }
177     }
178 }