Optimize AuthorizedKeysDecoder 91/110091/1
authorRobert Varga <robert.varga@pantheon.tech>
Tue, 30 Jan 2024 06:33:03 +0000 (07:33 +0100)
committerRobert Varga <robert.varga@pantheon.tech>
Tue, 30 Jan 2024 09:44:22 +0000 (10:44 +0100)
Use pre-computed bytes for our constants and do not compute individual
components twice.

Change-Id: Iff1428f529db3c5df634cd80b8c08da6772bdbee
Signed-off-by: Robert Varga <robert.varga@pantheon.tech>
apps/callhome-provider/src/main/java/org/opendaylight/netconf/callhome/mount/AuthorizedKeysDecoder.java

index 9470197d9db41c04d3796dcad8d04bb37c7d28ae..bafb0f6aa118b58b1dbc29d7c78b845856fed695 100644 (file)
@@ -9,6 +9,7 @@ package org.opendaylight.netconf.callhome.mount;
 
 import com.google.common.collect.ImmutableMap;
 import java.io.ByteArrayOutputStream;
+import java.io.DataOutput;
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.math.BigInteger;
@@ -17,7 +18,6 @@ import java.security.GeneralSecurityException;
 import java.security.KeyFactory;
 import java.security.NoSuchAlgorithmException;
 import java.security.PublicKey;
-import java.security.interfaces.DSAParams;
 import java.security.interfaces.DSAPublicKey;
 import java.security.interfaces.RSAPublicKey;
 import java.security.spec.DSAPublicKeySpec;
@@ -47,11 +47,17 @@ public class AuthorizedKeysDecoder {
         .build();
 
     private static final String ECDSA_SUPPORTED_CURVE_NAME = "nistp256";
+    private static final byte[] ECDSA_SUPPORTED_CURVE_NAME_BYTES =
+        ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8);
     private static final String ECDSA_SUPPORTED_CURVE_NAME_SPEC = ECDSA_CURVES.get(ECDSA_SUPPORTED_CURVE_NAME);
 
     private static final String KEY_TYPE_RSA = "ssh-rsa";
+    private static final byte[] KEY_TYPE_RSA_BYTES = KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8);
+
     private static final String KEY_TYPE_DSA = "ssh-dss";
+    private static final byte[] KEY_TYPE_DSA_BYTES = KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8);
     private static final String KEY_TYPE_ECDSA = "ecdsa-sha2-" + ECDSA_SUPPORTED_CURVE_NAME;
+    private static final byte[] KEY_TYPE_ECDSA_BYTES = KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8);
 
     private static final KeyFactory RSA_FACTORY = loadOrWarn("RSA");
     private static final KeyFactory DSA_FACTORY = loadOrWarn("DSA");
@@ -107,11 +113,11 @@ public class AuthorizedKeysDecoder {
             throw new NoSuchAlgorithmException("RSA keys are not supported");
         }
 
-        BigInteger prime = decodeBigInt();
-        BigInteger subPrime = decodeBigInt();
-        BigInteger base = decodeBigInt();
-        BigInteger publicKey = decodeBigInt();
-        return DSA_FACTORY.generatePublic(new DSAPublicKeySpec(publicKey, prime, subPrime, base));
+        final var p = decodeBigInt();
+        final var q = decodeBigInt();
+        final var g = decodeBigInt();
+        final var y = decodeBigInt();
+        return DSA_FACTORY.generatePublic(new DSAPublicKeySpec(y, p, q, g));
     }
 
     private PublicKey decodeAsRSA() throws GeneralSecurityException {
@@ -119,8 +125,8 @@ public class AuthorizedKeysDecoder {
             throw new NoSuchAlgorithmException("RSA keys are not supported");
         }
 
-        BigInteger exponent = decodeBigInt();
-        BigInteger modulus = decodeBigInt();
+        final var exponent = decodeBigInt();
+        final var modulus = decodeBigInt();
         return RSA_FACTORY.generatePublic(new RSAPublicKeySpec(modulus, exponent));
     }
 
@@ -144,45 +150,46 @@ public class AuthorizedKeysDecoder {
         return new BigInteger(bigIntBytes);
     }
 
+    private static void encodeBigInt(final DataOutput out, final BigInteger value) throws IOException {
+        final var bytes = value.toByteArray();
+        out.writeInt(bytes.length);
+        out.write(bytes);
+    }
+
     public static String encodePublicKey(final PublicKey publicKey) throws IOException {
-        ByteArrayOutputStream byteOs = new ByteArrayOutputStream();
-        if (publicKey instanceof RSAPublicKey rsa) {
-            DataOutputStream dout = new DataOutputStream(byteOs);
-            dout.writeInt(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8).length);
-            dout.write(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8));
-            dout.writeInt(rsa.getPublicExponent().toByteArray().length);
-            dout.write(rsa.getPublicExponent().toByteArray());
-            dout.writeInt(rsa.getModulus().toByteArray().length);
-            dout.write(rsa.getModulus().toByteArray());
-        } else if (publicKey instanceof DSAPublicKey dsa) {
-            DSAParams dsaParams = dsa.getParams();
-            DataOutputStream dout = new DataOutputStream(byteOs);
-            dout.writeInt(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8).length);
-            dout.write(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8));
-            dout.writeInt(dsaParams.getP().toByteArray().length);
-            dout.write(dsaParams.getP().toByteArray());
-            dout.writeInt(dsaParams.getQ().toByteArray().length);
-            dout.write(dsaParams.getQ().toByteArray());
-            dout.writeInt(dsaParams.getG().toByteArray().length);
-            dout.write(dsaParams.getG().toByteArray());
-            dout.writeInt(dsa.getY().toByteArray().length);
-            dout.write(dsa.getY().toByteArray());
-        } else if (publicKey instanceof ECPublicKey ec) {
-            DataOutputStream dout = new DataOutputStream(byteOs);
-            dout.writeInt(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8).length);
-            dout.write(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8));
-            dout.writeInt(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8).length);
-            dout.write(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8));
-
-            byte[] coordX = ec.getQ().getAffineXCoord().getEncoded();
-            byte[] coordY = ec.getQ().getAffineYCoord().getEncoded();
-            dout.writeInt(coordX.length + coordY.length + 1);
-            dout.writeByte(0x04);
-            dout.write(coordX);
-            dout.write(coordY);
-        } else {
-            throw new IllegalArgumentException("Unknown public key encoding: " + publicKey);
+        final var baos = new ByteArrayOutputStream();
+
+        try (var dout = new DataOutputStream(baos)) {
+            if (publicKey instanceof RSAPublicKey rsa) {
+                dout.writeInt(KEY_TYPE_RSA_BYTES.length);
+                dout.write(KEY_TYPE_RSA_BYTES);
+                encodeBigInt(dout, rsa.getPublicExponent());
+                encodeBigInt(dout, rsa.getModulus());
+            } else if (publicKey instanceof DSAPublicKey dsa) {
+                final var dsaParams = dsa.getParams();
+                dout.writeInt(KEY_TYPE_DSA_BYTES.length);
+                dout.write(KEY_TYPE_DSA_BYTES);
+                encodeBigInt(dout, dsaParams.getP());
+                encodeBigInt(dout, dsaParams.getQ());
+                encodeBigInt(dout, dsaParams.getG());
+                encodeBigInt(dout, dsa.getY());
+            } else if (publicKey instanceof ECPublicKey ec) {
+                dout.writeInt(KEY_TYPE_ECDSA_BYTES.length);
+                dout.write(KEY_TYPE_ECDSA_BYTES);
+                dout.writeInt(ECDSA_SUPPORTED_CURVE_NAME_BYTES.length);
+                dout.write(ECDSA_SUPPORTED_CURVE_NAME_BYTES);
+
+                final var q = ec.getQ();
+                final var coordX = q.getAffineXCoord().getEncoded();
+                final var coordY = q.getAffineYCoord().getEncoded();
+                dout.writeInt(coordX.length + coordY.length + 1);
+                dout.writeByte(0x04);
+                dout.write(coordX);
+                dout.write(coordY);
+            } else {
+                throw new IllegalArgumentException("Unknown public key encoding: " + publicKey);
+            }
         }
-        return Base64.getEncoder().encodeToString(byteOs.toByteArray());
+        return Base64.getEncoder().encodeToString(baos.toByteArray());
     }
 }