Skip to content

Commit

Permalink
Relax kem constaint from kem.AuthScheme to kem.Scheme.
Browse files Browse the repository at this point in the history
  • Loading branch information
armfazh committed Jan 21, 2025
1 parent 964fefa commit 9340445
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 79 deletions.
4 changes: 2 additions & 2 deletions hpke/algs.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (k KEM) IsValid() bool {

// Scheme returns an instance of a KEM that supports authentication. Panics if
// the KEM identifier is invalid.
func (k KEM) Scheme() kem.AuthScheme {
func (k KEM) Scheme() kem.Scheme {
switch k {
case KEM_P256_HKDF_SHA256:
return dhkemp256hkdfsha256
Expand Down Expand Up @@ -283,6 +283,6 @@ func init() {
hybridkemX25519Kyber768.kemA = dhkemx25519hkdfsha256
hybridkemX25519Kyber768.kemB = kyber768.Scheme()

kemXwing.kem = xwing.Scheme()
kemXwing.Scheme = xwing.Scheme()
kemXwing.name = "HPKE_KEM_XWING"
}
59 changes: 4 additions & 55 deletions hpke/genericnoauthkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,70 +9,19 @@ import (

// genericNoAuthKEM wraps a generic KEM (kem.Scheme) to be used as a HPKE KEM.
type genericNoAuthKEM struct {
kem kem.Scheme
kem.Scheme
name string
}

func (h genericNoAuthKEM) PrivateKeySize() int { return h.kem.PrivateKeySize() }
func (h genericNoAuthKEM) SeedSize() int { return h.kem.SeedSize() }
func (h genericNoAuthKEM) CiphertextSize() int { return h.kem.CiphertextSize() }
func (h genericNoAuthKEM) PublicKeySize() int { return h.kem.PublicKeySize() }
func (h genericNoAuthKEM) EncapsulationSeedSize() int { return h.kem.EncapsulationSeedSize() }
func (h genericNoAuthKEM) SharedKeySize() int { return h.kem.SharedKeySize() }
func (h genericNoAuthKEM) Name() string { return h.name }

func (h genericNoAuthKEM) AuthDecapsulate(skR kem.PrivateKey,
ct []byte,
pkS kem.PublicKey,
) ([]byte, error) {
panic("AuthDecapsulate is not supported for this KEM")
}

func (h genericNoAuthKEM) AuthEncapsulate(pkr kem.PublicKey, sks kem.PrivateKey) (
ct []byte, ss []byte, err error,
) {
panic("AuthEncapsulate is not supported for this KEM")
}

func (h genericNoAuthKEM) AuthEncapsulateDeterministically(pkr kem.PublicKey, sks kem.PrivateKey, seed []byte) (ct, ss []byte, err error) {
panic("AuthEncapsulateDeterministically is not supported for this KEM")
}

func (h genericNoAuthKEM) Encapsulate(pkr kem.PublicKey) (
ct []byte, ss []byte, err error,
) {
return h.kem.Encapsulate(pkr)
}

func (h genericNoAuthKEM) Decapsulate(skr kem.PrivateKey, ct []byte) ([]byte, error) {
return h.kem.Decapsulate(skr, ct)
}

func (h genericNoAuthKEM) EncapsulateDeterministically(
pkr kem.PublicKey, seed []byte,
) (ct, ss []byte, err error) {
return h.kem.EncapsulateDeterministically(pkr, seed)
}
func (h genericNoAuthKEM) Name() string { return h.name }

// HPKE requires DeriveKeyPair() to take any seed larger than the private key
// size, whereas typical KEMs expect a specific seed size. We'll just use
// SHAKE256 to hash it to the right size as in X-Wing.
func (h genericNoAuthKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
seed2 := make([]byte, h.kem.SeedSize())
seed2 := make([]byte, h.Scheme.SeedSize())
hh := sha3.NewShake256()
_, _ = hh.Write(seed)
_, _ = hh.Read(seed2)
return h.kem.DeriveKeyPair(seed2)
}

func (h genericNoAuthKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
return h.kem.GenerateKeyPair()
}

func (h genericNoAuthKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
return h.kem.UnmarshalBinaryPrivateKey(data)
}

func (h genericNoAuthKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
return h.kem.UnmarshalBinaryPublicKey(data)
return h.Scheme.DeriveKeyPair(seed2)
}
15 changes: 13 additions & 2 deletions hpke/hpke.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,12 @@ func (s *Sender) allSetup(rnd io.Reader) ([]byte, Sealer, error) {
case modeBase, modePSK:
enc, ss, err = scheme.EncapsulateDeterministically(s.pkR, seed)
case modeAuth, modeAuthPSK:
enc, ss, err = scheme.AuthEncapsulateDeterministically(s.pkR, s.skS, seed)
authScheme, ok := scheme.(kem.AuthScheme)
if !ok {
return nil, nil, ErrInvalidAuthKEM
}

enc, ss, err = authScheme.AuthEncapsulateDeterministically(s.pkR, s.skS, seed)
}
if err != nil {
return nil, nil, err
Expand All @@ -246,7 +251,12 @@ func (r *Receiver) allSetup() (Opener, error) {
case modeBase, modePSK:
ss, err = scheme.Decapsulate(r.skR, r.enc)
case modeAuth, modeAuthPSK:
ss, err = scheme.AuthDecapsulate(r.skR, r.enc, r.pkS)
authScheme, ok := scheme.(kem.AuthScheme)
if !ok {
return nil, ErrInvalidAuthKEM
}

ss, err = authScheme.AuthDecapsulate(r.skR, r.enc, r.pkS)
}
if err != nil {
return nil, err
Expand All @@ -263,6 +273,7 @@ var (
ErrInvalidHPKESuite = errors.New("hpke: invalid HPKE suite")
ErrInvalidKDF = errors.New("hpke: invalid KDF identifier")
ErrInvalidKEM = errors.New("hpke: invalid KEM identifier")
ErrInvalidAuthKEM = errors.New("hpke: KEM does not support Auth mode")
ErrInvalidAEAD = errors.New("hpke: invalid AEAD identifier")
ErrInvalidKEMPublicKey = errors.New("hpke: invalid KEM public key")
ErrInvalidKEMPrivateKey = errors.New("hpke: invalid KEM private key")
Expand Down
38 changes: 18 additions & 20 deletions kem/xwing/scheme.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,21 @@ import (
// generic KEM API.

// Returns the generic KEM interface for X-Wing PQ/T hybrid KEM.
func Scheme() kem.Scheme { return &xwing }
func Scheme() kem.Scheme { return scheme{} }

type scheme struct{}

var xwing scheme

func (*scheme) Name() string { return "X-Wing" }
func (*scheme) PublicKeySize() int { return PublicKeySize }
func (*scheme) PrivateKeySize() int { return PrivateKeySize }
func (*scheme) SeedSize() int { return SeedSize }
func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize }
func (*scheme) SharedKeySize() int { return SharedKeySize }
func (*scheme) CiphertextSize() int { return CiphertextSize }
func (*PrivateKey) Scheme() kem.Scheme { return &xwing }
func (*PublicKey) Scheme() kem.Scheme { return &xwing }

func (sch *scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
func (scheme) Name() string { return "X-Wing" }
func (scheme) PublicKeySize() int { return PublicKeySize }
func (scheme) PrivateKeySize() int { return PrivateKeySize }
func (scheme) SeedSize() int { return SeedSize }
func (scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize }
func (scheme) SharedKeySize() int { return SharedKeySize }
func (scheme) CiphertextSize() int { return CiphertextSize }
func (*PrivateKey) Scheme() kem.Scheme { return scheme{} }
func (*PublicKey) Scheme() kem.Scheme { return scheme{} }

func (sch scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
var seed [EncapsulationSeedSize]byte
_, err = cryptoRand.Read(seed[:])
if err != nil {
Expand All @@ -38,7 +36,7 @@ func (sch *scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
return sch.EncapsulateDeterministically(pk, seed[:])
}

func (sch *scheme) EncapsulateDeterministically(
func (scheme) EncapsulateDeterministically(
pk kem.PublicKey, seed []byte,
) ([]byte, []byte, error) {
if len(seed) != EncapsulationSeedSize {
Expand All @@ -56,7 +54,7 @@ func (sch *scheme) EncapsulateDeterministically(
return ct[:], ss[:], nil
}

func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
func (scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
var pk PublicKey
if len(buf) != PublicKeySize {
return nil, kem.ErrPubKeySize
Expand All @@ -68,7 +66,7 @@ func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
return &pk, nil
}

func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) {
func (scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) {
var sk PrivateKey
if len(buf) != PrivateKeySize {
return nil, kem.ErrPrivKeySize
Expand Down Expand Up @@ -114,17 +112,17 @@ func (pk *PublicKey) MarshalBinary() ([]byte, error) {
return ret[:], nil
}

func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
func (scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
sk, pk := DeriveKeyPair(seed)
return pk, sk
}

func (sch *scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
func (scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
sk, pk, err := GenerateKeyPair(nil)
return pk, sk, err
}

func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) {
func (scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) {
if len(ct) != CiphertextSize {
return nil, kem.ErrCiphertextSize
}
Expand Down

0 comments on commit 9340445

Please sign in to comment.