From 8a59df248568e24cbf06140af91c14b2907194fa Mon Sep 17 00:00:00 2001 From: larabr Date: Tue, 12 Jul 2022 16:43:06 +0200 Subject: [PATCH] Add support for automatic forwarding (#54) --- openpgp/ecdh/ecdh.go | 105 +++- openpgp/ecdh/ecdh_test.go | 36 +- openpgp/errors/errors.go | 2 + openpgp/forwarding.go | 163 +++++ openpgp/forwarding_test.go | 224 +++++++ openpgp/internal/ecc/curve25519.go | 3 +- openpgp/internal/ecc/curve25519/curve25519.go | 122 ++++ .../ecc/curve25519/curve25519_test.go | 89 +++ openpgp/internal/ecc/curve25519/field/fe.go | 416 +++++++++++++ .../ecc/curve25519/field/fe_alias_test.go | 126 ++++ .../internal/ecc/curve25519/field/fe_amd64.go | 16 + .../internal/ecc/curve25519/field/fe_amd64.s | 379 ++++++++++++ .../ecc/curve25519/field/fe_amd64_noasm.go | 12 + .../internal/ecc/curve25519/field/fe_arm64.go | 16 + .../internal/ecc/curve25519/field/fe_arm64.s | 43 ++ .../ecc/curve25519/field/fe_arm64_noasm.go | 12 + .../ecc/curve25519/field/fe_bench_test.go | 36 ++ .../ecc/curve25519/field/fe_generic.go | 264 +++++++++ .../internal/ecc/curve25519/field/fe_test.go | 558 ++++++++++++++++++ openpgp/keys.go | 10 +- openpgp/packet/encrypted_key.go | 30 + openpgp/packet/forwarding.go | 36 ++ openpgp/packet/public_key.go | 53 +- openpgp/packet/signature.go | 13 +- openpgp/v2/forwarding.go | 159 +++++ openpgp/v2/forwarding_test.go | 223 +++++++ openpgp/v2/keys.go | 17 +- 27 files changed, 3132 insertions(+), 31 deletions(-) create mode 100644 openpgp/forwarding.go create mode 100644 openpgp/forwarding_test.go create mode 100644 openpgp/internal/ecc/curve25519/curve25519.go create mode 100644 openpgp/internal/ecc/curve25519/curve25519_test.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_alias_test.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_amd64.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_amd64.s create mode 100644 openpgp/internal/ecc/curve25519/field/fe_amd64_noasm.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_arm64.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_arm64.s create mode 100644 openpgp/internal/ecc/curve25519/field/fe_arm64_noasm.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_bench_test.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_generic.go create mode 100644 openpgp/internal/ecc/curve25519/field/fe_test.go create mode 100644 openpgp/packet/forwarding.go create mode 100644 openpgp/v2/forwarding.go create mode 100644 openpgp/v2/forwarding_test.go diff --git a/openpgp/ecdh/ecdh.go b/openpgp/ecdh/ecdh.go index db8fb163b..85a06b17c 100644 --- a/openpgp/ecdh/ecdh.go +++ b/openpgp/ecdh/ecdh.go @@ -12,13 +12,50 @@ import ( "io" "github.com/ProtonMail/go-crypto/openpgp/aes/keywrap" + pgperrors "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc" + "github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve25519" +) + +const ( + KDFVersion1 = 1 + KDFVersionForwarding = 255 ) type KDF struct { - Hash algorithm.Hash - Cipher algorithm.Cipher + Version int // Defaults to v1; 255 for forwarding + Hash algorithm.Hash + Cipher algorithm.Cipher + ReplacementFingerprint []byte // (forwarding only) fingerprint to use instead of recipient's (20 octets) +} + +func (kdf *KDF) Serialize(w io.Writer) (err error) { + switch kdf.Version { + case 0, KDFVersion1: // Default to v1 if unspecified + return kdf.serializeForHash(w) + case KDFVersionForwarding: + // Length || Version || Hash || Cipher || Replacement Fingerprint + length := byte(3 + len(kdf.ReplacementFingerprint)) + if _, err := w.Write([]byte{length, KDFVersionForwarding, kdf.Hash.Id(), kdf.Cipher.Id()}); err != nil { + return err + } + if _, err := w.Write(kdf.ReplacementFingerprint); err != nil { + return err + } + + return nil + default: + return errors.New("ecdh: invalid KDF version") + } +} + +func (kdf *KDF) serializeForHash(w io.Writer) (err error) { + // Length || Version || Hash || Cipher + if _, err := w.Write([]byte{3, KDFVersion1, kdf.Hash.Id(), kdf.Cipher.Id()}); err != nil { + return err + } + return nil } type PublicKey struct { @@ -32,13 +69,10 @@ type PrivateKey struct { D []byte } -func NewPublicKey(curve ecc.ECDHCurve, kdfHash algorithm.Hash, kdfCipher algorithm.Cipher) *PublicKey { +func NewPublicKey(curve ecc.ECDHCurve, kdf KDF) *PublicKey { return &PublicKey{ curve: curve, - KDF: KDF{ - Hash: kdfHash, - Cipher: kdfCipher, - }, + KDF: kdf, } } @@ -149,21 +183,31 @@ func Decrypt(priv *PrivateKey, vsG, c, curveOID, fingerprint []byte) (msg []byte } func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLeading, stripTrailing bool) ([]byte, error) { - // Param = curve_OID_len || curve_OID || public_key_alg_ID || 03 - // || 01 || KDF_hash_ID || KEK_alg_ID for AESKeyWrap + // Param = curve_OID_len || curve_OID || public_key_alg_ID + // || KDF_params for AESKeyWrap // || "Anonymous Sender " || recipient_fingerprint; param := new(bytes.Buffer) if _, err := param.Write(curveOID); err != nil { return nil, err } - algKDF := []byte{18, 3, 1, pub.KDF.Hash.Id(), pub.KDF.Cipher.Id()} - if _, err := param.Write(algKDF); err != nil { + algo := []byte{18} + if _, err := param.Write(algo); err != nil { + return nil, err + } + + if err := pub.KDF.serializeForHash(param); err != nil { return nil, err } + if _, err := param.Write([]byte("Anonymous Sender ")); err != nil { return nil, err } - if _, err := param.Write(fingerprint[:]); err != nil { + + if pub.KDF.ReplacementFingerprint != nil { + fingerprint = pub.KDF.ReplacementFingerprint + } + + if _, err := param.Write(fingerprint); err != nil { return nil, err } @@ -204,3 +248,40 @@ func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLead func Validate(priv *PrivateKey) error { return priv.curve.ValidateECDH(priv.Point, priv.D) } + +func DeriveProxyParam(recipientKey, forwardeeKey *PrivateKey) (proxyParam []byte, err error) { + if recipientKey.GetCurve().GetCurveName() != "curve25519" { + return nil, pgperrors.InvalidArgumentError("recipient subkey is not curve25519") + } + + if forwardeeKey.GetCurve().GetCurveName() != "curve25519" { + return nil, pgperrors.InvalidArgumentError("forwardee subkey is not curve25519") + } + + c := ecc.NewCurve25519() + + // Clamp and reverse two secrets + proxyParam, err = curve25519.DeriveProxyParam(c.MarshalByteSecret(recipientKey.D), c.MarshalByteSecret(forwardeeKey.D)) + + return proxyParam, err +} + +func ProxyTransform(ephemeral, proxyParam []byte) ([]byte, error) { + c := ecc.NewCurve25519() + + parsedEphemeral := c.UnmarshalBytePoint(ephemeral) + if parsedEphemeral == nil { + return nil, pgperrors.InvalidArgumentError("invalid ephemeral") + } + + if len(proxyParam) != curve25519.ParamSize { + return nil, pgperrors.InvalidArgumentError("invalid proxy parameter") + } + + transformed, err := curve25519.ProxyTransform(parsedEphemeral, proxyParam) + if err != nil { + return nil, err + } + + return c.MarshalBytePoint(transformed), nil +} diff --git a/openpgp/ecdh/ecdh_test.go b/openpgp/ecdh/ecdh_test.go index 1f70b7dd0..0170d776c 100644 --- a/openpgp/ecdh/ecdh_test.go +++ b/openpgp/ecdh/ecdh_test.go @@ -88,7 +88,7 @@ func testMarshalUnmarshal(t *testing.T, priv *PrivateKey) { p := priv.MarshalPoint() d := priv.MarshalByteSecret() - parsed := NewPrivateKey(*NewPublicKey(priv.GetCurve(), priv.KDF.Hash, priv.KDF.Cipher)) + parsed := NewPrivateKey(*NewPublicKey(priv.GetCurve(), priv.KDF)) if err := parsed.UnmarshalPoint(p); err != nil { t.Fatalf("unable to unmarshal point: %s", err) @@ -112,3 +112,37 @@ func testMarshalUnmarshal(t *testing.T, priv *PrivateKey) { t.Fatal("failed to marshal/unmarshal correctly") } } + +func TestKDFParamsWrite(t *testing.T) { + kdf := KDF{ + Hash: algorithm.SHA512, + Cipher: algorithm.AES256, + } + byteBuffer := new(bytes.Buffer) + + testFingerprint := make([]byte, 20) + + expectBytesV1 := []byte{3, 1, kdf.Hash.Id(), kdf.Cipher.Id()} + kdf.Serialize(byteBuffer) + gotBytes := byteBuffer.Bytes() + if !bytes.Equal(gotBytes, expectBytesV1) { + t.Errorf("error serializing KDF params, got %x, want: %x", gotBytes, expectBytesV1) + } + byteBuffer.Reset() + + kdfV2 := KDF{ + Version: KDFVersionForwarding, + Hash: algorithm.SHA512, + Cipher: algorithm.AES256, + ReplacementFingerprint: testFingerprint, + } + expectBytesV2 := []byte{23, 0xFF, kdfV2.Hash.Id(), kdfV2.Cipher.Id()} + expectBytesV2 = append(expectBytesV2, testFingerprint...) + + kdfV2.Serialize(byteBuffer) + gotBytes = byteBuffer.Bytes() + if !bytes.Equal(gotBytes, expectBytesV2) { + t.Errorf("error serializing KDF params v2, got %x, want: %x", gotBytes, expectBytesV2) + } + byteBuffer.Reset() +} diff --git a/openpgp/errors/errors.go b/openpgp/errors/errors.go index c42b01cb0..c20292c61 100644 --- a/openpgp/errors/errors.go +++ b/openpgp/errors/errors.go @@ -33,6 +33,8 @@ func (i InvalidArgumentError) Error() string { return "openpgp: invalid argument: " + string(i) } +var InvalidForwardeeKeyError = InvalidArgumentError("invalid forwardee key") + // SignatureError indicates that a syntactically valid signature failed to // validate. type SignatureError string diff --git a/openpgp/forwarding.go b/openpgp/forwarding.go new file mode 100644 index 000000000..ae45c3c2b --- /dev/null +++ b/openpgp/forwarding.go @@ -0,0 +1,163 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package openpgp + +import ( + goerrors "errors" + + "github.com/ProtonMail/go-crypto/openpgp/ecdh" + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/packet" +) + +// NewForwardingEntity generates a new forwardee key and derives the proxy parameters from the entity e. +// If strict, it will return an error if encryption-capable non-revoked subkeys with a wrong algorithm are found, +// instead of ignoring them +func (e *Entity) NewForwardingEntity( + name, comment, email string, config *packet.Config, strict bool, +) ( + forwardeeKey *Entity, instances []packet.ForwardingInstance, err error, +) { + if e.PrimaryKey.Version != 4 { + return nil, nil, errors.InvalidArgumentError("unsupported key version") + } + + now := config.Now() + i := e.PrimaryIdentity() + if e.PrimaryKey.KeyExpired(i.SelfSignature, now) || // primary key has expired + i.SelfSignature.SigExpired(now) || // user ID self-signature has expired + e.Revoked(now) || // primary key has been revoked + i.Revoked(now) { // user ID has been revoked + return nil, nil, errors.InvalidArgumentError("primary key is expired") + } + + // Generate a new Primary key for the forwardee + config.Algorithm = packet.PubKeyAlgoEdDSA + config.Curve = packet.Curve25519 + keyLifetimeSecs := config.KeyLifetime() + + forwardeePrimaryPrivRaw, err := newSigner(config) + if err != nil { + return nil, nil, err + } + + primary := packet.NewSignerPrivateKey(now, forwardeePrimaryPrivRaw) + + forwardeeKey = &Entity{ + PrimaryKey: &primary.PublicKey, + PrivateKey: primary, + Identities: make(map[string]*Identity), + Subkeys: []Subkey{}, + } + + err = forwardeeKey.addUserId(name, comment, email, config, now, keyLifetimeSecs, true) + if err != nil { + return nil, nil, err + } + + // Init empty instances + instances = []packet.ForwardingInstance{} + + // Handle all forwarder subkeys + for _, forwarderSubKey := range e.Subkeys { + // Filter flags + if !forwarderSubKey.PublicKey.PubKeyAlgo.CanEncrypt() { + continue + } + + // Filter expiration & revokal + if forwarderSubKey.PublicKey.KeyExpired(forwarderSubKey.Sig, now) || + forwarderSubKey.Sig.SigExpired(now) || + forwarderSubKey.Revoked(now) { + continue + } + + if forwarderSubKey.PublicKey.PubKeyAlgo != packet.PubKeyAlgoECDH { + if strict { + return nil, nil, errors.InvalidArgumentError("encryption subkey is not algorithm 18 (ECDH)") + } else { + continue + } + } + + forwarderEcdhKey, ok := forwarderSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, errors.InvalidArgumentError("malformed key") + } + + err = forwardeeKey.addEncryptionSubkey(config, now, 0) + if err != nil { + return nil, nil, err + } + + forwardeeSubKey := forwardeeKey.Subkeys[len(forwardeeKey.Subkeys)-1] + + forwardeeEcdhKey, ok := forwardeeSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, goerrors.New("wrong forwarding sub key generation") + } + + instance := packet.ForwardingInstance{ + KeyVersion: 4, + ForwarderFingerprint: forwarderSubKey.PublicKey.Fingerprint, + } + + instance.ProxyParameter, err = ecdh.DeriveProxyParam(forwarderEcdhKey, forwardeeEcdhKey) + if err != nil { + return nil, nil, err + } + + kdf := ecdh.KDF{ + Version: ecdh.KDFVersionForwarding, + Hash: forwarderEcdhKey.KDF.Hash, + Cipher: forwarderEcdhKey.KDF.Cipher, + } + + // If deriving a forwarding key from a forwarding key + if forwarderSubKey.Sig.FlagForward { + if forwarderEcdhKey.KDF.Version != ecdh.KDFVersionForwarding { + return nil, nil, goerrors.New("malformed forwarder key") + } + kdf.ReplacementFingerprint = forwarderEcdhKey.KDF.ReplacementFingerprint + } else { + kdf.ReplacementFingerprint = forwarderSubKey.PublicKey.Fingerprint + } + + err = forwardeeSubKey.PublicKey.ReplaceKDF(kdf) + if err != nil { + return nil, nil, err + } + + // Extract fingerprint after changing the KDF + instance.ForwardeeFingerprint = forwardeeSubKey.PublicKey.Fingerprint + + // 0x04 - This key may be used to encrypt communications. + forwardeeSubKey.Sig.FlagEncryptCommunications = false + + // 0x08 - This key may be used to encrypt storage. + forwardeeSubKey.Sig.FlagEncryptStorage = false + + // 0x10 - The private component of this key may have been split by a secret-sharing mechanism. + forwardeeSubKey.Sig.FlagSplitKey = true + + // 0x40 - This key may be used for forwarded communications. + forwardeeSubKey.Sig.FlagForward = true + + // Re-sign subkey binding signature + err = forwardeeSubKey.Sig.SignKey(forwardeeSubKey.PublicKey, forwardeeKey.PrivateKey, config) + if err != nil { + return nil, nil, err + } + + // Append each valid instance to the list + instances = append(instances, instance) + } + + if len(instances) == 0 { + return nil, nil, errors.InvalidArgumentError("no valid subkey found") + } + + return forwardeeKey, instances, nil +} diff --git a/openpgp/forwarding_test.go b/openpgp/forwarding_test.go new file mode 100644 index 000000000..7bc167180 --- /dev/null +++ b/openpgp/forwarding_test.go @@ -0,0 +1,224 @@ +package openpgp + +import ( + "bytes" + "crypto/rand" + goerrors "errors" + "io" + "io/ioutil" + "strings" + "testing" + + "github.com/ProtonMail/go-crypto/openpgp/packet" + "golang.org/x/crypto/openpgp/armor" +) + +const forwardeeKey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xVgEZAdtGBYJKwYBBAHaRw8BAQdAcNgHyRGEaqGmzEqEwCobfUkyrJnY8faBvsf9 +R2c5ZzYAAP9bFL4nPBdo04ei0C2IAh5RXOpmuejGC3GAIn/UmL5cYQ+XzRtjaGFy +bGVzIDxjaGFybGVzQHByb3Rvbi5tZT7CigQTFggAPAUCZAdtGAmQFXJtmBzDhdcW +IQRl2gNflypl1XjRUV8Vcm2YHMOF1wIbAwIeAQIZAQILBwIVCAIWAAIiAQAAJKYA +/2qY16Ozyo5erNz51UrKViEoWbEpwY3XaFVNzrw+b54YAQC7zXkf/t5ieylvjmA/ +LJz3/qgH5GxZRYAH9NTpWyW1AsdxBGQHbRgSCisGAQQBl1UBBQEBB0CxmxoJsHTW +TiETWh47ot+kwNA1hCk1IYB9WwKxkXYyIBf/CgmKXzV1ODP/mRmtiBYVV+VQk5MF +EAAA/1NW8D8nMc2ky140sPhQrwkeR7rVLKP2fe5n4BEtAnVQEB3CeAQYFggAKgUC +ZAdtGAmQFXJtmBzDhdcWIQRl2gNflypl1XjRUV8Vcm2YHMOF1wIbUAAAl/8A/iIS +zWBsBR8VnoOVfEE+VQk6YAi7cTSjcMjfsIez9FYtAQDKo9aCMhUohYyqvhZjn8aS +3t9mIZPc+zRJtCHzQYmhDg== +=lESj +-----END PGP PRIVATE KEY BLOCK-----` + +const forwardedMessage = `-----BEGIN PGP MESSAGE----- + +wV4DB27Wn97eACkSAQdA62TlMU2QoGmf5iBLnIm4dlFRkLIg+6MbaatghwxK+Ccw +yGZuVVMAK/ypFfebDf4D/rlEw3cysv213m8aoK8nAUO8xQX3XQq3Sg+EGm0BNV8E +0kABEPyCWARoo5klT1rHPEhelnz8+RQXiOIX3G685XCWdCmaV+tzW082D0xGXSlC +7lM8r1DumNnO8srssko2qIja +=pVRa +-----END PGP MESSAGE-----` + +const forwardedPlaintext = "Message for Bob" + +func TestForwardingStatic(t *testing.T) { + charlesKey, err := ReadArmoredKeyRing(bytes.NewBufferString(forwardeeKey)) + if err != nil { + t.Error(err) + return + } + + ciphertext, err := armor.Decode(strings.NewReader(forwardedMessage)) + if err != nil { + t.Error(err) + return + } + + m, err := ReadMessage(ciphertext.Body, charlesKey, nil, nil) + if err != nil { + t.Fatal(err) + } + + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, []byte(forwardedPlaintext)) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func TestForwardingFull(t *testing.T) { + keyConfig := &packet.Config{ + Algorithm: packet.PubKeyAlgoEdDSA, + Curve: packet.Curve25519, + } + + plaintext := make([]byte, 1024) + rand.Read(plaintext) + + bobEntity, err := NewEntity("bob", "", "bob@proton.me", keyConfig) + if err != nil { + t.Fatal(err) + } + + charlesEntity, instances, err := bobEntity.NewForwardingEntity("charles", "", "charles@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + charlesEntity = serializeAndParseForwardeeKey(t, charlesEntity) + + if len(instances) != 1 { + t.Fatalf("invalid number of instances, expected 1 got %d", len(instances)) + } + + if !bytes.Equal(instances[0].ForwarderFingerprint, bobEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwarder key ID, expected: %x, got: %x", bobEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwarderFingerprint) + } + + if !bytes.Equal(instances[0].ForwardeeFingerprint, charlesEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwardee key ID, expected: %x, got: %x", charlesEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwardeeFingerprint) + } + + // Encrypt message + buf := bytes.NewBuffer(nil) + w, err := Encrypt(buf, []*Entity{bobEntity}, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + + _, err = w.Write(plaintext) + if err != nil { + t.Fatal(err) + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + encrypted := buf.Bytes() + + // Decrypt message for Bob + m, err := ReadMessage(bytes.NewBuffer(encrypted), EntityList([]*Entity{bobEntity}), nil, nil) + if err != nil { + t.Fatal(err) + } + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("decrypted does not match original") + } + + // Forward message + + transformed := transformTestMessage(t, encrypted, instances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(transformed), EntityList([]*Entity{charlesEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } + + // Setup further forwarding + danielEntity, secondForwardInstances, err := charlesEntity.NewForwardingEntity("Daniel", "", "daniel@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + danielEntity = serializeAndParseForwardeeKey(t, danielEntity) + + secondTransformed := transformTestMessage(t, transformed, secondForwardInstances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(secondTransformed), EntityList([]*Entity{danielEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func transformTestMessage(t *testing.T, encrypted []byte, instance packet.ForwardingInstance) []byte { + bytesReader := bytes.NewReader(encrypted) + packets := packet.NewReader(bytesReader) + splitPoint := int64(0) + transformedEncryptedKey := bytes.NewBuffer(nil) + +Loop: + for { + p, err := packets.Next() + if goerrors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("error in parsing message: %s", err) + } + switch p := p.(type) { + case *packet.EncryptedKey: + tp, err := p.ProxyTransform(instance) + if err != nil { + t.Fatalf("error transforming PKESK: %s", err) + } + + splitPoint = bytesReader.Size() - int64(bytesReader.Len()) + + err = tp.Serialize(transformedEncryptedKey) + if err != nil { + t.Fatalf("error serializing transformed PKESK: %s", err) + } + break Loop + } + } + + transformed := transformedEncryptedKey.Bytes() + transformed = append(transformed, encrypted[splitPoint:]...) + + return transformed +} + +func serializeAndParseForwardeeKey(t *testing.T, key *Entity) *Entity { + serializedEntity := bytes.NewBuffer(nil) + err := key.SerializePrivateWithoutSigning(serializedEntity, nil) + if err != nil { + t.Fatalf("Error in serializing forwardee key: %s", err) + } + el, err := ReadKeyRing(serializedEntity) + if err != nil { + t.Fatalf("Error in reading forwardee key: %s", err) + } + + if len(el) != 1 { + t.Fatalf("Wrong number of entities in parsing, expected 1, got %d", len(el)) + } + + return el[0] +} diff --git a/openpgp/internal/ecc/curve25519.go b/openpgp/internal/ecc/curve25519.go index 888767c4e..a6721ff98 100644 --- a/openpgp/internal/ecc/curve25519.go +++ b/openpgp/internal/ecc/curve25519.go @@ -3,10 +3,9 @@ package ecc import ( "crypto/subtle" - "io" - "github.com/ProtonMail/go-crypto/openpgp/errors" x25519lib "github.com/cloudflare/circl/dh/x25519" + "io" ) type curve25519 struct{} diff --git a/openpgp/internal/ecc/curve25519/curve25519.go b/openpgp/internal/ecc/curve25519/curve25519.go new file mode 100644 index 000000000..21670a82c --- /dev/null +++ b/openpgp/internal/ecc/curve25519/curve25519.go @@ -0,0 +1,122 @@ +// Package curve25519 implements custom field operations without clamping for forwarding. +package curve25519 + +import ( + "crypto/subtle" + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve25519/field" + x25519lib "github.com/cloudflare/circl/dh/x25519" + "math/big" +) + +var curveGroupByte = [x25519lib.Size]byte{ + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6, 0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed, +} + +const ParamSize = x25519lib.Size + +func DeriveProxyParam(recipientSecretByte, forwardeeSecretByte []byte) (proxyParam []byte, err error) { + curveGroup := new(big.Int).SetBytes(curveGroupByte[:]) + recipientSecret := new(big.Int).SetBytes(recipientSecretByte) + forwardeeSecret := new(big.Int).SetBytes(forwardeeSecretByte) + + proxyTransform := new(big.Int).Mod( + new(big.Int).Mul( + new(big.Int).ModInverse(forwardeeSecret, curveGroup), + recipientSecret, + ), + curveGroup, + ) + + rawProxyParam := proxyTransform.Bytes() + + // pad and convert to small endian + proxyParam = make([]byte, x25519lib.Size) + l := len(rawProxyParam) + for i := 0; i < l; i++ { + proxyParam[i] = rawProxyParam[l-i-1] + } + + return proxyParam, nil +} + +func ProxyTransform(ephemeral, proxyParam []byte) ([]byte, error) { + var transformed, safetyCheck [x25519lib.Size]byte + + var scalarEight = make([]byte, x25519lib.Size) + scalarEight[0] = 0x08 + err := ScalarMult(&safetyCheck, scalarEight, ephemeral) + if err != nil { + return nil, err + } + + err = ScalarMult(&transformed, proxyParam, ephemeral) + if err != nil { + return nil, err + } + + return transformed[:], nil +} + +func ScalarMult(dst *[32]byte, scalar, point []byte) error { + var in, base, zero [32]byte + copy(in[:], scalar) + copy(base[:], point) + + scalarMult(dst, &in, &base) + if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 { + return errors.InvalidArgumentError("invalid ephemeral: low order point") + } + + return nil +} + +func scalarMult(dst, scalar, point *[32]byte) { + var e [32]byte + + copy(e[:], scalar[:]) + + var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element + x1.SetBytes(point[:]) + x2.One() + x3.Set(&x1) + z3.One() + + swap := 0 + for pos := 254; pos >= 0; pos-- { + b := e[pos/8] >> uint(pos&7) + b &= 1 + swap ^= int(b) + x2.Swap(&x3, swap) + z2.Swap(&z3, swap) + swap = int(b) + + tmp0.Subtract(&x3, &z3) + tmp1.Subtract(&x2, &z2) + x2.Add(&x2, &z2) + z2.Add(&x3, &z3) + z3.Multiply(&tmp0, &x2) + z2.Multiply(&z2, &tmp1) + tmp0.Square(&tmp1) + tmp1.Square(&x2) + x3.Add(&z3, &z2) + z2.Subtract(&z3, &z2) + x2.Multiply(&tmp1, &tmp0) + tmp1.Subtract(&tmp1, &tmp0) + z2.Square(&z2) + + z3.Mult32(&tmp1, 121666) + x3.Square(&x3) + tmp0.Add(&tmp0, &z3) + z3.Multiply(&x1, &z2) + z2.Multiply(&tmp1, &tmp0) + } + + x2.Swap(&x3, swap) + z2.Swap(&z3, swap) + + z2.Invert(&z2) + x2.Multiply(&x2, &z2) + copy(dst[:], x2.Bytes()) +} diff --git a/openpgp/internal/ecc/curve25519/curve25519_test.go b/openpgp/internal/ecc/curve25519/curve25519_test.go new file mode 100644 index 000000000..bd82e03e2 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/curve25519_test.go @@ -0,0 +1,89 @@ +// Package curve25519 implements custom field operations without clamping for forwarding. +package curve25519 + +import ( + "bytes" + "encoding/hex" + "testing" +) + +const ( + hexBobSecret = "5989216365053dcf9e35a04b2a1fc19b83328426be6bb7d0a2ae78105e2e3188" + hexCharlesSecret = "684da6225bcd44d880168fc5bec7d2f746217f014c8019005f144cc148f16a00" + hexExpectedProxyParam = "e89786987c3a3ec761a679bc372cd11a425eda72bd5265d78ad0f5f32ee64f02" + + hexMessagePoint = "aaea7b3bb92f5f545d023ccb15b50f84ba1bdd53be7f5cfadcfb0106859bf77e" + hexInputProxyParam = "83c57cbe645a132477af55d5020281305860201608e81a1de43ff83f245fb302" + hexExpectedTransformedPoint = "ec31bb937d7ef08c451d516be1d7976179aa7171eea598370661d1152b85005a" + + hexSmallSubgroupPoint = "ecffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f" +) + +func TestDeriveProxyParam(t *testing.T) { + bobSecret, err := hex.DecodeString(hexBobSecret) + if err != nil { + t.Fatalf("Unexpected error in decoding recipient secret: %s", err) + } + + charlesSecret, err := hex.DecodeString(hexCharlesSecret) + if err != nil { + t.Fatalf("Unexpected error in decoding forwardee secret: %s", err) + } + + expectedProxyParam, err := hex.DecodeString(hexExpectedProxyParam) + if err != nil { + t.Fatalf("Unexpected error in parameter decoding expected proxy parameter: %s", err) + } + + proxyParam, err := DeriveProxyParam(bobSecret, charlesSecret) + if err != nil { + t.Fatalf("Unexpected error in parameter derivation: %s", err) + } + + if bytes.Compare(proxyParam, expectedProxyParam) != 0 { + t.Errorf("Computed wrong proxy parameter, expected %x got %x", expectedProxyParam, proxyParam) + } +} + +func TestTransformMessage(t *testing.T) { + proxyParam, err := hex.DecodeString(hexInputProxyParam) + if err != nil { + t.Fatalf("Unexpected error in decoding proxy parameter: %s", err) + } + + messagePoint, err := hex.DecodeString(hexMessagePoint) + if err != nil { + t.Fatalf("Unexpected error in decoding message point: %s", err) + } + + expectedTransformed, err := hex.DecodeString(hexExpectedTransformedPoint) + if err != nil { + t.Fatalf("Unexpected error in parameter decoding expected transformed point: %s", err) + } + + transformed, err := ProxyTransform(messagePoint, proxyParam) + if err != nil { + t.Fatalf("Unexpected error in parameter derivation: %s", err) + } + + if bytes.Compare(transformed, expectedTransformed) != 0 { + t.Errorf("Computed wrong proxy parameter, expected %x got %x", expectedTransformed, transformed) + } +} + +func TestTransformSmallSubgroup(t *testing.T) { + proxyParam, err := hex.DecodeString(hexInputProxyParam) + if err != nil { + t.Fatalf("Unexpected error in decoding proxy parameter: %s", err) + } + + messagePoint, err := hex.DecodeString(hexSmallSubgroupPoint) + if err != nil { + t.Fatalf("Unexpected error in decoding small sugroup point: %s", err) + } + + _, err = ProxyTransform(messagePoint, proxyParam) + if err == nil { + t.Error("Expected small subgroup error") + } +} diff --git a/openpgp/internal/ecc/curve25519/field/fe.go b/openpgp/internal/ecc/curve25519/field/fe.go new file mode 100644 index 000000000..ca841ad99 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe.go @@ -0,0 +1,416 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package field implements fast arithmetic modulo 2^255-19. +package field + +import ( + "crypto/subtle" + "encoding/binary" + "math/bits" +) + +// Element represents an element of the field GF(2^255-19). Note that this +// is not a cryptographically secure group, and should only be used to interact +// with edwards25519.Point coordinates. +// +// This type works similarly to math/big.Int, and all arguments and receivers +// are allowed to alias. +// +// The zero value is a valid zero element. +type Element struct { + // An element t represents the integer + // t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204 + // + // Between operations, all limbs are expected to be lower than 2^52. + l0 uint64 + l1 uint64 + l2 uint64 + l3 uint64 + l4 uint64 +} + +const maskLow51Bits uint64 = (1 << 51) - 1 + +var feZero = &Element{0, 0, 0, 0, 0} + +// Zero sets v = 0, and returns v. +func (v *Element) Zero() *Element { + *v = *feZero + return v +} + +var feOne = &Element{1, 0, 0, 0, 0} + +// One sets v = 1, and returns v. +func (v *Element) One() *Element { + *v = *feOne + return v +} + +// reduce reduces v modulo 2^255 - 19 and returns it. +func (v *Element) reduce() *Element { + v.carryPropagate() + + // After the light reduction we now have a field element representation + // v < 2^255 + 2^13 * 19, but need v < 2^255 - 19. + + // If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1, + // generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise. + c := (v.l0 + 19) >> 51 + c = (v.l1 + c) >> 51 + c = (v.l2 + c) >> 51 + c = (v.l3 + c) >> 51 + c = (v.l4 + c) >> 51 + + // If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's + // effectively applying the reduction identity to the carry. + v.l0 += 19 * c + + v.l1 += v.l0 >> 51 + v.l0 = v.l0 & maskLow51Bits + v.l2 += v.l1 >> 51 + v.l1 = v.l1 & maskLow51Bits + v.l3 += v.l2 >> 51 + v.l2 = v.l2 & maskLow51Bits + v.l4 += v.l3 >> 51 + v.l3 = v.l3 & maskLow51Bits + // no additional carry + v.l4 = v.l4 & maskLow51Bits + + return v +} + +// Add sets v = a + b, and returns v. +func (v *Element) Add(a, b *Element) *Element { + v.l0 = a.l0 + b.l0 + v.l1 = a.l1 + b.l1 + v.l2 = a.l2 + b.l2 + v.l3 = a.l3 + b.l3 + v.l4 = a.l4 + b.l4 + // Using the generic implementation here is actually faster than the + // assembly. Probably because the body of this function is so simple that + // the compiler can figure out better optimizations by inlining the carry + // propagation. TODO + return v.carryPropagateGeneric() +} + +// Subtract sets v = a - b, and returns v. +func (v *Element) Subtract(a, b *Element) *Element { + // We first add 2 * p, to guarantee the subtraction won't underflow, and + // then subtract b (which can be up to 2^255 + 2^13 * 19). + v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0 + v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1 + v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2 + v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3 + v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4 + return v.carryPropagate() +} + +// Negate sets v = -a, and returns v. +func (v *Element) Negate(a *Element) *Element { + return v.Subtract(feZero, a) +} + +// Invert sets v = 1/z mod p, and returns v. +// +// If z == 0, Invert returns v = 0. +func (v *Element) Invert(z *Element) *Element { + // Inversion is implemented as exponentiation with exponent p − 2. It uses the + // same sequence of 255 squarings and 11 multiplications as [Curve25519]. + var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element + + z2.Square(z) // 2 + t.Square(&z2) // 4 + t.Square(&t) // 8 + z9.Multiply(&t, z) // 9 + z11.Multiply(&z9, &z2) // 11 + t.Square(&z11) // 22 + z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0 + + t.Square(&z2_5_0) // 2^6 - 2^1 + for i := 0; i < 4; i++ { + t.Square(&t) // 2^10 - 2^5 + } + z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0 + + t.Square(&z2_10_0) // 2^11 - 2^1 + for i := 0; i < 9; i++ { + t.Square(&t) // 2^20 - 2^10 + } + z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0 + + t.Square(&z2_20_0) // 2^21 - 2^1 + for i := 0; i < 19; i++ { + t.Square(&t) // 2^40 - 2^20 + } + t.Multiply(&t, &z2_20_0) // 2^40 - 2^0 + + t.Square(&t) // 2^41 - 2^1 + for i := 0; i < 9; i++ { + t.Square(&t) // 2^50 - 2^10 + } + z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0 + + t.Square(&z2_50_0) // 2^51 - 2^1 + for i := 0; i < 49; i++ { + t.Square(&t) // 2^100 - 2^50 + } + z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0 + + t.Square(&z2_100_0) // 2^101 - 2^1 + for i := 0; i < 99; i++ { + t.Square(&t) // 2^200 - 2^100 + } + t.Multiply(&t, &z2_100_0) // 2^200 - 2^0 + + t.Square(&t) // 2^201 - 2^1 + for i := 0; i < 49; i++ { + t.Square(&t) // 2^250 - 2^50 + } + t.Multiply(&t, &z2_50_0) // 2^250 - 2^0 + + t.Square(&t) // 2^251 - 2^1 + t.Square(&t) // 2^252 - 2^2 + t.Square(&t) // 2^253 - 2^3 + t.Square(&t) // 2^254 - 2^4 + t.Square(&t) // 2^255 - 2^5 + + return v.Multiply(&t, &z11) // 2^255 - 21 +} + +// Set sets v = a, and returns v. +func (v *Element) Set(a *Element) *Element { + *v = *a + return v +} + +// SetBytes sets v to x, which must be a 32-byte little-endian encoding. +// +// Consistent with RFC 7748, the most significant bit (the high bit of the +// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1) +// are accepted. Note that this is laxer than specified by RFC 8032. +func (v *Element) SetBytes(x []byte) *Element { + if len(x) != 32 { + panic("edwards25519: invalid field element input size") + } + + // Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51). + v.l0 = binary.LittleEndian.Uint64(x[0:8]) + v.l0 &= maskLow51Bits + // Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51). + v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3 + v.l1 &= maskLow51Bits + // Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51). + v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6 + v.l2 &= maskLow51Bits + // Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51). + v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1 + v.l3 &= maskLow51Bits + // Bits 204:251 (bytes 24:32, bits 192:256, shift 12, mask 51). + // Note: not bytes 25:33, shift 4, to avoid overread. + v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12 + v.l4 &= maskLow51Bits + + return v +} + +// Bytes returns the canonical 32-byte little-endian encoding of v. +func (v *Element) Bytes() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [32]byte + return v.bytes(&out) +} + +func (v *Element) bytes(out *[32]byte) []byte { + t := *v + t.reduce() + + var buf [8]byte + for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} { + bitsOffset := i * 51 + binary.LittleEndian.PutUint64(buf[:], l<= len(out) { + break + } + out[off] |= bb + } + } + + return out[:] +} + +// Equal returns 1 if v and u are equal, and 0 otherwise. +func (v *Element) Equal(u *Element) int { + sa, sv := u.Bytes(), v.Bytes() + return subtle.ConstantTimeCompare(sa, sv) +} + +// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise. +func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) } + +// Select sets v to a if cond == 1, and to b if cond == 0. +func (v *Element) Select(a, b *Element, cond int) *Element { + m := mask64Bits(cond) + v.l0 = (m & a.l0) | (^m & b.l0) + v.l1 = (m & a.l1) | (^m & b.l1) + v.l2 = (m & a.l2) | (^m & b.l2) + v.l3 = (m & a.l3) | (^m & b.l3) + v.l4 = (m & a.l4) | (^m & b.l4) + return v +} + +// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v. +func (v *Element) Swap(u *Element, cond int) { + m := mask64Bits(cond) + t := m & (v.l0 ^ u.l0) + v.l0 ^= t + u.l0 ^= t + t = m & (v.l1 ^ u.l1) + v.l1 ^= t + u.l1 ^= t + t = m & (v.l2 ^ u.l2) + v.l2 ^= t + u.l2 ^= t + t = m & (v.l3 ^ u.l3) + v.l3 ^= t + u.l3 ^= t + t = m & (v.l4 ^ u.l4) + v.l4 ^= t + u.l4 ^= t +} + +// IsNegative returns 1 if v is negative, and 0 otherwise. +func (v *Element) IsNegative() int { + return int(v.Bytes()[0] & 1) +} + +// Absolute sets v to |u|, and returns v. +func (v *Element) Absolute(u *Element) *Element { + return v.Select(new(Element).Negate(u), u, u.IsNegative()) +} + +// Multiply sets v = x * y, and returns v. +func (v *Element) Multiply(x, y *Element) *Element { + feMul(v, x, y) + return v +} + +// Square sets v = x * x, and returns v. +func (v *Element) Square(x *Element) *Element { + feSquare(v, x) + return v +} + +// Mult32 sets v = x * y, and returns v. +func (v *Element) Mult32(x *Element, y uint32) *Element { + x0lo, x0hi := mul51(x.l0, y) + x1lo, x1hi := mul51(x.l1, y) + x2lo, x2hi := mul51(x.l2, y) + x3lo, x3hi := mul51(x.l3, y) + x4lo, x4hi := mul51(x.l4, y) + v.l0 = x0lo + 19*x4hi // carried over per the reduction identity + v.l1 = x1lo + x0hi + v.l2 = x2lo + x1hi + v.l3 = x3lo + x2hi + v.l4 = x4lo + x3hi + // The hi portions are going to be only 32 bits, plus any previous excess, + // so we can skip the carry propagation. + return v +} + +// mul51 returns lo + hi * 2⁵¹ = a * b. +func mul51(a uint64, b uint32) (lo uint64, hi uint64) { + mh, ml := bits.Mul64(a, uint64(b)) + lo = ml & maskLow51Bits + hi = (mh << 13) | (ml >> 51) + return +} + +// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3. +func (v *Element) Pow22523(x *Element) *Element { + var t0, t1, t2 Element + + t0.Square(x) // x^2 + t1.Square(&t0) // x^4 + t1.Square(&t1) // x^8 + t1.Multiply(x, &t1) // x^9 + t0.Multiply(&t0, &t1) // x^11 + t0.Square(&t0) // x^22 + t0.Multiply(&t1, &t0) // x^31 + t1.Square(&t0) // x^62 + for i := 1; i < 5; i++ { // x^992 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1 + t1.Square(&t0) // 2^11 - 2 + for i := 1; i < 10; i++ { // 2^20 - 2^10 + t1.Square(&t1) + } + t1.Multiply(&t1, &t0) // 2^20 - 1 + t2.Square(&t1) // 2^21 - 2 + for i := 1; i < 20; i++ { // 2^40 - 2^20 + t2.Square(&t2) + } + t1.Multiply(&t2, &t1) // 2^40 - 1 + t1.Square(&t1) // 2^41 - 2 + for i := 1; i < 10; i++ { // 2^50 - 2^10 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // 2^50 - 1 + t1.Square(&t0) // 2^51 - 2 + for i := 1; i < 50; i++ { // 2^100 - 2^50 + t1.Square(&t1) + } + t1.Multiply(&t1, &t0) // 2^100 - 1 + t2.Square(&t1) // 2^101 - 2 + for i := 1; i < 100; i++ { // 2^200 - 2^100 + t2.Square(&t2) + } + t1.Multiply(&t2, &t1) // 2^200 - 1 + t1.Square(&t1) // 2^201 - 2 + for i := 1; i < 50; i++ { // 2^250 - 2^50 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // 2^250 - 1 + t0.Square(&t0) // 2^251 - 2 + t0.Square(&t0) // 2^252 - 4 + return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3) +} + +// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion. +var sqrtM1 = &Element{1718705420411056, 234908883556509, + 2233514472574048, 2117202627021982, 765476049583133} + +// SqrtRatio sets r to the non-negative square root of the ratio of u and v. +// +// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio +// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00, +// and returns r and 0. +func (r *Element) SqrtRatio(u, v *Element) (rr *Element, wasSquare int) { + var a, b Element + + // r = (u * v3) * (u * v7)^((p-5)/8) + v2 := a.Square(v) + uv3 := b.Multiply(u, b.Multiply(v2, v)) + uv7 := a.Multiply(uv3, a.Square(v2)) + r.Multiply(uv3, r.Pow22523(uv7)) + + check := a.Multiply(v, a.Square(r)) // check = v * r^2 + + uNeg := b.Negate(u) + correctSignSqrt := check.Equal(u) + flippedSignSqrt := check.Equal(uNeg) + flippedSignSqrtI := check.Equal(uNeg.Multiply(uNeg, sqrtM1)) + + rPrime := b.Multiply(r, sqrtM1) // r_prime = SQRT_M1 * r + // r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r) + r.Select(rPrime, r, flippedSignSqrt|flippedSignSqrtI) + + r.Absolute(r) // Choose the nonnegative square root. + return r, correctSignSqrt | flippedSignSqrt +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_alias_test.go b/openpgp/internal/ecc/curve25519/field/fe_alias_test.go new file mode 100644 index 000000000..64e57c4f3 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_alias_test.go @@ -0,0 +1,126 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import ( + "testing" + "testing/quick" +) + +func checkAliasingOneArg(f func(v, x *Element) *Element) func(v, x Element) bool { + return func(v, x Element) bool { + x1, v1 := x, x + + // Calculate a reference f(x) without aliasing. + if out := f(&v, &x); out != &v && isInBounds(out) { + return false + } + + // Test aliasing the argument and the receiver. + if out := f(&v1, &v1); out != &v1 || v1 != v { + return false + } + + // Ensure the arguments was not modified. + return x == x1 + } +} + +func checkAliasingTwoArgs(f func(v, x, y *Element) *Element) func(v, x, y Element) bool { + return func(v, x, y Element) bool { + x1, y1, v1 := x, y, Element{} + + // Calculate a reference f(x, y) without aliasing. + if out := f(&v, &x, &y); out != &v && isInBounds(out) { + return false + } + + // Test aliasing the first argument and the receiver. + v1 = x + if out := f(&v1, &v1, &y); out != &v1 || v1 != v { + return false + } + // Test aliasing the second argument and the receiver. + v1 = y + if out := f(&v1, &x, &v1); out != &v1 || v1 != v { + return false + } + + // Calculate a reference f(x, x) without aliasing. + if out := f(&v, &x, &x); out != &v { + return false + } + + // Test aliasing the first argument and the receiver. + v1 = x + if out := f(&v1, &v1, &x); out != &v1 || v1 != v { + return false + } + // Test aliasing the second argument and the receiver. + v1 = x + if out := f(&v1, &x, &v1); out != &v1 || v1 != v { + return false + } + // Test aliasing both arguments and the receiver. + v1 = x + if out := f(&v1, &v1, &v1); out != &v1 || v1 != v { + return false + } + + // Ensure the arguments were not modified. + return x == x1 && y == y1 + } +} + +// TestAliasing checks that receivers and arguments can alias each other without +// leading to incorrect results. That is, it ensures that it's safe to write +// +// v.Invert(v) +// +// or +// +// v.Add(v, v) +// +// without any of the inputs getting clobbered by the output being written. +func TestAliasing(t *testing.T) { + type target struct { + name string + oneArgF func(v, x *Element) *Element + twoArgsF func(v, x, y *Element) *Element + } + for _, tt := range []target{ + {name: "Absolute", oneArgF: (*Element).Absolute}, + {name: "Invert", oneArgF: (*Element).Invert}, + {name: "Negate", oneArgF: (*Element).Negate}, + {name: "Set", oneArgF: (*Element).Set}, + {name: "Square", oneArgF: (*Element).Square}, + {name: "Multiply", twoArgsF: (*Element).Multiply}, + {name: "Add", twoArgsF: (*Element).Add}, + {name: "Subtract", twoArgsF: (*Element).Subtract}, + { + name: "Select0", + twoArgsF: func(v, x, y *Element) *Element { + return (*Element).Select(v, x, y, 0) + }, + }, + { + name: "Select1", + twoArgsF: func(v, x, y *Element) *Element { + return (*Element).Select(v, x, y, 1) + }, + }, + } { + var err error + switch { + case tt.oneArgF != nil: + err = quick.Check(checkAliasingOneArg(tt.oneArgF), &quick.Config{MaxCountScale: 1 << 8}) + case tt.twoArgsF != nil: + err = quick.Check(checkAliasingTwoArgs(tt.twoArgsF), &quick.Config{MaxCountScale: 1 << 8}) + } + if err != nil { + t.Errorf("%v: %v", tt.name, err) + } + } +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_amd64.go b/openpgp/internal/ecc/curve25519/field/fe_amd64.go new file mode 100644 index 000000000..edcf163c4 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_amd64.go @@ -0,0 +1,16 @@ +// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT. + +//go:build amd64 && gc && !purego +// +build amd64,gc,!purego + +package field + +// feMul sets out = a * b. It works like feMulGeneric. +// +//go:noescape +func feMul(out *Element, a *Element, b *Element) + +// feSquare sets out = a * a. It works like feSquareGeneric. +// +//go:noescape +func feSquare(out *Element, a *Element) diff --git a/openpgp/internal/ecc/curve25519/field/fe_amd64.s b/openpgp/internal/ecc/curve25519/field/fe_amd64.s new file mode 100644 index 000000000..293f013c9 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_amd64.s @@ -0,0 +1,379 @@ +// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT. + +//go:build amd64 && gc && !purego +// +build amd64,gc,!purego + +#include "textflag.h" + +// func feMul(out *Element, a *Element, b *Element) +TEXT ·feMul(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), CX + MOVQ b+16(FP), BX + + // r0 = a0×b0 + MOVQ (CX), AX + MULQ (BX) + MOVQ AX, DI + MOVQ DX, SI + + // r0 += 19×a1×b4 + MOVQ 8(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a2×b3 + MOVQ 16(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a3×b2 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 16(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a4×b1 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 8(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r1 = a0×b1 + MOVQ (CX), AX + MULQ 8(BX) + MOVQ AX, R9 + MOVQ DX, R8 + + // r1 += a1×b0 + MOVQ 8(CX), AX + MULQ (BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a2×b4 + MOVQ 16(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a3×b3 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a4×b2 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 16(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r2 = a0×b2 + MOVQ (CX), AX + MULQ 16(BX) + MOVQ AX, R11 + MOVQ DX, R10 + + // r2 += a1×b1 + MOVQ 8(CX), AX + MULQ 8(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += a2×b0 + MOVQ 16(CX), AX + MULQ (BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += 19×a3×b4 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += 19×a4×b3 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r3 = a0×b3 + MOVQ (CX), AX + MULQ 24(BX) + MOVQ AX, R13 + MOVQ DX, R12 + + // r3 += a1×b2 + MOVQ 8(CX), AX + MULQ 16(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += a2×b1 + MOVQ 16(CX), AX + MULQ 8(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += a3×b0 + MOVQ 24(CX), AX + MULQ (BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += 19×a4×b4 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r4 = a0×b4 + MOVQ (CX), AX + MULQ 32(BX) + MOVQ AX, R15 + MOVQ DX, R14 + + // r4 += a1×b3 + MOVQ 8(CX), AX + MULQ 24(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a2×b2 + MOVQ 16(CX), AX + MULQ 16(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a3×b1 + MOVQ 24(CX), AX + MULQ 8(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a4×b0 + MOVQ 32(CX), AX + MULQ (BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // First reduction chain + MOVQ $0x0007ffffffffffff, AX + SHLQ $0x0d, DI, SI + SHLQ $0x0d, R9, R8 + SHLQ $0x0d, R11, R10 + SHLQ $0x0d, R13, R12 + SHLQ $0x0d, R15, R14 + ANDQ AX, DI + IMUL3Q $0x13, R14, R14 + ADDQ R14, DI + ANDQ AX, R9 + ADDQ SI, R9 + ANDQ AX, R11 + ADDQ R8, R11 + ANDQ AX, R13 + ADDQ R10, R13 + ANDQ AX, R15 + ADDQ R12, R15 + + // Second reduction chain (carryPropagate) + MOVQ DI, SI + SHRQ $0x33, SI + MOVQ R9, R8 + SHRQ $0x33, R8 + MOVQ R11, R10 + SHRQ $0x33, R10 + MOVQ R13, R12 + SHRQ $0x33, R12 + MOVQ R15, R14 + SHRQ $0x33, R14 + ANDQ AX, DI + IMUL3Q $0x13, R14, R14 + ADDQ R14, DI + ANDQ AX, R9 + ADDQ SI, R9 + ANDQ AX, R11 + ADDQ R8, R11 + ANDQ AX, R13 + ADDQ R10, R13 + ANDQ AX, R15 + ADDQ R12, R15 + + // Store output + MOVQ out+0(FP), AX + MOVQ DI, (AX) + MOVQ R9, 8(AX) + MOVQ R11, 16(AX) + MOVQ R13, 24(AX) + MOVQ R15, 32(AX) + RET + +// func feSquare(out *Element, a *Element) +TEXT ·feSquare(SB), NOSPLIT, $0-16 + MOVQ a+8(FP), CX + + // r0 = l0×l0 + MOVQ (CX), AX + MULQ (CX) + MOVQ AX, SI + MOVQ DX, BX + + // r0 += 38×l1×l4 + MOVQ 8(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, SI + ADCQ DX, BX + + // r0 += 38×l2×l3 + MOVQ 16(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 24(CX) + ADDQ AX, SI + ADCQ DX, BX + + // r1 = 2×l0×l1 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 8(CX) + MOVQ AX, R8 + MOVQ DX, DI + + // r1 += 38×l2×l4 + MOVQ 16(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, R8 + ADCQ DX, DI + + // r1 += 19×l3×l3 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(CX) + ADDQ AX, R8 + ADCQ DX, DI + + // r2 = 2×l0×l2 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 16(CX) + MOVQ AX, R10 + MOVQ DX, R9 + + // r2 += l1×l1 + MOVQ 8(CX), AX + MULQ 8(CX) + ADDQ AX, R10 + ADCQ DX, R9 + + // r2 += 38×l3×l4 + MOVQ 24(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, R10 + ADCQ DX, R9 + + // r3 = 2×l0×l3 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 24(CX) + MOVQ AX, R12 + MOVQ DX, R11 + + // r3 += 2×l1×l2 + MOVQ 8(CX), AX + IMUL3Q $0x02, AX, AX + MULQ 16(CX) + ADDQ AX, R12 + ADCQ DX, R11 + + // r3 += 19×l4×l4 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(CX) + ADDQ AX, R12 + ADCQ DX, R11 + + // r4 = 2×l0×l4 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 32(CX) + MOVQ AX, R14 + MOVQ DX, R13 + + // r4 += 2×l1×l3 + MOVQ 8(CX), AX + IMUL3Q $0x02, AX, AX + MULQ 24(CX) + ADDQ AX, R14 + ADCQ DX, R13 + + // r4 += l2×l2 + MOVQ 16(CX), AX + MULQ 16(CX) + ADDQ AX, R14 + ADCQ DX, R13 + + // First reduction chain + MOVQ $0x0007ffffffffffff, AX + SHLQ $0x0d, SI, BX + SHLQ $0x0d, R8, DI + SHLQ $0x0d, R10, R9 + SHLQ $0x0d, R12, R11 + SHLQ $0x0d, R14, R13 + ANDQ AX, SI + IMUL3Q $0x13, R13, R13 + ADDQ R13, SI + ANDQ AX, R8 + ADDQ BX, R8 + ANDQ AX, R10 + ADDQ DI, R10 + ANDQ AX, R12 + ADDQ R9, R12 + ANDQ AX, R14 + ADDQ R11, R14 + + // Second reduction chain (carryPropagate) + MOVQ SI, BX + SHRQ $0x33, BX + MOVQ R8, DI + SHRQ $0x33, DI + MOVQ R10, R9 + SHRQ $0x33, R9 + MOVQ R12, R11 + SHRQ $0x33, R11 + MOVQ R14, R13 + SHRQ $0x33, R13 + ANDQ AX, SI + IMUL3Q $0x13, R13, R13 + ADDQ R13, SI + ANDQ AX, R8 + ADDQ BX, R8 + ANDQ AX, R10 + ADDQ DI, R10 + ANDQ AX, R12 + ADDQ R9, R12 + ANDQ AX, R14 + ADDQ R11, R14 + + // Store output + MOVQ out+0(FP), AX + MOVQ SI, (AX) + MOVQ R8, 8(AX) + MOVQ R10, 16(AX) + MOVQ R12, 24(AX) + MOVQ R14, 32(AX) + RET diff --git a/openpgp/internal/ecc/curve25519/field/fe_amd64_noasm.go b/openpgp/internal/ecc/curve25519/field/fe_amd64_noasm.go new file mode 100644 index 000000000..ddb6c9b8f --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_amd64_noasm.go @@ -0,0 +1,12 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !amd64 || !gc || purego +// +build !amd64 !gc purego + +package field + +func feMul(v, x, y *Element) { feMulGeneric(v, x, y) } + +func feSquare(v, x *Element) { feSquareGeneric(v, x) } diff --git a/openpgp/internal/ecc/curve25519/field/fe_arm64.go b/openpgp/internal/ecc/curve25519/field/fe_arm64.go new file mode 100644 index 000000000..af459ef51 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_arm64.go @@ -0,0 +1,16 @@ +// Copyright (c) 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build arm64 && gc && !purego +// +build arm64,gc,!purego + +package field + +//go:noescape +func carryPropagate(v *Element) + +func (v *Element) carryPropagate() *Element { + carryPropagate(v) + return v +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_arm64.s b/openpgp/internal/ecc/curve25519/field/fe_arm64.s new file mode 100644 index 000000000..5c91e4589 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_arm64.s @@ -0,0 +1,43 @@ +// Copyright (c) 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build arm64 && gc && !purego +// +build arm64,gc,!purego + +#include "textflag.h" + +// carryPropagate works exactly like carryPropagateGeneric and uses the +// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but +// avoids loading R0-R4 twice and uses LDP and STP. +// +// See https://golang.org/issues/43145 for the main compiler issue. +// +// func carryPropagate(v *Element) +TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8 + MOVD v+0(FP), R20 + + LDP 0(R20), (R0, R1) + LDP 16(R20), (R2, R3) + MOVD 32(R20), R4 + + AND $0x7ffffffffffff, R0, R10 + AND $0x7ffffffffffff, R1, R11 + AND $0x7ffffffffffff, R2, R12 + AND $0x7ffffffffffff, R3, R13 + AND $0x7ffffffffffff, R4, R14 + + ADD R0>>51, R11, R11 + ADD R1>>51, R12, R12 + ADD R2>>51, R13, R13 + ADD R3>>51, R14, R14 + // R4>>51 * 19 + R10 -> R10 + LSR $51, R4, R21 + MOVD $19, R22 + MADD R22, R10, R21, R10 + + STP (R10, R11), 0(R20) + STP (R12, R13), 16(R20) + MOVD R14, 32(R20) + + RET diff --git a/openpgp/internal/ecc/curve25519/field/fe_arm64_noasm.go b/openpgp/internal/ecc/curve25519/field/fe_arm64_noasm.go new file mode 100644 index 000000000..234a5b2e5 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_arm64_noasm.go @@ -0,0 +1,12 @@ +// Copyright (c) 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !arm64 || !gc || purego +// +build !arm64 !gc purego + +package field + +func (v *Element) carryPropagate() *Element { + return v.carryPropagateGeneric() +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_bench_test.go b/openpgp/internal/ecc/curve25519/field/fe_bench_test.go new file mode 100644 index 000000000..77dc06cf9 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_bench_test.go @@ -0,0 +1,36 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import "testing" + +func BenchmarkAdd(b *testing.B) { + var x, y Element + x.One() + y.Add(feOne, feOne) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Add(&x, &y) + } +} + +func BenchmarkMultiply(b *testing.B) { + var x, y Element + x.One() + y.Add(feOne, feOne) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Multiply(&x, &y) + } +} + +func BenchmarkMult32(b *testing.B) { + var x Element + x.One() + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Mult32(&x, 0xaa42aa42) + } +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_generic.go b/openpgp/internal/ecc/curve25519/field/fe_generic.go new file mode 100644 index 000000000..7b5b78cbd --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_generic.go @@ -0,0 +1,264 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import "math/bits" + +// uint128 holds a 128-bit number as two 64-bit limbs, for use with the +// bits.Mul64 and bits.Add64 intrinsics. +type uint128 struct { + lo, hi uint64 +} + +// mul64 returns a * b. +func mul64(a, b uint64) uint128 { + hi, lo := bits.Mul64(a, b) + return uint128{lo, hi} +} + +// addMul64 returns v + a * b. +func addMul64(v uint128, a, b uint64) uint128 { + hi, lo := bits.Mul64(a, b) + lo, c := bits.Add64(lo, v.lo, 0) + hi, _ = bits.Add64(hi, v.hi, c) + return uint128{lo, hi} +} + +// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits. +func shiftRightBy51(a uint128) uint64 { + return (a.hi << (64 - 51)) | (a.lo >> 51) +} + +func feMulGeneric(v, a, b *Element) { + a0 := a.l0 + a1 := a.l1 + a2 := a.l2 + a3 := a.l3 + a4 := a.l4 + + b0 := b.l0 + b1 := b.l1 + b2 := b.l2 + b3 := b.l3 + b4 := b.l4 + + // Limb multiplication works like pen-and-paper columnar multiplication, but + // with 51-bit limbs instead of digits. + // + // a4 a3 a2 a1 a0 x + // b4 b3 b2 b1 b0 = + // ------------------------ + // a4b0 a3b0 a2b0 a1b0 a0b0 + + // a4b1 a3b1 a2b1 a1b1 a0b1 + + // a4b2 a3b2 a2b2 a1b2 a0b2 + + // a4b3 a3b3 a2b3 a1b3 a0b3 + + // a4b4 a3b4 a2b4 a1b4 a0b4 = + // ---------------------------------------------- + // r8 r7 r6 r5 r4 r3 r2 r1 r0 + // + // We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to + // reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5, + // r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc. + // + // Reduction can be carried out simultaneously to multiplication. For + // example, we do not compute r5: whenever the result of a multiplication + // belongs to r5, like a1b4, we multiply it by 19 and add the result to r0. + // + // a4b0 a3b0 a2b0 a1b0 a0b0 + + // a3b1 a2b1 a1b1 a0b1 19×a4b1 + + // a2b2 a1b2 a0b2 19×a4b2 19×a3b2 + + // a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 + + // a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 = + // -------------------------------------- + // r4 r3 r2 r1 r0 + // + // Finally we add up the columns into wide, overlapping limbs. + + a1_19 := a1 * 19 + a2_19 := a2 * 19 + a3_19 := a3 * 19 + a4_19 := a4 * 19 + + // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1) + r0 := mul64(a0, b0) + r0 = addMul64(r0, a1_19, b4) + r0 = addMul64(r0, a2_19, b3) + r0 = addMul64(r0, a3_19, b2) + r0 = addMul64(r0, a4_19, b1) + + // r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2) + r1 := mul64(a0, b1) + r1 = addMul64(r1, a1, b0) + r1 = addMul64(r1, a2_19, b4) + r1 = addMul64(r1, a3_19, b3) + r1 = addMul64(r1, a4_19, b2) + + // r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3) + r2 := mul64(a0, b2) + r2 = addMul64(r2, a1, b1) + r2 = addMul64(r2, a2, b0) + r2 = addMul64(r2, a3_19, b4) + r2 = addMul64(r2, a4_19, b3) + + // r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4 + r3 := mul64(a0, b3) + r3 = addMul64(r3, a1, b2) + r3 = addMul64(r3, a2, b1) + r3 = addMul64(r3, a3, b0) + r3 = addMul64(r3, a4_19, b4) + + // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0 + r4 := mul64(a0, b4) + r4 = addMul64(r4, a1, b3) + r4 = addMul64(r4, a2, b2) + r4 = addMul64(r4, a3, b1) + r4 = addMul64(r4, a4, b0) + + // After the multiplication, we need to reduce (carry) the five coefficients + // to obtain a result with limbs that are at most slightly larger than 2⁵¹, + // to respect the Element invariant. + // + // Overall, the reduction works the same as carryPropagate, except with + // wider inputs: we take the carry for each coefficient by shifting it right + // by 51, and add it to the limb above it. The top carry is multiplied by 19 + // according to the reduction identity and added to the lowest limb. + // + // The largest coefficient (r0) will be at most 111 bits, which guarantees + // that all carries are at most 111 - 51 = 60 bits, which fits in a uint64. + // + // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1) + // r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²) + // r0 < (1 + 19 × 4) × 2⁵² × 2⁵² + // r0 < 2⁷ × 2⁵² × 2⁵² + // r0 < 2¹¹¹ + // + // Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most + // 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and + // allows us to easily apply the reduction identity. + // + // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0 + // r4 < 5 × 2⁵² × 2⁵² + // r4 < 2¹⁰⁷ + // + + c0 := shiftRightBy51(r0) + c1 := shiftRightBy51(r1) + c2 := shiftRightBy51(r2) + c3 := shiftRightBy51(r3) + c4 := shiftRightBy51(r4) + + rr0 := r0.lo&maskLow51Bits + c4*19 + rr1 := r1.lo&maskLow51Bits + c0 + rr2 := r2.lo&maskLow51Bits + c1 + rr3 := r3.lo&maskLow51Bits + c2 + rr4 := r4.lo&maskLow51Bits + c3 + + // Now all coefficients fit into 64-bit registers but are still too large to + // be passed around as a Element. We therefore do one last carry chain, + // where the carries will be small enough to fit in the wiggle room above 2⁵¹. + *v = Element{rr0, rr1, rr2, rr3, rr4} + v.carryPropagate() +} + +func feSquareGeneric(v, a *Element) { + l0 := a.l0 + l1 := a.l1 + l2 := a.l2 + l3 := a.l3 + l4 := a.l4 + + // Squaring works precisely like multiplication above, but thanks to its + // symmetry we get to group a few terms together. + // + // l4 l3 l2 l1 l0 x + // l4 l3 l2 l1 l0 = + // ------------------------ + // l4l0 l3l0 l2l0 l1l0 l0l0 + + // l4l1 l3l1 l2l1 l1l1 l0l1 + + // l4l2 l3l2 l2l2 l1l2 l0l2 + + // l4l3 l3l3 l2l3 l1l3 l0l3 + + // l4l4 l3l4 l2l4 l1l4 l0l4 = + // ---------------------------------------------- + // r8 r7 r6 r5 r4 r3 r2 r1 r0 + // + // l4l0 l3l0 l2l0 l1l0 l0l0 + + // l3l1 l2l1 l1l1 l0l1 19×l4l1 + + // l2l2 l1l2 l0l2 19×l4l2 19×l3l2 + + // l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 + + // l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 = + // -------------------------------------- + // r4 r3 r2 r1 r0 + // + // With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with + // only three Mul64 and four Add64, instead of five and eight. + + l0_2 := l0 * 2 + l1_2 := l1 * 2 + + l1_38 := l1 * 38 + l2_38 := l2 * 38 + l3_38 := l3 * 38 + + l3_19 := l3 * 19 + l4_19 := l4 * 19 + + // r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3) + r0 := mul64(l0, l0) + r0 = addMul64(r0, l1_38, l4) + r0 = addMul64(r0, l2_38, l3) + + // r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3 + r1 := mul64(l0_2, l1) + r1 = addMul64(r1, l2_38, l4) + r1 = addMul64(r1, l3_19, l3) + + // r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4 + r2 := mul64(l0_2, l2) + r2 = addMul64(r2, l1, l1) + r2 = addMul64(r2, l3_38, l4) + + // r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4 + r3 := mul64(l0_2, l3) + r3 = addMul64(r3, l1_2, l2) + r3 = addMul64(r3, l4_19, l4) + + // r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2 + r4 := mul64(l0_2, l4) + r4 = addMul64(r4, l1_2, l3) + r4 = addMul64(r4, l2, l2) + + c0 := shiftRightBy51(r0) + c1 := shiftRightBy51(r1) + c2 := shiftRightBy51(r2) + c3 := shiftRightBy51(r3) + c4 := shiftRightBy51(r4) + + rr0 := r0.lo&maskLow51Bits + c4*19 + rr1 := r1.lo&maskLow51Bits + c0 + rr2 := r2.lo&maskLow51Bits + c1 + rr3 := r3.lo&maskLow51Bits + c2 + rr4 := r4.lo&maskLow51Bits + c3 + + *v = Element{rr0, rr1, rr2, rr3, rr4} + v.carryPropagate() +} + +// carryPropagate brings the limbs below 52 bits by applying the reduction +// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry. TODO inline +func (v *Element) carryPropagateGeneric() *Element { + c0 := v.l0 >> 51 + c1 := v.l1 >> 51 + c2 := v.l2 >> 51 + c3 := v.l3 >> 51 + c4 := v.l4 >> 51 + + v.l0 = v.l0&maskLow51Bits + c4*19 + v.l1 = v.l1&maskLow51Bits + c0 + v.l2 = v.l2&maskLow51Bits + c1 + v.l3 = v.l3&maskLow51Bits + c2 + v.l4 = v.l4&maskLow51Bits + c3 + + return v +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_test.go b/openpgp/internal/ecc/curve25519/field/fe_test.go new file mode 100644 index 000000000..b484459ff --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_test.go @@ -0,0 +1,558 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "io" + "math/big" + "math/bits" + mathrand "math/rand" + "reflect" + "testing" + "testing/quick" +) + +func (v Element) String() string { + return hex.EncodeToString(v.Bytes()) +} + +// quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks) +// times. The default value of -quickchecks is 100. +var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10} + +func generateFieldElement(rand *mathrand.Rand) Element { + const maskLow52Bits = (1 << 52) - 1 + return Element{ + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + } +} + +// weirdLimbs can be combined to generate a range of edge-case field elements. +// 0 and -1 are intentionally more weighted, as they combine well. +var ( + weirdLimbs51 = []uint64{ + 0, 0, 0, 0, + 1, + 19 - 1, + 19, + 0x2aaaaaaaaaaaa, + 0x5555555555555, + (1 << 51) - 20, + (1 << 51) - 19, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + } + weirdLimbs52 = []uint64{ + 0, 0, 0, 0, 0, 0, + 1, + 19 - 1, + 19, + 0x2aaaaaaaaaaaa, + 0x5555555555555, + (1 << 51) - 20, + (1 << 51) - 19, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + 1 << 51, + (1 << 51) + 1, + (1 << 52) - 19, + (1 << 52) - 1, + } +) + +func generateWeirdFieldElement(rand *mathrand.Rand) Element { + return Element{ + weirdLimbs52[rand.Intn(len(weirdLimbs52))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + } +} + +func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value { + if rand.Intn(2) == 0 { + return reflect.ValueOf(generateWeirdFieldElement(rand)) + } + return reflect.ValueOf(generateFieldElement(rand)) +} + +// isInBounds returns whether the element is within the expected bit size bounds +// after a light reduction. +func isInBounds(x *Element) bool { + return bits.Len64(x.l0) <= 52 && + bits.Len64(x.l1) <= 52 && + bits.Len64(x.l2) <= 52 && + bits.Len64(x.l3) <= 52 && + bits.Len64(x.l4) <= 52 +} + +func TestMultiplyDistributesOverAdd(t *testing.T) { + multiplyDistributesOverAdd := func(x, y, z Element) bool { + // Compute t1 = (x+y)*z + t1 := new(Element) + t1.Add(&x, &y) + t1.Multiply(t1, &z) + + // Compute t2 = x*z + y*z + t2 := new(Element) + t3 := new(Element) + t2.Multiply(&x, &z) + t3.Multiply(&y, &z) + t2.Add(t2, t3) + + return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) + } + + if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestMul64to128(t *testing.T) { + a := uint64(5) + b := uint64(5) + r := mul64(a, b) + if r.lo != 0x19 || r.hi != 0 { + t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) + } + + a = uint64(18014398509481983) // 2^54 - 1 + b = uint64(18014398509481983) // 2^54 - 1 + r = mul64(a, b) + if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff { + t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) + } + + a = uint64(1125899906842661) + b = uint64(2097155) + r = mul64(a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + if r.lo != 16888498990613035 || r.hi != 640 { + t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi) + } +} + +func TestSetBytesRoundTrip(t *testing.T) { + f1 := func(in [32]byte, fe Element) bool { + fe.SetBytes(in[:]) + + // Mask the most significant bit as it's ignored by SetBytes. (Now + // instead of earlier so we check the masking in SetBytes is working.) + in[len(in)-1] &= (1 << 7) - 1 + + return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe) + } + if err := quick.Check(f1, nil); err != nil { + t.Errorf("failed bytes->FE->bytes round-trip: %v", err) + } + + f2 := func(fe, r Element) bool { + r.SetBytes(fe.Bytes()) + + // Intentionally not using Equal not to go through Bytes again. + // Calling reduce because both Generate and SetBytes can produce + // non-canonical representations. + fe.reduce() + r.reduce() + return fe == r + } + if err := quick.Check(f2, nil); err != nil { + t.Errorf("failed FE->bytes->FE round-trip: %v", err) + } + + // Check some fixed vectors from dalek + type feRTTest struct { + fe Element + b []byte + } + var tests = []feRTTest{ + { + fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}, + b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31}, + }, + { + fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}, + b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122}, + }, + } + + for _, tt := range tests { + b := tt.fe.Bytes() + if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 { + t.Errorf("Failed fixed roundtrip: %v", tt) + } + } +} + +func swapEndianness(buf []byte) []byte { + for i := 0; i < len(buf)/2; i++ { + buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i] + } + return buf +} + +func TestBytesBigEquivalence(t *testing.T) { + f1 := func(in [32]byte, fe, fe1 Element) bool { + fe.SetBytes(in[:]) + + in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit + b := new(big.Int).SetBytes(swapEndianness(in[:])) + fe1.fromBig(b) + + if fe != fe1 { + return false + } + + buf := make([]byte, 32) // pad with zeroes + copy(buf, swapEndianness(fe1.toBig().Bytes())) + + return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1) + } + if err := quick.Check(f1, nil); err != nil { + t.Error(err) + } +} + +// fromBig sets v = n, and returns v. The bit length of n must not exceed 256. +func (v *Element) fromBig(n *big.Int) *Element { + if n.BitLen() > 32*8 { + panic("edwards25519: invalid field element input size") + } + + buf := make([]byte, 0, 32) + for _, word := range n.Bits() { + for i := 0; i < bits.UintSize; i += 8 { + if len(buf) >= cap(buf) { + break + } + buf = append(buf, byte(word)) + word >>= 8 + } + } + + return v.SetBytes(buf[:32]) +} + +func (v *Element) fromDecimal(s string) *Element { + n, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("not a valid decimal: " + s) + } + return v.fromBig(n) +} + +// toBig returns v as a big.Int. +func (v *Element) toBig() *big.Int { + buf := v.Bytes() + + words := make([]big.Word, 32*8/bits.UintSize) + for n := range words { + for i := 0; i < bits.UintSize; i += 8 { + if len(buf) == 0 { + break + } + words[n] |= big.Word(buf[0]) << big.Word(i) + buf = buf[1:] + } + } + + return new(big.Int).SetBits(words) +} + +func TestDecimalConstants(t *testing.T) { + sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752" + if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 { + t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp) + } + // d is in the parent package, and we don't want to expose d or fromDecimal. + // dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555" + // if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 { + // t.Errorf("d is %v, expected %v", d, exp) + // } +} + +func TestSetBytesRoundTripEdgeCases(t *testing.T) { + // TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1, + // and between 2^255 and 2^256-1. Test both the documented SetBytes + // behavior, and that Bytes reduces them. +} + +// Tests self-consistency between Multiply and Square. +func TestConsistency(t *testing.T) { + var x Element + var x2, x2sq Element + + x = Element{1, 1, 1, 1, 1} + x2.Multiply(&x, &x) + x2sq.Square(&x) + + if x2 != x2sq { + t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) + } + + var bytes [32]byte + + _, err := io.ReadFull(rand.Reader, bytes[:]) + if err != nil { + t.Fatal(err) + } + x.SetBytes(bytes[:]) + + x2.Multiply(&x, &x) + x2sq.Square(&x) + + if x2 != x2sq { + t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) + } +} + +func TestEqual(t *testing.T) { + x := Element{1, 1, 1, 1, 1} + y := Element{5, 4, 3, 2, 1} + + eq := x.Equal(&x) + if eq != 1 { + t.Errorf("wrong about equality") + } + + eq = x.Equal(&y) + if eq != 0 { + t.Errorf("wrong about inequality") + } +} + +func TestInvert(t *testing.T) { + x := Element{1, 1, 1, 1, 1} + one := Element{1, 0, 0, 0, 0} + var xinv, r Element + + xinv.Invert(&x) + r.Multiply(&x, &xinv) + r.reduce() + + if one != r { + t.Errorf("inversion identity failed, got: %x", r) + } + + var bytes [32]byte + + _, err := io.ReadFull(rand.Reader, bytes[:]) + if err != nil { + t.Fatal(err) + } + x.SetBytes(bytes[:]) + + xinv.Invert(&x) + r.Multiply(&x, &xinv) + r.reduce() + + if one != r { + t.Errorf("random inversion identity failed, got: %x for field element %x", r, x) + } + + zero := Element{} + x.Set(&zero) + if xx := xinv.Invert(&x); xx != &xinv { + t.Errorf("inverting zero did not return the receiver") + } else if xinv.Equal(&zero) != 1 { + t.Errorf("inverting zero did not return zero") + } +} + +func TestSelectSwap(t *testing.T) { + a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676} + b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972} + + var c, d Element + + c.Select(&a, &b, 1) + d.Select(&a, &b, 0) + + if c.Equal(&a) != 1 || d.Equal(&b) != 1 { + t.Errorf("Select failed") + } + + c.Swap(&d, 0) + + if c.Equal(&a) != 1 || d.Equal(&b) != 1 { + t.Errorf("Swap failed") + } + + c.Swap(&d, 1) + + if c.Equal(&b) != 1 || d.Equal(&a) != 1 { + t.Errorf("Swap failed") + } +} + +func TestMult32(t *testing.T) { + mult32EquivalentToMul := func(x Element, y uint32) bool { + t1 := new(Element) + for i := 0; i < 100; i++ { + t1.Mult32(&x, y) + } + + ty := new(Element) + ty.l0 = uint64(y) + + t2 := new(Element) + for i := 0; i < 100; i++ { + t2.Multiply(&x, ty) + } + + return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) + } + + if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestSqrtRatio(t *testing.T) { + // From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4. + type test struct { + u, v string + wasSquare int + r string + } + var tests = []test{ + // If u is 0, the function is defined to return (0, TRUE), even if v + // is zero. Note that where used in this package, the denominator v + // is never zero. + { + "0000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + 1, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // 0/1 == 0² + { + "0000000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 1, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // If u is non-zero and v is zero, defined to return (0, FALSE). + { + "0100000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + 0, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // 2/1 is not square in this field. + { + "0200000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54", + }, + // 4/1 == 2² + { + "0400000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 1, "0200000000000000000000000000000000000000000000000000000000000000", + }, + // 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem + { + "0100000000000000000000000000000000000000000000000000000000000000", + "0400000000000000000000000000000000000000000000000000000000000000", + 1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f", + }, + } + + for i, tt := range tests { + u := new(Element).SetBytes(decodeHex(tt.u)) + v := new(Element).SetBytes(decodeHex(tt.v)) + want := new(Element).SetBytes(decodeHex(tt.r)) + got, wasSquare := new(Element).SqrtRatio(u, v) + if got.Equal(want) == 0 || wasSquare != tt.wasSquare { + t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare) + } + } +} + +func TestCarryPropagate(t *testing.T) { + asmLikeGeneric := func(a [5]uint64) bool { + t1 := &Element{a[0], a[1], a[2], a[3], a[4]} + t2 := &Element{a[0], a[1], a[2], a[3], a[4]} + + t1.carryPropagate() + t2.carryPropagateGeneric() + + if *t1 != *t2 { + t.Logf("got: %#v,\nexpected: %#v", t1, t2) + } + + return *t1 == *t2 && isInBounds(t2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } + + if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) { + t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}") + } +} + +func TestFeSquare(t *testing.T) { + asmLikeGeneric := func(a Element) bool { + t1 := a + t2 := a + + feSquareGeneric(&t1, &t1) + feSquare(&t2, &t2) + + if t1 != t2 { + t.Logf("got: %#v,\nexpected: %#v", t1, t2) + } + + return t1 == t2 && isInBounds(&t2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestFeMul(t *testing.T) { + asmLikeGeneric := func(a, b Element) bool { + a1 := a + a2 := a + b1 := b + b2 := b + + feMulGeneric(&a1, &a1, &b1) + feMul(&a2, &a2, &b2) + + if a1 != a2 || b1 != b2 { + t.Logf("got: %#v,\nexpected: %#v", a1, a2) + t.Logf("got: %#v,\nexpected: %#v", b1, b2) + } + + return a1 == a2 && isInBounds(&a2) && + b1 == b2 && isInBounds(&b2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func decodeHex(s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return b +} diff --git a/openpgp/keys.go b/openpgp/keys.go index da3809df3..34bffb97e 100644 --- a/openpgp/keys.go +++ b/openpgp/keys.go @@ -381,7 +381,7 @@ func (el EntityList) KeysByIdUsage(id uint64, requiredUsage byte) (keys []Key) { func (el EntityList) DecryptionKeys() (keys []Key) { for _, e := range el { for _, subKey := range e.Subkeys { - if subKey.PrivateKey != nil && subKey.Sig.FlagsValid && (subKey.Sig.FlagEncryptStorage || subKey.Sig.FlagEncryptCommunications) { + if subKey.PrivateKey != nil && subKey.Sig.FlagsValid && (subKey.Sig.FlagEncryptStorage || subKey.Sig.FlagEncryptCommunications || subKey.Sig.FlagForward) { keys = append(keys, Key{e, subKey.PublicKey, subKey.PrivateKey, subKey.Sig, subKey.Revocations}) } } @@ -772,7 +772,7 @@ func (e *Entity) serializePrivate(w io.Writer, config *packet.Config, reSign boo // signatures from other entities. No private key material will be output. func (e *Entity) Serialize(w io.Writer) error { if e.PrimaryKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoHMAC || - e.PrimaryKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD { + e.PrimaryKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD { return errors.InvalidArgumentError("Can't serialize symmetric primary key") } err := e.PrimaryKey.Serialize(w) @@ -804,14 +804,16 @@ func (e *Entity) Serialize(w io.Writer) error { } } for _, subkey := range e.Subkeys { - // The types of keys below are only useful as private keys. Thus, the // public key packets contain no meaningful information and do not need // to be serialized. + // Prevent public key export for forwarding keys, see forwarding section 4.1. if subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoHMAC || - subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD { + subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD || + subkey.Sig.FlagForward { continue } + err = subkey.PublicKey.Serialize(w) if err != nil { return err diff --git a/openpgp/packet/encrypted_key.go b/openpgp/packet/encrypted_key.go index 4ef72a02b..5cbc966d4 100644 --- a/openpgp/packet/encrypted_key.go +++ b/openpgp/packet/encrypted_key.go @@ -481,6 +481,36 @@ func SerializeEncryptedKeyWithHiddenOption(w io.Writer, pub *PublicKey, cipherFu return SerializeEncryptedKeyAEADwithHiddenOption(w, pub, cipherFunc, config.AEAD() != nil, key, hidden, config) } +func (e *EncryptedKey) ProxyTransform(instance ForwardingInstance) (transformed *EncryptedKey, err error) { + if e.Algo != PubKeyAlgoECDH { + return nil, errors.InvalidArgumentError("invalid PKESK") + } + + if e.KeyId != 0 && e.KeyId != instance.GetForwarderKeyId() { + return nil, errors.InvalidArgumentError("invalid key id in PKESK") + } + + ephemeral := e.encryptedMPI1.Bytes() + transformedEphemeral, err := ecdh.ProxyTransform(ephemeral, instance.ProxyParameter) + if err != nil { + return nil, err + } + + wrappedKey := e.encryptedMPI2.Bytes() + copiedWrappedKey := make([]byte, len(wrappedKey)) + copy(copiedWrappedKey, wrappedKey) + + transformed = &EncryptedKey{ + Version: e.Version, + KeyId: instance.getForwardeeKeyIdOrZero(e.KeyId), + Algo: e.Algo, + encryptedMPI1: encoding.NewMPI(transformedEphemeral), + encryptedMPI2: encoding.NewOID(copiedWrappedKey), + } + + return transformed, nil +} + func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header []byte, pub *rsa.PublicKey, keyBlock []byte) error { cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock) if err != nil { diff --git a/openpgp/packet/forwarding.go b/openpgp/packet/forwarding.go new file mode 100644 index 000000000..f16a2fbdc --- /dev/null +++ b/openpgp/packet/forwarding.go @@ -0,0 +1,36 @@ +package packet + +import "encoding/binary" + +// ForwardingInstance represents a single forwarding instance (mapping IDs to a Proxy Param) +type ForwardingInstance struct { + KeyVersion int + ForwarderFingerprint []byte + ForwardeeFingerprint []byte + ProxyParameter []byte +} + +func (f *ForwardingInstance) GetForwarderKeyId() uint64 { + return computeForwardingKeyId(f.ForwarderFingerprint, f.KeyVersion) +} + +func (f *ForwardingInstance) GetForwardeeKeyId() uint64 { + return computeForwardingKeyId(f.ForwardeeFingerprint, f.KeyVersion) +} + +func (f *ForwardingInstance) getForwardeeKeyIdOrZero(originalKeyId uint64) uint64 { + if originalKeyId == 0 { + return 0 + } + + return f.GetForwardeeKeyId() +} + +func computeForwardingKeyId(fingerprint []byte, version int) uint64 { + switch version { + case 4: + return binary.BigEndian.Uint64(fingerprint[12:20]) + default: + panic("invalid pgp key version") + } +} diff --git a/openpgp/packet/public_key.go b/openpgp/packet/public_key.go index 231acd503..794cebec3 100644 --- a/openpgp/packet/public_key.go +++ b/openpgp/packet/public_key.go @@ -5,6 +5,7 @@ package packet import ( + "bytes" "crypto" "crypto/dsa" "crypto/rsa" @@ -80,6 +81,26 @@ func (pk *PublicKey) UpgradeToV6() error { return pk.checkV6Compatibility() } +// ReplaceKDF replaces the KDF instance, and updates all necessary fields. +func (pk *PublicKey) ReplaceKDF(kdf ecdh.KDF) error { + ecdhKey, ok := pk.PublicKey.(*ecdh.PublicKey) + if !ok { + return goerrors.New("wrong forwarding sub key generation") + } + + ecdhKey.KDF = kdf + byteBuffer := new(bytes.Buffer) + err := kdf.Serialize(byteBuffer) + if err != nil { + return err + } + + pk.kdf = encoding.NewOID(byteBuffer.Bytes()[1:]) + pk.setFingerprintAndKeyId() + + return nil +} + // signingKey provides a convenient abstraction over signature verification // for v3 and v4 public keys. type signingKey interface { @@ -556,11 +577,13 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) { return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid)) } - if kdfLen := len(pk.kdf.Bytes()); kdfLen < 3 { + kdfLen := len(pk.kdf.Bytes()) + if kdfLen < 3 { return errors.UnsupportedError("unsupported ECDH KDF length: " + strconv.Itoa(kdfLen)) } - if reserved := pk.kdf.Bytes()[0]; reserved != 0x01 { - return errors.UnsupportedError("unsupported KDF reserved field: " + strconv.Itoa(int(reserved))) + kdfVersion := int(pk.kdf.Bytes()[0]) + if kdfVersion != ecdh.KDFVersion1 && kdfVersion != ecdh.KDFVersionForwarding { + return errors.UnsupportedError("unsupported ECDH KDF version: " + strconv.Itoa(kdfVersion)) } kdfHash, ok := algorithm.HashById[pk.kdf.Bytes()[1]] if !ok { @@ -571,10 +594,23 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) { return errors.UnsupportedError("unsupported ECDH KDF cipher: " + strconv.Itoa(int(pk.kdf.Bytes()[2]))) } - ecdhKey := ecdh.NewPublicKey(c, kdfHash, kdfCipher) + kdf := ecdh.KDF{ + Version: kdfVersion, + Hash: kdfHash, + Cipher: kdfCipher, + } + + if kdfVersion == ecdh.KDFVersionForwarding { + if pk.Version != 4 || kdfLen != 23 { + return errors.UnsupportedError("unsupported ECDH KDF v2 length: " + strconv.Itoa(kdfLen)) + } + + kdf.ReplacementFingerprint = pk.kdf.Bytes()[3:23] + } + + ecdhKey := ecdh.NewPublicKey(c, kdf) err = ecdhKey.UnmarshalPoint(pk.p.Bytes()) pk.PublicKey = ecdhKey - return } @@ -1190,6 +1226,13 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error } } + // Keys having this flag MUST have the forwarding KDF parameters version 2 defined in Section 5.1. + if sig.FlagForward && (signed.PubKeyAlgo != PubKeyAlgoECDH || + signed.kdf == nil || + signed.kdf.Bytes()[0] != ecdh.KDFVersionForwarding) { + return errors.StructuralError("forwarding key with wrong ecdh kdf version") + } + return nil } diff --git a/openpgp/packet/signature.go b/openpgp/packet/signature.go index 9742381e5..fe36cc4f9 100644 --- a/openpgp/packet/signature.go +++ b/openpgp/packet/signature.go @@ -38,7 +38,7 @@ const ( KeyFlagEncryptStorage KeyFlagSplitKey KeyFlagAuthenticate - _ + KeyFlagForward KeyFlagGroupKey ) @@ -134,8 +134,9 @@ type Signature struct { // FlagsValid is set if any flags were given. See RFC 9580, section // 5.2.3.29 for details. - FlagsValid bool - FlagCertify, FlagSign, FlagEncryptCommunications, FlagEncryptStorage, FlagSplitKey, FlagAuthenticate, FlagGroupKey bool + FlagsValid bool + FlagCertify, FlagSign, FlagEncryptCommunications, FlagEncryptStorage bool + FlagSplitKey, FlagAuthenticate, FlagForward, FlagGroupKey bool // RevocationReason is set if this signature has been revoked. // See RFC 9580, section 5.2.3.31 for details. @@ -617,6 +618,9 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r if subpacket[0]&KeyFlagAuthenticate != 0 { sig.FlagAuthenticate = true } + if subpacket[0]&KeyFlagForward != 0 { + sig.FlagForward = true + } if subpacket[0]&KeyFlagGroupKey != 0 { sig.FlagGroupKey = true } @@ -1421,6 +1425,9 @@ func (sig *Signature) buildSubpackets(issuer PublicKey) (subpackets []outputSubp if sig.FlagAuthenticate { flags |= KeyFlagAuthenticate } + if sig.FlagForward { + flags |= KeyFlagForward + } if sig.FlagGroupKey { flags |= KeyFlagGroupKey } diff --git a/openpgp/v2/forwarding.go b/openpgp/v2/forwarding.go new file mode 100644 index 000000000..1306c510c --- /dev/null +++ b/openpgp/v2/forwarding.go @@ -0,0 +1,159 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package v2 + +import ( + goerrors "errors" + + "github.com/ProtonMail/go-crypto/openpgp/ecdh" + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/packet" +) + +// NewForwardingEntity generates a new forwardee key and derives the proxy parameters from the entity e. +// If strict, it will return an error if encryption-capable non-revoked subkeys with a wrong algorithm are found, +// instead of ignoring them +func (e *Entity) NewForwardingEntity( + name, comment, email string, config *packet.Config, strict bool, +) ( + forwardeeKey *Entity, instances []packet.ForwardingInstance, err error, +) { + if e.PrimaryKey.Version != 4 { + return nil, nil, errors.InvalidArgumentError("unsupported key version") + } + + now := config.Now() + + if _, err = e.VerifyPrimaryKey(now, config); err != nil { + return nil, nil, err + } + + // Generate a new Primary key for the forwardee + config.Algorithm = packet.PubKeyAlgoEdDSA + config.Curve = packet.Curve25519 + + forwardeePrimaryPrivRaw, err := newSigner(config) + if err != nil { + return nil, nil, err + } + + primary := packet.NewSignerPrivateKey(now, forwardeePrimaryPrivRaw) + + forwardeeKey = &Entity{ + PrimaryKey: &primary.PublicKey, + PrivateKey: primary, + Identities: make(map[string]*Identity), + Subkeys: []Subkey{}, + } + + keyProperties := selectKeyProperties(now, config, primary) + err = forwardeeKey.addUserId(userIdData{name, comment, email}, config, keyProperties) + if err != nil { + return nil, nil, err + } + + // Init empty instances + instances = []packet.ForwardingInstance{} + + // Handle all forwarder subkeys + for _, forwarderSubKey := range e.Subkeys { + // Filter flags + if !forwarderSubKey.PublicKey.PubKeyAlgo.CanEncrypt() { + continue + } + + forwarderSubKeySelfSig, err := forwarderSubKey.Verify(now, config) + // Filter expiration & revokal + if err != nil { + continue + } + + if forwarderSubKey.PublicKey.PubKeyAlgo != packet.PubKeyAlgoECDH { + if strict { + return nil, nil, errors.InvalidArgumentError("encryption subkey is not algorithm 18 (ECDH)") + } else { + continue + } + } + + forwarderEcdhKey, ok := forwarderSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, errors.InvalidArgumentError("malformed key") + } + + err = forwardeeKey.addEncryptionSubkey(config, now, 0) + if err != nil { + return nil, nil, err + } + + forwardeeSubKey := forwardeeKey.Subkeys[len(forwardeeKey.Subkeys)-1] + forwardeeSubKeySelfSig := forwardeeSubKey.Bindings[0].Packet + + forwardeeEcdhKey, ok := forwardeeSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, goerrors.New("wrong forwarding sub key generation") + } + + instance := packet.ForwardingInstance{ + KeyVersion: 4, + ForwarderFingerprint: forwarderSubKey.PublicKey.Fingerprint, + } + + instance.ProxyParameter, err = ecdh.DeriveProxyParam(forwarderEcdhKey, forwardeeEcdhKey) + if err != nil { + return nil, nil, err + } + + kdf := ecdh.KDF{ + Version: ecdh.KDFVersionForwarding, + Hash: forwarderEcdhKey.KDF.Hash, + Cipher: forwarderEcdhKey.KDF.Cipher, + } + + // If deriving a forwarding key from a forwarding key + if forwarderSubKeySelfSig.FlagForward { + if forwarderEcdhKey.KDF.Version != ecdh.KDFVersionForwarding { + return nil, nil, goerrors.New("malformed forwarder key") + } + kdf.ReplacementFingerprint = forwarderEcdhKey.KDF.ReplacementFingerprint + } else { + kdf.ReplacementFingerprint = forwarderSubKey.PublicKey.Fingerprint + } + + err = forwardeeSubKey.PublicKey.ReplaceKDF(kdf) + if err != nil { + return nil, nil, err + } + + // Extract fingerprint after changing the KDF + instance.ForwardeeFingerprint = forwardeeSubKey.PublicKey.Fingerprint + + // 0x04 - This key may be used to encrypt communications. + forwardeeSubKeySelfSig.FlagEncryptCommunications = false + + // 0x08 - This key may be used to encrypt storage. + forwardeeSubKeySelfSig.FlagEncryptStorage = false + + // 0x10 - The private component of this key may have been split by a secret-sharing mechanism. + forwardeeSubKeySelfSig.FlagSplitKey = true + + // 0x40 - This key may be used for forwarded communications. + forwardeeSubKeySelfSig.FlagForward = true + + err = forwardeeSubKeySelfSig.SignKey(forwardeeSubKey.PublicKey, forwardeeKey.PrivateKey, config) + if err != nil { + return nil, nil, err + } + + // Append each valid instance to the list + instances = append(instances, instance) + } + + if len(instances) == 0 { + return nil, nil, errors.InvalidArgumentError("no valid subkey found") + } + + return forwardeeKey, instances, nil +} diff --git a/openpgp/v2/forwarding_test.go b/openpgp/v2/forwarding_test.go new file mode 100644 index 000000000..21c8be0e8 --- /dev/null +++ b/openpgp/v2/forwarding_test.go @@ -0,0 +1,223 @@ +package v2 + +import ( + "bytes" + "crypto/rand" + goerrors "errors" + "io" + "io/ioutil" + "strings" + "testing" + + "github.com/ProtonMail/go-crypto/openpgp/armor" + "github.com/ProtonMail/go-crypto/openpgp/packet" +) + +const forwardeeKey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xVgEZQRXoxYJKwYBBAHaRw8BAQdAhxdzZ8ZP1M4UcauXSGbts38KhhAZxHNRcChs +9H7danMAAQC4tHykQmFpnlvhLYJDDc4MJm68mUB9qUls34GgKkqKNw6FzRtjaGFy +bGVzIDxjaGFybGVzQHByb3Rvbi5tZT7CiwQTFggAPQUCZQRXowkQizX+kwlYIwMW +IQTYm4qmQoyzTnG0eZKLNf6TCVgjAwIbAwIeAQIZAQILBwIVCAIWAAMnBwIAAMsQ +AQD9UHMIU418Z10UQrymhbjkGq/PHCytaaneaq5oycpN/QD/UiK3aA4+HxWhX/F2 +VrvEKL5a2xyd1AKKQ2DInF3xUg3HcQRlBFejEgorBgEEAZdVAQUBAQdAep7x8ncL +ShzEgKL6h9MAJbgX2z3BBgSLeAdg/rczKngX/woJjSg9O4DzqQOtAvdhYkDoOCNf +QgUAAP9OMqK0IwNmshCtktDy1/RTeyPKT8ItHDFAZ1ReKMA5CA63wngEGBYIACoF +AmUEV6MJEIs1/pMJWCMDFiEE2JuKpkKMs05xtHmSizX+kwlYIwMCG1wAAC5EAP9s +AbYBf9NGv1NxJvU0n0K++k3UIGkw9xgGJa3VFHFKvwEAx0DZpTVpCkJmiOFAOcfu +cSvjlMyQwsC/hAAzQpcqvwE= +=8LJg +-----END PGP PRIVATE KEY BLOCK-----` + +const forwardedMessage = `-----BEGIN PGP MESSAGE----- + +wV4DKsXbtIU9/JMSAQdA/6+foCjeUhS7Xto3fimUi6pfMQ/Ft3caHkK/1i767isw +NvG8xRbjQ0sAE1IZVGE1MBcVhCIbHhqp0h2J479Zmfn/iP7hfomYxrkJ/6UMnlEo +0kABKyyfO3QVAzBBNeq6hH27uqzwLgjWVrpgY7dmWPv0goSSaqHUda0lm+8JNUuF +wssOJTwrSwQrX3ezy5D/h/E6 +=okS+ +-----END PGP MESSAGE-----` + +const forwardedPlaintext = "Message for Bob" + +func TestForwardingStatic(t *testing.T) { + charlesKey, err := ReadArmoredKeyRing(bytes.NewBufferString(forwardeeKey)) + if err != nil { + t.Error(err) + return + } + + ciphertext, err := armor.Decode(strings.NewReader(forwardedMessage)) + if err != nil { + t.Error(err) + return + } + + m, err := ReadMessage(ciphertext.Body, charlesKey, nil, nil) + if err != nil { + t.Fatal(err) + } + + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, []byte(forwardedPlaintext)) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func TestForwardingFull(t *testing.T) { + keyConfig := &packet.Config{ + Algorithm: packet.PubKeyAlgoEdDSA, + Curve: packet.Curve25519, + } + + plaintext := make([]byte, 1024) + rand.Read(plaintext) + + bobEntity, err := NewEntity("bob", "", "bob@proton.me", keyConfig) + if err != nil { + t.Fatal(err) + } + + charlesEntity, instances, err := bobEntity.NewForwardingEntity("charles", "", "charles@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + charlesEntity = serializeAndParseForwardeeKey(t, charlesEntity) + + if len(instances) != 1 { + t.Fatalf("invalid number of instances, expected 1 got %d", len(instances)) + } + + if !bytes.Equal(instances[0].ForwarderFingerprint, bobEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwarder key ID, expected: %x, got: %x", bobEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwarderFingerprint) + } + + if !bytes.Equal(instances[0].ForwardeeFingerprint, charlesEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwardee key ID, expected: %x, got: %x", charlesEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwardeeFingerprint) + } + + // Encrypt message + buf := bytes.NewBuffer(nil) + w, err := Encrypt(buf, []*Entity{bobEntity}, nil, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + + _, err = w.Write(plaintext) + if err != nil { + t.Fatal(err) + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + encrypted := buf.Bytes() + + // Decrypt message for Bob + m, err := ReadMessage(bytes.NewBuffer(encrypted), EntityList([]*Entity{bobEntity}), nil, nil) + if err != nil { + t.Fatal(err) + } + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("decrypted does not match original") + } + + // Forward message + transformed := transformTestMessage(t, encrypted, instances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(transformed), EntityList([]*Entity{charlesEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } + + // Setup further forwarding + danielEntity, secondForwardInstances, err := charlesEntity.NewForwardingEntity("Daniel", "", "daniel@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + danielEntity = serializeAndParseForwardeeKey(t, danielEntity) + + secondTransformed := transformTestMessage(t, transformed, secondForwardInstances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(secondTransformed), EntityList([]*Entity{danielEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func transformTestMessage(t *testing.T, encrypted []byte, instance packet.ForwardingInstance) []byte { + bytesReader := bytes.NewReader(encrypted) + packets := packet.NewReader(bytesReader) + splitPoint := int64(0) + transformedEncryptedKey := bytes.NewBuffer(nil) + +Loop: + for { + p, err := packets.Next() + if goerrors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("error in parsing message: %s", err) + } + switch p := p.(type) { + case *packet.EncryptedKey: + tp, err := p.ProxyTransform(instance) + if err != nil { + t.Fatalf("error transforming PKESK: %s", err) + } + + splitPoint = bytesReader.Size() - int64(bytesReader.Len()) + + err = tp.Serialize(transformedEncryptedKey) + if err != nil { + t.Fatalf("error serializing transformed PKESK: %s", err) + } + break Loop + } + } + + transformed := transformedEncryptedKey.Bytes() + transformed = append(transformed, encrypted[splitPoint:]...) + + return transformed +} + +func serializeAndParseForwardeeKey(t *testing.T, key *Entity) *Entity { + serializedEntity := bytes.NewBuffer(nil) + err := key.SerializePrivateWithoutSigning(serializedEntity, nil) + if err != nil { + t.Fatalf("Error in serializing forwardee key: %s", err) + } + el, err := ReadKeyRing(serializedEntity) + if err != nil { + t.Fatalf("Error in reading forwardee key: %s", err) + } + + if len(el) != 1 { + t.Fatalf("Wrong number of entities in parsing, expected 1, got %d", len(el)) + } + + return el[0] +} diff --git a/openpgp/v2/keys.go b/openpgp/v2/keys.go index 954b023dc..063713f1e 100644 --- a/openpgp/v2/keys.go +++ b/openpgp/v2/keys.go @@ -166,12 +166,12 @@ func (e *Entity) DecryptionKeys(id uint64, date time.Time, config *packet.Config for _, subkey := range e.Subkeys { subkeySelfSig, err := subkey.LatestValidBindingSignature(date, config) if err == nil && - isValidEncryptionKey(subkeySelfSig, subkey.PublicKey.PubKeyAlgo) && + isValidDecryptionKey(subkeySelfSig, subkey.PublicKey.PubKeyAlgo) && (id == 0 || subkey.PublicKey.KeyId == id) { keys = append(keys, Key{subkey.Primary, primarySelfSignature, subkey.PublicKey, subkey.PrivateKey, subkeySelfSig}) } } - if isValidEncryptionKey(primarySelfSignature, e.PrimaryKey.PubKeyAlgo) { + if isValidDecryptionKey(primarySelfSignature, e.PrimaryKey.PubKeyAlgo) { keys = append(keys, Key{e, primarySelfSignature, e.PrimaryKey, e.PrivateKey, primarySelfSignature}) } return @@ -641,8 +641,11 @@ func (e *Entity) Serialize(w io.Writer) error { // The types of keys below are only useful as private keys. Thus, the // public key packets contain no meaningful information and do not need // to be serialized. + // Prevent public key export for forwarding keys, see forwarding section 4.1. + subKeySelfSig, err := subkey.LatestValidBindingSignature(time.Time{}, nil) if subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoHMAC || - subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD { + subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD || + (err == nil && subKeySelfSig.FlagForward) { continue } if err := subkey.Serialize(w, false); err != nil { @@ -801,5 +804,11 @@ func isValidCertificationKey(signature *packet.Signature, algo packet.PublicKeyA func isValidEncryptionKey(signature *packet.Signature, algo packet.PublicKeyAlgorithm) bool { return algo.CanEncrypt() && signature.FlagsValid && - (signature.FlagEncryptCommunications || signature.FlagEncryptStorage) + (signature.FlagEncryptCommunications || signature.FlagForward || signature.FlagEncryptStorage) +} + +func isValidDecryptionKey(signature *packet.Signature, algo packet.PublicKeyAlgorithm) bool { + return algo.CanEncrypt() && + signature.FlagsValid && + (signature.FlagEncryptCommunications || signature.FlagForward || signature.FlagEncryptStorage) }