Skip to content

Commit

Permalink
crypto/tls: add cipher suites TLS_ECDHE_PSK
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph authored and joseph committed Jun 29, 2022
1 parent 160414c commit e70353c
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 17 deletions.
36 changes: 36 additions & 0 deletions src/crypto/tls/cipher_suites.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"crypto/rc4"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"fmt"
"hash"
"internal/cpu"
Expand Down Expand Up @@ -69,6 +70,10 @@ func CipherSuites() []*CipherSuite {
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA", supportedOnlyTLS12, false},
{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, "TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA", supportedOnlyTLS12, false},
{TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, "TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
}
Expand Down Expand Up @@ -128,6 +133,9 @@ const (
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
// handshake hash.
suiteSHA384
// suiteNoCerts indicates that the cipher suite doesn't use certificate exchange
// (anonymous ciphersuites or pre-shared-secret)
suiteNoCerts
)

// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange
Expand Down Expand Up @@ -169,6 +177,11 @@ var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil},

{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdhePSKKA, suiteECDHE | suiteTLS12 | suiteNoCerts, cipherAES, macSHA1, nil},
{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdhePSKKA, suiteECDHE | suiteTLS12 | suiteNoCerts, cipherAES, macSHA256, nil},
{TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdhePSKKA, suiteECDHE | suiteTLS12 | suiteNoCerts, cipherAES, macSHA1, nil},
{TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, 32, 48, 16, ecdhePSKKA, suiteECDHE | suiteTLS12 | suiteNoCerts, cipherAES, macSHA384, nil},
}

// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which
Expand Down Expand Up @@ -297,6 +310,10 @@ var cipherSuitesPreferenceOrder = []uint16{
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,

// PSK
TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA,
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA,
}

var cipherSuitesPreferenceOrderNoAES = []uint16{
Expand All @@ -320,6 +337,10 @@ var cipherSuitesPreferenceOrderNoAES = []uint16{
TLS_RSA_WITH_AES_128_CBC_SHA256,
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,

// PSK
TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA,
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA,
}

// disabledCipherSuites are not used unless explicitly listed in
Expand Down Expand Up @@ -437,6 +458,11 @@ func macSHA256(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}

// macSHA384 returns a SHA-384 based MAC.
func macSHA384(key []byte) hash.Hash {
return hmac.New(sha512.New384, key)
}

type aead interface {
cipher.AEAD

Expand Down Expand Up @@ -619,6 +645,12 @@ func ecdheRSAKA(version uint16) keyAgreement {
}
}

func ecdhePSKKA(version uint16) keyAgreement {
return &ecdhePskKeyAgreement{
version: version,
}
}

// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
Expand Down Expand Up @@ -683,6 +715,10 @@ const (
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xc035
TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xc036
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xc037
TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xc038
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9

Expand Down
3 changes: 3 additions & 0 deletions src/crypto/tls/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,9 @@ type Config struct {
// used for debugging.
KeyLogWriter io.Writer

// Extra is used to hold extra configuration for external cipher-suites
Extra interface{}

// mutex protects sessionTicketKeys and autoSessionTicketKeys.
mutex sync.RWMutex
// sessionTicketKeys contains zero or more ticket keys. If set, it means the
Expand Down
36 changes: 20 additions & 16 deletions src/crypto/tls/handshake_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,14 @@ func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
if !hs.ecdheOk {
return false
}
if c.flags&suiteECSign != 0 {
if !hs.ecSignOk {
if c.flags&suiteNoCerts == 0 {
if c.flags&suiteECSign != 0 {
if !hs.ecSignOk {
return false
}
} else if !hs.rsaSignOk {
return false
}
} else if !hs.rsaSignOk {
return false
}
} else if !hs.rsaDecryptOk {
return false
Expand Down Expand Up @@ -502,20 +504,22 @@ func (hs *serverHandshakeState) doFullHandshake() error {
return err
}

certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
hs.finishedHash.Write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}

if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
hs.finishedHash.Write(certStatus.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
if hs.suite.flags&suiteNoCerts == 0 { // this suite requires certificate handshake
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
hs.finishedHash.Write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}

if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
hs.finishedHash.Write(certStatus.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
return err
}
}
}

keyAgreement := hs.suite.ka(c.vers)
Expand Down
218 changes: 218 additions & 0 deletions src/crypto/tls/key_agreement.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,221 @@ func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHel

return ka.preMasterSecret, ka.ckx, nil
}

// ecdhePskKeyAgreement implements a TLS key agreement where the server
// generates an ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH with Pre-shared key.
type ecdhePskKeyAgreement struct {
version uint16
params ecdheParameters

// ckx and otherSecret are generated in processServerKeyExchange
// and returned in generateClientKeyExchange.
ckx *clientKeyExchangeMsg
otherSecret []byte
pskIdentity string
}

func (ka *ecdhePskKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
var curveID CurveID
for _, c := range clientHello.supportedCurves {
if config.supportsCurve(c) {
curveID = c
break
}
}

if curveID == 0 {
return nil, errors.New("tls: no supported elliptic curves offered")
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
}

params, err := generateECDHEParameters(config.rand(), curveID)
if err != nil {
return nil, err
}
ka.params = params

ecdhePublic := params.PublicKey()

serverECDHEParamsSize := 1 + 2 + 1 + len(ecdhePublic)
skx := new(serverKeyExchangeMsg)
skx.key = make([]byte, 2+serverECDHEParamsSize)

// See RFC 4492, Section 5.4.
serverECDHEParams := skx.key[2:]
serverECDHEParams[0] = 3 // named curve
serverECDHEParams[1] = byte(curveID >> 8)
serverECDHEParams[2] = byte(curveID)
serverECDHEParams[3] = byte(len(ecdhePublic))
copy(serverECDHEParams[4:], ecdhePublic)

return skx, nil
}

func (ka *ecdhePskKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
pskConfig, ok := config.Extra.(PSKConfig)
if !ok {
return nil, errors.New("bad Config - Extra not of type PSKConfig")
}

if pskConfig.GetKey == nil {
return nil, errors.New("bad Config - GetKey required for PSK")
}

if len(ckx.ciphertext) < 2 {
return nil, errors.New("bad ClientKeyExchange")
}

ciphertext := ckx.ciphertext
pskIdentityLen := int(ciphertext[0])<<8 | int(ciphertext[1])
if len(ciphertext) < (pskIdentityLen + 2) {
return nil, errors.New("bad ClientKeyExchange")
}
pskIdentity := string(ciphertext[2 : 2+pskIdentityLen])
ciphertext = ciphertext[2+pskIdentityLen:]

// ciphertext is actually the pskIdentity here
psk, err := pskConfig.GetKey(pskIdentity)
if err != nil {
return nil, err
}
pskLen := len(psk)

if len(ciphertext) < 1 {
return nil, errors.New("bad ClientKeyExchange")
}

publicKeyLen := int(ciphertext[0])
if len(ciphertext) < (publicKeyLen + 1) {
return nil, errors.New("bad ClientKeyExchange")
}
publicKey := ciphertext[1 : 1+publicKeyLen]

otherSecret := ka.params.SharedKey(publicKey)
if otherSecret == nil {
return nil, errClientKeyExchange
}
otherSecretLen := len(otherSecret)

preMasterSecret := make([]byte, 4+pskLen+otherSecretLen)
preMasterSecret[0] = byte(otherSecretLen >> 8)
preMasterSecret[1] = byte(otherSecretLen)
copy(preMasterSecret[2:], otherSecret)
preMasterSecret[2+otherSecretLen] = byte(pskLen >> 8)
preMasterSecret[3+otherSecretLen] = byte(pskLen)
copy(preMasterSecret[4+otherSecretLen:], psk)

return preMasterSecret, nil
}

func (ka *ecdhePskKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
pskConfig, ok := config.Extra.(PSKConfig)
if !ok {
return errors.New("bad Config - Extra not of type PSKConfig")
}

if pskConfig.GetIdentity == nil {
return errors.New("bad PSKConfig - GetIdentity required for PSK")
}

if pskConfig.GetKey == nil {
return errors.New("bad Config - GetKey required for PSK")
}

key := skx.key

if len(key) < 2 {
return errServerKeyExchange
}
pskIdentityFromServerLen := int(key[0])<<8 | int(key[1])
if pskIdentityFromServerLen > 0 {
if len(key) < (pskIdentityFromServerLen + 2) {
return errServerKeyExchange
}
pskIdentityFromServer := string(key[2 : 2+pskIdentityFromServerLen])
key = key[2+pskIdentityFromServerLen:]
_ = pskIdentityFromServer
}

pskIdentity := pskConfig.GetIdentity()
bPskIdentity := []byte(pskIdentity)
pskIdentityLen := len(bPskIdentity)
ka.pskIdentity = pskIdentity

if len(key) < 3 {
return errServerKeyExchange
}

if key[0] != 3 { // named curve
return errors.New("tls: server selected unsupported curve")
}
curveID := CurveID(key[1])<<8 | CurveID(key[2])

publicLen := int(key[3])
if publicLen+4 > len(key) {
return errServerKeyExchange
}
serverECDHEParams := key[:4+publicLen]
publicKey := serverECDHEParams[4:]

if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
return errors.New("tls: server selected unsupported curve")
}

params, err := generateECDHEParameters(config.rand(), curveID)
if err != nil {
return err
}
ka.params = params

ka.otherSecret = params.SharedKey(publicKey)
if ka.otherSecret == nil {
return errServerKeyExchange
}

ourPublicKey := params.PublicKey()
ka.ckx = new(clientKeyExchangeMsg)
ka.ckx.ciphertext = make([]byte, 2+pskIdentityLen+1+len(ourPublicKey))
ka.ckx.ciphertext[0] = byte(pskIdentityLen >> 8)
ka.ckx.ciphertext[1] = byte(pskIdentityLen)
copy(ka.ckx.ciphertext[2:], bPskIdentity)
ka.ckx.ciphertext[2+pskIdentityLen] = byte(len(ourPublicKey))
copy(ka.ckx.ciphertext[3+pskIdentityLen:], ourPublicKey)

return nil
}

func (ka *ecdhePskKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
pskConfig, ok := config.Extra.(PSKConfig)
if !ok {
return nil, nil, errors.New("bad Config - Extra not of type PSKConfig")
}

if pskConfig.GetKey == nil {
return nil, nil, errors.New("bad Config - GetKey required for PSK")
}

if ka.ckx == nil {
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
}

psk, err := pskConfig.GetKey(ka.pskIdentity)
if err != nil {
return nil, nil, err
}
pskLen := len(psk)

otherSecretLen := len(ka.otherSecret)
preMasterSecret := make([]byte, 4+pskLen+otherSecretLen)
preMasterSecret[0] = byte(otherSecretLen >> 8)
preMasterSecret[1] = byte(otherSecretLen)
copy(preMasterSecret[2:], ka.otherSecret)
preMasterSecret[2+otherSecretLen] = byte(pskLen >> 8)
preMasterSecret[3+otherSecretLen] = byte(pskLen)
copy(preMasterSecret[4+otherSecretLen:], psk)

return preMasterSecret, ka.ckx, nil
}
11 changes: 11 additions & 0 deletions src/crypto/tls/psk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package tls

// Configuration for PSK cipher-suite. The client needs to provide a GetIdentity and GetKey functions to retrieve client id and pre-shared-key
type PSKConfig struct {
// client-only - returns the client identity
GetIdentity func() string

// for server - returns the key associated to a client identity
// for client - returns the key for this client
GetKey func(identity string) ([]byte, error)
}
Loading

0 comments on commit e70353c

Please sign in to comment.