Do not use SecurityUtils in callhome-provider 49/110049/1
authorRobert Varga <robert.varga@pantheon.tech>
Sat, 27 Jan 2024 20:50:26 +0000 (21:50 +0100)
committerRobert Varga <robert.varga@pantheon.tech>
Sat, 27 Jan 2024 20:50:26 +0000 (21:50 +0100)
Use plain KeyStore access to acquire provides. This eliminate
AuthorizedKeysDecoder's dependency on sshd -- which does not make sense
in its current shape and form.

Change-Id: I95e743a34d78f7220e2edf49dbac177a132f0c3f
Signed-off-by: Robert Varga <robert.varga@pantheon.tech>
apps/callhome-provider/src/main/java/org/opendaylight/netconf/callhome/mount/AuthorizedKeysDecoder.java

index ec8593818e7ff6ff59aa9d23b8d1da070a355338..9470197d9db41c04d3796dcad8d04bb37c7d28ae 100644 (file)
@@ -7,6 +7,7 @@
  */
 package org.opendaylight.netconf.callhome.mount;
 
+import com.google.common.collect.ImmutableMap;
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
@@ -14,6 +15,7 @@ import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
 import java.security.GeneralSecurityException;
 import java.security.KeyFactory;
+import java.security.NoSuchAlgorithmException;
 import java.security.PublicKey;
 import java.security.interfaces.DSAParams;
 import java.security.interfaces.DSAPublicKey;
@@ -24,31 +26,25 @@ import java.security.spec.ECPublicKeySpec;
 import java.security.spec.RSAPublicKeySpec;
 import java.util.Arrays;
 import java.util.Base64;
-import java.util.HashMap;
-import java.util.Map;
-import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
 import org.bouncycastle.jce.ECNamedCurveTable;
 import org.bouncycastle.jce.ECPointUtil;
+import org.bouncycastle.jce.interfaces.ECPublicKey;
 import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
 import org.bouncycastle.jce.spec.ECNamedCurveSpec;
-import org.opendaylight.netconf.shaded.sshd.common.util.security.SecurityUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * FIXME: This should be probably located at AAA library.
  */
 public class AuthorizedKeysDecoder {
+    private static final Logger LOG = LoggerFactory.getLogger(AuthorizedKeysDecoder.class);
 
-    private static final String KEY_FACTORY_TYPE_RSA = "RSA";
-    private static final String KEY_FACTORY_TYPE_DSA = "DSA";
-    private static final String KEY_FACTORY_TYPE_ECDSA = "EC";
-
-    private static final Map<String, String> ECDSA_CURVES = new HashMap<>();
-
-    static {
-        ECDSA_CURVES.put("nistp256", "secp256r1");
-        ECDSA_CURVES.put("nistp384", "secp384r1");
-        ECDSA_CURVES.put("nistp512", "secp512r1");
-    }
+    private static final ImmutableMap<String, String> ECDSA_CURVES = ImmutableMap.<String, String>builder()
+        .put("nistp256", "secp256r1")
+        .put("nistp384", "secp384r1")
+        .put("nistp512", "secp512r1")
+        .build();
 
     private static final String ECDSA_SUPPORTED_CURVE_NAME = "nistp256";
     private static final String ECDSA_SUPPORTED_CURVE_NAME_SPEC = ECDSA_CURVES.get(ECDSA_SUPPORTED_CURVE_NAME);
@@ -57,6 +53,19 @@ public class AuthorizedKeysDecoder {
     private static final String KEY_TYPE_DSA = "ssh-dss";
     private static final String KEY_TYPE_ECDSA = "ecdsa-sha2-" + ECDSA_SUPPORTED_CURVE_NAME;
 
+    private static final KeyFactory RSA_FACTORY = loadOrWarn("RSA");
+    private static final KeyFactory DSA_FACTORY = loadOrWarn("DSA");
+    private static final KeyFactory EC_FACTORY = loadOrWarn("EC");
+
+    private static KeyFactory loadOrWarn(final String algorithm) {
+        try {
+            return KeyFactory.getInstance(algorithm);
+        } catch (NoSuchAlgorithmException e) {
+            LOG.warn("KeyFactory for {} not found", algorithm, e);
+            return null;
+        }
+    }
+
     private byte[] bytes = new byte[0];
     private int pos = 0;
 
@@ -71,53 +80,48 @@ public class AuthorizedKeysDecoder {
 
         pos = 0;
 
-        String type = decodeType();
-        if (type.equals(KEY_TYPE_RSA)) {
-            return decodeAsRSA();
-        }
-
-        if (type.equals(KEY_TYPE_DSA)) {
-            return decodeAsDSA();
-        }
-
-        if (type.equals(KEY_TYPE_ECDSA)) {
-            return decodeAsEcDSA();
-        }
-
-        throw new IllegalArgumentException("Unknown decode key type " + type + " in " + keyLine);
+        final var type = decodeType();
+        return switch (type) {
+            case KEY_TYPE_RSA -> decodeAsRSA();
+            case KEY_TYPE_DSA -> decodeAsDSA();
+            case KEY_TYPE_ECDSA -> decodeAsEcDSA();
+            default -> throw new IllegalArgumentException("Unknown decode key type " + type + " in " + keyLine);
+        };
     }
 
     private PublicKey decodeAsEcDSA() throws GeneralSecurityException {
-        KeyFactory ecdsaFactory = SecurityUtils.getKeyFactory(KEY_FACTORY_TYPE_ECDSA);
+        if (EC_FACTORY == null) {
+            throw new NoSuchAlgorithmException("ECDSA keys are not supported");
+        }
 
         ECNamedCurveParameterSpec spec256r1 = ECNamedCurveTable.getParameterSpec(ECDSA_SUPPORTED_CURVE_NAME_SPEC);
         ECNamedCurveSpec params256r1 = new ECNamedCurveSpec(
             ECDSA_SUPPORTED_CURVE_NAME_SPEC, spec256r1.getCurve(), spec256r1.getG(), spec256r1.getN());
         // copy last 65 bytes from ssh key.
         ECPoint point = ECPointUtil.decodePoint(params256r1.getCurve(), Arrays.copyOfRange(bytes, 39, bytes.length));
-        ECPublicKeySpec pubKeySpec = new ECPublicKeySpec(point, params256r1);
-
-        return ecdsaFactory.generatePublic(pubKeySpec);
+        return EC_FACTORY.generatePublic(new ECPublicKeySpec(point, params256r1));
     }
 
     private PublicKey decodeAsDSA() throws GeneralSecurityException {
-        KeyFactory dsaFactory = SecurityUtils.getKeyFactory(KEY_FACTORY_TYPE_DSA);
+        if (DSA_FACTORY == null) {
+            throw new NoSuchAlgorithmException("RSA keys are not supported");
+        }
+
         BigInteger prime = decodeBigInt();
         BigInteger subPrime = decodeBigInt();
         BigInteger base = decodeBigInt();
         BigInteger publicKey = decodeBigInt();
-        DSAPublicKeySpec spec = new DSAPublicKeySpec(publicKey, prime, subPrime, base);
-
-        return dsaFactory.generatePublic(spec);
+        return DSA_FACTORY.generatePublic(new DSAPublicKeySpec(publicKey, prime, subPrime, base));
     }
 
     private PublicKey decodeAsRSA() throws GeneralSecurityException {
-        KeyFactory rsaFactory = SecurityUtils.getKeyFactory(KEY_FACTORY_TYPE_RSA);
+        if (RSA_FACTORY == null) {
+            throw new NoSuchAlgorithmException("RSA keys are not supported");
+        }
+
         BigInteger exponent = decodeBigInt();
         BigInteger modulus = decodeBigInt();
-        RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);
-
-        return rsaFactory.generatePublic(spec);
+        return RSA_FACTORY.generatePublic(new RSAPublicKeySpec(modulus, exponent));
     }
 
     private String decodeType() {
@@ -142,18 +146,16 @@ public class AuthorizedKeysDecoder {
 
     public static String encodePublicKey(final PublicKey publicKey) throws IOException {
         ByteArrayOutputStream byteOs = new ByteArrayOutputStream();
-        if (publicKey.getAlgorithm().equals(KEY_FACTORY_TYPE_RSA) && publicKey instanceof RSAPublicKey) {
-            RSAPublicKey rsaPublicKey = (RSAPublicKey) publicKey;
+        if (publicKey instanceof RSAPublicKey rsa) {
             DataOutputStream dout = new DataOutputStream(byteOs);
             dout.writeInt(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8).length);
             dout.write(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8));
-            dout.writeInt(rsaPublicKey.getPublicExponent().toByteArray().length);
-            dout.write(rsaPublicKey.getPublicExponent().toByteArray());
-            dout.writeInt(rsaPublicKey.getModulus().toByteArray().length);
-            dout.write(rsaPublicKey.getModulus().toByteArray());
-        } else if (publicKey.getAlgorithm().equals(KEY_FACTORY_TYPE_DSA) && publicKey instanceof DSAPublicKey) {
-            DSAPublicKey dsaPublicKey = (DSAPublicKey) publicKey;
-            DSAParams dsaParams = dsaPublicKey.getParams();
+            dout.writeInt(rsa.getPublicExponent().toByteArray().length);
+            dout.write(rsa.getPublicExponent().toByteArray());
+            dout.writeInt(rsa.getModulus().toByteArray().length);
+            dout.write(rsa.getModulus().toByteArray());
+        } else if (publicKey instanceof DSAPublicKey dsa) {
+            DSAParams dsaParams = dsa.getParams();
             DataOutputStream dout = new DataOutputStream(byteOs);
             dout.writeInt(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8).length);
             dout.write(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8));
@@ -163,24 +165,23 @@ public class AuthorizedKeysDecoder {
             dout.write(dsaParams.getQ().toByteArray());
             dout.writeInt(dsaParams.getG().toByteArray().length);
             dout.write(dsaParams.getG().toByteArray());
-            dout.writeInt(dsaPublicKey.getY().toByteArray().length);
-            dout.write(dsaPublicKey.getY().toByteArray());
-        } else if (publicKey.getAlgorithm().equals(KEY_FACTORY_TYPE_ECDSA) && publicKey instanceof BCECPublicKey) {
-            BCECPublicKey ecPublicKey = (BCECPublicKey) publicKey;
+            dout.writeInt(dsa.getY().toByteArray().length);
+            dout.write(dsa.getY().toByteArray());
+        } else if (publicKey instanceof ECPublicKey ec) {
             DataOutputStream dout = new DataOutputStream(byteOs);
             dout.writeInt(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8).length);
             dout.write(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8));
             dout.writeInt(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8).length);
             dout.write(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8));
 
-            byte[] coordX = ecPublicKey.getQ().getAffineXCoord().getEncoded();
-            byte[] coordY = ecPublicKey.getQ().getAffineYCoord().getEncoded();
+            byte[] coordX = ec.getQ().getAffineXCoord().getEncoded();
+            byte[] coordY = ec.getQ().getAffineYCoord().getEncoded();
             dout.writeInt(coordX.length + coordY.length + 1);
             dout.writeByte(0x04);
             dout.write(coordX);
             dout.write(coordY);
         } else {
-            throw new IllegalArgumentException("Unknown public key encoding: " + publicKey.getAlgorithm());
+            throw new IllegalArgumentException("Unknown public key encoding: " + publicKey);
         }
         return Base64.getEncoder().encodeToString(byteOs.toByteArray());
     }