From 1da426529d45e6988b3ca03b39cab6f3d5417c8e Mon Sep 17 00:00:00 2001 From: Maxime Piraux Date: Tue, 28 Apr 2020 12:01:28 +0200 Subject: [PATCH] Adds a lock to CryptoStates, fixes #23 --- agents/parse_agent.go | 2 +- agents/send_agent.go | 6 +++--- agents/tls_agent.go | 2 ++ connection.go | 19 +++++++++++++++++-- scenarii/key_update.go | 2 ++ scenarii/padding.go | 2 +- 6 files changed, 26 insertions(+), 7 deletions(-) diff --git a/agents/parse_agent.go b/agents/parse_agent.go index 62e8ee3..5d1bee2 100644 --- a/agents/parse_agent.go +++ b/agents/parse_agent.go @@ -43,7 +43,7 @@ func (a *ParsingAgent) Run(conn *Connection) { } header := ReadHeader(bytes.NewReader(ciphertext), a.conn) - cryptoState := a.conn.CryptoStates[header.EncryptionLevel()] + cryptoState := a.conn.CryptoState(header.EncryptionLevel()) switch header.PacketType() { case Initial, Handshake, ZeroRTTProtected, ShortHeaderPacket: // Decrypt PN diff --git a/agents/send_agent.go b/agents/send_agent.go index 3139320..590359f 100644 --- a/agents/send_agent.go +++ b/agents/send_agent.go @@ -50,7 +50,7 @@ func (a *SendingAgent) Run(conn *Connection) { initialSent := false fillPacket := func(packet Framer, level EncryptionLevel) Framer { - spaceLeft := int(a.MTU) - packet.Header().HeaderLength() - conn.CryptoStates[level].Write.Overhead() + spaceLeft := int(a.MTU) - packet.Header().HeaderLength() - conn.CryptoState(level).Write.Overhead() addFrame: for i, fp := range a.FrameProducer { @@ -114,7 +114,7 @@ func (a *SendingAgent) Run(conn *Connection) { } else { initialLength = MinimumInitialLength } - initialLength -= conn.CryptoStates[EncryptionLevelInitial].Write.Overhead() + initialLength -= conn.CryptoState(EncryptionLevelInitial).Write.Overhead() p.PadTo(initialLength) initialSent = true conn.DoSendPacket(p, EncryptionLevelInitial) @@ -191,7 +191,7 @@ func (a *SendingAgent) Run(conn *Connection) { } else { initialLength = MinimumInitialLength } - initialLength -= conn.CryptoStates[EncryptionLevelInitial].Write.Overhead() + initialLength -= conn.CryptoState(EncryptionLevelInitial).Write.Overhead() initial.PadTo(initialLength) initialSent = true } diff --git a/agents/tls_agent.go b/agents/tls_agent.go index b404c85..d92b0c4 100644 --- a/agents/tls_agent.go +++ b/agents/tls_agent.go @@ -83,6 +83,7 @@ func (a *TLSAgent) Run(conn *Connection) { a.TLSStatus.Submit(TLSStatus{false, packet, err}) } + conn.CryptoStateLock.Lock() if conn.CryptoStates[EncryptionLevelHandshake] == nil { conn.CryptoStates[EncryptionLevelHandshake] = new(CryptoState) } @@ -126,6 +127,7 @@ func (a *TLSAgent) Run(conn *Connection) { conn.EncryptionLevels.Submit(*e) } } + conn.CryptoStateLock.Unlock() if !resumptionTicketSent && len(conn.Tls.ResumptionTicket()) > 0 { a.ResumptionTicket.Submit(conn.Tls.ResumptionTicket()) diff --git a/connection.go b/connection.go index 17f57f9..28d3587 100644 --- a/connection.go +++ b/connection.go @@ -31,6 +31,7 @@ type Connection struct { SpinBit SpinBit LastSpinNumber PacketNumber + CryptoStateLock sync.Locker CryptoStates map[EncryptionLevel]*CryptoState ReceivedPacketHandler func([]byte, unsafe.Pointer) @@ -90,10 +91,19 @@ func (c *Connection) nextPacketNumber(space PNSpace) PacketNumber { // TODO: Th c.PacketNumberLock.Unlock() return pn } +func (c *Connection) CryptoState(level EncryptionLevel) *CryptoState { + c.CryptoStateLock.Lock() + cs, ok := c.CryptoStates[level] + c.CryptoStateLock.Unlock() + if ok { + return cs + } + return nil +} func (c *Connection) EncodeAndEncrypt(packet Packet, level EncryptionLevel) []byte { switch packet.PNSpace() { case PNSpaceInitial, PNSpaceHandshake, PNSpaceAppData: - cryptoState := c.CryptoStates[level] + cryptoState := c.CryptoState(level) payload := packet.EncodePayload() if h, ok := packet.Header().(*LongHeader); ok { @@ -161,7 +171,9 @@ func (c *Connection) GetInitialPacket() *InitialPacket { if len(c.Tls.ZeroRTTSecret()) > 0 { c.Logger.Printf("0-RTT secret is available, installing crypto state") + c.CryptoStateLock.Lock() c.CryptoStates[EncryptionLevel0RTT] = NewProtectedCryptoState(c.Tls, nil, c.Tls.ZeroRTTSecret()) + c.CryptoStateLock.Unlock() c.EncryptionLevels.Submit(DirectionalEncryptionLevel{EncryptionLevel: EncryptionLevel0RTT, Read: false, Available: true}) } @@ -174,7 +186,7 @@ func (c *Connection) GetInitialPacket() *InitialPacket { initialPacket := NewInitialPacket(c) initialPacket.Frames = append(initialPacket.Frames, cryptoFrame) - initialPacket.PadTo(initialLength - c.CryptoStates[EncryptionLevelInitial].Write.Overhead()) + initialPacket.PadTo(initialLength - c.CryptoState(EncryptionLevelInitial).Write.Overhead()) return initialPacket } @@ -250,9 +262,12 @@ func (c *Connection) TransitionTo(version uint32, ALPN string) { c.AckQueue[space] = nil } + c.CryptoStateLock = &sync.Mutex{} + c.CryptoStateLock.Lock() c.CryptoStates = make(map[EncryptionLevel]*CryptoState) c.CryptoStreams = make(map[PNSpace]*Stream) c.CryptoStates[EncryptionLevelInitial] = NewInitialPacketProtection(c) + c.CryptoStateLock.Unlock() c.Streams = Streams{streams: make(map[uint64]*Stream), lock: &sync.Mutex{}, input: &c.StreamInput} } func (c *Connection) CloseConnection(quicLayer bool, errCode uint64, reasonPhrase string) { diff --git a/scenarii/key_update.go b/scenarii/key_update.go index 180584f..b81052d 100644 --- a/scenarii/key_update.go +++ b/scenarii/key_update.go @@ -50,12 +50,14 @@ forLoop1: readSecret := conn.Tls.HkdfExpandLabel(conn.Tls.ProtectedReadSecret(), "ku", nil, conn.Tls.HashDigestSize(), pigotls.QuicBaseLabel) writeSecret := conn.Tls.HkdfExpandLabel(conn.Tls.ProtectedWriteSecret(), "ku", nil, conn.Tls.HashDigestSize(), pigotls.QuicBaseLabel) + conn.CryptoStateLock.Lock() oldState := conn.CryptoStates[qt.EncryptionLevel1RTT] conn.CryptoStates[qt.EncryptionLevel1RTT] = qt.NewProtectedCryptoState(conn.Tls, readSecret, writeSecret) conn.CryptoStates[qt.EncryptionLevel1RTT].HeaderRead = oldState.HeaderRead conn.CryptoStates[qt.EncryptionLevel1RTT].HeaderWrite = oldState.HeaderWrite conn.KeyPhaseIndex++ + conn.CryptoStateLock.Unlock() responseChan := connAgents.AddHTTPAgent().SendRequest(preferredPath, "GET", trace.Host, nil) diff --git a/scenarii/padding.go b/scenarii/padding.go index ff701e3..f0c1215 100644 --- a/scenarii/padding.go +++ b/scenarii/padding.go @@ -32,7 +32,7 @@ func (s *PaddingScenario) Run(conn *qt.Connection, trace *qt.Trace, preferredPat initialPacket := qt.NewInitialPacket(conn) payloadLen := len(initialPacket.EncodePayload()) - paddingLength := initialLength - (len(initialPacket.Header().Encode()) + int(VarIntLen(uint64(payloadLen))) + payloadLen + conn.CryptoStates[qt.EncryptionLevelInitial].Write.Overhead()) + paddingLength := initialLength - (len(initialPacket.Header().Encode()) + int(VarIntLen(uint64(payloadLen))) + payloadLen + conn.CryptoState(qt.EncryptionLevelInitial).Write.Overhead()) for i := 0; i < paddingLength; i++ { initialPacket.Frames = append(initialPacket.Frames, new(qt.PaddingFrame)) }