Cleanup switch certificate chain handling
[openflowplugin.git] / openflowjava / openflow-protocol-impl / src / main / java / org / opendaylight / openflowjava / protocol / impl / core / SslContextFactory.java
index 4ed7618cfaa951d551b84c15f376f6cd233529d0..35ba6f9a58067d13702e43bdcc34f3d37668b48e 100644 (file)
@@ -8,6 +8,8 @@
 
 package org.opendaylight.openflowjava.protocol.impl.core;
 
+import static java.util.Objects.requireNonNull;
+
 import java.io.IOException;
 import java.net.Socket;
 import java.security.KeyManagementException;
@@ -18,11 +20,15 @@ import java.security.Security;
 import java.security.UnrecoverableKeyException;
 import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
+import java.util.List;
 import javax.net.ssl.KeyManagerFactory;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
+import javax.net.ssl.TrustManager;
 import javax.net.ssl.TrustManagerFactory;
 import javax.net.ssl.X509ExtendedTrustManager;
+import javax.net.ssl.X509TrustManager;
+import org.eclipse.jdt.annotation.Nullable;
 import org.opendaylight.openflowjava.protocol.api.connection.TlsConfiguration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -33,16 +39,15 @@ import org.slf4j.LoggerFactory;
  * @author michal.polkorab
  */
 public class SslContextFactory {
+    private static final Logger LOG = LoggerFactory.getLogger(SslContextFactory.class);
 
     // "TLS" - supports some version of TLS
     // Use "TLSv1", "TLSv1.1", "TLSv1.2" for specific TLS version
     private static final String PROTOCOL = "TLS";
+
     private final TlsConfiguration tlsConfig;
-    private static X509Certificate switchCertificate = null;
-    private static boolean isCustomTrustManagerEnabled;
 
-    private static final Logger LOG = LoggerFactory
-            .getLogger(SslContextFactory.class);
+    private volatile List<X509Certificate> switchCertificateChain;
 
     /**
      * Sets the TlsConfiguration.
@@ -51,29 +56,20 @@ public class SslContextFactory {
      *            TLS configuration object, contains keystore locations +
      *            keystore types
      */
-    public SslContextFactory(TlsConfiguration tlsConfig) {
-        this.tlsConfig = tlsConfig;
-    }
-
-    public X509Certificate getSwitchCertificate() {
-        return switchCertificate;
+    public SslContextFactory(final TlsConfiguration tlsConfig) {
+        this.tlsConfig = requireNonNull(tlsConfig);
     }
 
-    public boolean isCustomTrustManagerEnabled() {
-        return isCustomTrustManagerEnabled;
+    @Nullable List<X509Certificate> getSwitchCertificateChain() {
+        return switchCertificateChain;
     }
 
-    public static void setSwitchCertificate(X509Certificate certificate) {
-        switchCertificate = certificate;
-    }
-
-    public static void setIsCustomTrustManagerEnabled(boolean customTrustManagerEnabled) {
-        isCustomTrustManagerEnabled = customTrustManagerEnabled;
+    void setSwitchCertificateChain(final X509Certificate[] chain) {
+        switchCertificateChain = List.of(chain);
     }
 
     public SSLContext getServerContext() {
-        String algorithm = Security
-                .getProperty("ssl.KeyManagerFactory.algorithm");
+        String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
         if (algorithm == null) {
             algorithm = "SunX509";
         }
@@ -92,16 +88,27 @@ public class SslContextFactory {
             tmf.init(ts);
 
             serverContext = SSLContext.getInstance(PROTOCOL);
-            if (isCustomTrustManagerEnabled) {
-                CustomTrustManager[] customTrustManager = new CustomTrustManager[tmf.getTrustManagers().length];
-                for (int i = 0; i < tmf.getTrustManagers().length; i++) {
-                    customTrustManager[i] = new CustomTrustManager((X509ExtendedTrustManager)
-                            tmf.getTrustManagers()[i]);
+
+            // A bit ugly: intercept trust checks to establish switch certificate
+            final TrustManager[] delegates = tmf.getTrustManagers();
+            final TrustManager[] proxies;
+            if (delegates != null) {
+                proxies = new TrustManager[delegates.length];
+                for (int i = 0; i < delegates.length; i++) {
+                    final TrustManager delegate = delegates[i];
+                    if (delegate instanceof X509ExtendedTrustManager) {
+                        proxies[i] = new ProxyExtendedTrustManager((X509ExtendedTrustManager) delegate);
+                    } else if (delegate instanceof X509TrustManager) {
+                        proxies[i] = new ProxyTrustManager((X509TrustManager) delegate);
+                    } else {
+                        LOG.debug("Cannot handle trust manager {}, passing through", delegate);
+                        proxies[i] = delegate;
+                    }
                 }
-                serverContext.init(kmf.getKeyManagers(), customTrustManager, null);
             } else {
-                serverContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
+                proxies = null;
             }
+            serverContext.init(kmf.getKeyManagers(), proxies, null);
         } catch (IOException e) {
             LOG.warn("IOException - Failed to load keystore / truststore."
                     + " Failed to initialize the server-side SSLContext", e);
@@ -117,58 +124,81 @@ public class SslContextFactory {
         return serverContext;
     }
 
+    private final class ProxyTrustManager implements X509TrustManager {
+        private final X509TrustManager delegate;
+
+        ProxyTrustManager(final X509TrustManager delegate) {
+            this.delegate = requireNonNull(delegate);
+        }
+
+        @Override
+        public void checkClientTrusted(final X509Certificate[] chain, final String authType)
+                throws CertificateException {
+            setSwitchCertificateChain(chain);
+            delegate.checkClientTrusted(chain, authType);
+        }
+
+        @Override
+        public void checkServerTrusted(final X509Certificate[] chain, final String authType)
+                throws CertificateException {
+            delegate.checkServerTrusted(chain, authType);
+        }
+
+        @Override
+        public X509Certificate[] getAcceptedIssuers() {
+            return delegate.getAcceptedIssuers();
+        }
+    }
 
-    private static class CustomTrustManager extends X509ExtendedTrustManager {
-        private final X509ExtendedTrustManager trustManager;
+    private final class ProxyExtendedTrustManager extends X509ExtendedTrustManager {
+        private final X509ExtendedTrustManager delegate;
 
-        CustomTrustManager(final X509ExtendedTrustManager trustManager) {
-            this.trustManager = trustManager;
+        ProxyExtendedTrustManager(final X509ExtendedTrustManager trustManager) {
+            delegate = requireNonNull(trustManager);
         }
 
         @Override
-        public void checkClientTrusted(final X509Certificate[] x509Certificates, final String authType,
-                final Socket socket) throws CertificateException {
-            SslContextFactory.setSwitchCertificate(x509Certificates[0]);
-            trustManager.checkClientTrusted(x509Certificates, authType, socket);
+        public void checkClientTrusted(final X509Certificate[] chain, final String authType, final Socket socket)
+                throws CertificateException {
+            setSwitchCertificateChain(chain);
+            delegate.checkClientTrusted(chain, authType, socket);
         }
 
         @Override
-        public void checkClientTrusted(final X509Certificate[] x509Certificates, final String authType)
+        public void checkClientTrusted(final X509Certificate[] chain, final String authType)
                 throws CertificateException {
-            SslContextFactory.setSwitchCertificate(x509Certificates[0]);
-            trustManager.checkClientTrusted(x509Certificates, authType);
+            setSwitchCertificateChain(chain);
+            delegate.checkClientTrusted(chain, authType);
         }
 
         @Override
-        public void checkClientTrusted(final X509Certificate[] x509Certificates, final String authType,
-                final SSLEngine sslEngine) throws CertificateException {
-            SslContextFactory.setSwitchCertificate(x509Certificates[0]);
-            trustManager.checkClientTrusted(x509Certificates, authType, sslEngine);
+        public void checkClientTrusted(final X509Certificate[] chain, final String authType, final SSLEngine sslEngine)
+                throws CertificateException {
+            setSwitchCertificateChain(chain);
+            delegate.checkClientTrusted(chain, authType, sslEngine);
         }
 
         @Override
-        public void checkServerTrusted(final X509Certificate[] x509Certificates, final String authType,
-                final SSLEngine sslEngine) throws CertificateException {
-            trustManager.checkServerTrusted(x509Certificates, authType, sslEngine);
+        public void checkServerTrusted(final X509Certificate[] chain, final String authType, final SSLEngine sslEngine)
+                throws CertificateException {
+            delegate.checkServerTrusted(chain, authType, sslEngine);
         }
 
         @Override
-        public void checkServerTrusted(final X509Certificate[] x509Certificates, final String authType)
+        public void checkServerTrusted(final X509Certificate[] chain, final String authType)
                 throws CertificateException {
-            trustManager.checkServerTrusted(x509Certificates, authType);
+            delegate.checkServerTrusted(chain, authType);
         }
 
         @Override
-        public void checkServerTrusted(final X509Certificate[] x509Certificates, final String authType,
-                final Socket socket) throws CertificateException {
-            trustManager.checkServerTrusted(x509Certificates, authType, socket);
+        public void checkServerTrusted(final X509Certificate[] chain, final String authType, final Socket socket)
+                throws CertificateException {
+            delegate.checkServerTrusted(chain, authType, socket);
         }
 
         @Override
         public X509Certificate[] getAcceptedIssuers() {
-            return trustManager.getAcceptedIssuers();
+            return delegate.getAcceptedIssuers();
         }
-
     }
-
 }