9470197d9db41c04d3796dcad8d04bb37c7d28ae
[netconf.git] / apps / callhome-provider / src / main / java / org / opendaylight / netconf / callhome / mount / AuthorizedKeysDecoder.java
1 /*
2  * Copyright (c) 2016 Brocade Communication Systems and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8 package org.opendaylight.netconf.callhome.mount;
9
10 import com.google.common.collect.ImmutableMap;
11 import java.io.ByteArrayOutputStream;
12 import java.io.DataOutputStream;
13 import java.io.IOException;
14 import java.math.BigInteger;
15 import java.nio.charset.StandardCharsets;
16 import java.security.GeneralSecurityException;
17 import java.security.KeyFactory;
18 import java.security.NoSuchAlgorithmException;
19 import java.security.PublicKey;
20 import java.security.interfaces.DSAParams;
21 import java.security.interfaces.DSAPublicKey;
22 import java.security.interfaces.RSAPublicKey;
23 import java.security.spec.DSAPublicKeySpec;
24 import java.security.spec.ECPoint;
25 import java.security.spec.ECPublicKeySpec;
26 import java.security.spec.RSAPublicKeySpec;
27 import java.util.Arrays;
28 import java.util.Base64;
29 import org.bouncycastle.jce.ECNamedCurveTable;
30 import org.bouncycastle.jce.ECPointUtil;
31 import org.bouncycastle.jce.interfaces.ECPublicKey;
32 import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
33 import org.bouncycastle.jce.spec.ECNamedCurveSpec;
34 import org.slf4j.Logger;
35 import org.slf4j.LoggerFactory;
36
37 /**
38  * FIXME: This should be probably located at AAA library.
39  */
40 public class AuthorizedKeysDecoder {
41     private static final Logger LOG = LoggerFactory.getLogger(AuthorizedKeysDecoder.class);
42
43     private static final ImmutableMap<String, String> ECDSA_CURVES = ImmutableMap.<String, String>builder()
44         .put("nistp256", "secp256r1")
45         .put("nistp384", "secp384r1")
46         .put("nistp512", "secp512r1")
47         .build();
48
49     private static final String ECDSA_SUPPORTED_CURVE_NAME = "nistp256";
50     private static final String ECDSA_SUPPORTED_CURVE_NAME_SPEC = ECDSA_CURVES.get(ECDSA_SUPPORTED_CURVE_NAME);
51
52     private static final String KEY_TYPE_RSA = "ssh-rsa";
53     private static final String KEY_TYPE_DSA = "ssh-dss";
54     private static final String KEY_TYPE_ECDSA = "ecdsa-sha2-" + ECDSA_SUPPORTED_CURVE_NAME;
55
56     private static final KeyFactory RSA_FACTORY = loadOrWarn("RSA");
57     private static final KeyFactory DSA_FACTORY = loadOrWarn("DSA");
58     private static final KeyFactory EC_FACTORY = loadOrWarn("EC");
59
60     private static KeyFactory loadOrWarn(final String algorithm) {
61         try {
62             return KeyFactory.getInstance(algorithm);
63         } catch (NoSuchAlgorithmException e) {
64             LOG.warn("KeyFactory for {} not found", algorithm, e);
65             return null;
66         }
67     }
68
69     private byte[] bytes = new byte[0];
70     private int pos = 0;
71
72     public PublicKey decodePublicKey(final String keyLine) throws GeneralSecurityException {
73
74         // look for the Base64 encoded part of the line to decode
75         // both ssh-rsa and ssh-dss begin with "AAAA" due to the length bytes
76         bytes = Base64.getDecoder().decode(keyLine.getBytes(StandardCharsets.UTF_8));
77         if (bytes.length == 0) {
78             throw new IllegalArgumentException("No Base64 part to decode in " + keyLine);
79         }
80
81         pos = 0;
82
83         final var type = decodeType();
84         return switch (type) {
85             case KEY_TYPE_RSA -> decodeAsRSA();
86             case KEY_TYPE_DSA -> decodeAsDSA();
87             case KEY_TYPE_ECDSA -> decodeAsEcDSA();
88             default -> throw new IllegalArgumentException("Unknown decode key type " + type + " in " + keyLine);
89         };
90     }
91
92     private PublicKey decodeAsEcDSA() throws GeneralSecurityException {
93         if (EC_FACTORY == null) {
94             throw new NoSuchAlgorithmException("ECDSA keys are not supported");
95         }
96
97         ECNamedCurveParameterSpec spec256r1 = ECNamedCurveTable.getParameterSpec(ECDSA_SUPPORTED_CURVE_NAME_SPEC);
98         ECNamedCurveSpec params256r1 = new ECNamedCurveSpec(
99             ECDSA_SUPPORTED_CURVE_NAME_SPEC, spec256r1.getCurve(), spec256r1.getG(), spec256r1.getN());
100         // copy last 65 bytes from ssh key.
101         ECPoint point = ECPointUtil.decodePoint(params256r1.getCurve(), Arrays.copyOfRange(bytes, 39, bytes.length));
102         return EC_FACTORY.generatePublic(new ECPublicKeySpec(point, params256r1));
103     }
104
105     private PublicKey decodeAsDSA() throws GeneralSecurityException {
106         if (DSA_FACTORY == null) {
107             throw new NoSuchAlgorithmException("RSA keys are not supported");
108         }
109
110         BigInteger prime = decodeBigInt();
111         BigInteger subPrime = decodeBigInt();
112         BigInteger base = decodeBigInt();
113         BigInteger publicKey = decodeBigInt();
114         return DSA_FACTORY.generatePublic(new DSAPublicKeySpec(publicKey, prime, subPrime, base));
115     }
116
117     private PublicKey decodeAsRSA() throws GeneralSecurityException {
118         if (RSA_FACTORY == null) {
119             throw new NoSuchAlgorithmException("RSA keys are not supported");
120         }
121
122         BigInteger exponent = decodeBigInt();
123         BigInteger modulus = decodeBigInt();
124         return RSA_FACTORY.generatePublic(new RSAPublicKeySpec(modulus, exponent));
125     }
126
127     private String decodeType() {
128         int len = decodeInt();
129         String type = new String(bytes, pos, len, StandardCharsets.UTF_8);
130         pos += len;
131         return type;
132     }
133
134     private int decodeInt() {
135         return (bytes[pos++] & 0xFF) << 24 | (bytes[pos++] & 0xFF) << 16
136                 | (bytes[pos++] & 0xFF) << 8 | bytes[pos++] & 0xFF;
137     }
138
139     private BigInteger decodeBigInt() {
140         int len = decodeInt();
141         byte[] bigIntBytes = new byte[len];
142         System.arraycopy(bytes, pos, bigIntBytes, 0, len);
143         pos += len;
144         return new BigInteger(bigIntBytes);
145     }
146
147     public static String encodePublicKey(final PublicKey publicKey) throws IOException {
148         ByteArrayOutputStream byteOs = new ByteArrayOutputStream();
149         if (publicKey instanceof RSAPublicKey rsa) {
150             DataOutputStream dout = new DataOutputStream(byteOs);
151             dout.writeInt(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8).length);
152             dout.write(KEY_TYPE_RSA.getBytes(StandardCharsets.UTF_8));
153             dout.writeInt(rsa.getPublicExponent().toByteArray().length);
154             dout.write(rsa.getPublicExponent().toByteArray());
155             dout.writeInt(rsa.getModulus().toByteArray().length);
156             dout.write(rsa.getModulus().toByteArray());
157         } else if (publicKey instanceof DSAPublicKey dsa) {
158             DSAParams dsaParams = dsa.getParams();
159             DataOutputStream dout = new DataOutputStream(byteOs);
160             dout.writeInt(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8).length);
161             dout.write(KEY_TYPE_DSA.getBytes(StandardCharsets.UTF_8));
162             dout.writeInt(dsaParams.getP().toByteArray().length);
163             dout.write(dsaParams.getP().toByteArray());
164             dout.writeInt(dsaParams.getQ().toByteArray().length);
165             dout.write(dsaParams.getQ().toByteArray());
166             dout.writeInt(dsaParams.getG().toByteArray().length);
167             dout.write(dsaParams.getG().toByteArray());
168             dout.writeInt(dsa.getY().toByteArray().length);
169             dout.write(dsa.getY().toByteArray());
170         } else if (publicKey instanceof ECPublicKey ec) {
171             DataOutputStream dout = new DataOutputStream(byteOs);
172             dout.writeInt(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8).length);
173             dout.write(KEY_TYPE_ECDSA.getBytes(StandardCharsets.UTF_8));
174             dout.writeInt(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8).length);
175             dout.write(ECDSA_SUPPORTED_CURVE_NAME.getBytes(StandardCharsets.UTF_8));
176
177             byte[] coordX = ec.getQ().getAffineXCoord().getEncoded();
178             byte[] coordY = ec.getQ().getAffineYCoord().getEncoded();
179             dout.writeInt(coordX.length + coordY.length + 1);
180             dout.writeByte(0x04);
181             dout.write(coordX);
182             dout.write(coordY);
183         } else {
184             throw new IllegalArgumentException("Unknown public key encoding: " + publicKey);
185         }
186         return Base64.getEncoder().encodeToString(byteOs.toByteArray());
187     }
188 }