Refactor ClientFactoryManagerConfigurator
[netconf.git] / transport / transport-ssh / src / test / java / org / opendaylight / netconf / transport / ssh / SshClientServerTest.java
1 /*
2  * Copyright (c) 2023 PANTHEON.tech s.r.o. All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.netconf.transport.ssh;
9
10 import static org.junit.jupiter.api.Assertions.assertEquals;
11 import static org.junit.jupiter.api.Assertions.assertInstanceOf;
12 import static org.junit.jupiter.api.Assertions.assertNotNull;
13 import static org.junit.jupiter.api.Assertions.assertTrue;
14 import static org.mockito.ArgumentMatchers.any;
15 import static org.mockito.Mockito.timeout;
16 import static org.mockito.Mockito.verify;
17 import static org.mockito.Mockito.when;
18 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildClientAuthHostBased;
19 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildClientAuthWithPassword;
20 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildClientAuthWithPublicKey;
21 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildClientIdentityHostBased;
22 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildClientIdentityWithPassword;
23 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildClientIdentityWithPublicKey;
24 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildServerAuthWithCertificate;
25 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildServerAuthWithPublicKey;
26 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildServerIdentityWithCertificate;
27 import static org.opendaylight.netconf.transport.ssh.TestUtils.buildServerIdentityWithKeyPair;
28 import static org.opendaylight.netconf.transport.ssh.TestUtils.generateKeyPairWithCertificate;
29
30 import com.google.common.util.concurrent.ListenableFuture;
31 import com.google.common.util.concurrent.SettableFuture;
32 import io.netty.channel.Channel;
33 import io.netty.channel.ChannelHandlerContext;
34 import io.netty.channel.ChannelInboundHandlerAdapter;
35 import java.io.IOException;
36 import java.net.InetAddress;
37 import java.net.ServerSocket;
38 import java.util.Collection;
39 import java.util.List;
40 import java.util.concurrent.TimeUnit;
41 import java.util.concurrent.atomic.AtomicInteger;
42 import java.util.concurrent.atomic.AtomicReference;
43 import java.util.stream.Stream;
44 import org.apache.commons.codec.digest.Crypt;
45 import org.junit.jupiter.api.AfterAll;
46 import org.junit.jupiter.api.BeforeAll;
47 import org.junit.jupiter.api.BeforeEach;
48 import org.junit.jupiter.api.DisplayName;
49 import org.junit.jupiter.api.Test;
50 import org.junit.jupiter.api.extension.ExtendWith;
51 import org.junit.jupiter.params.ParameterizedTest;
52 import org.junit.jupiter.params.provider.Arguments;
53 import org.junit.jupiter.params.provider.MethodSource;
54 import org.mockito.ArgumentCaptor;
55 import org.mockito.Captor;
56 import org.mockito.Mock;
57 import org.mockito.junit.jupiter.MockitoExtension;
58 import org.opendaylight.netconf.shaded.sshd.client.ClientFactoryManager;
59 import org.opendaylight.netconf.shaded.sshd.client.auth.password.PasswordIdentityProvider;
60 import org.opendaylight.netconf.shaded.sshd.client.session.ClientSession;
61 import org.opendaylight.netconf.shaded.sshd.common.session.Session;
62 import org.opendaylight.netconf.shaded.sshd.server.auth.password.UserAuthPasswordFactory;
63 import org.opendaylight.netconf.shaded.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider;
64 import org.opendaylight.netconf.shaded.sshd.server.session.ServerSession;
65 import org.opendaylight.netconf.transport.api.TransportChannel;
66 import org.opendaylight.netconf.transport.api.TransportChannelListener;
67 import org.opendaylight.netconf.transport.api.UnsupportedConfigurationException;
68 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.inet.types.rev130715.Host;
69 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.inet.types.rev130715.IetfInetUtil;
70 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.inet.types.rev130715.PortNumber;
71 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.ssh.client.rev240208.SshClientGrouping;
72 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.ssh.client.rev240208.ssh.client.grouping.ClientIdentity;
73 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.ssh.client.rev240208.ssh.client.grouping.ClientIdentityBuilder;
74 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.ssh.client.rev240208.ssh.client.grouping.ServerAuthentication;
75 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.ssh.server.rev240208.SshServerGrouping;
76 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.ssh.server.rev240208.ssh.server.grouping.ClientAuthentication;
77 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.ssh.server.rev240208.ssh.server.grouping.ServerIdentity;
78 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.tcp.client.rev240208.TcpClientGrouping;
79 import org.opendaylight.yang.gen.v1.urn.ietf.params.xml.ns.yang.ietf.tcp.server.rev240208.TcpServerGrouping;
80 import org.opendaylight.yangtools.yang.common.Uint16;
81
82 @ExtendWith(MockitoExtension.class)
83 public class SshClientServerTest {
84     private static final String RSA = "RSA";
85     private static final String EC = "EC";
86     private static final String USER = "user";
87     private static final String PASSWORD = "pa$$w0rd";
88     private static final String SUBSYSTEM = "subsystem";
89     private static final AtomicInteger COUNTER = new AtomicInteger(0);
90     private static final AtomicReference<String> USERNAME = new AtomicReference<>(USER);
91
92     private static SSHTransportStackFactory FACTORY;
93
94     @Mock
95     private TcpClientGrouping tcpClientConfig;
96     @Mock
97     private SshClientGrouping sshClientConfig;
98     @Mock
99     private TransportChannelListener clientListener;
100     @Mock
101     private TcpServerGrouping tcpServerConfig;
102     @Mock
103     private SshServerGrouping sshServerConfig;
104     @Mock
105     private TransportChannelListener serverListener;
106
107     @Captor
108     ArgumentCaptor<TransportChannel> clientTransportChannelCaptor;
109     @Captor
110     ArgumentCaptor<TransportChannel> serverTransportChannelCaptor;
111
112     private ServerSocket socket;
113
114     @BeforeAll
115     static void beforeAll() {
116         FACTORY = new SSHTransportStackFactory("IntegrationTest", 0);
117     }
118
119     @AfterAll
120     static void afterAll() {
121         FACTORY.close();
122     }
123
124     @BeforeEach
125     void beforeEach() throws IOException {
126
127         // create temp socket to get available port for test
128         socket = new ServerSocket(0);
129         final var localAddress = IetfInetUtil.ipAddressFor(InetAddress.getLoopbackAddress());
130         final var localPort = new PortNumber(Uint16.valueOf(socket.getLocalPort()));
131         socket.close();
132
133         when(tcpServerConfig.getLocalAddress()).thenReturn(localAddress);
134         when(tcpServerConfig.requireLocalAddress()).thenCallRealMethod();
135         when(tcpServerConfig.getLocalPort()).thenReturn(localPort);
136         when(tcpServerConfig.requireLocalPort()).thenCallRealMethod();
137
138         when(tcpClientConfig.getRemoteAddress()).thenReturn(new Host(localAddress));
139         when(tcpClientConfig.requireRemoteAddress()).thenCallRealMethod();
140         when(tcpClientConfig.getRemotePort()).thenReturn(localPort);
141         when(tcpClientConfig.requireRemotePort()).thenCallRealMethod();
142     }
143
144     @ParameterizedTest(name = "SSH Server Host Key Verification -- {0}")
145     @MethodSource("itServerKeyVerifyArgs")
146     void itServerKeyVerify(final String testDesc, final ServerIdentity serverIdentity,
147             final ServerAuthentication serverAuth) throws Exception {
148         final var clientIdentity = buildClientIdentityWithPassword(getUsername(), PASSWORD);
149         final var clientAuth = buildClientAuthWithPassword(getUsernameAndUpdate(), "$0$" + PASSWORD);
150         when(sshClientConfig.getClientIdentity()).thenReturn(clientIdentity);
151         when(sshClientConfig.getServerAuthentication()).thenReturn(serverAuth);
152         when(sshServerConfig.getServerIdentity()).thenReturn(serverIdentity);
153         when(sshServerConfig.getClientAuthentication()).thenReturn(clientAuth);
154         integrationTest(
155             () -> FACTORY.listenServer(SUBSYSTEM, serverListener, tcpServerConfig, sshServerConfig),
156             () -> FACTORY.connectClient(SUBSYSTEM, clientListener, tcpClientConfig, sshClientConfig));
157     }
158
159     private static Stream<Arguments> itServerKeyVerifyArgs() throws Exception {
160         final var rsaKeyData = generateKeyPairWithCertificate(RSA);
161         final var ecKeyData = generateKeyPairWithCertificate(EC);
162         return Stream.of(
163                 Arguments.of("RSA public key",
164                         buildServerIdentityWithKeyPair(rsaKeyData), buildServerAuthWithPublicKey(rsaKeyData)),
165                 Arguments.of("EC public key",
166                         buildServerIdentityWithKeyPair(ecKeyData), buildServerAuthWithPublicKey(ecKeyData)),
167                 Arguments.of("RSA certificate",
168                         buildServerIdentityWithCertificate(rsaKeyData), buildServerAuthWithCertificate(rsaKeyData)),
169                 Arguments.of("EC certificate",
170                         buildServerIdentityWithCertificate(ecKeyData), buildServerAuthWithCertificate(ecKeyData))
171         );
172     }
173
174     @ParameterizedTest(name = "SSH User Auth using {0}")
175     @MethodSource("itUserAuthArgs")
176     void itUserAuth(final String testDesc, final ClientIdentity clientIdentity, final ClientAuthentication clientAuth)
177             throws Exception {
178         final var serverIdentity = buildServerIdentityWithKeyPair(generateKeyPairWithCertificate(RSA)); // required
179         when(sshClientConfig.getClientIdentity()).thenReturn(clientIdentity);
180         when(sshClientConfig.getServerAuthentication()).thenReturn(null); // Accept all keys
181         when(sshServerConfig.getServerIdentity()).thenReturn(serverIdentity);
182         when(sshServerConfig.getClientAuthentication()).thenReturn(clientAuth);
183         integrationTest(
184             () -> FACTORY.listenServer(SUBSYSTEM, serverListener, tcpServerConfig, sshServerConfig),
185             () -> FACTORY.connectClient(SUBSYSTEM, clientListener, tcpClientConfig, sshClientConfig));
186     }
187
188     private static Stream<Arguments> itUserAuthArgs() throws Exception {
189         final var rsaKeyData = generateKeyPairWithCertificate(RSA);
190         final var ecKeyData = generateKeyPairWithCertificate(EC);
191         return Stream.of(
192                 Arguments.of("Password -- clear text ",
193                         buildClientIdentityWithPassword(getUsername(), PASSWORD),
194                         buildClientAuthWithPassword(getUsernameAndUpdate(), "$0$" + PASSWORD)),
195                 Arguments.of("Password -- MD5",
196                         buildClientIdentityWithPassword(getUsername(), PASSWORD),
197                         buildClientAuthWithPassword(getUsernameAndUpdate(), Crypt.crypt(PASSWORD, "$1$md5salt"))),
198                 Arguments.of("Password -- SHA-256",
199                         buildClientIdentityWithPassword(getUsername(), PASSWORD),
200                         buildClientAuthWithPassword(getUsernameAndUpdate(),
201                                 Crypt.crypt(PASSWORD, "$5$sha256salt"))),
202                 Arguments.of("Password -- SHA-512 with rounds",
203                         buildClientIdentityWithPassword(getUsername(), PASSWORD),
204                         buildClientAuthWithPassword(getUsernameAndUpdate(),
205                                 Crypt.crypt(PASSWORD, "$6$rounds=4500$sha512salt"))),
206                 Arguments.of("HostBased -- RSA keys",
207                         buildClientIdentityHostBased(getUsername(), rsaKeyData),
208                         buildClientAuthHostBased(getUsernameAndUpdate(), rsaKeyData)),
209                 Arguments.of("HostBased -- EC keys",
210                         buildClientIdentityHostBased(getUsername(), ecKeyData),
211                         buildClientAuthHostBased(getUsernameAndUpdate(), ecKeyData)),
212                 Arguments.of("PublicKey -- RSA keys",
213                         buildClientIdentityWithPublicKey(getUsername(), rsaKeyData),
214                         buildClientAuthWithPublicKey(getUsernameAndUpdate(), rsaKeyData)),
215                 Arguments.of("PublicKey -- EC keys",
216                         buildClientIdentityWithPublicKey(getUsername(), ecKeyData),
217                         buildClientAuthWithPublicKey(getUsernameAndUpdate(), ecKeyData))
218         );
219     }
220
221     private static String getUsername() {
222         return USERNAME.get();
223     }
224
225     /**
226      * Update username for next test.
227      */
228     private static String getUsernameAndUpdate() {
229         return USERNAME.getAndSet(USER + COUNTER.incrementAndGet());
230     }
231
232     private void integrationTest(final Builder<SSHServer> serverBuilder,
233             final Builder<SSHClient> clientBuilder) throws Exception {
234         // start server
235         final var server = serverBuilder.build().get(2, TimeUnit.SECONDS);
236         try {
237             // connect with client
238             final var client = clientBuilder.build().get(2, TimeUnit.SECONDS);
239             try {
240                 verify(serverListener, timeout(10_000))
241                         .onTransportChannelEstablished(serverTransportChannelCaptor.capture());
242                 verify(clientListener, timeout(10_000))
243                         .onTransportChannelEstablished(clientTransportChannelCaptor.capture());
244                 // validate channels are in expected state
245                 var serverChannel = assertChannel(serverTransportChannelCaptor.getAllValues());
246                 var clientChannel = assertChannel(clientTransportChannelCaptor.getAllValues());
247                 // validate channels are connecting same sockets
248                 assertEquals(serverChannel.remoteAddress(), clientChannel.localAddress());
249                 assertEquals(serverChannel.localAddress(), clientChannel.remoteAddress());
250                 // validate sessions are authenticated
251                 assertSession(ServerSession.class, server.getSessions());
252                 assertSession(ClientSession.class, client.getSessions());
253
254             } finally {
255                 client.shutdown().get(2, TimeUnit.SECONDS);
256             }
257         } finally {
258             server.shutdown().get(2, TimeUnit.SECONDS);
259         }
260     }
261
262     @Test
263     @DisplayName("External service integration")
264     void externalServiceIntegration() throws Exception {
265         final var username = getUsernameAndUpdate();
266         when(sshClientConfig.getClientIdentity()).thenReturn(usernameOnlyIdentity(username));
267         when(sshClientConfig.getServerAuthentication()).thenReturn(null);
268         integrationTest(
269             () -> FACTORY.listenServer(SUBSYSTEM, serverListener, tcpServerConfig, null, serverConfigurator(username)),
270             () -> FACTORY.connectClient(SUBSYSTEM, clientListener, tcpClientConfig, sshClientConfig,
271                 clientConfigurator(username)));
272     }
273
274     @Test
275     @DisplayName("Call-home protocol support with services integration")
276     void callHome() throws Exception {
277         final var username = getUsernameAndUpdate();
278         when(sshClientConfig.getClientIdentity()).thenReturn(usernameOnlyIdentity(username));
279         when(sshClientConfig.getServerAuthentication()).thenReturn(null);
280
281         // start call-home client first, accepting inbound tcp connections
282         final var client = FACTORY.listenClient(SUBSYSTEM, clientListener, tcpServerConfig, sshClientConfig,
283                 clientConfigurator(username)).get(2, TimeUnit.SECONDS);
284         try {
285             // start a call-home server, init connection
286             final var server = FACTORY.connectServer(SUBSYSTEM, serverListener, tcpClientConfig, null,
287                     serverConfigurator(username)).get(2, TimeUnit.SECONDS);
288             try {
289                 verify(serverListener, timeout(10_000))
290                     .onTransportChannelEstablished(serverTransportChannelCaptor.capture());
291                 verify(clientListener, timeout(10_000))
292                     .onTransportChannelEstablished(clientTransportChannelCaptor.capture());
293                 // validate channels are in expected state
294                 var serverChannel = assertChannel(serverTransportChannelCaptor.getAllValues());
295                 var clientChannel = assertChannel(clientTransportChannelCaptor.getAllValues());
296                 // validate channels are connecting same sockets
297                 assertEquals(serverChannel.remoteAddress(), clientChannel.localAddress());
298                 assertEquals(serverChannel.localAddress(), clientChannel.remoteAddress());
299                 // validate sessions are authenticated
300                 assertSession(ClientSession.class, client.getSessions());
301                 assertSession(ServerSession.class, server.getSessions());
302
303             } finally {
304                 server.shutdown().get(2, TimeUnit.SECONDS);
305             }
306         } finally {
307             client.shutdown().get(2, TimeUnit.SECONDS);
308         }
309     }
310
311     private static Channel assertChannel(final List<TransportChannel> transportChannels) {
312         assertNotNull(transportChannels);
313         assertEquals(1, transportChannels.size());
314         final var channel = assertInstanceOf(SSHTransportChannel.class, transportChannels.get(0)).channel();
315         assertNotNull(channel);
316         assertTrue(channel.isOpen());
317         return channel;
318     }
319
320     private static <T extends Session> void assertSession(final Class<T> type, final Collection<Session> sessions) {
321         assertNotNull(sessions);
322         assertEquals(1, sessions.size());
323         final T session = assertInstanceOf(type, sessions.iterator().next());
324         assertTrue(session.isAuthenticated());
325     }
326
327     private static ClientIdentity usernameOnlyIdentity(final String username) {
328         return new ClientIdentityBuilder().setUsername(username).build();
329     }
330
331     private static ServerFactoryManagerConfigurator serverConfigurator(final String username) {
332         return factoryManager -> {
333             // authenticate user by credentials and generate host key
334             factoryManager.setUserAuthFactories(List.of(new UserAuthPasswordFactory()));
335             factoryManager.setPasswordAuthenticator(
336                 (usr, psw, session) -> username.equals(usr) && PASSWORD.equals(psw));
337             factoryManager.setKeyPairProvider(new SimpleGeneratorHostKeyProvider());
338         };
339     }
340
341     private static ClientFactoryManagerConfigurator clientConfigurator(final String username) {
342         return new ClientFactoryManagerConfigurator() {
343             @Override
344             protected void configureClientFactoryManager(final ClientFactoryManager factoryManager)
345                     throws UnsupportedConfigurationException {
346                 factoryManager.setPasswordIdentityProvider(PasswordIdentityProvider.wrapPasswords(PASSWORD));
347                 factoryManager.setUserAuthFactories(List.of(
348                     new org.opendaylight.netconf.shaded.sshd.client.auth.password.UserAuthPasswordFactory()));
349             }
350         };
351     }
352
353     @Test
354     @DisplayName("Handle channel inactive event")
355     void handleChannelInactive() throws Exception {
356         final var username = getUsernameAndUpdate();
357         when(sshClientConfig.getClientIdentity()).thenReturn(usernameOnlyIdentity(username));
358         when(sshClientConfig.getServerAuthentication()).thenReturn(null);
359
360         // place channelInactive handlers on a server side channel when connection is established
361         final var firstHandlerFuture = SettableFuture.<Boolean>create();
362         final var lastHandlerFuture = SettableFuture.<Boolean>create();
363         final var serverTransportListener = new TransportChannelListener() {
364             @Override
365             public void onTransportChannelEstablished(final TransportChannel channel) {
366                 channel.channel().pipeline().addFirst("FIRST", new ChannelInboundHandlerAdapter() {
367                     @Override
368                     public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
369                         firstHandlerFuture.set(Boolean.TRUE);
370                         ctx.fireChannelInactive();
371                     }
372                 });
373                 channel.channel().pipeline().addLast("LAST", new ChannelInboundHandlerAdapter() {
374                     @Override
375                     public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
376                         lastHandlerFuture.set(Boolean.TRUE);
377                         ctx.fireChannelInactive();
378                     }
379                 });
380             }
381
382             @Override
383             public void onTransportChannelFailed(final Throwable cause) {
384                 // not used
385             }
386         };
387
388         final var server = FACTORY.listenServer(SUBSYSTEM, serverTransportListener, tcpServerConfig, null,
389                 serverConfigurator(username)).get(2, TimeUnit.SECONDS);
390         try {
391             // connect with client
392             final var client = FACTORY.connectClient(SUBSYSTEM, clientListener, tcpClientConfig, sshClientConfig,
393                 clientConfigurator(username)).get(2, TimeUnit.SECONDS);
394             try {
395                 verify(clientListener, timeout(10_000)).onTransportChannelEstablished(any(TransportChannel.class));
396             } finally {
397                 // disconnect client
398                 client.shutdown().get(2, TimeUnit.SECONDS);
399                 // validate channel closure on server side is handled properly:
400                 // both first and last handlers expected to be triggered
401                 // indicating there is no obstacles for the event in a channel pipeline
402                 assertEquals(Boolean.TRUE, firstHandlerFuture.get(1, TimeUnit.SECONDS));
403                 assertEquals(Boolean.TRUE, lastHandlerFuture.get(1, TimeUnit.SECONDS));
404             }
405         } finally {
406             server.shutdown().get(2, TimeUnit.SECONDS);
407         }
408     }
409
410     @FunctionalInterface
411     private interface Builder<T extends SSHTransportStack> {
412         ListenableFuture<T> build() throws UnsupportedConfigurationException;
413     }
414 }