/*
 * Decompiled with CFR 0.152.
 */
package org.openecard.bouncycastle.crypto.agreement;

import java.math.BigInteger;
import org.openecard.bouncycastle.crypto.CipherParameters;
import org.openecard.bouncycastle.crypto.Digest;
import org.openecard.bouncycastle.crypto.digests.SM3Digest;
import org.openecard.bouncycastle.crypto.params.ECDomainParameters;
import org.openecard.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.openecard.bouncycastle.crypto.params.ParametersWithID;
import org.openecard.bouncycastle.crypto.params.SM2KeyExchangePrivateParameters;
import org.openecard.bouncycastle.crypto.params.SM2KeyExchangePublicParameters;
import org.openecard.bouncycastle.math.ec.ECFieldElement;
import org.openecard.bouncycastle.math.ec.ECPoint;
import org.openecard.bouncycastle.util.Arrays;
import org.openecard.bouncycastle.util.BigIntegers;

/**
 * SM2 key agreement between an initiator and a responder.
 *
 * <p>Usage: call {@link #init(CipherParameters)} with this party's static and ephemeral key
 * material, then either {@link #calculateKey(int, CipherParameters)} or
 * {@link #calculateKeyWithConfirmation(int, byte[], CipherParameters)} with the peer's public
 * parameters.
 *
 * <p>Instances are not safe for concurrent use: every operation feeds the shared {@code digest}
 * field and relies on {@code doFinal} to reset it between hashes.
 */
public class SM2KeyExchange {
    private final Digest digest;
    /** This party's user ID (empty when init was not given a {@link ParametersWithID}). */
    private byte[] userID;
    private ECPrivateKeyParameters staticKey;
    private ECPoint staticPubPoint;
    private ECPoint ephemeralPubPoint;
    private ECDomainParameters ecParams;
    /** Byte length used when hashing field elements: ceil(fieldSize / 8). */
    private int curveLength;
    /** Bit position used by {@link #reduce(BigInteger)}: fieldSize / 2 - 1. */
    private int w;
    private ECPrivateKeyParameters ephemeralKey;
    private boolean initiator;

    /** Creates a key exchange that hashes with SM3. */
    public SM2KeyExchange() {
        this(new SM3Digest());
    }

    /**
     * Creates a key exchange that uses the given digest for Z values, the KDF and the
     * confirmation tags.
     *
     * @param digest the digest to use; it is reused (and therefore mutated) by every operation
     */
    public SM2KeyExchange(Digest digest) {
        this.digest = digest;
    }

    /**
     * Initialises this party.
     *
     * @param privParam an {@link SM2KeyExchangePrivateParameters}, optionally wrapped in a
     *     {@link ParametersWithID} carrying this party's user ID (an empty ID is used otherwise)
     * @throws ClassCastException if the (unwrapped) parameters are not
     *     {@link SM2KeyExchangePrivateParameters}
     */
    public void init(CipherParameters privParam) {
        SM2KeyExchangePrivateParameters baseParam =
                (SM2KeyExchangePrivateParameters)SM2KeyExchange.unwrapParameters(privParam);
        this.userID = SM2KeyExchange.unwrapUserID(privParam);
        this.initiator = baseParam.isInitiator();
        this.staticKey = baseParam.getStaticPrivateKey();
        this.ephemeralKey = baseParam.getEphemeralPrivateKey();
        this.ecParams = this.staticKey.getParameters();
        this.staticPubPoint = baseParam.getStaticPublicPoint();
        this.ephemeralPubPoint = baseParam.getEphemeralPublicPoint();
        this.curveLength = (this.ecParams.getCurve().getFieldSize() + 7) / 8;
        // NOTE(review): w is derived from the field size with floor division; the SM2 text
        // defines it from the bit length of n with a ceiling. The two agree for the standard
        // 256-bit SM2 curve - confirm before using curves of odd size. Left unchanged for
        // interoperability with other Bouncy Castle peers.
        this.w = this.ecParams.getCurve().getFieldSize() / 2 - 1;
    }

    /**
     * Returns the byte length of a field element of this party's curve.
     *
     * @return ceil(fieldSize / 8)
     */
    public int getFieldSize() {
        return (this.staticKey.getParameters().getCurve().getFieldSize() + 7) / 8;
    }

    /**
     * Derives the shared key without key confirmation.
     *
     * @param kLen the desired key length in <em>bits</em>; the result has ceil(kLen / 8) bytes
     * @param pubParam the peer's {@link SM2KeyExchangePublicParameters}, optionally wrapped in a
     *     {@link ParametersWithID} carrying the peer's user ID
     * @return the derived key
     * @throws IllegalStateException if the shared point is the point at infinity
     * @throws IllegalArgumentException if a user ID is 8192 bytes or longer
     */
    public byte[] calculateKey(int kLen, CipherParameters pubParam) {
        SM2KeyExchangePublicParameters otherPub =
                (SM2KeyExchangePublicParameters)SM2KeyExchange.unwrapParameters(pubParam);
        byte[] otherUserID = SM2KeyExchange.unwrapUserID(pubParam);
        byte[] za = this.getZ(this.digest, this.userID, this.staticPubPoint);
        byte[] zb = this.getZ(this.digest, otherUserID, otherPub.getStaticPublicKey().getQ());
        ECPoint U = this.calculateU(otherPub);
        // The initiator's Z value always comes first in the KDF input.
        return this.initiator ? this.kdf(U, za, zb, kLen) : this.kdf(U, zb, za, kLen);
    }

    /**
     * Derives the shared key together with the confirmation tags.
     *
     * <p>As initiator, {@code confirmationTag} is compared in constant time against the locally
     * computed S1 tag. The result is {@code {key, S2}}. As responder, the result is
     * {@code {key, S1, S2}}, and {@code confirmationTag} is not used.
     *
     * @param kLen the desired key length in bits
     * @param confirmationTag the tag received from the responder (required when initiating)
     * @param pubParam the peer's public parameters, optionally wrapped in a
     *     {@link ParametersWithID}
     * @return the key followed by the confirmation tag(s) described above
     * @throws IllegalArgumentException if initiating and {@code confirmationTag} is null, or a
     *     user ID is 8192 bytes or longer
     * @throws IllegalStateException if the tag does not match, or the shared point is infinity
     */
    public byte[][] calculateKeyWithConfirmation(int kLen, byte[] confirmationTag, CipherParameters pubParam) {
        SM2KeyExchangePublicParameters otherPub =
                (SM2KeyExchangePublicParameters)SM2KeyExchange.unwrapParameters(pubParam);
        byte[] otherUserID = SM2KeyExchange.unwrapUserID(pubParam);
        if (this.initiator && confirmationTag == null) {
            throw new IllegalArgumentException("if initiating, confirmationTag must be set");
        }
        byte[] za = this.getZ(this.digest, this.userID, this.staticPubPoint);
        byte[] zb = this.getZ(this.digest, otherUserID, otherPub.getStaticPublicKey().getQ());
        ECPoint U = this.calculateU(otherPub);
        if (this.initiator) {
            byte[] rv = this.kdf(U, za, zb, kLen);
            byte[] inner = this.calculateInnerHash(this.digest, U, za, zb, this.ephemeralPubPoint, otherPub.getEphemeralPublicKey().getQ());
            byte[] s1 = this.S1(this.digest, U, inner);
            if (!Arrays.constantTimeAreEqual(s1, confirmationTag)) {
                throw new IllegalStateException("confirmation tag mismatch");
            }
            return new byte[][]{rv, this.S2(this.digest, U, inner)};
        }
        byte[] rv = this.kdf(U, zb, za, kLen);
        byte[] inner = this.calculateInnerHash(this.digest, U, zb, za, otherPub.getEphemeralPublicKey().getQ(), this.ephemeralPubPoint);
        return new byte[][]{rv, this.S1(this.digest, U, inner), this.S2(this.digest, U, inner)};
    }

    /** Returns the parameters wrapped by a {@link ParametersWithID}, or {@code param} itself. */
    private static CipherParameters unwrapParameters(CipherParameters param) {
        return param instanceof ParametersWithID ? ((ParametersWithID)param).getParameters() : param;
    }

    /** Returns the user ID carried by a {@link ParametersWithID}, or an empty ID otherwise. */
    private static byte[] unwrapUserID(CipherParameters param) {
        return param instanceof ParametersWithID ? ((ParametersWithID)param).getID() : new byte[0];
    }

    /**
     * Computes the shared point U = [h * t](P_other + [x2'] R_other).
     *
     * <p>Here t = (d + x1' * r) mod n uses this party's static (d) and ephemeral (r) private keys.
     * x1' and x2' are the reduced x-coordinates of the two ephemeral public points.
     *
     * @throws IllegalStateException if U is the point at infinity; the exchange must be aborted
     *     in that case
     */
    private ECPoint calculateU(SM2KeyExchangePublicParameters otherPub) {
        BigInteger x1 = this.reduce(this.ephemeralPubPoint.getAffineXCoord().toBigInteger());
        BigInteger tA = this.staticKey.getD().add(x1.multiply(this.ephemeralKey.getD())).mod(this.ecParams.getN());
        BigInteger x2 = this.reduce(otherPub.getEphemeralPublicKey().getQ().getAffineXCoord().toBigInteger());
        ECPoint B0 = otherPub.getEphemeralPublicKey().getQ().multiply(x2).normalize();
        ECPoint B1 = otherPub.getStaticPublicKey().getQ().add(B0).normalize();
        ECPoint u = B1.multiply(this.ecParams.getH().multiply(tA)).normalize();
        if (u.isInfinity()) {
            // Without this check the affine coordinates of the infinity point would be read later.
            throw new IllegalStateException("SM2 key exchange failed: shared point is infinity");
        }
        return u;
    }

    /**
     * Key derivation function. It produces ceil(klen / 8) bytes from successive blocks of
     * Hash(xU || yU || Z1 || Z2 || ct), where ct is a 32-bit big-endian counter starting at 1.
     *
     * @param klen the output length in bits
     */
    private byte[] kdf(ECPoint u, byte[] za, byte[] zb, int klen) {
        byte[] buf = new byte[this.digest.getDigestSize()];
        byte[] rv = new byte[(klen + 7) / 8];
        int off = 0;
        int ct = 1;
        while (off < rv.length) {
            this.addFieldElement(this.digest, u.getAffineXCoord());
            this.addFieldElement(this.digest, u.getAffineYCoord());
            this.digest.update(za, 0, za.length);
            this.digest.update(zb, 0, zb.length);
            this.digest.update((byte)(ct >> 24));
            this.digest.update((byte)(ct >> 16));
            this.digest.update((byte)(ct >> 8));
            this.digest.update((byte)ct);
            this.digest.doFinal(buf, 0);
            // The final block may only be partially needed.
            int copyLen = Math.min(buf.length, rv.length - off);
            System.arraycopy(buf, 0, rv, off, copyLen);
            off += copyLen;
            ++ct;
        }
        return rv;
    }

    /** Returns 2^w + (x AND (2^w - 1)): the low w bits of x with bit w forced on. */
    private BigInteger reduce(BigInteger x) {
        return x.and(BigInteger.ONE.shiftLeft(this.w).subtract(BigInteger.ONE)).setBit(this.w);
    }

    /** Confirmation tag S1 = Hash(0x02 || yU || inner). */
    private byte[] S1(Digest digest, ECPoint u, byte[] inner) {
        byte[] rv = new byte[digest.getDigestSize()];
        digest.update((byte)2);
        this.addFieldElement(digest, u.getAffineYCoord());
        digest.update(inner, 0, inner.length);
        digest.doFinal(rv, 0);
        return rv;
    }

    /** Inner confirmation hash: Hash(xU || Z1 || Z2 || x(p1) || y(p1) || x(p2) || y(p2)). */
    private byte[] calculateInnerHash(Digest digest, ECPoint u, byte[] za, byte[] zb, ECPoint p1, ECPoint p2) {
        this.addFieldElement(digest, u.getAffineXCoord());
        digest.update(za, 0, za.length);
        digest.update(zb, 0, zb.length);
        this.addFieldElement(digest, p1.getAffineXCoord());
        this.addFieldElement(digest, p1.getAffineYCoord());
        this.addFieldElement(digest, p2.getAffineXCoord());
        this.addFieldElement(digest, p2.getAffineYCoord());
        byte[] rv = new byte[digest.getDigestSize()];
        digest.doFinal(rv, 0);
        return rv;
    }

    /** Confirmation tag S2 = Hash(0x03 || yU || inner). */
    private byte[] S2(Digest digest, ECPoint u, byte[] inner) {
        byte[] rv = new byte[digest.getDigestSize()];
        digest.update((byte)3);
        this.addFieldElement(digest, u.getAffineYCoord());
        digest.update(inner, 0, inner.length);
        digest.doFinal(rv, 0);
        return rv;
    }

    /** Computes Z = Hash(ENTL || ID || a || b || xG || yG || xP || yP) for a user. */
    private byte[] getZ(Digest digest, byte[] userID, ECPoint pubPoint) {
        this.addUserID(digest, userID);
        this.addFieldElement(digest, this.ecParams.getCurve().getA());
        this.addFieldElement(digest, this.ecParams.getCurve().getB());
        this.addFieldElement(digest, this.ecParams.getG().getAffineXCoord());
        this.addFieldElement(digest, this.ecParams.getG().getAffineYCoord());
        this.addFieldElement(digest, pubPoint.getAffineXCoord());
        this.addFieldElement(digest, pubPoint.getAffineYCoord());
        byte[] rv = new byte[digest.getDigestSize()];
        digest.doFinal(rv, 0);
        return rv;
    }

    /**
     * Feeds the 2-byte big-endian bit length of the ID (ENTL) followed by the ID itself.
     *
     * @throws IllegalArgumentException if the ID's bit length does not fit in 16 bits. It is
     *     checked before any digest update, so the digest state is left untouched.
     */
    private void addUserID(Digest digest, byte[] userID) {
        if (userID.length >= 8192) {
            throw new IllegalArgumentException("SM2 user ID must be less than 2^16 bits long");
        }
        int len = userID.length * 8;
        digest.update((byte)(len >> 8 & 0xFF));
        digest.update((byte)(len & 0xFF));
        digest.update(userID, 0, userID.length);
    }

    /** Feeds a field element as an unsigned big-endian integer padded to {@code curveLength} bytes. */
    private void addFieldElement(Digest digest, ECFieldElement v) {
        byte[] p = BigIntegers.asUnsignedByteArray(this.curveLength, v.toBigInteger());
        digest.update(p, 0, p.length);
    }
}

