Package org.bouncycastle.crypto.tls

Source Code of org.bouncycastle.crypto.tls.TlsECCUtils

package org.bouncycastle.crypto.tls;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Hashtable;

import org.bouncycastle.asn1.sec.SECNamedCurves;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.agreement.ECDHBasicAgreement;
import org.bouncycastle.crypto.generators.ECKeyPairGenerator;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECKeyGenerationParameters;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECPoint;
import org.bouncycastle.util.BigIntegers;
import org.bouncycastle.util.Integers;

public class TlsECCUtils
{

    public static final Integer EXT_elliptic_curves = Integers.valueOf(ExtensionType.elliptic_curves);
    public static final Integer EXT_ec_point_formats = Integers.valueOf(ExtensionType.ec_point_formats);

    private static final String[] curveNames = new String[]{"sect163k1", "sect163r1", "sect163r2", "sect193r1",
        "sect193r2", "sect233k1", "sect233r1", "sect239k1", "sect283k1", "sect283r1", "sect409k1", "sect409r1",
        "sect571k1", "sect571r1", "secp160k1", "secp160r1", "secp160r2", "secp192k1", "secp192r1", "secp224k1",
        "secp224r1", "secp256k1", "secp256r1", "secp384r1", "secp521r1",};

    public static void addSupportedEllipticCurvesExtension(Hashtable extensions, int[] namedCurves)
        throws IOException
    {

        extensions.put(EXT_elliptic_curves, createSupportedEllipticCurvesExtension(namedCurves));
    }

    public static void addSupportedPointFormatsExtension(Hashtable extensions, short[] ecPointFormats)
        throws IOException
    {

        extensions.put(EXT_ec_point_formats, createSupportedPointFormatsExtension(ecPointFormats));
    }

    public static int[] getSupportedEllipticCurvesExtension(Hashtable extensions)
        throws IOException
    {

        if (extensions == null)
        {
            return null;
        }
        byte[] extensionValue = (byte[])extensions.get(EXT_elliptic_curves);
        if (extensionValue == null)
        {
            return null;
        }
        return readSupportedEllipticCurvesExtension(extensionValue);
    }

    public static short[] getSupportedPointFormatsExtension(Hashtable extensions)
        throws IOException
    {

        if (extensions == null)
        {
            return null;
        }
        byte[] extensionValue = (byte[])extensions.get(EXT_ec_point_formats);
        if (extensionValue == null)
        {
            return null;
        }
        return readSupportedPointFormatsExtension(extensionValue);
    }

    public static byte[] createSupportedEllipticCurvesExtension(int[] namedCurves)
        throws IOException
    {

        if (namedCurves == null || namedCurves.length < 1)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        TlsUtils.writeUint16(2 * namedCurves.length, buf);
        TlsUtils.writeUint16Array(namedCurves, buf);
        return buf.toByteArray();
    }

    public static byte[] createSupportedPointFormatsExtension(short[] ecPointFormats)
        throws IOException
    {

        if (ecPointFormats == null)
        {
            ecPointFormats = new short[]{ECPointFormat.uncompressed};
        }
        else if (!TlsProtocol.arrayContains(ecPointFormats, ECPointFormat.uncompressed))
        {
            /*
             * RFC 4492 5.1. If the Supported Point Formats Extension is indeed sent, it MUST
             * contain the value 0 (uncompressed) as one of the items in the list of point formats.
             */

            // NOTE: We add it at the end (lowest preference)
            short[] tmp = new short[ecPointFormats.length + 1];
            System.arraycopy(ecPointFormats, 0, tmp, 0, ecPointFormats.length);
            tmp[ecPointFormats.length] = ECPointFormat.uncompressed;

            ecPointFormats = tmp;
        }

        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        TlsUtils.writeUint8((short)ecPointFormats.length, buf);
        TlsUtils.writeUint8Array(ecPointFormats, buf);
        return buf.toByteArray();
    }

    public static int[] readSupportedEllipticCurvesExtension(byte[] extensionValue)
        throws IOException
    {

        if (extensionValue == null)
        {
            throw new IllegalArgumentException("'extensionValue' cannot be null");
        }

        ByteArrayInputStream buf = new ByteArrayInputStream(extensionValue);

        int length = TlsUtils.readUint16(buf);
        if (length < 2 || (length & 1) != 0)
        {
            throw new TlsFatalAlert(AlertDescription.decode_error);
        }

        int[] namedCurves = TlsUtils.readUint16Array(length / 2, buf);

        TlsProtocol.assertEmpty(buf);

        return namedCurves;
    }

    public static short[] readSupportedPointFormatsExtension(byte[] extensionValue)
        throws IOException
    {

        if (extensionValue == null)
        {
            throw new IllegalArgumentException("'extensionValue' cannot be null");
        }

        ByteArrayInputStream buf = new ByteArrayInputStream(extensionValue);

        short length = TlsUtils.readUint8(buf);
        if (length < 1)
        {
            throw new TlsFatalAlert(AlertDescription.decode_error);
        }

        short[] ecPointFormats = TlsUtils.readUint8Array(length, buf);

        TlsProtocol.assertEmpty(buf);

        if (!TlsProtocol.arrayContains(ecPointFormats, ECPointFormat.uncompressed))
        {
            /*
             * RFC 4492 5.1. If the Supported Point Formats Extension is indeed sent, it MUST
             * contain the value 0 (uncompressed) as one of the items in the list of point formats.
             */
            throw new TlsFatalAlert(AlertDescription.illegal_parameter);
        }

        return ecPointFormats;
    }

    public static String getNameOfNamedCurve(int namedCurve)
    {
        return isSupportedNamedCurve(namedCurve) ? curveNames[namedCurve - 1] : null;
    }

    public static ECDomainParameters getParametersForNamedCurve(int namedCurve)
    {
        String curveName = getNameOfNamedCurve(namedCurve);
        if (curveName == null)
        {
            return null;
        }

        // Lazily created the first time a particular curve is accessed
        X9ECParameters ecP = SECNamedCurves.getByName(curveName);

        if (ecP == null)
        {
            return null;
        }

        // It's a bit inefficient to do this conversion every time
        return new ECDomainParameters(ecP.getCurve(), ecP.getG(), ecP.getN(), ecP.getH(), ecP.getSeed());
    }

    public static boolean hasAnySupportedNamedCurves()
    {
        return curveNames.length > 0;
    }

    public static boolean containsECCCipherSuites(int[] cipherSuites)
    {
        for (int i = 0; i < cipherSuites.length; ++i)
        {
            if (isECCCipherSuite(cipherSuites[i]))
            {
                return true;
            }
        }
        return false;
    }

    public static boolean isECCCipherSuite(int cipherSuite)
    {
        switch (cipherSuite)
        {
        case CipherSuite.TLS_ECDH_ECDSA_WITH_NULL_SHA:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_RC4_128_SHA:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_NULL_SHA:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
        case CipherSuite.TLS_ECDH_RSA_WITH_NULL_SHA:
        case CipherSuite.TLS_ECDH_RSA_WITH_RC4_128_SHA:
        case CipherSuite.TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA:
        case CipherSuite.TLS_ECDH_RSA_WITH_AES_128_CBC_SHA:
        case CipherSuite.TLS_ECDH_RSA_WITH_AES_256_CBC_SHA:
        case CipherSuite.TLS_ECDHE_RSA_WITH_NULL_SHA:
        case CipherSuite.TLS_ECDHE_RSA_WITH_RC4_128_SHA:
        case CipherSuite.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA:
        case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
        case CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
        case CipherSuite.TLS_ECDH_anon_WITH_NULL_SHA:
        case CipherSuite.TLS_ECDH_anon_WITH_RC4_128_SHA:
        case CipherSuite.TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA:
        case CipherSuite.TLS_ECDH_anon_WITH_AES_128_CBC_SHA:
        case CipherSuite.TLS_ECDH_anon_WITH_AES_256_CBC_SHA:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384:
        case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
        case CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384:
        case CipherSuite.TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256:
        case CipherSuite.TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
        case CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256:
        case CipherSuite.TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384:
        case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
        case CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
        case CipherSuite.TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256:
        case CipherSuite.TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384:
            return true;
        default:
            return false;
        }
    }

    public static boolean areOnSameCurve(ECDomainParameters a, ECDomainParameters b)
    {
        // TODO Move to ECDomainParameters.equals() or other utility method?
        return a.getCurve().equals(b.getCurve()) && a.getG().equals(b.getG()) && a.getN().equals(b.getN())
            && a.getH().equals(b.getH());
    }

    public static boolean isSupportedNamedCurve(int namedCurve)
    {
        return (namedCurve > 0 && namedCurve <= curveNames.length);
    }

    public static boolean isCompressionPreferred(short[] ecPointFormats, short compressionFormat)
    {
        if (ecPointFormats == null)
        {
            return false;
        }
        for (int i = 0; i < ecPointFormats.length; ++i)
        {
            short ecPointFormat = ecPointFormats[i];
            if (ecPointFormat == ECPointFormat.uncompressed)
            {
                return false;
            }
            if (ecPointFormat == compressionFormat)
            {
                return true;
            }
        }
        return false;
    }

    public static byte[] serializeECFieldElement(int fieldSize, BigInteger x)
        throws IOException
    {
        int requiredLength = (fieldSize + 7) / 8;
        return BigIntegers.asUnsignedByteArray(requiredLength, x);
    }

    public static byte[] serializeECPoint(short[] ecPointFormats, ECPoint point)
        throws IOException
    {

        ECCurve curve = point.getCurve();

        /*
         * RFC 4492 5.7. ...an elliptic curve point in uncompressed or compressed format. Here, the
         * format MUST conform to what the server has requested through a Supported Point Formats
         * Extension if this extension was used, and MUST be uncompressed if this extension was not
         * used.
         */
        boolean compressed = false;
        if (curve instanceof ECCurve.F2m)
        {
            compressed = isCompressionPreferred(ecPointFormats, ECPointFormat.ansiX962_compressed_char2);
        }
        else if (curve instanceof ECCurve.Fp)
        {
            compressed = isCompressionPreferred(ecPointFormats, ECPointFormat.ansiX962_compressed_prime);
        }
        return point.getEncoded(compressed);
    }

    public static byte[] serializeECPublicKey(short[] ecPointFormats, ECPublicKeyParameters keyParameters)
        throws IOException
    {

        return serializeECPoint(ecPointFormats, keyParameters.getQ());
    }

    public static BigInteger deserializeECFieldElement(int fieldSize, byte[] encoding)
        throws IOException
    {
        int requiredLength = (fieldSize + 7) / 8;
        if (encoding.length != requiredLength)
        {
            throw new TlsFatalAlert(AlertDescription.decode_error);
        }
        return new BigInteger(1, encoding);
    }

    public static ECPoint deserializeECPoint(short[] ecPointFormats, ECCurve curve, byte[] encoding)
        throws IOException
    {
        /*
         * NOTE: Here we implicitly decode compressed or uncompressed encodings. DefaultTlsClient by
         * default is set up to advertise that we can parse any encoding so this works fine, but
         * extra checks might be needed here if that were changed.
         */
        return curve.decodePoint(encoding);
    }

    public static ECPublicKeyParameters deserializeECPublicKey(short[] ecPointFormats, ECDomainParameters curve_params,
                                                               byte[] encoding)
        throws IOException
    {

        try
        {
            ECPoint Y = deserializeECPoint(ecPointFormats, curve_params.getCurve(), encoding);
            return new ECPublicKeyParameters(Y, curve_params);
        }
        catch (RuntimeException e)
        {
            throw new TlsFatalAlert(AlertDescription.illegal_parameter);
        }
    }

    public static byte[] calculateECDHBasicAgreement(ECPublicKeyParameters publicKey, ECPrivateKeyParameters privateKey)
    {

        ECDHBasicAgreement basicAgreement = new ECDHBasicAgreement();
        basicAgreement.init(privateKey);
        BigInteger agreementValue = basicAgreement.calculateAgreement(publicKey);

        /*
         * RFC 4492 5.10. Note that this octet string (Z in IEEE 1363 terminology) as output by
         * FE2OSP, the Field Element to Octet String Conversion Primitive, has constant length for
         * any given field; leading zeros found in this octet string MUST NOT be truncated.
         */
        return BigIntegers.asUnsignedByteArray(basicAgreement.getFieldSize(), agreementValue);
    }

    public static AsymmetricCipherKeyPair generateECKeyPair(SecureRandom random, ECDomainParameters ecParams)
    {

        ECKeyPairGenerator keyPairGenerator = new ECKeyPairGenerator();
        ECKeyGenerationParameters keyGenerationParameters = new ECKeyGenerationParameters(ecParams, random);
        keyPairGenerator.init(keyGenerationParameters);
        return keyPairGenerator.generateKeyPair();
    }

    public static ECPublicKeyParameters validateECPublicKey(ECPublicKeyParameters key)
        throws IOException
    {
        // TODO Check RFC 4492 for validation
        return key;
    }

    public static int readECExponent(int fieldSize, InputStream input)
        throws IOException
    {
        BigInteger K = readECParameter(input);
        if (K.bitLength() < 32)
        {
            int k = K.intValue();
            if (k > 0 && k < fieldSize)
            {
                return k;
            }
        }
        throw new TlsFatalAlert(AlertDescription.illegal_parameter);
    }

    public static BigInteger readECFieldElement(int fieldSize, InputStream input)
        throws IOException
    {
        return deserializeECFieldElement(fieldSize, TlsUtils.readOpaque8(input));
    }

    public static BigInteger readECParameter(InputStream input)
        throws IOException
    {
        // TODO Are leading zeroes okay here?
        return new BigInteger(1, TlsUtils.readOpaque8(input));
    }

    public static ECDomainParameters readECParameters(int[] namedCurves, short[] ecPointFormats, InputStream input)
        throws IOException
    {

        try
        {
            short curveType = TlsUtils.readUint8(input);

            switch (curveType)
            {
            case ECCurveType.explicit_prime:
            {
                BigInteger prime_p = readECParameter(input);
                BigInteger a = readECFieldElement(prime_p.bitLength(), input);
                BigInteger b = readECFieldElement(prime_p.bitLength(), input);
                ECCurve curve = new ECCurve.Fp(prime_p, a, b);
                ECPoint base = deserializeECPoint(ecPointFormats, curve, TlsUtils.readOpaque8(input));
                BigInteger order = readECParameter(input);
                BigInteger cofactor = readECParameter(input);
                return new ECDomainParameters(curve, base, order, cofactor);
            }
            case ECCurveType.explicit_char2:
            {
                int m = TlsUtils.readUint16(input);
                short basis = TlsUtils.readUint8(input);
                ECCurve curve;
                switch (basis)
                {
                case ECBasisType.ec_basis_trinomial:
                {
                    int k = readECExponent(m, input);
                    BigInteger a = readECFieldElement(m, input);
                    BigInteger b = readECFieldElement(m, input);
                    curve = new ECCurve.F2m(m, k, a, b);
                    break;
                }
                case ECBasisType.ec_basis_pentanomial:
                {
                    int k1 = readECExponent(m, input);
                    int k2 = readECExponent(m, input);
                    int k3 = readECExponent(m, input);
                    BigInteger a = readECFieldElement(m, input);
                    BigInteger b = readECFieldElement(m, input);
                    curve = new ECCurve.F2m(m, k1, k2, k3, a, b);
                    break;
                }
                default:
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }
                ECPoint base = deserializeECPoint(ecPointFormats, curve, TlsUtils.readOpaque8(input));
                BigInteger order = readECParameter(input);
                BigInteger cofactor = readECParameter(input);
                return new ECDomainParameters(curve, base, order, cofactor);
            }
            case ECCurveType.named_curve:
            {
                int namedCurve = TlsUtils.readUint16(input);
                if (!NamedCurve.refersToASpecificNamedCurve(namedCurve))
                {
                    /*
                     * RFC 4492 5.4. All those values of NamedCurve are allowed that refer to a
                     * specific curve. Values of NamedCurve that indicate support for a class of
                     * explicitly defined curves are not allowed here [...].
                     */
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }

                if (!TlsProtocol.arrayContains(namedCurves, namedCurve))
                {
                    /*
                     * RFC 4492 4. [...] servers MUST NOT negotiate the use of an ECC cipher suite
                     * unless they can complete the handshake while respecting the choice of curves
                     * and compression techniques specified by the client.
                     */
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }

                return TlsECCUtils.getParametersForNamedCurve(namedCurve);
            }
            default:
                throw new TlsFatalAlert(AlertDescription.illegal_parameter);
            }
        }
        catch (RuntimeException e)
        {
            throw new TlsFatalAlert(AlertDescription.illegal_parameter);
        }
    }

    public static void writeECExponent(int k, OutputStream output)
        throws IOException
    {
        BigInteger K = BigInteger.valueOf(k);
        writeECParameter(K, output);
    }

    public static void writeECFieldElement(int fieldSize, BigInteger x, OutputStream output)
        throws IOException
    {
        TlsUtils.writeOpaque8(serializeECFieldElement(fieldSize, x), output);
    }

    public static void writeECParameter(BigInteger x, OutputStream output)
        throws IOException
    {
        TlsUtils.writeOpaque8(BigIntegers.asUnsignedByteArray(x), output);
    }

    public static void writeExplicitECParameters(short[] ecPointFormats, ECDomainParameters ecParameters,
                                                 OutputStream output)
        throws IOException
    {

        ECCurve curve = ecParameters.getCurve();
        if (curve instanceof ECCurve.Fp)
        {

            TlsUtils.writeUint8(ECCurveType.explicit_prime, output);

            ECCurve.Fp fp = (ECCurve.Fp)curve;
            writeECParameter(fp.getQ(), output);

        }
        else if (curve instanceof ECCurve.F2m)
        {

            TlsUtils.writeUint8(ECCurveType.explicit_char2, output);

            ECCurve.F2m f2m = (ECCurve.F2m)curve;
            TlsUtils.writeUint16(f2m.getM(), output);

            if (f2m.isTrinomial())
            {
                TlsUtils.writeUint8(ECBasisType.ec_basis_trinomial, output);
                writeECExponent(f2m.getK1(), output);
            }
            else
            {
                TlsUtils.writeUint8(ECBasisType.ec_basis_pentanomial, output);
                writeECExponent(f2m.getK1(), output);
                writeECExponent(f2m.getK2(), output);
                writeECExponent(f2m.getK3(), output);
            }

        }
        else
        {
            throw new IllegalArgumentException("'ecParameters' not a known curve type");
        }

        writeECFieldElement(curve.getFieldSize(), curve.getA().toBigInteger(), output);
        writeECFieldElement(curve.getFieldSize(), curve.getB().toBigInteger(), output);
        TlsUtils.writeOpaque8(serializeECPoint(ecPointFormats, ecParameters.getG()), output);
        writeECParameter(ecParameters.getN(), output);
        writeECParameter(ecParameters.getH(), output);
    }

    public static void writeNamedECParameters(int namedCurve, OutputStream output)
        throws IOException
    {

        if (!NamedCurve.refersToASpecificNamedCurve(namedCurve))
        {
            /*
             * RFC 4492 5.4. All those values of NamedCurve are allowed that refer to a specific
             * curve. Values of NamedCurve that indicate support for a class of explicitly defined
             * curves are not allowed here [...].
             */
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        TlsUtils.writeUint8(ECCurveType.named_curve, output);
        TlsUtils.writeUint16(namedCurve, output);
    }
}
TOP

Related Classes of org.bouncycastle.crypto.tls.TlsECCUtils

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.