Use pattern match on instanceof
[aaa.git] / aaa-encrypt-service / api / src / main / java / org / opendaylight / aaa / encrypt / PKIUtil.java
1 /*
2  * Copyright (c) 2017 Brocade Communication Systems and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.aaa.encrypt;
9
10 import java.io.ByteArrayOutputStream;
11 import java.io.DataOutputStream;
12 import java.io.FileInputStream;
13 import java.io.IOException;
14 import java.io.InputStreamReader;
15 import java.io.Reader;
16 import java.io.StringReader;
17 import java.math.BigInteger;
18 import java.nio.charset.StandardCharsets;
19 import java.security.GeneralSecurityException;
20 import java.security.KeyFactory;
21 import java.security.KeyPair;
22 import java.security.NoSuchAlgorithmException;
23 import java.security.Provider;
24 import java.security.PublicKey;
25 import java.security.Security;
26 import java.security.interfaces.DSAParams;
27 import java.security.interfaces.DSAPublicKey;
28 import java.security.interfaces.RSAPublicKey;
29 import java.security.spec.DSAPublicKeySpec;
30 import java.security.spec.ECPoint;
31 import java.security.spec.ECPublicKeySpec;
32 import java.security.spec.RSAPublicKeySpec;
33 import java.util.Arrays;
34 import java.util.Base64;
35 import java.util.HashMap;
36 import java.util.Map;
37 import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
38 import org.bouncycastle.jce.ECNamedCurveTable;
39 import org.bouncycastle.jce.ECPointUtil;
40 import org.bouncycastle.jce.provider.BouncyCastleProvider;
41 import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
42 import org.bouncycastle.jce.spec.ECNamedCurveSpec;
43 import org.bouncycastle.openssl.PEMDecryptorProvider;
44 import org.bouncycastle.openssl.PEMEncryptedKeyPair;
45 import org.bouncycastle.openssl.PEMKeyPair;
46 import org.bouncycastle.openssl.PEMParser;
47 import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
48 import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder;
49
50 /**
51  * PKI related utilities.
52  */
53 public class PKIUtil {
54     @FunctionalInterface
55     private interface KeyFactorySupplier {
56         KeyFactory get() throws NoSuchAlgorithmException;
57     }
58
59     private static final Provider BCPROV;
60
61     static {
62         final Provider prov = Security.getProvider(BouncyCastleProvider.PROVIDER_NAME);
63         BCPROV = prov != null ? prov : new BouncyCastleProvider();
64     }
65
66     private static final String KEY_FACTORY_TYPE_RSA = "RSA";
67     private static final String KEY_FACTORY_TYPE_DSA = "DSA";
68     private static final String KEY_FACTORY_TYPE_ECDSA = "EC";
69
70     private static final KeyFactorySupplier RSA_KEY_FACTORY_SUPPLIER = resolveKeyFactory(KEY_FACTORY_TYPE_RSA);
71     private static final KeyFactorySupplier DSA_KEY_FACTORY_SUPPLIER = resolveKeyFactory(KEY_FACTORY_TYPE_DSA);
72     private static final KeyFactorySupplier ECDSA_KEY_FACTORY_SUPPLIER = resolveKeyFactory(KEY_FACTORY_TYPE_ECDSA);
73
74     private static KeyFactorySupplier resolveKeyFactory(final String algorithm) {
75         final KeyFactory factory;
76         try {
77             factory = KeyFactory.getInstance(algorithm);
78         } catch (NoSuchAlgorithmException e) {
79             return () -> {
80                 throw e;
81             };
82         }
83         return () -> factory;
84     }
85
86     private static final Map<String, String> ECDSA_CURVES = new HashMap<>();
87
88     static {
89         ECDSA_CURVES.put("nistp256", "secp256r1");
90         ECDSA_CURVES.put("nistp384", "secp384r1");
91         ECDSA_CURVES.put("nistp512", "secp512r1");
92     }
93
94     private static final String ECDSA_SUPPORTED_CURVE_NAME = "nistp256";
95     private static final String ECDSA_SUPPORTED_CURVE_NAME_SPEC = ECDSA_CURVES.get(ECDSA_SUPPORTED_CURVE_NAME);
96     private static final int ECDSA_THIRD_STR_LEN = 65;
97     private static final int ECDSA_TOTAL_STR_LEN = 104;
98
99     private static final String KEY_TYPE_RSA = "ssh-rsa";
100     private static final String KEY_TYPE_DSA = "ssh-dss";
101     private static final String KEY_TYPE_ECDSA = "ecdsa-sha2-" + ECDSA_SUPPORTED_CURVE_NAME;
102
103     private byte[] bytes = new byte[0];
104     private int pos = 0;
105
106     public PublicKey decodePublicKey(final String keyLine) throws GeneralSecurityException {
107
108         // look for the Base64 encoded part of the line to decode
109         // both ssh-rsa and ssh-dss begin with "AAAA" due to the length bytes
110         bytes = Base64.getDecoder().decode(keyLine.getBytes(StandardCharsets.UTF_8));
111         if (bytes.length == 0) {
112             throw new IllegalArgumentException("No Base64 part to decode in " + keyLine);
113         }
114         pos = 0;
115
116         String type = decodeType();
117         if (type.equals(KEY_TYPE_RSA)) {
118             return decodeAsRSA();
119         }
120
121         if (type.equals(KEY_TYPE_DSA)) {
122             return decodeAsDSA();
123         }
124
125         if (type.equals(KEY_TYPE_ECDSA)) {
126             return decodeAsECDSA();
127         }
128
129         throw new IllegalArgumentException("Unknown decode key type " + type + " in " + keyLine);
130     }
131
132     @SuppressWarnings("AbbreviationAsWordInName")
133     private PublicKey decodeAsECDSA() throws GeneralSecurityException {
134         KeyFactory ecdsaFactory = ECDSA_KEY_FACTORY_SUPPLIER.get();
135
136         ECNamedCurveParameterSpec spec256r1 = ECNamedCurveTable.getParameterSpec(ECDSA_SUPPORTED_CURVE_NAME_SPEC);
137         ECNamedCurveSpec params256r1 = new ECNamedCurveSpec(ECDSA_SUPPORTED_CURVE_NAME_SPEC, spec256r1.getCurve(),
138                 spec256r1.getG(), spec256r1.getN());
139         // The total length is 104 bytes, and the X and Y encoding uses the last 65 of these 104 bytes.
140         ECPoint point = ECPointUtil.decodePoint(params256r1.getCurve(), Arrays.copyOfRange(bytes, ECDSA_TOTAL_STR_LEN
141                 - ECDSA_THIRD_STR_LEN, ECDSA_TOTAL_STR_LEN));
142         ECPublicKeySpec pubKeySpec = new ECPublicKeySpec(point, params256r1);
143
144         return ecdsaFactory.generatePublic(pubKeySpec);
145     }
146
147     private PublicKey decodeAsDSA() throws GeneralSecurityException {
148         KeyFactory dsaFactory = DSA_KEY_FACTORY_SUPPLIER.get();
149         BigInteger var1 = decodeBigInt();
150         BigInteger var2 = decodeBigInt();
151         BigInteger var3 = decodeBigInt();
152         BigInteger var4 = decodeBigInt();
153         DSAPublicKeySpec spec = new DSAPublicKeySpec(var4, var1, var2, var3);
154
155         return dsaFactory.generatePublic(spec);
156     }
157
158     private PublicKey decodeAsRSA() throws GeneralSecurityException {
159         KeyFactory rsaFactory = RSA_KEY_FACTORY_SUPPLIER.get();
160         BigInteger exponent = decodeBigInt();
161         BigInteger modulus = decodeBigInt();
162         RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);
163
164         return rsaFactory.generatePublic(spec);
165     }
166
167     private String decodeType() {
168         int len = decodeInt();
169         String type = new String(bytes, pos, len, StandardCharsets.UTF_8);
170         pos += len;
171         return type;
172     }
173
174     private int decodeInt() {
175         return (bytes[pos++] & 0xFF) << 24 | (bytes[pos++] & 0xFF) << 16 | (bytes[pos++] & 0xFF) << 8
176                 | bytes[pos++] & 0xFF;
177     }
178
179     private BigInteger decodeBigInt() {
180         int len = decodeInt();
181         byte[] bigIntBytes = new byte[len];
182         System.arraycopy(bytes, pos, bigIntBytes, 0, len);
183         pos += len;
184         return new BigInteger(bigIntBytes);
185     }
186
187     public String encodePublicKey(final PublicKey publicKey) throws IOException {
188         ByteArrayOutputStream byteOs = new ByteArrayOutputStream();
189         if (publicKey instanceof RSAPublicKey rsaPublicKey
190             && rsaPublicKey.getAlgorithm().equals(KEY_FACTORY_TYPE_RSA)) {
191             DataOutputStream dataOutputStream = new DataOutputStream(byteOs);
192             dataOutputStream.writeInt(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8).length);
193             dataOutputStream.write(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8));
194             dataOutputStream.writeInt(rsaPublicKey.getPublicExponent().toByteArray().length);
195             dataOutputStream.write(rsaPublicKey.getPublicExponent().toByteArray());
196             dataOutputStream.writeInt(rsaPublicKey.getModulus().toByteArray().length);
197             dataOutputStream.write(rsaPublicKey.getModulus().toByteArray());
198         } else if (publicKey instanceof DSAPublicKey dsaPublicKey
199             && dsaPublicKey.getAlgorithm().equals(KEY_FACTORY_TYPE_DSA)) {
200             DSAParams dsaParams = dsaPublicKey.getParams();
201             DataOutputStream dataOutputStream = new DataOutputStream(byteOs);
202             dataOutputStream.writeInt(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8).length);
203             dataOutputStream.write(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8));
204             dataOutputStream.writeInt(dsaParams.getP().toByteArray().length);
205             dataOutputStream.write(dsaParams.getP().toByteArray());
206             dataOutputStream.writeInt(dsaParams.getQ().toByteArray().length);
207             dataOutputStream.write(dsaParams.getQ().toByteArray());
208             dataOutputStream.writeInt(dsaParams.getG().toByteArray().length);
209             dataOutputStream.write(dsaParams.getG().toByteArray());
210             dataOutputStream.writeInt(dsaPublicKey.getY().toByteArray().length);
211             dataOutputStream.write(dsaPublicKey.getY().toByteArray());
212         } else if (publicKey instanceof BCECPublicKey ecPublicKey
213             && ecPublicKey.getAlgorithm().equals(KEY_FACTORY_TYPE_ECDSA)) {
214             DataOutputStream dataOutputStream = new DataOutputStream(byteOs);
215             dataOutputStream.writeInt(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8).length);
216             dataOutputStream.write(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8));
217             dataOutputStream.writeInt(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8).length);
218             dataOutputStream.write(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8));
219             byte[] affineXCoord = ecPublicKey.getQ().getAffineXCoord().getEncoded();
220             byte[] affineYCoord = ecPublicKey.getQ().getAffineYCoord().getEncoded();
221             dataOutputStream.writeInt(affineXCoord.length + affineYCoord.length + 1);
222             dataOutputStream.writeByte(0x04);
223             dataOutputStream.write(affineXCoord);
224             dataOutputStream.write(affineYCoord);
225         } else {
226             throw new IllegalArgumentException("Unknown public key encoding: " + publicKey.getAlgorithm());
227         }
228
229         return Base64.getEncoder().encodeToString(byteOs.toByteArray());
230
231     }
232
233     public KeyPair decodePrivateKey(final StringReader reader, final String passphrase) throws IOException {
234         return doDecodePrivateKey(reader, passphrase);
235     }
236
237     public KeyPair decodePrivateKey(final String keyPath, final String passphrase) throws IOException {
238         try (Reader reader = new InputStreamReader(new FileInputStream(keyPath), StandardCharsets.UTF_8)) {
239             return doDecodePrivateKey(reader, passphrase);
240         }
241     }
242
243     private static KeyPair doDecodePrivateKey(final Reader reader, final String passphrase) throws IOException {
244         try (PEMParser keyReader = new PEMParser(reader)) {
245             JcaPEMKeyConverter converter = new JcaPEMKeyConverter();
246             PEMDecryptorProvider decryptionProv = new JcePEMDecryptorProviderBuilder().setProvider(BCPROV)
247                     .build(passphrase.toCharArray());
248
249             Object privateKey = keyReader.readObject();
250             KeyPair keyPair;
251             if (privateKey instanceof PEMEncryptedKeyPair pemPrivateKey) {
252                 keyPair = converter.getKeyPair(pemPrivateKey.decryptKeyPair(decryptionProv));
253             } else {
254                 keyPair = converter.getKeyPair((PEMKeyPair) privateKey);
255             }
256             return keyPair;
257         }
258     }
259 }