Skip to content

Commit

Permalink
feat: rm curveHint & parrot aes-gcm for tls12
Browse files Browse the repository at this point in the history
  • Loading branch information
3andne committed Mar 9, 2023
1 parent 9785380 commit 3e17d79
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 90 deletions.
18 changes: 8 additions & 10 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -755,11 +755,10 @@ type Config struct {
// auto-rotation logic. See Config.ticketKeys.
autoSessionTicketKeys []ticketKey

RestlsSecret []byte // #RESTLS#
VersionHint versionHint // #RESTLS#
CurveIDHint atomic.Uint32 // #RESTLS#
RestlsScript []Line // #RESTLS#
ClientID *ClientHelloID // #RESTLS#
RestlsSecret []byte // #RESTLS#
VersionHint versionHint // #RESTLS#
RestlsScript []Line // #RESTLS#
ClientID atomic.Pointer[ClientHelloID] // #RESTLS#
}

const (
Expand Down Expand Up @@ -842,7 +841,6 @@ func (c *Config) Clone() *Config {
KeyLogWriter: c.KeyLogWriter,
sessionTicketKeys: c.sessionTicketKeys,
autoSessionTicketKeys: c.autoSessionTicketKeys,
CurveIDHint: c.CurveIDHint, // #RESTLS#
VersionHint: c.VersionHint, // #RESTLS#
RestlsSecret: c.RestlsSecret, // #RESTLS#
RestlsScript: c.RestlsScript, // #RESTLS#
Expand Down Expand Up @@ -1534,11 +1532,11 @@ const (
restlsMaskLength int = restlsCmdLength + 2
restlsAppDataAuthHeaderLength int = restlsAppDataMACLength +
restlsMaskLength
restls12SessionTicketMACOffset int = 16
restls12PubKeyMACOffset int = 0
restlsAppDataOffset int = 5 + restlsAppDataAuthHeaderLength
restlsAppDataLenOffset int = 5 + restlsAppDataMACLength
restlsAppDataOffset int = restlsAppDataAuthHeaderLength
restlsAppDataLenOffset int = restlsAppDataMACLength
)

// #RESTLS#
var restlsRandomResponseMagic []byte = []byte("restls-random-response")
var restls12ClientAuthLayout3 []int = []int{0, 11, 22, 32}
var restls12ClientAuthLayout4 []int = []int{0, 8, 16, 24, 32}
135 changes: 95 additions & 40 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,16 @@ type Conn struct {

tmp [16]byte

tls12PubKey []byte // #RESTLS#
eagerEcdheParameters *ecdheParameters // #RESTLS#
serverRandom []byte // #RESTLS#
restlsAuthed bool // #RESTLS#
restlsInboundCounter uint32 // #RESTLS#
restlsOutboundCounter uint32 // #RESTLS#
restlsSendBuf []byte // #RESTLS#
restlsWritePending atomic.Bool // #RESTLS#
eagerEcdheParameters []*ecdheParameters // #RESTLS#
serverRandom []byte // #RESTLS#
restlsAuthed bool // #RESTLS#
restlsToClientCounter uint64 // #RESTLS#
restlsToServerCounter uint64 // #RESTLS#
restlsSendBuf []byte // #RESTLS#
restlsWritePending atomic.Bool // #RESTLS#
restls12WithGCM bool // #RESTLS#

restls12GCMServerDisableCtr bool // #RESTLS#
}

// Access to net.Conn methods.
Expand Down Expand Up @@ -628,12 +630,27 @@ func (c *Conn) extractRestlsAppData(record []byte) ([]byte, restlsCommand, error
return nil, nil, alertBadRecordMAC
}

header := record[:recordHeaderLen]
if c.restls12WithGCM && !c.restls12GCMServerDisableCtr {
serverCounter := binary.BigEndian.Uint64(record[recordHeaderLen:])
if serverCounter != c.restlsToClientCounter+1 {
c.restls12GCMServerDisableCtr = true
record = record[recordHeaderLen:]
} else {
header = record[:recordHeaderLen+8]
record = record[recordHeaderLen+8:]
}
} else {
record = record[recordHeaderLen:]
}

hmacAuth := c.restlsAuthHeaderHash(true)
hmacAuth.Write(header)
hmacAuth.Write(record[restlsAppDataLenOffset:])
authMac := hmacAuth.Sum(nil)[:restlsAppDataMACLength]
for i, m := range authMac {
if m != record[5+i] {
// fmt.Printf("extractRestlsAppData: bad authMac, expect %v, actual %v, to_server: %d, to_client: %d\n", authMac, record[5:5+restlsAppDataMACLength], c.restlsOutboundCounter, c.restlsInboundCounter)
if m != record[i] {
// fmt.Printf("extractRestlsAppData: bad authMac, expect %v, actual %v, to_server: %d, to_client: %d\n", authMac, record[:+restlsAppDataMACLength], c.restlsToServerCounter, c.restlsToClientCounter)
return nil, nil, alertBadRecordMAC
}
}
Expand All @@ -651,7 +668,7 @@ func (c *Conn) extractRestlsAppData(record []byte) ([]byte, restlsCommand, error
return nil, nil, err
}
data := record[restlsAppDataOffset : restlsAppDataOffset+dataLen]
// fmt.Printf("extractRestlsAppData: lengthMask: %v, recordLen: %v, dataLen: %v, authMac: %v, to_server: %d, to_client: %d\n", mask, len(record), dataLen, authMac, c.restlsOutboundCounter, c.restlsInboundCounter)
// fmt.Printf("extractRestlsAppData: lengthMask: %v, recordLen: %v, dataLen: %v, authMac: %v, to_server: %d, to_client: %d\n", mask, len(record), dataLen, authMac, c.restlsToServerCounter, c.restlsToClientCounter)
return data, command, nil
}

Expand Down Expand Up @@ -765,7 +782,13 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
hmac.Write(c.serverRandom)
serverRandomMac := hmac.Sum(nil)
recordCopy := append([]byte(nil), record...)
xorWithMac(recordCopy[recordHeaderLen:], serverRandomMac[:restlsHandshakeMACLength])
if c.restls12WithGCM && binary.BigEndian.Uint64(recordCopy[recordHeaderLen:recordHeaderLen+8]) == 0 {
xorWithMac(recordCopy[recordHeaderLen+8:], serverRandomMac[:restlsHandshakeMACLength])
} else {
c.restls12GCMServerDisableCtr = true
xorWithMac(recordCopy[recordHeaderLen:], serverRandomMac[:restlsHandshakeMACLength])
}

data, typ, err = c.in.decrypt(recordCopy)
if err != nil {
c.in.cipher = backupCipher
Expand All @@ -780,13 +803,25 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
} else {
data, command, err = c.extractRestlsAppData(record)
}
n := 0
if c.restlsWritePending.Swap(false) {
// fmt.Printf("restls unblock writers, len %d\n", len(data))
n, err = c.Write([]byte{})
if err != nil {
data, typ, err = c.in.decrypt(record)
if typ == recordTypeApplicationData {
data = nil
}
c.restlsToClientCounter += 1
} else {
n := 0
if c.restlsWritePending.Swap(false) {
// fmt.Printf("restls unblock writers, len %d\n", len(data))
n, err = c.Write([]byte{})
if err != nil {
fmt.Printf("n, err = c.Write([]byte{}) %v\n", err)
err = alertInternalError
}
}
c.restlsToClientCounter += 1
c.handleRestlsCommand(command, n > 0)
}
c.restlsInboundCounter += 1
c.handleRestlsCommand(command, n > 0)
// fmt.Printf("c.handleRestlsCommand returned\n")
} else {
// #Restls# End
Expand Down Expand Up @@ -1235,37 +1270,44 @@ var (
)

// #RESTLS#
func (c *Conn) restlsAuthHeaderHash(isInbound bool) hash.Hash {
func (c *Conn) restlsAuthHeaderHash(isToClient bool) hash.Hash {
hmac := RestlsHmac(c.config.RestlsSecret)
hmac.Write(c.serverRandom)
counterBytes := make([]byte, 4)
if isInbound {
counterBytes := make([]byte, 8)
if isToClient {
hmac.Write([]byte("server-to-client"))
binary.BigEndian.PutUint32(counterBytes, c.restlsInboundCounter)
binary.BigEndian.PutUint64(counterBytes, c.restlsToClientCounter)
} else {
hmac.Write([]byte("client-to-server"))
binary.BigEndian.PutUint32(counterBytes, c.restlsOutboundCounter)
binary.BigEndian.PutUint64(counterBytes, c.restlsToServerCounter)
}
hmac.Write(counterBytes)
return hmac
}

// #RESTLS#
func (c *Conn) write0x17AuthHeader(paddingLen int, dataLen int, command restlsCommand, outBuf []byte) ([]byte, error) {
func (c *Conn) writePadding(paddingLen int, outBuf []byte) ([]byte, error) {
outBuf, padding := sliceForAppend(outBuf, paddingLen)
_, err := rand.Read(padding)
if err != nil {
return nil, err
}
return outBuf, nil
}

// #RESTLS#
func (c *Conn) write0x17AuthHeader(paddingLen int, dataLen int, command restlsCommand, outBuf []byte) error {
restlsHeaderOffset := len(outBuf) - paddingLen - dataLen - restlsAppDataAuthHeaderLength
header := outBuf[:restlsHeaderOffset]
outBuf = outBuf[restlsHeaderOffset:]
sampleSize := 32
if sampleSize > len(outBuf[restlsAppDataOffset:]) {
sampleSize = len(outBuf[restlsAppDataOffset:])
}
hmacMask := c.restlsAuthHeaderHash(false)
_, err = hmacMask.Write(outBuf[restlsAppDataOffset : restlsAppDataOffset+sampleSize])
_, err := hmacMask.Write(outBuf[restlsAppDataOffset : restlsAppDataOffset+sampleSize])
if err != nil {
return nil, err
return err
}
mask := hmacMask.Sum(nil)[:restlsMaskLength]
dataLenBytes := outBuf[restlsAppDataLenOffset : restlsAppDataLenOffset+2]
Expand All @@ -1278,23 +1320,21 @@ func (c *Conn) write0x17AuthHeader(paddingLen int, dataLen int, command restlsCo
// fmt.Printf("writing clientFinished %v\n", clientFinished)
hmacAuth.Write(clientFinished)
}
_, err = hmacAuth.Write(outBuf[restlsAppDataLenOffset:]) // data len as well as the data are protected
if err != nil {
return nil, err
}
hmacAuth.Write(header)
hmacAuth.Write(outBuf[restlsAppDataLenOffset:]) // data len as well as the data are protected
authMac := hmacAuth.Sum(nil)[:restlsAppDataMACLength]
// fmt.Printf("lengthMask: %v, authMac: %v, to_server: %d, to_client: %d\n", mask, authMac, c.restlsOutboundCounter, c.restlsInboundCounter)
copy(outBuf[5:], authMac)
return outBuf, nil
// fmt.Printf("lengthMask: %v, authMac: %v, to_server: %d, to_client: %d\n", mask, authMac, c.restlsToServerCounter, c.restlsToClientCounter)
copy(outBuf[:restlsAppDataMACLength], authMac)
return nil
}

// #RESTLS#
func (c *Conn) actAccordingToScript(data []byte) (int, int, int, restlsCommand) {
paddingLen := 0
dataLen := len(data)
var command restlsCommand = ActNoop{}
if c.restlsOutboundCounter < uint32(len(c.config.RestlsScript)) {
line := c.config.RestlsScript[c.restlsOutboundCounter]
if c.restlsToServerCounter < uint64(len(c.config.RestlsScript)) {
line := c.config.RestlsScript[c.restlsToServerCounter]
dataLen = line.targetLen.Len()
command = line.command
}
Expand All @@ -1305,11 +1345,15 @@ func (c *Conn) actAccordingToScript(data []byte) (int, int, int, restlsCommand)
paddingLen = dataLen - len(data)
dataLen = len(data)
}
payloadLen := dataLen + restlsAppDataAuthHeaderLength + paddingLen
headerLen := restlsAppDataAuthHeaderLength
if c.restls12WithGCM {
headerLen += 8
}
payloadLen := dataLen + paddingLen + headerLen
if payloadLen > maxPlaintext {
payloadLen = maxPlaintext
}
dataLen = payloadLen - restlsAppDataAuthHeaderLength - paddingLen
dataLen = payloadLen - headerLen - paddingLen
if dataLen < 0 {
panic("target length is too large")
}
Expand Down Expand Up @@ -1358,17 +1402,28 @@ func (c *Conn) writeRestlsApplicationRecord(dataNew []byte) (int, error) {
if payloadLen == 0 {
return 0, nil
}
_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen+restlsAppDataAuthHeaderLength)
headersLength := recordHeaderLen + restlsAppDataAuthHeaderLength
if c.restls12WithGCM {
headersLength += 8
}
_, outBuf = sliceForAppend(outBuf[:0], headersLength)
outBuf[0] = byte(recordTypeApplicationData)
vers := VersionTLS12
outBuf[1] = byte(vers >> 8)
outBuf[2] = byte(vers)
outBuf[3] = byte(payloadLen >> 8)
outBuf[4] = byte(payloadLen)

if c.restls12WithGCM {
binary.BigEndian.PutUint64(outBuf[recordHeaderLen:], c.restlsToServerCounter+1)
}
outBuf = append(outBuf, data[:dataLen]...)
var err error
if outBuf, err = c.write0x17AuthHeader(paddingLen, dataLen, command, outBuf); err != nil {
if outBuf, err = c.writePadding(paddingLen, outBuf); err != nil {
// fmt.Printf("writePadding failed %v", err)
return n, err
}
if err = c.write0x17AuthHeader(paddingLen, dataLen, command, outBuf); err != nil {
// fmt.Printf("write0x17AuthHeader failed %v", err)
return n, err
}
Expand All @@ -1379,7 +1434,7 @@ func (c *Conn) writeRestlsApplicationRecord(dataNew []byte) (int, error) {
// fmt.Printf("writeRestls c.write failed %v", err)
return n, err
}
c.restlsOutboundCounter += 1
c.restlsToServerCounter += 1
n += payloadLen
data = data[dataLen:]
if command.needInterrupt() && !fakeResponse {
Expand Down
44 changes: 31 additions & 13 deletions handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,21 +175,31 @@ func (c *Conn) generateSessionIDForTLS12(hello *clientHelloMsg) error {
if c.config.VersionHint != TLS12Hint && hello.supportedVersions[0] != VersionTLS12 {
panic("session id should only be generated from pub key and session ticket for TLS 1.2")
}
hint := CurveID(c.config.CurveIDHint.Load())
params, err := generateECDHEParameters(c.config.rand(), hint)
if err != nil {
return fmt.Errorf("restls: CurvePreferences includes unsupported curve: %v", err)
paramsList := []*ecdheParameters{}
materials := make([][]byte, 0, 4)
for _, curve := range curveIDList {
params, err := generateECDHEParameters(c.config.rand(), curve)
if err != nil {
return fmt.Errorf("restls: CurvePreferences includes unsupported curve: %v", err)
}
paramsList = append(paramsList, &params)
materials = append(materials, params.PublicKey())
}
c.tls12PubKey = params.PublicKey()
c.eagerEcdheParameters = &params
hmac := RestlsHmac(c.config.RestlsSecret)
hmac.Write(c.tls12PubKey)
pubkeyhash := hmac.Sum(nil)[:restlsHandshakeMACLength]
copy(hello.sessionId[restls12PubKeyMACOffset:], pubkeyhash)
c.eagerEcdheParameters = paramsList

if len(hello.sessionTicket) > 0 {
materials = append(materials, hello.sessionTicket)
}

layout := restls12ClientAuthLayout3
if len(materials) == 4 {
layout = restls12ClientAuthLayout4
}

for i, material := range materials {
hmac := RestlsHmac(c.config.RestlsSecret)
hmac.Write(hello.sessionTicket)
copy(hello.sessionId[restls12SessionTicketMACOffset:], hmac.Sum(nil)[:restlsHandshakeMACLength])
hmac.Write(material)
copy(hello.sessionId[layout[i]:layout[i+1]], hmac.Sum(nil))
}
return nil
}
Expand Down Expand Up @@ -458,7 +468,15 @@ func (hs *clientHandshakeState) handshake() error {

c.buffering = true
c.didResume = isResume
c.clientFinishedIsFirst = true // #Restls#
// #Restls# Begin
c.clientFinishedIsFirst = true
for _, ci := range tls12GCMCiphers {
if ci == hs.suite.id {
c.restls12WithGCM = true
break
}
}
// #Restls# End
if isResume {
if err := hs.establishKeys(); err != nil {
return err
Expand Down
Loading

0 comments on commit 3e17d79

Please sign in to comment.