import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
+import com.google.common.annotations.VisibleForTesting;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.netty.channel.EventLoopGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
class CallHomeSessionContext implements CallHomeProtocolSessionContext {
private static final Logger LOG = LoggerFactory.getLogger(CallHomeSessionContext.class);
- static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
-
private static final String NETCONF = "netconf";
+ @VisibleForTesting
+ static final Session.AttributeKey<CallHomeSessionContext> SESSION_KEY = new Session.AttributeKey<>();
+
private final ClientSession sshSession;
private final CallHomeAuthorization authorization;
private final Factory factory;
checkArgument(this.authorization.isServerAllowed(), "Server was not allowed.");
this.factory = requireNonNull(factory);
this.sshSession = requireNonNull(sshSession);
- this.sshSession.setAttribute(SESSION_KEY, this);
this.remoteAddress = (InetSocketAddress) this.sshSession.getIoSession().getRemoteAddress();
serverKey = this.sshSession.getServerKey();
}
+ final void associate() {
+ sshSession.setAttribute(SESSION_KEY, this);
+ }
+
static CallHomeSessionContext getFrom(final ClientSession sshSession) {
return sshSession.getAttribute(SESSION_KEY);
}
}
static class Factory {
-
+ private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
private final EventLoopGroup nettyGroup;
private final NetconfClientSessionNegotiatorFactory negotiatorFactory;
private final CallHomeNetconfSubsystemListener subsystemListener;
- private final ConcurrentMap<String, CallHomeSessionContext> sessions = new ConcurrentHashMap<>();
Factory(final EventLoopGroup nettyGroup, final NetconfClientSessionNegotiatorFactory negotiatorFactory,
final CallHomeNetconfSubsystemListener subsystemListener) {
- this.nettyGroup = requireNonNull(nettyGroup, "nettyGroup");
- this.negotiatorFactory = requireNonNull(negotiatorFactory, "negotiatorFactory");
+ this.nettyGroup = requireNonNull(nettyGroup);
+ this.negotiatorFactory = requireNonNull(negotiatorFactory);
this.subsystemListener = requireNonNull(subsystemListener);
}
- void remove(final CallHomeSessionContext session) {
- sessions.remove(session.getSessionId(), session);
- }
-
ReverseSshChannelInitializer getChannelInitializer(final NetconfClientSessionListener listener) {
return ReverseSshChannelInitializer.create(negotiatorFactory, listener);
}
return subsystemListener;
}
+ EventLoopGroup getNettyGroup() {
+ return nettyGroup;
+ }
+
@Nullable CallHomeSessionContext createIfNotExists(final ClientSession sshSession,
final CallHomeAuthorization authorization, final SocketAddress remoteAddress) {
- CallHomeSessionContext session = new CallHomeSessionContext(sshSession, authorization,
- remoteAddress, this);
- CallHomeSessionContext preexisting = sessions.putIfAbsent(session.getSessionId(), session);
- // If preexisting is null - session does not exist, so we can safely create new one, otherwise we return
- // null and incoming connection will be rejected.
- return preexisting == null ? session : null;
+ final var newSession = new CallHomeSessionContext(sshSession, authorization, remoteAddress, this);
+ final var existing = sessions.putIfAbsent(newSession.getSessionId(), newSession);
+ if (existing == null) {
+ // There was no mapping, but now there is. Associate the the context with the session.
+ newSession.associate();
+ return newSession;
+ }
+
+ // We already have a mapping, do not create a new one. But also check if the current session matches
+ // the one stored in the session. This can happen during rekeying.
+ return existing == CallHomeSessionContext.getFrom(sshSession) ? existing : null;
}
- EventLoopGroup getNettyGroup() {
- return nettyGroup;
+ void remove(final CallHomeSessionContext session) {
+ sessions.remove(session.getSessionId(), session);
}
}
}