From bd8ac9ecf8d3c89c10b91a3b40cb5f536a99635b Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 20 Jul 2023 16:45:15 -0700 Subject: [PATCH] quic: fill out connection id handling Add support for sending and receiving NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames. Keep the peer supplied with up to 4 connection IDs. Retire connection IDs as required by the peer. Support connection IDs provided in the preferred_address transport parameter. RFC 9000, Section 5.1. For golang/go#58547 Change-Id: I015a69b94c40a6396e9f117a92c88acaf83c594e Reviewed-on: https://go-review.googlesource.com/c/net/+/513440 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Jonathan Amsterdam --- internal/quic/conn.go | 32 ++- internal/quic/conn_id.go | 238 +++++++++++++++++- internal/quic/conn_id_test.go | 422 +++++++++++++++++++++++++++++++- internal/quic/conn_loss.go | 6 + internal/quic/conn_loss_test.go | 65 +++++ internal/quic/conn_recv.go | 29 ++- internal/quic/conn_send.go | 19 +- internal/quic/conn_test.go | 74 ++++-- internal/quic/frame_debug.go | 5 +- internal/quic/packet_parser.go | 4 +- internal/quic/packet_writer.go | 7 +- internal/quic/ping_test.go | 2 +- internal/quic/quic.go | 10 + internal/quic/tls.go | 4 +- internal/quic/tls_test.go | 156 ++++++++++-- 15 files changed, 998 insertions(+), 75 deletions(-) diff --git a/internal/quic/conn.go b/internal/quic/conn.go index ff03bd7f8..77ecea0d6 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -69,6 +69,7 @@ type connListener interface { type connTestHooks interface { nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any) handleTLSEvent(tls.QUICEvent) + newConnID(seq int64) ([]byte, error) } func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) { @@ -90,12 +91,12 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip. c.msgc = make(chan any, 1) if c.side == clientSide { - if err := c.connIDState.initClient(newRandomConnID); err != nil { + if err := c.connIDState.initClient(c.newConnIDFunc()); err != nil { return nil, err } - initialConnID = c.connIDState.dstConnID() + initialConnID, _ = c.connIDState.dstConnID() } else { - if err := c.connIDState.initServer(newRandomConnID, initialConnID); err != nil { + if err := c.connIDState.initServer(c.newConnIDFunc(), initialConnID); err != nil { return nil, err } } @@ -154,11 +155,27 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) { } // receiveTransportParameters applies transport parameters sent by the peer. -func (c *Conn) receiveTransportParameters(p transportParameters) { +func (c *Conn) receiveTransportParameters(p transportParameters) error { c.peerAckDelayExponent = p.ackDelayExponent c.loss.setMaxAckDelay(p.maxAckDelay) + if err := c.connIDState.setPeerActiveConnIDLimit(p.activeConnIDLimit, c.newConnIDFunc()); err != nil { + return err + } + if p.preferredAddrConnID != nil { + var ( + seq int64 = 1 // sequence number of this conn id is 1 + retirePriorTo int64 = 0 // retire nothing + resetToken [16]byte + ) + copy(resetToken[:], p.preferredAddrResetToken) + if err := c.connIDState.handleNewConnID(seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil { + return err + } + } // TODO: Many more transport parameters to come. + + return nil } type timerEvent struct{} @@ -295,3 +312,10 @@ func firstTime(a, b time.Time) time.Time { return b } } + +func (c *Conn) newConnIDFunc() newConnIDFunc { + if c.testHooks != nil { + return c.testHooks.newConnID + } + return newRandomConnID +} diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index deea70d32..561dea2c1 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -7,6 +7,7 @@ package quic import ( + "bytes" "crypto/rand" ) @@ -18,8 +19,16 @@ type connIDState struct { // Local IDs are usually issued by us, and remote IDs by the peer. // The exception is the transient destination connection ID sent in // a client's Initial packets, which is chosen by the client. + // + // These are []connID rather than []*connID to minimize allocations. local []connID remote []connID + + nextLocalSeq int64 + retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer + peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter + + needSend bool } // A connID is a connection ID and associated metadata. @@ -32,12 +41,24 @@ type connID struct { // // For the transient destination ID in a client's Initial packet, this is -1. seq int64 + + // retired is set when the connection ID is retired. + retired bool + + // send is set when the connection ID's state needs to be sent to the peer. + // + // For local IDs, this indicates a new ID that should be sent + // in a NEW_CONNECTION_ID frame. + // + // For remote IDs, this indicates a retired ID that should be sent + // in a RETIRE_CONNECTION_ID frame. + send sentVal } func (s *connIDState) initClient(newID newConnIDFunc) error { // Client chooses its initial connection ID, and sends it // in the Source Connection ID field of the first Initial packet. - locid, err := newID() + locid, err := newID(0) if err != nil { return err } @@ -45,10 +66,11 @@ func (s *connIDState) initClient(newID newConnIDFunc) error { seq: 0, cid: locid, }) + s.nextLocalSeq = 1 // Client chooses an initial, transient connection ID for the server, // and sends it in the Destination Connection ID field of the first Initial packet. - remid, err := newID() + remid, err := newID(-1) if err != nil { return err } @@ -70,7 +92,7 @@ func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error { // Server chooses a connection ID, and sends it in the Source Connection ID of // the response to the clent. - locid, err := newID() + locid, err := newID(0) if err != nil { return err } @@ -78,6 +100,7 @@ func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error { seq: 0, cid: locid, }) + s.nextLocalSeq = 1 return nil } @@ -91,8 +114,44 @@ func (s *connIDState) srcConnID() []byte { } // dstConnID is the Destination Connection ID to use in a sent packet. -func (s *connIDState) dstConnID() []byte { - return s.remote[0].cid +func (s *connIDState) dstConnID() (cid []byte, ok bool) { + for i := range s.remote { + if !s.remote[i].retired { + return s.remote[i].cid, true + } + } + return nil, false +} + +// setPeerActiveConnIDLimit sets the active_connection_id_limit +// transport parameter received from the peer. +func (s *connIDState) setPeerActiveConnIDLimit(lim int64, newID newConnIDFunc) error { + s.peerActiveConnIDLimit = lim + return s.issueLocalIDs(newID) +} + +func (s *connIDState) issueLocalIDs(newID newConnIDFunc) error { + toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit) + for i := range s.local { + if s.local[i].seq != -1 && !s.local[i].retired { + toIssue-- + } + } + for toIssue > 0 { + cid, err := newID(s.nextLocalSeq) + if err != nil { + return err + } + s.local = append(s.local, connID{ + seq: s.nextLocalSeq, + cid: cid, + }) + s.local[len(s.local)-1].send.setUnsent() + s.nextLocalSeq++ + s.needSend = true + toIssue-- + } + return nil } // handlePacket updates the connection ID state during the handshake @@ -128,19 +187,184 @@ func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID [] } } +func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken [16]byte) error { + if len(s.remote[0].cid) == 0 { + // "An endpoint that is sending packets with a zero-length + // Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID + // frame as a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6 + return localTransportError(errProtocolViolation) + } + + if retire > s.retireRemotePriorTo { + s.retireRemotePriorTo = retire + } + + have := false // do we already have this connection ID? + active := 0 + for i := range s.remote { + rcid := &s.remote[i] + if !rcid.retired && rcid.seq < s.retireRemotePriorTo { + s.retireRemote(rcid) + } + if !rcid.retired { + active++ + } + if rcid.seq == seq { + if !bytes.Equal(rcid.cid, cid) { + return localTransportError(errProtocolViolation) + } + have = true // yes, we've seen this sequence number + } + } + + if !have { + // This is a new connection ID that we have not seen before. + // + // We could take steps to keep the list of remote connection IDs + // sorted by sequence number, but there's no particular need + // so we don't bother. + s.remote = append(s.remote, connID{ + seq: seq, + cid: cloneBytes(cid), + }) + if seq < s.retireRemotePriorTo { + // This ID was already retired by a previous NEW_CONNECTION_ID frame. + s.retireRemote(&s.remote[len(s.remote)-1]) + } else { + active++ + } + } + + if active > activeConnIDLimit { + // Retired connection IDs (including newly-retired ones) do not count + // against the limit. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5 + return localTransportError(errConnectionIDLimit) + } + + // "An endpoint SHOULD limit the number of connection IDs it has retired locally + // for which RETIRE_CONNECTION_ID frames have not yet been acknowledged." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6 + // + // Set a limit of four times the active_connection_id_limit for + // the total number of remote connection IDs we keep state for locally. + if len(s.remote) > 4*activeConnIDLimit { + return localTransportError(errConnectionIDLimit) + } + + return nil +} + +// retireRemote marks a remote connection ID as retired. +func (s *connIDState) retireRemote(rcid *connID) { + rcid.retired = true + rcid.send.setUnsent() + s.needSend = true +} + +func (s *connIDState) handleRetireConnID(seq int64, newID newConnIDFunc) error { + if seq >= s.nextLocalSeq { + return localTransportError(errProtocolViolation) + } + for i := range s.local { + if s.local[i].seq == seq { + s.local = append(s.local[:i], s.local[i+1:]...) + break + } + } + s.issueLocalIDs(newID) + return nil +} + +func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) { + for i := range s.local { + if s.local[i].seq != seq { + continue + } + s.local[i].send.ackOrLoss(pnum, fate) + if fate != packetAcked { + s.needSend = true + } + return + } +} + +func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) { + for i := 0; i < len(s.remote); i++ { + if s.remote[i].seq != seq { + continue + } + if fate == packetAcked { + // We have retired this connection ID, and the peer has acked. + // Discard its state completely. + s.remote = append(s.remote[:i], s.remote[i+1:]...) + } else { + // RETIRE_CONNECTION_ID frame was lost, mark for retransmission. + s.needSend = true + s.remote[i].send.ackOrLoss(pnum, fate) + } + return + } +} + +// appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames +// to the current packet. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (s *connIDState) appendFrames(w *packetWriter, pnum packetNumber, pto bool) bool { + if !s.needSend && !pto { + // Fast path: We don't need to send anything. + return true + } + retireBefore := int64(0) + if s.local[0].seq != -1 { + retireBefore = s.local[0].seq + } + for i := range s.local { + if !s.local[i].send.shouldSendPTO(pto) { + continue + } + if !w.appendNewConnectionIDFrame( + s.local[i].seq, + retireBefore, + s.local[i].cid, + [16]byte{}, // TODO: stateless reset token + ) { + return false + } + s.local[i].send.setSent(pnum) + } + for i := range s.remote { + if !s.remote[i].send.shouldSendPTO(pto) { + continue + } + if !w.appendRetireConnectionIDFrame(s.remote[i].seq) { + return false + } + s.remote[i].send.setSent(pnum) + } + s.needSend = false + return true +} + func cloneBytes(b []byte) []byte { n := make([]byte, len(b)) copy(n, b) return n } -type newConnIDFunc func() ([]byte, error) +type newConnIDFunc func(seq int64) ([]byte, error) -func newRandomConnID() ([]byte, error) { +func newRandomConnID(_ int64) ([]byte, error) { // It is not necessary for connection IDs to be cryptographically secure, // but it doesn't hurt. id := make([]byte, connIDLen) if _, err := rand.Read(id); err != nil { + // TODO: Surface this error as a metric or log event or something. + // rand.Read really shouldn't ever fail, but if it does, we should + // have a way to inform the user. return nil, err } return id, nil diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go index 7c31e9d56..74905578d 100644 --- a/internal/quic/conn_id_test.go +++ b/internal/quic/conn_id_test.go @@ -7,7 +7,10 @@ package quic import ( + "bytes" + "crypto/tls" "fmt" + "net/netip" "reflect" "testing" ) @@ -22,14 +25,16 @@ func TestConnIDClientHandshake(t *testing.T) { if got, want := string(s.srcConnID()), "local-1"; got != want { t.Errorf("after initClient: srcConnID = %q, want %q", got, want) } - if got, want := string(s.dstConnID()), "local-2"; got != want { + dstConnID, _ := s.dstConnID() + if got, want := string(dstConnID), "local-2"; got != want { t.Errorf("after initClient: dstConnID = %q, want %q", got, want) } // The server's first Initial packet provides the client with a // non-transient remote connection ID. s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1")) - if got, want := string(s.dstConnID()), "remote-1"; got != want { + dstConnID, _ = s.dstConnID() + if got, want := string(dstConnID), "remote-1"; got != want { t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want) } @@ -59,7 +64,8 @@ func TestConnIDServerHandshake(t *testing.T) { if got, want := string(s.srcConnID()), "local-1"; got != want { t.Errorf("after initClient: srcConnID = %q, want %q", got, want) } - if got, want := string(s.dstConnID()), "remote-1"; got != want { + dstConnID, _ := s.dstConnID() + if got, want := string(dstConnID), "remote-1"; got != want { t.Errorf("after initClient: dstConnID = %q, want %q", got, want) } @@ -95,15 +101,421 @@ func TestConnIDServerHandshake(t *testing.T) { func newConnIDSequence() newConnIDFunc { var n uint64 - return func() ([]byte, error) { + return func(_ int64) ([]byte, error) { n++ return []byte(fmt.Sprintf("local-%v", n)), nil } } func TestNewRandomConnID(t *testing.T) { - cid, err := newRandomConnID() + cid, err := newRandomConnID(0) if len(cid) != connIDLen || err != nil { t.Fatalf("newConnID() = %x, %v; want %v bytes", cid, connIDLen, err) } } + +func TestConnIDPeerRequestsManyIDs(t *testing.T) { + // "An endpoint SHOULD ensure that its peer has a sufficient number + // of available and unused connection IDs." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-4 + // + // "An endpoint MAY limit the total number of connection IDs + // issued for each connection [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-6 + // + // Peer requests 100 connection IDs. + // We give them 4 in total. + tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.activeConnIDLimit = 100 + }) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeCrypto) + + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.wantFrame("provide additional connection ID 1", + packetType1RTT, debugFrameNewConnectionID{ + seq: 1, + connID: testLocalConnID(1), + }) + tc.wantFrame("provide additional connection ID 2", + packetType1RTT, debugFrameNewConnectionID{ + seq: 2, + connID: testLocalConnID(2), + }) + tc.wantFrame("provide additional connection ID 3", + packetType1RTT, debugFrameNewConnectionID{ + seq: 3, + connID: testLocalConnID(3), + }) + tc.wantIdle("connection ID limit reached, no more to provide") +} + +func TestConnIDPeerProvidesTooManyIDs(t *testing.T) { + // "An endpoint MUST NOT provide more connection IDs than the peer's limit." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-4 + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + connID: testLocalConnID(2), + }) + tc.wantFrame("peer provided 3 connection IDs, our limit is 2", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errConnectionIDLimit, + }) +} + +func TestConnIDPeerTemporarilyExceedsActiveConnIDLimit(t *testing.T) { + // "An endpoint MAY send connection IDs that temporarily exceed a peer's limit + // if the NEW_CONNECTION_ID frame also requires the retirement of any excess [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-4 + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + retirePriorTo: 2, + seq: 2, + connID: testPeerConnID(2), + }, debugFrameNewConnectionID{ + retirePriorTo: 2, + seq: 3, + connID: testPeerConnID(3), + }) + tc.wantFrame("peer requested we retire conn id 0", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 0, + }) + tc.wantFrame("peer requested we retire conn id 1", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 1, + }) +} + +func TestConnIDPeerRetiresConnID(t *testing.T) { + // "An endpoint SHOULD supply a new connection ID when the peer retires a connection ID." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-6 + for _, side := range []connSide{ + clientSide, + serverSide, + } { + t.Run(side.String(), func(t *testing.T) { + tc := newTestConn(t, side) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameRetireConnectionID{ + seq: 0, + }) + tc.wantFrame("provide replacement connection ID", + packetType1RTT, debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 1, + connID: testLocalConnID(2), + }) + }) + } +} + +func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) { + // An endpoint that selects a zero-length connection ID during the handshake + // cannot issue a new connection ID." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-8 + tc := newTestConn(t, clientSide, func(c *tls.Config) { + c.SessionTicketsDisabled = true + }) + tc.peerConnID = []byte{} + tc.ignoreFrame(frameTypeAck) + tc.uncheckedHandshake() + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 1, + connID: testPeerConnID(1), + }) + tc.wantFrame("invalid NEW_CONNECTION_ID: previous conn id is zero-length", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }) +} + +func TestConnIDPeerRequestsRetirement(t *testing.T) { + // "Upon receipt of an increased Retire Prior To field, the peer MUST + // stop using the corresponding connection IDs and retire them with + // RETIRE_CONNECTION_ID frames [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-5 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 1, + connID: testPeerConnID(2), + }) + tc.wantFrame("peer asked for conn id 0 to be retired", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 0, + }) + if got, want := tc.sentFramePacket.dstConnID, testPeerConnID(1); !bytes.Equal(got, want) { + t.Fatalf("used destination conn id {%x}, want {%x}", got, want) + } +} + +func TestConnIDPeerDoesNotAcknowledgeRetirement(t *testing.T) { + // "An endpoint SHOULD limit the number of connection IDs it has retired locally + // for which RETIRE_CONNECTION_ID frames have not yet been acknowledged." + // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeRetireConnectionID) + + // Send a number of NEW_CONNECTION_ID frames, each retiring an old one. + for seq := int64(0); seq < 7; seq++ { + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: seq + 2, + retirePriorTo: seq + 1, + connID: testPeerConnID(seq + 2), + }) + // We're ignoring the RETIRE_CONNECTION_ID frames. + } + tc.wantFrame("number of retired, unacked conn ids is too large", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errConnectionIDLimit, + }) +} + +func TestConnIDRepeatedNewConnectionIDFrame(t *testing.T) { + // "Receipt of the same [NEW_CONNECTION_ID] frame multiple times + // MUST NOT be treated as a connection error. + // https://www.rfc-editor.org/rfc/rfc9000#section-19.15-7 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + for i := 0; i < 4; i++ { + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 1, + connID: testPeerConnID(2), + }) + } + tc.wantFrame("peer asked for conn id to be retired", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 0, + }) + tc.wantIdle("repeated NEW_CONNECTION_ID frames are not an error") +} + +func TestConnIDForSequenceNumberChanges(t *testing.T) { + // "[...] if a sequence number is used for different connection IDs, + // the endpoint MAY treat that receipt as a connection error + // of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000#section-19.15-8 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeRetireConnectionID) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 1, + connID: testPeerConnID(2), + }) + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 1, + connID: testPeerConnID(3), + }) + tc.wantFrame("connection ID for sequence 0 has changed", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }) +} + +func TestConnIDRetirePriorToAfterNewConnID(t *testing.T) { + // "Receiving a value in the Retire Prior To field that is greater than + // that in the Sequence Number field MUST be treated as a connection error + // of type FRAME_ENCODING_ERROR. + // https://www.rfc-editor.org/rfc/rfc9000#section-19.15-9 + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + retirePriorTo: 3, + seq: 2, + connID: testPeerConnID(2), + }) + tc.wantFrame("invalid NEW_CONNECTION_ID: retired the new conn id", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errFrameEncoding, + }) +} + +func TestConnIDAlreadyRetired(t *testing.T) { + // "An endpoint that receives a NEW_CONNECTION_ID frame with a + // sequence number smaller than the Retire Prior To field of a + // previously received NEW_CONNECTION_ID frame MUST send a + // corresponding RETIRE_CONNECTION_ID frame [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-19.15-11 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 4, + retirePriorTo: 3, + connID: testPeerConnID(4), + }) + tc.wantFrame("peer asked for conn id to be retired", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 0, + }) + tc.wantFrame("peer asked for conn id to be retired", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 1, + }) + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 0, + connID: testPeerConnID(2), + }) + tc.wantFrame("NEW_CONNECTION_ID was for an already-retired ID", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 2, + }) +} + +func TestConnIDRepeatedRetireConnectionIDFrame(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + for i := 0; i < 4; i++ { + tc.writeFrames(packetType1RTT, + debugFrameRetireConnectionID{ + seq: 0, + }) + } + tc.wantFrame("issue new conn id after peer retires one", + packetType1RTT, debugFrameNewConnectionID{ + retirePriorTo: 1, + seq: 2, + connID: testLocalConnID(2), + }) + tc.wantIdle("repeated RETIRE_CONNECTION_ID frames are not an error") +} + +func TestConnIDRetiredUnsent(t *testing.T) { + // "Receipt of a RETIRE_CONNECTION_ID frame containing a sequence number + // greater than any previously sent to the peer MUST be treated as a + // connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000#section-19.16-7 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameRetireConnectionID{ + seq: 2, + }) + tc.wantFrame("invalid NEW_CONNECTION_ID: previous conn id is zero-length", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }) +} + +func TestConnIDUsePreferredAddressConnID(t *testing.T) { + // Peer gives us a connection ID in the preferred address transport parameter. + // We don't use the preferred address at this time, but we should use the + // connection ID. (It isn't tied to any specific address.) + // + // This test will probably need updating if/when we start using the preferred address. + cid := testPeerConnID(10) + tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0") + p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0") + p.preferredAddrConnID = cid + p.preferredAddrResetToken = make([]byte, 16) + }) + tc.uncheckedHandshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 1, + connID: []byte{0xff}, + }) + tc.wantFrame("peer asked for conn id 0 to be retired", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 0, + }) + if got, want := tc.sentFramePacket.dstConnID, cid; !bytes.Equal(got, want) { + t.Fatalf("used destination conn id {%x}, want {%x} from preferred address transport parameter", got, want) + } +} + +func TestConnIDPeerProvidesPreferredAddrAndTooManyConnIDs(t *testing.T) { + // Peer gives us more conn ids than our advertised limit, + // including a conn id in the preferred address transport parameter. + cid := testPeerConnID(10) + tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0") + p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0") + p.preferredAddrConnID = cid + p.preferredAddrResetToken = make([]byte, 16) + }) + tc.uncheckedHandshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 0, + connID: testPeerConnID(2), + }) + tc.wantFrame("peer provided 3 connection IDs, our limit is 2", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errConnectionIDLimit, + }) +} + +func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { + // Peer gives us more conn ids than our advertised limit, + // including a conn id in the preferred address transport parameter. + tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0") + p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0") + p.preferredAddrConnID = testPeerConnID(1) + p.preferredAddrResetToken = make([]byte, 16) + }) + tc.peerConnID = []byte{} + + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.wantFrame("peer with zero-length connection ID tried to provide another in transport parameters", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }) +} diff --git a/internal/quic/conn_loss.go b/internal/quic/conn_loss.go index 6cb459c33..57570d086 100644 --- a/internal/quic/conn_loss.go +++ b/internal/quic/conn_loss.go @@ -44,6 +44,12 @@ func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetF case frameTypeCrypto: start, end := sent.nextRange() c.crypto[space].ackOrLoss(start, end, fate) + case frameTypeNewConnectionID: + seq := int64(sent.nextInt()) + c.connIDState.ackOrLossNewConnectionID(sent.num, seq, fate) + case frameTypeRetireConnectionID: + seq := int64(sent.nextInt()) + c.connIDState.ackOrLossRetireConnectionID(sent.num, seq, fate) case frameTypeHandshakeDone: c.handshakeConfirmed.ackOrLoss(sent.num, fate) } diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go index be4f5fb2c..021c86c87 100644 --- a/internal/quic/conn_loss_test.go +++ b/internal/quic/conn_loss_test.go @@ -93,6 +93,11 @@ func TestLostCRYPTOFrame(t *testing.T) { packetTypeHandshake, debugFrameCrypto{ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], }) + tc.wantFrame("client provides server with an additional connection ID", + packetType1RTT, debugFrameNewConnectionID{ + seq: 1, + connID: testLocalConnID(1), + }) tc.triggerLossOrPTO(packetTypeHandshake, pto) tc.wantFrame("client resends Handshake CRYPTO frame", packetTypeHandshake, debugFrameCrypto{ @@ -101,6 +106,61 @@ func TestLostCRYPTOFrame(t *testing.T) { }) } +func TestLostNewConnectionIDFrame(t *testing.T) { + // "New connection IDs are [...] retransmitted if the packet containing them is lost." + // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.13 + lostFrameTest(t, func(t *testing.T, pto bool) { + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameRetireConnectionID{ + seq: 1, + }) + tc.wantFrame("provide a new connection ID after peer retires old one", + packetType1RTT, debugFrameNewConnectionID{ + seq: 2, + connID: testLocalConnID(2), + }) + + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantFrame("resend new connection ID", + packetType1RTT, debugFrameNewConnectionID{ + seq: 2, + connID: testLocalConnID(2), + }) + }) +} + +func TestLostRetireConnectionIDFrame(t *testing.T) { + // "[...] retired connection IDs are [...] retransmitted + // if the packet containing them is lost." + // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.13 + lostFrameTest(t, func(t *testing.T, pto bool) { + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 1, + connID: testPeerConnID(2), + }) + tc.wantFrame("peer requested connection id be retired", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 0, + }) + + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantFrame("resend RETIRE_CONNECTION_ID", + packetType1RTT, debugFrameRetireConnectionID{ + seq: 0, + }) + }) +} + func TestLostHandshakeDoneFrame(t *testing.T) { // "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged." // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16 @@ -120,6 +180,11 @@ func TestLostHandshakeDoneFrame(t *testing.T) { packetTypeHandshake, debugFrameCrypto{ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], }) + tc.wantFrame("server provides an additional connection ID", + packetType1RTT, debugFrameNewConnectionID{ + seq: 1, + connID: testLocalConnID(1), + }) tc.writeFrames(packetTypeHandshake, debugFrameCrypto{ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 3baa79a0c..7992a619f 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -211,7 +211,12 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, if !frameOK(c, ptype, __01) { return } - _, _, _, _, n = consumeNewConnectionIDFrame(payload) + n = c.handleNewConnectionIDFrame(now, space, payload) + case frameTypeRetireConnectionID: + if !frameOK(c, ptype, __01) { + return + } + n = c.handleRetireConnectionIDFrame(now, space, payload) case frameTypeConnectionCloseTransport: // CONNECTION_CLOSE is OK in all spaces. _, _, _, n = consumeConnectionCloseTransportFrame(payload) @@ -285,6 +290,28 @@ func (c *Conn) handleCryptoFrame(now time.Time, space numberSpace, payload []byt return n } +func (c *Conn) handleNewConnectionIDFrame(now time.Time, space numberSpace, payload []byte) int { + seq, retire, connID, resetToken, n := consumeNewConnectionIDFrame(payload) + if n < 0 { + return -1 + } + if err := c.connIDState.handleNewConnID(seq, retire, connID, resetToken); err != nil { + c.abort(now, err) + } + return n +} + +func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, payload []byte) int { + seq, n := consumeRetireConnectionIDFrame(payload) + if n < 0 { + return -1 + } + if err := c.connIDState.handleRetireConnID(seq, c.newConnIDFunc()); err != nil { + c.abort(now, err) + } + return n +} + func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payload []byte) int { if c.side == serverSide { // Clients should never send HANDSHAKE_DONE. diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 62c9b62ec..d410548a9 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -44,6 +44,13 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Prepare to write a datagram of at most maxSendSize bytes. c.w.reset(c.loss.maxSendSize()) + dstConnID, ok := c.connIDState.dstConnID() + if !ok { + // It is currently not possible for us to end up without a connection ID, + // but handle the case anyway. + return time.Time{} + } + // Initial packet. pad := false var sentInitial *sentPacket @@ -54,7 +61,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { ptype: packetTypeInitial, version: 1, num: pnum, - dstConnID: c.connIDState.dstConnID(), + dstConnID: dstConnID, srcConnID: c.connIDState.srcConnID(), } c.w.startProtectedLongHeaderPacket(pnumMaxAcked, p) @@ -81,7 +88,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { ptype: packetTypeHandshake, version: 1, num: pnum, - dstConnID: c.connIDState.dstConnID(), + dstConnID: dstConnID, srcConnID: c.connIDState.srcConnID(), } c.w.startProtectedLongHeaderPacket(pnumMaxAcked, p) @@ -104,7 +111,6 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if k := c.wkeys[appDataSpace]; k.isSet() { pnumMaxAcked := c.acks[appDataSpace].largestSeen() pnum := c.loss.nextNumber(appDataSpace) - dstConnID := c.connIDState.dstConnID() c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID) c.appendFrames(now, appDataSpace, pnum, limit) if pad && len(c.w.payload()) > 0 { @@ -233,6 +239,13 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, return int64(len(b)) }) + // NEW_CONNECTION_ID, RETIRE_CONNECTION_ID + if space == appDataSpace { + if !c.connIDState.appendFrames(&c.w, pnum, pto) { + return + } + } + // Test-only PING frames. if space == c.testSendPingSpace && c.testSendPing.shouldSendPTO(pto) { if !c.w.appendPingFrame() { diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 511fb97a0..317ca8f81 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -128,19 +128,16 @@ type testConn struct { cryptoDataIn map[tls.QUICEncryptionLevel][]byte peerTLSConn *tls.QUICConn - localConnID []byte - transientConnID []byte - // Information about the conn's (fake) peer. peerConnID []byte // source conn id of peer's packets peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use // Datagrams, packets, and frames sent by the conn, // but not yet processed by the test. - sentDatagrams [][]byte - sentPackets []*testPacket - sentFrames []debugFrame - sentFramePacketType packetType + sentDatagrams [][]byte + sentPackets []*testPacket + sentFrames []debugFrame + sentFramePacket *testPacket // Frame types to ignore in tests. ignoreFrames map[byte]bool @@ -162,7 +159,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { tc := &testConn{ t: t, now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), - peerConnID: []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5}, + peerConnID: testPeerConnID(0), ignoreFrames: map[byte]bool{ frameTypePadding: true, // ignore PADDING by default }, @@ -179,6 +176,8 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { switch o := o.(type) { case func(*tls.Config): o(config.TLSConfig) + case func(p *transportParameters): + o(&peerProvidedParams) default: t.Fatalf("unknown newTestConn option %T", o) } @@ -189,7 +188,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { // The initial connection ID for the server is chosen by the client. // When creating a server-side connection, pick a random connection ID here. var err error - initialConnID, err = newRandomConnID() + initialConnID, err = newRandomConnID(0) if err != nil { tc.t.Fatal(err) } @@ -217,14 +216,6 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { } tc.conn = conn - if side == serverSide { - tc.transientConnID = tc.conn.connIDState.local[0].cid - tc.localConnID = tc.conn.connIDState.local[1].cid - } else if side == clientSide { - tc.transientConnID = tc.conn.connIDState.remote[0].cid - tc.localConnID = tc.conn.connIDState.local[0].cid - } - tc.wkeys[initialSpace].k = conn.wkeys[initialSpace] tc.rkeys[initialSpace].k = conn.rkeys[initialSpace] @@ -326,7 +317,11 @@ func (tc *testConn) write(d *testDatagram) { if p.num >= tc.peerNextPacketNum[space] { tc.peerNextPacketNum[space] = p.num + 1 } - buf = append(buf, tc.encodeTestPacket(p)...) + pad := 0 + if p.ptype == packetType1RTT { + pad = d.paddedSize + } + buf = append(buf, tc.encodeTestPacket(p, pad)...) } for len(buf) < d.paddedSize { buf = append(buf, 0) @@ -407,12 +402,12 @@ func (tc *testConn) readFrame() (debugFrame, packetType) { if p == nil { return nil, packetTypeInvalid } - tc.sentFramePacketType = p.ptype + tc.sentFramePacket = p tc.sentFrames = p.frames } f := tc.sentFrames[0] tc.sentFrames = tc.sentFrames[1:] - return f, tc.sentFramePacketType + return f, tc.sentFramePacket.ptype } // wantDatagram indicates that we expect the Conn to send a datagram. @@ -462,7 +457,7 @@ func (tc *testConn) wantIdle(expectation string) { } } -func (tc *testConn) encodeTestPacket(p *testPacket) []byte { +func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte { tc.t.Helper() var w packetWriter w.reset(1200) @@ -486,6 +481,7 @@ func (tc *testConn) encodeTestPacket(p *testPacket) []byte { tc.t.Fatalf("sending packet with no %v keys available", space) return nil } + w.appendPaddingTo(pad) if p.ptype != packetType1RTT { w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space].k, longPacket{ ptype: p.ptype, @@ -504,6 +500,7 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram { tc.t.Helper() bufSize := len(buf) d := &testDatagram{} + size := len(buf) for len(buf) > 0 { if buf[0] == 0 { d.paddedSize = bufSize @@ -552,6 +549,20 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram { buf = buf[n:] } } + // This is rather hackish: If the last frame in the last packet + // in the datagram is PADDING, then remove it and record + // the padded size in the testDatagram.paddedSize. + // + // This makes it easier to write a test that expects a datagram + // padded to 1200 bytes. + if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 { + p := d.packets[len(d.packets)-1] + f := p.frames[len(p.frames)-1] + if _, ok := f.(debugFramePadding); ok { + p.frames = p.frames[:len(p.frames)-1] + d.paddedSize = size + } + } return d } @@ -686,6 +697,27 @@ func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.T return tc.now, m } +func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { + return testLocalConnID(seq), nil +} + +// testLocalConnID returns the connection ID with a given sequence number +// used by a Conn under test. +func testLocalConnID(seq int64) []byte { + cid := make([]byte, connIDLen) + copy(cid, []byte{0xc0, 0xff, 0xee}) + cid[len(cid)-1] = byte(seq) + return cid +} + +// testPeerConnID returns the connection ID with a given sequence number +// used by the fake peer of a Conn under test. +func testPeerConnID(seq int64) []byte { + // Use a different length than we choose for our own conn ids, + // to help catch any bad assumptions. + return []byte{0xbe, 0xee, 0xff, byte(seq)} +} + // testConnListener implements connListener. type testConnListener testConn diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go index 3009a0450..7a5aee57b 100644 --- a/internal/quic/frame_debug.go +++ b/internal/quic/frame_debug.go @@ -386,10 +386,7 @@ func (f debugFrameNewConnectionID) write(w *packetWriter) bool { // debugFrameRetireConnectionID is a NEW_CONNECTION_ID frame. type debugFrameRetireConnectionID struct { - seq uint64 - retirePriorTo uint64 - connID []byte - token [16]byte + seq int64 } func parseDebugFrameRetireConnectionID(b []byte) (f debugFrameRetireConnectionID, n int) { diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go index c22f03103..052007897 100644 --- a/internal/quic/packet_parser.go +++ b/internal/quic/packet_parser.go @@ -454,10 +454,10 @@ func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, re return seq, retire, connID, resetToken, n } -func consumeRetireConnectionIDFrame(b []byte) (seq uint64, n int) { +func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) { n = 1 var nn int - seq, nn = consumeVarint(b[n:]) + seq, nn = consumeVarintInt64(b[n:]) if nn < 0 { return 0, -1 } diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go index 6c4c452cd..a80b4711e 100644 --- a/internal/quic/packet_writer.go +++ b/internal/quic/packet_writer.go @@ -482,13 +482,14 @@ func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, conn return true } -func (w *packetWriter) appendRetireConnectionIDFrame(seq uint64) (added bool) { - if w.avail() < 1+sizeVarint(seq) { +func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(seq)) { return false } w.b = append(w.b, frameTypeRetireConnectionID) - w.b = appendVarint(w.b, seq) + w.b = appendVarint(w.b, uint64(seq)) w.sent.appendAckElicitingFrame(frameTypeRetireConnectionID) + w.sent.appendInt(uint64(seq)) return true } diff --git a/internal/quic/ping_test.go b/internal/quic/ping_test.go index c370aaf1d..a8fdf2567 100644 --- a/internal/quic/ping_test.go +++ b/internal/quic/ping_test.go @@ -37,7 +37,7 @@ func TestAck(t *testing.T) { tc.wantFrame("connection should respond to ack-eliciting packet with an ACK frame", packetType1RTT, debugFrameAck{ - ranges: []i64range[packetNumber]{{0, 3}}, + ranges: []i64range[packetNumber]{{0, 4}}, }, ) } diff --git a/internal/quic/quic.go b/internal/quic/quic.go index a61c91f16..84ce2bda1 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -35,6 +35,16 @@ const ( ackDelayExponent = 3 // ack_delay_exponent maxAckDelay = 25 * time.Millisecond // max_ack_delay + + // The active_conn_id_limit transport parameter is the maximum + // number of connection IDs from the peer we're willing to store. + // + // maxPeerActiveConnIDLimit is the maximum number of connection IDs + // we're willing to send to the peer. + // + // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-6.2.1 + activeConnIDLimit = 2 + maxPeerActiveConnIDLimit = 4 ) // Local timer granularity. diff --git a/internal/quic/tls.go b/internal/quic/tls.go index 4306a3e46..ed848c6a1 100644 --- a/internal/quic/tls.go +++ b/internal/quic/tls.go @@ -83,7 +83,9 @@ func (c *Conn) handleTLSEvents(now time.Time) error { if err != nil { return err } - c.receiveTransportParameters(params) + if err := c.receiveTransportParameters(params); err != nil { + return err + } } } } diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go index df0782008..3768dc0c0 100644 --- a/internal/quic/tls_test.go +++ b/internal/quic/tls_test.go @@ -63,15 +63,26 @@ func (tc *testConn) handshake() { func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { var ( - clientConnID []byte - serverConnID []byte + clientConnIDs [][]byte + serverConnIDs [][]byte + transientConnID []byte ) + localConnIDs := [][]byte{ + testLocalConnID(0), + testLocalConnID(1), + } + peerConnIDs := [][]byte{ + testPeerConnID(0), + testPeerConnID(1), + } if tc.conn.side == clientSide { - clientConnID = tc.localConnID - serverConnID = tc.peerConnID + clientConnIDs = localConnIDs + serverConnIDs = peerConnIDs + transientConnID = testLocalConnID(-1) } else { - clientConnID = tc.peerConnID - serverConnID = tc.localConnID + clientConnIDs = peerConnIDs + serverConnIDs = localConnIDs + transientConnID = []byte{0xde, 0xad, 0xbe, 0xef} } return []*testDatagram{{ // Client Initial @@ -79,21 +90,21 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { ptype: packetTypeInitial, num: 0, version: 1, - srcConnID: clientConnID, - dstConnID: tc.transientConnID, + srcConnID: clientConnIDs[0], + dstConnID: transientConnID, frames: []debugFrame{ debugFrameCrypto{}, }, }}, paddedSize: 1200, }, { - // Server Initial + Handshake + // Server Initial + Handshake + 1-RTT packets: []*testPacket{{ ptype: packetTypeInitial, num: 0, version: 1, - srcConnID: serverConnID, - dstConnID: clientConnID, + srcConnID: serverConnIDs[0], + dstConnID: clientConnIDs[0], frames: []debugFrame{ debugFrameAck{ ranges: []i64range[packetNumber]{{0, 1}}, @@ -104,20 +115,30 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { ptype: packetTypeHandshake, num: 0, version: 1, - srcConnID: serverConnID, - dstConnID: clientConnID, + srcConnID: serverConnIDs[0], + dstConnID: clientConnIDs[0], frames: []debugFrame{ debugFrameCrypto{}, }, + }, { + ptype: packetType1RTT, + num: 0, + dstConnID: clientConnIDs[0], + frames: []debugFrame{ + debugFrameNewConnectionID{ + seq: 1, + connID: serverConnIDs[1], + }, + }, }}, }, { - // Client Handshake + // Client Initial + Handshake + 1-RTT packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, version: 1, - srcConnID: clientConnID, - dstConnID: serverConnID, + srcConnID: clientConnIDs[0], + dstConnID: serverConnIDs[0], frames: []debugFrame{ debugFrameAck{ ranges: []i64range[packetNumber]{{0, 1}}, @@ -127,23 +148,39 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { ptype: packetTypeHandshake, num: 0, version: 1, - srcConnID: clientConnID, - dstConnID: serverConnID, + srcConnID: clientConnIDs[0], + dstConnID: serverConnIDs[0], frames: []debugFrame{ debugFrameAck{ ranges: []i64range[packetNumber]{{0, 1}}, }, debugFrameCrypto{}, }, + }, { + ptype: packetType1RTT, + num: 0, + dstConnID: serverConnIDs[0], + frames: []debugFrame{ + debugFrameAck{ + ranges: []i64range[packetNumber]{{0, 1}}, + }, + debugFrameNewConnectionID{ + seq: 1, + connID: clientConnIDs[1], + }, + }, }}, paddedSize: 1200, }, { // Server HANDSHAKE_DONE and session ticket packets: []*testPacket{{ ptype: packetType1RTT, - num: 0, - dstConnID: clientConnID, + num: 1, + dstConnID: clientConnIDs[0], frames: []debugFrame{ + debugFrameAck{ + ranges: []i64range[packetNumber]{{0, 1}}, + }, debugFrameHandshakeDone{}, debugFrameCrypto{}, }, @@ -152,13 +189,13 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { // Client ack (after max_ack_delay) packets: []*testPacket{{ ptype: packetType1RTT, - num: 0, - dstConnID: serverConnID, + num: 1, + dstConnID: serverConnIDs[0], frames: []debugFrame{ debugFrameAck{ ackDelay: unscaledAckDelayFromDuration( maxAckDelay, ackDelayExponent), - ranges: []i64range[packetNumber]{{0, 1}}, + ranges: []i64range[packetNumber]{{0, 2}}, }, }, }}, @@ -190,6 +227,69 @@ func fillCryptoFrames(d *testDatagram, data map[tls.QUICEncryptionLevel][]byte) } } +// uncheckedHandshake executes the handshake. +// +// Unlike testConn.handshake, it sends nothing unnecessary +// (in particular, no NEW_CONNECTION_ID frames), +// and does not validate the conn's responses. +// +// Useful for testing scenarios where configuration has +// changed the handshake responses in some way. +func (tc *testConn) uncheckedHandshake() { + defer func(saved map[byte]bool) { + tc.ignoreFrames = saved + }(tc.ignoreFrames) + tc.ignoreFrames = map[byte]bool{ + frameTypeAck: true, + frameTypeCrypto: true, + frameTypeNewConnectionID: true, + } + if tc.conn.side == serverSide { + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrame("send HANDSHAKE_DONE after handshake completes", + packetType1RTT, debugFrameHandshakeDone{}) + tc.writeFrames(packetType1RTT, + debugFrameAck{ + ackDelay: unscaledAckDelayFromDuration( + maxAckDelay, ackDelayExponent), + ranges: []i64range[packetNumber]{{0, tc.sentFramePacket.num + 1}}, + }) + } else { + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantIdle("don't expect any frames we aren't ignoring") + // Send the next two frames in separate packets, so the client sends an + // ack immediately without delay. We want to consume that ack here, rather + // than returning with a delayed ack waiting to be sent. + tc.ignoreFrames = nil + tc.writeFrames(packetType1RTT, + debugFrameHandshakeDone{}) + tc.writeFrames(packetType1RTT, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrame("client ACKs server's first 1-RTT packet", + packetType1RTT, debugFrameAck{ + ranges: []i64range[packetNumber]{{0, 2}}, + }) + + } + tc.wantIdle("handshake is done") +} + func TestConnClientHandshake(t *testing.T) { tc := newTestConn(t, clientSide) tc.handshake() @@ -224,6 +324,11 @@ func TestConnKeysDiscardedClient(t *testing.T) { packetTypeHandshake, debugFrameCrypto{ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], }) + tc.wantFrame("client provides an additional connection ID", + packetType1RTT, debugFrameNewConnectionID{ + seq: 1, + connID: testLocalConnID(1), + }) // The client discards Initial keys after sending a Handshake packet. tc.writeFrames(packetTypeInitial, @@ -273,6 +378,11 @@ func TestConnKeysDiscardedServer(t *testing.T) { }) tc.writeFrames(packetTypeInitial, debugFrameConnectionCloseTransport{code: errInternal}) + tc.wantFrame("server provides an additional connection ID", + packetType1RTT, debugFrameNewConnectionID{ + seq: 1, + connID: testLocalConnID(1), + }) tc.wantIdle("server has discarded Initial keys, cannot read CONNECTION_CLOSE") // The server discards Handshake keys after sending a HANDSHAKE_DONE frame.