diff --git a/Readme.md b/Readme.md
index 4c35732..15dd0bd 100644
--- a/Readme.md
+++ b/Readme.md
@@ -216,48 +216,43 @@ title: "TomTP Payload Packet"
 packet-beta
   0-3: "Ack LEN"
   4: "S/R"
-  5: "LRG"
+  5: "N/A"
   6-7: "CLOSE"
-  8-23: "Opt. ACKs: 16bit LRGs headers"
-  24-55: "Opt. ACKs: RCV_WND_SIZE 32bit (or 64bit if LRG)"
-  56-87: "Opt. ACKs: Example ACK: StreamId 32bit"
-  88-119: "Opt. ACKs: Example ACK: StreamOffset 32bit (or 64bit if LRG)"
-  120-135: "Opt. ACKs: Example ACK: Len 16bit"
-  136-167: "StreamId 32bit"
-  168-199: "StreamOffset 32bit (or 64bit if LRG)"
-  200-255: "If data at this point available: Data..."
+  8-71: "Opt. ACKs: RCV_WND_SIZE 64bit"
+  72-103: "Opt. ACKs: Example ACK: StreamId 32bit"
+  104-167: "Opt. ACKs: Example ACK: StreamOffset 64bit"
+  168-183: "Opt. ACKs: Example ACK: Len 16bit"
+  184-215: "StreamId 32bit"
+  216-279: "StreamOffset 64bit"
+  280-287: "Data..."
 ```
 
 The TomTP payload packet begins with a header byte containing several control bits:
 
 * Bits 0-3 contain the "Ack LEN" field, indicating the number of ACK entries (0-15).
 * Bit 4 is the "S/R" flag which distinguishes between sender and receiver roles.
-* Bit 5 is the "LRG" flag which, when set, indicates 64-bit offsets are used instead of 32-bit.
+* Bit 5 is currently unused (reserved).
 * Bits 6-7 form the "CLOSE" field for connection control (00: no close, 01: close stream, 10: close connection and all streams, 11: not used).
 
 If ACKs are present (Ack LEN > 0), the following section appears:
 
-* Bytes 8-23 contain an 16-bit LRGs header for ACK flags
-* Bytes 24-55 hold the RCV_WND_SIZE, using 32 bits (or 64 bits if LRG is set)
+* Bits 8-71 hold the RCV_WND_SIZE, using 64 bits
 * For each ACK entry:
-  * Bytes 56-87 contain the StreamId (32 bits)
-  * Bytes 88-119 hold the StreamOffset, using 32 bits (or 64 bits if LRG is set)
-  * Bytes 120-135 contain the Len field (16 bits)
+  * Bits 72-103 contain the StreamId (32 bits)
+  * Bits 104-167 hold the StreamOffset, using 64 bits
+  * Bits 168-183 contain the Len field (16 bits)
 
 The Data section:
-* Bytes 176-207 contain the StreamId (32 bits)
-* Bytes 208-239 hold the StreamOffset, using 32 bits (or 64 bits if LRG is set)
-* Bytes 240-255 contain the data length (16 bits)
+* Bits 184-215 contain the StreamId (32 bits)
+* Bits 216-279 hold the StreamOffset, using 64 bits
 
 Only if data length is greater than zero:
-* Bytes 256-287 and beyond contain the actual data payload
-
-This example shows the layout with 32-bit offsets (LRG=false), one ACK entry, and a 4-byte filler.
+* Bits 280-287 and beyond contain the actual data payload
 
 ### Overhead
 - **Total Overhead for Data Packets:**
-  double encrypted sn: 48 (39+9) bytes (for a 1400-byte packet, this results in an overhead of ~3.4%).
+  52 bytes (39-byte crypto header + 13-byte payload header) with zero payload data (for a 1400-byte packet, this results in an overhead of ~3.7%).
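To make the new header byte and the 13-byte payload header concrete, here is a minimal sketch that packs the flags byte and computes the overhead. It mirrors the flag constants and `CalcProtoOverhead` from `proto.go` as changed in this patch; the standalone `main` wrapper and lower-case helper name are illustrative only, not part of the patch:

```go
package main

import "fmt"

const (
	FlagAckMask     = 0xf // bits 0-3: number of ACK entries (0-15)
	FlagSenderShift = 4   // bit 4: sender/receiver flag
	FlagCloseShift  = 6   // bits 6-7: close flags
)

// calcProtoOverhead mirrors CalcProtoOverhead from proto.go in this patch:
// 1 header byte, an optional ACK section, then StreamId and StreamOffset.
func calcProtoOverhead(ackLen int) int {
	overhead := 1 // header byte
	if ackLen > 0 {
		overhead += 8                    // RCV_WND_SIZE, 64 bit
		overhead += ackLen * (4 + 8 + 2) // per ACK: StreamId, StreamOffset, Len
	}
	overhead += 4 // StreamId
	overhead += 8 // StreamOffset, 64 bit
	return overhead
}

func main() {
	// Header byte for: 1 ACK entry, sender role, CLOSE=01 (close stream).
	flags := uint8(1)&FlagAckMask | 1<<FlagSenderShift | 1<<FlagCloseShift
	fmt.Printf("header byte: %08b\n", flags) // 01010001
	fmt.Println(calcProtoOverhead(0))        // 13 -> plus 39-byte crypto header = 52
	fmt.Println(calcProtoOverhead(1))        // 35 (one ACK entry)
}
```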
### Communication States @@ -284,18 +279,18 @@ Source Code LoC =============================================================================== Language Files Lines Code Comments Blanks =============================================================================== - Go 15 3013 2362 151 500 - Markdown 1 302 0 248 54 + Go 16 3033 2315 212 506 + Markdown 1 296 0 243 53 =============================================================================== - Total 16 3315 2362 399 554 + Total 17 3329 2315 455 559 =============================================================================== Test Code LoC =============================================================================== Language Files Lines Code Comments Blanks =============================================================================== - Go 11 2547 2054 149 344 + Go 15 3123 2421 251 451 =============================================================================== - Total 11 2547 2054 149 344 + Total 15 3123 2421 251 451 =============================================================================== ``` \ No newline at end of file diff --git a/codec.go b/codec.go index 6232303..a182142 100644 --- a/codec.go +++ b/codec.go @@ -8,86 +8,80 @@ import ( "net/netip" ) -func (s *Stream) encode(b []byte) (enc []byte, offset int, err error) { +func (s *Stream) Overhead(ackLen int) (overhead int) { + protoOverhead := CalcProtoOverhead(ackLen) + switch { + case s.conn.firstPaket && s.conn.sender && s.conn.snCrypto == 0 && !s.conn.isRollover: + return protoOverhead + MsgInitSndSize + case s.conn.firstPaket && !s.conn.sender && s.conn.snCrypto == 0 && !s.conn.isRollover: + return protoOverhead + MinMsgInitRcvSize + case !s.conn.firstPaket && s.conn.sender && s.conn.snCrypto == 0: //rollover + return protoOverhead + MinMsgData0Size + case !s.conn.firstPaket && !s.conn.sender && s.conn.snCrypto == 0: //rollover + return protoOverhead + MinMsgData0Size + default: + return protoOverhead + MinMsgSize + } +} + +func (s *Stream) encode(origData []byte, acks []Ack) (encData []byte, err error) { if s.state == StreamEnded || s.conn.state == ConnectionEnded { - return nil, 0, ErrStreamClosed + return nil, ErrStreamClosed } - p := &Payload{ - CloseOp: GetCloseOp(s.state == StreamEnding, s.conn.state == ConnectionEnding), - IsSender: s.conn.sender, - RcvWndSize: s.rcvWndSize - uint64(s.rbRcv.Size()), - Acks: s.rbRcv.GetAcks(), - StreamId: s.streamId, - StreamOffset: s.streamOffsetNext, - Data: []byte{}, + p := &PayloadMeta{ + CloseOp: GetCloseOp(s.state == StreamEnding, s.conn.state == ConnectionEnding), + IsSender: s.conn.sender, + RcvWndSize: s.conn.maxRcvWndSize - uint64(s.conn.rbRcv.Size()), + Acks: acks, + StreamId: s.streamId, } switch { case s.conn.firstPaket && s.conn.sender && s.conn.snCrypto == 0 && !s.conn.isRollover: - overhead := CalcOverhead(p) + MsgInitSndSize - offset = min(s.conn.mtu-overhead, len(b)) - p.Data = b[:offset] - var payRaw []byte - payRaw, _, err = EncodePayload(p) + payRaw, _, err = EncodePayload(p, origData) if err != nil { - return nil, 0, err + return nil, err } slog.Debug("EncodeWriteInitS0", debugGoroutineID(), s.debug(), slog.Int("len(payRaw)", len(payRaw))) - enc, err = EncodeWriteInitS0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.prvKeyEpSnd, s.conn.prvKeyEpSndRollover, payRaw) + encData, err = EncodeWriteInitS0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.prvKeyEpSnd, s.conn.prvKeyEpSndRollover, payRaw) case s.conn.firstPaket && !s.conn.sender && s.conn.snCrypto == 0 && 
!s.conn.isRollover: - overhead := CalcOverhead(p) + MinMsgInitRcvSize - offset = min(s.conn.mtu-overhead, len(b)) - p.Data = b[:offset] - var payRaw []byte - payRaw, _, err = EncodePayload(p) + payRaw, _, err = EncodePayload(p, origData) if err != nil { - return nil, 0, err + return nil, err } slog.Debug("EncodeWriteInitR0", debugGoroutineID(), s.debug(), slog.Int("len(payRaw)", len(payRaw))) - enc, err = EncodeWriteInitR0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.pubKeyEpRcv, s.conn.prvKeyEpSnd, s.conn.prvKeyEpSndRollover, payRaw) + encData, err = EncodeWriteInitR0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.pubKeyEpRcv, s.conn.prvKeyEpSnd, s.conn.prvKeyEpSndRollover, payRaw) case !s.conn.firstPaket && s.conn.sender && s.conn.snCrypto == 0: //rollover - overhead := CalcOverhead(p) + MinMsgData0Size - offset = min(s.conn.mtu-overhead, len(b)) - p.Data = b[:offset] - var payRaw []byte - payRaw, _, err = EncodePayload(p) + payRaw, _, err = EncodePayload(p, origData) if err != nil { - return nil, 0, err + return nil, err } slog.Debug("EncodeWriteData0", debugGoroutineID(), s.debug(), slog.Int("len(payRaw)", len(payRaw))) - enc, err = EncodeWriteData0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.sender, s.conn.pubKeyEpRcv, s.conn.prvKeyEpSndRollover, payRaw) + encData, err = EncodeWriteData0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.sender, s.conn.pubKeyEpRcv, s.conn.prvKeyEpSndRollover, payRaw) case !s.conn.firstPaket && !s.conn.sender && s.conn.snCrypto == 0: //rollover - overhead := CalcOverhead(p) + MinMsgData0Size - offset = min(s.conn.mtu-overhead, len(b)) - p.Data = b[:offset] - var payRaw []byte - payRaw, _, err = EncodePayload(p) + payRaw, _, err = EncodePayload(p, origData) if err != nil { - return nil, 0, err + return nil, err } slog.Debug("EncodeWriteData0", debugGoroutineID(), s.debug(), slog.Int("len(payRaw)", len(payRaw))) - enc, err = EncodeWriteData0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.sender, s.conn.pubKeyEpRcv, s.conn.prvKeyEpSndRollover, payRaw) + encData, err = EncodeWriteData0(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.sender, s.conn.pubKeyEpRcv, s.conn.prvKeyEpSndRollover, payRaw) default: - overhead := CalcOverhead(p) + MinMsgSize - offset = min(s.conn.mtu-overhead, len(b)) - p.Data = b[:offset] - var payRaw []byte - payRaw, _, err = EncodePayload(p) + payRaw, _, err = EncodePayload(p, origData) if err != nil { - return nil, 0, err + return nil, err } slog.Debug("EncodeWriteData", debugGoroutineID(), s.debug(), slog.Int("len(payRaw)", len(payRaw))) - enc, err = EncodeWriteData(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.sender, s.conn.sharedSecret, s.conn.snCrypto, payRaw) + encData, err = EncodeWriteData(s.conn.pubKeyIdRcv, s.conn.listener.prvKeyId.PublicKey(), s.conn.sender, s.conn.sharedSecret, s.conn.snCrypto, payRaw) } if err != nil { - return nil, 0, err + return nil, err } s.conn.snCrypto++ @@ -104,11 +98,7 @@ func (s *Stream) encode(b []byte) (enc []byte, offset int, err error) { s.conn.state = ConnectionEnded } - //only if we send data, increase the sequence number of the stream - if len(p.Data) > 0 { - s.streamOffsetNext += uint64(offset) - } - return enc, offset, nil + return encData, nil } func (l *Listener) decode(buffer []byte, remoteAddr netip.AddrPort) (conn *Connection, m *Message, err error) { diff --git a/codec_test.go b/codec_test.go index ad43ce4..6a2fd0e 100644 --- a/codec_test.go +++ 
b/codec_test.go
@@ -32,7 +32,7 @@ func TestStreamEncode(t *testing.T) {
 		setupStream   func() *Stream
 		input         []byte
 		expectedError error
-		validateOutput func(*testing.T, []byte, int, error)
+		validateOutput func(*testing.T, []byte, error)
 	}{
 		{
 			name: "Stream closed",
@@ -73,19 +73,18 @@ func TestStreamEncode(t *testing.T) {
 				listener: &Listener{
 					prvKeyId: prvIdAlice,
 				},
+				rbRcv: NewReceiveBuffer(1000),
 			}
 			stream := &Stream{
 				state: StreamOpen,
 				conn:  conn,
-				rbRcv: NewReceiveBuffer(1000),
 			}
 			return stream
 		},
 		input: []byte("test data"),
-		validateOutput: func(t *testing.T, output []byte, n int, err error) {
+		validateOutput: func(t *testing.T, output []byte, err error) {
 			assert.NoError(t, err)
 			assert.NotNil(t, output)
-			assert.Greater(t, n, 0)
 		},
 	},
 }
@@ -93,7 +92,7 @@ func TestStreamEncode(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			stream := tt.setupStream()
-			output, n, err := stream.encode(tt.input)
+			output, err := stream.encode(tt.input, nil)
 
 			if tt.expectedError != nil {
 				assert.Equal(t, tt.expectedError, err)
@@ -101,7 +100,7 @@ func TestStreamEncode(t *testing.T) {
 			}
 
 			if tt.validateOutput != nil {
-				tt.validateOutput(t, output, n, err)
+				tt.validateOutput(t, output, err)
 			}
 		})
 	}
@@ -124,12 +123,12 @@ func TestEndToEndCodec(t *testing.T) {
 		prvKeyEpSnd:         prvEpAlice,
 		prvKeyEpSndRollover: prvEpAliceRoll,
 		listener:            lAlice,
+		rbRcv:               NewReceiveBuffer(1000),
 	}
 
 	streamAlice := &Stream{
 		state: StreamOpen,
 		conn:  connAlice,
-		rbRcv: NewReceiveBuffer(1000),
 	}
 
 	lBob := &Listener{
@@ -139,10 +138,9 @@ func TestEndToEndCodec(t *testing.T) {
 	// Test encoding and decoding
 	testData := []byte("test message")
-	encoded, n, err := streamAlice.encode(testData)
+	encoded, err := streamAlice.encode(testData, nil)
 	require.NoError(t, err)
 	require.NotNil(t, encoded)
-	require.Greater(t, n, 0)
 
 	a, _ := netip.ParseAddr("127.0.0.1")
 	remoteAddr := netip.AddrPortFrom(a, uint16(8080))
@@ -182,6 +180,7 @@ func TestEndToEndCodecLargeData(t *testing.T) {
 			prvKeyEpSndRollover: prvEpAliceRoll,
 			listener:            lAlice,
 			rbSnd:               NewSendBuffer(rcvBufferCapacity),
+			rbRcv:               NewReceiveBuffer(12000),
 		}
 		connId := binary.LittleEndian.Uint64(prvIdAlice.PublicKey().Bytes()) ^ binary.LittleEndian.Uint64(prvIdBob.PublicKey().Bytes())
 		lAlice.connMap[connId] = connAlice
@@ -189,7 +188,6 @@ func TestEndToEndCodecLargeData(t *testing.T) {
 		streamAlice := &Stream{
 			state: StreamOpen,
 			conn:  connAlice,
-			rbRcv: NewReceiveBuffer(12000),
 		}
 
 		a, _ := netip.ParseAddr("127.0.0.1")
@@ -204,38 +202,28 @@ func TestEndToEndCodecLargeData(t *testing.T) {
 		remainingData := testData
 		var decodedData []byte
-		for len(remainingData) > 0 {
-			encoded, nA, err := streamAlice.encode(remainingData)
-			require.NoError(t, err)
-			require.NotNil(t, encoded)
-			require.LessOrEqual(t, nA, connAlice.mtu)
-
-			connBob, m, err := lBob.decode(encoded, remoteAddr)
-			require.NoError(t, err)
-			s, _, err := connBob.decode(m.PayloadRaw, 0)
-			require.NoError(t, err)
-			rb, err := s.ReadBytes()
-			require.NoError(t, err)
-			decodedData = append(decodedData, rb...)
-
-			streamBob, _ := connBob.GetOrNewStreamRcv(s.streamId)
-			encoded, nB, err := streamBob.encode([]byte{})
-			require.NoError(t, err)
-			require.NotNil(t, encoded)
-			require.LessOrEqual(t, nB, connAlice.mtu)
-
-			connAlice, m, err = lAlice.decode(encoded, remoteAddr)
-			s, _, err = connAlice.decode(m.PayloadRaw, 0)
-			require.NoError(t, err)
-			//rb, err = s.ReadBytes()
-			//require.NoError(t, err)
-
-			if len(remainingData) > nA {
-				remainingData = remainingData[nA:]
-			} else {
-				break
-			}
-		}
+		encoded, err := streamAlice.encode(remainingData, nil)
+		require.NoError(t, err)
+		require.NotNil(t, encoded)
+
+		connBob, m, err := lBob.decode(encoded, remoteAddr)
+		require.NoError(t, err)
+		s, _, err := connBob.decode(m.PayloadRaw, 0)
+		require.NoError(t, err)
+		rb, err := s.ReadBytes()
+		require.NoError(t, err)
+		decodedData = append(decodedData, rb...)
+
+		streamBob, _ := connBob.GetOrNewStreamRcv(s.streamId)
+		encoded, err = streamBob.encode([]byte{}, nil)
+		require.NoError(t, err)
+		require.NotNil(t, encoded)
+
+		connAlice, m, err = lAlice.decode(encoded, remoteAddr)
+		s, _, err = connAlice.decode(m.PayloadRaw, 0)
+		require.NoError(t, err)
+		//rb, err = s.ReadBytes()
+		//require.NoError(t, err)
 
 		assert.Equal(t, testData, decodedData, "Data mismatch for Size %d", size)
 	})
diff --git a/conn.go b/conn.go
index ff04862..84d73ab 100644
--- a/conn.go
+++ b/conn.go
@@ -30,13 +30,19 @@ type Connection struct {
 	sharedSecretRollover1 []byte
 	sharedSecretRollover2 []byte
 	nextSleepMillis       uint64
 	rbSnd                 *SendBuffer // Send buffer for outgoing data, handles the global sn
+	rbRcv                 *ReceiveBuffer
 	bytesWritten          uint64
 	mtu                   int
 	sender                bool
 	firstPaket            bool
 	isRollover            bool
 	snCrypto              uint64 //this is 48bit
+
+	// Flow control
+	maxRcvWndSize uint64 // Receive window Size
+	maxSndWndSize uint64 // Send window Size
+
 	RTT
 	BBR
 
 	mu sync.Mutex
@@ -123,14 +129,12 @@ func (c *Connection) GetOrNewStreamRcv(streamId uint32) (*Stream, bool) {
 	if stream, ok := c.streams[streamId]; !ok {
 		ctx, cancel := context.WithCancel(context.Background())
 		s := &Stream{
-			streamId:         streamId,
-			streamOffsetNext: 0,
-			state:            StreamStarting,
-			conn:             c,
-			rbRcv:            NewReceiveBuffer(rcvBufferCapacity),
-			closeCtx:         ctx,
-			closeCancelFn:    cancel,
-			mu:               sync.Mutex{},
+			streamId:      streamId,
+			state:         StreamStarting,
+			conn:          c,
+			closeCtx:      ctx,
+			closeCancelFn: cancel,
+			mu:            sync.Mutex{},
 		}
 		c.streams[streamId] = s
 		return s, true
@@ -212,7 +216,7 @@ func (c *Connection) SetAlphaBeta(alpha, beta float64) {
 }
 
 func (c *Connection) decode(decryptedData []byte, nowMillis uint64) (s *Stream, isNew bool, err error) {
-	p, _, err := DecodePayload(decryptedData)
+	p, _, payloadData, err := DecodePayload(decryptedData)
 	if err != nil {
 		slog.Info("error in decoding payload from new connection", slog.Any("error", err))
 		return nil, false, err
@@ -249,7 +253,7 @@ func (c *Connection) decode(decryptedData []byte, nowMillis uint64) (s *Stream,
 	}
 
 	//TODO: handle status, e.g., we may have duplicates
-	s.receive(p.Data, p.StreamOffset)
+	s.receive(p.StreamOffset, payloadData)
 
 	return s, isNew, nil
 }
diff --git a/crypto.go b/crypto.go
index 4cd4038..369524e 100644
--- a/crypto.go
+++ b/crypto.go
@@ -25,7 +25,7 @@ const (
 	//MinPayloadSize is the minimum payload Size in bytes. We need at least 8 bytes as
 	// 8 + the MAC Size (16 bytes) is 24 bytes, which is used as the input for
 	// sealing with chacha20poly1305.NewX().
-	MinPayloadSize = 9
+	MinPayloadSize = 8
 
 	PubKeySize = 32
 	HeaderSize = 1
@@ -47,7 +47,7 @@ type Message struct {
 	MsgType      MsgType
 	SnConn       uint64
 	PayloadRaw   []byte
-	Payload      *Payload
+	Payload      *PayloadMeta
 	Fill         []byte
 	SharedSecret []byte
 }
diff --git a/end2end_test.go b/end2end_test.go
index 9708007..a591d2c 100644
--- a/end2end_test.go
+++ b/end2end_test.go
@@ -224,3 +224,59 @@ func TestEndToEndInMemory(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(t, a, b)
 }
+
+/*func TestSlowStart(t *testing.T) {
+	nConnA, nConnB, err := setupInMemoryPair()
+	assert.Nil(t, err)
+	defer nConnA.Close()
+	defer nConnB.Close()
+
+	var streamB *Stream
+	streamA, listenerB, err := createTwoStreams(nConnA, nConnB, testPrvKey1, testPrvKey2, func(s *Stream) { streamB = s })
+	assert.Nil(t, err)
+
+	msgSize := 500
+	msgA := make([]byte, msgSize)
+
+	numPackets := 10
+	for i := 0; i < numPackets; i++ {
+		t.Run(fmt.Sprintf("Packet %d", i+1), func(t *testing.T) {
+			// Send data from A to B
+			_, err = streamA.Write(msgA)
+			assert.Nil(t, err)
+
+			err = streamA.conn.listener.Update(0)
+			assert.Nil(t, err)
+			_, err = relayData(nConnA, nConnB, startMtu)
+			assert.Nil(t, err)
+
+			err = listenerB.Update(0)
+			assert.Nil(t, err)
+			_, err = relayData(nConnB, nConnA, startMtu)
+			assert.Nil(t, err)
+
+			err = streamA.conn.listener.Update(0)
+			assert.Nil(t, err)
+
+			//read stream
+			msgB := make([]byte, msgSize)
+			_, err := streamB.Read(msgB)
+			if err != nil {
+				if !errors.Is(err, io.EOF) {
+					t.Error(err)
+				}
+			}
+			// Assert in order to keep the test from crashing for stream B
+			assert.Equal(t, msgA, msgB)
+
+			fmt.Println("cwnd", streamA.conn.BBR.cwnd, "ssthresh", streamA.conn.BBR.ssthresh, "streamB-Read", streamB.bytesRead)
+		})
+
+	}
+
+	lastRead := streamB.bytesRead
+
+	if streamA.conn.BBR.ssthresh <= lastRead {
+		t.Error("expected ssthresh to stay above the bytes read")
+	}
+}*/
diff --git a/listener.go b/listener.go
index a04efb6..1d9a571 100644
--- a/listener.go
+++ b/listener.go
@@ -232,33 +232,62 @@ func (l *Listener) UpdateRcv(nowMillis uint64) (err error) {
 
 func (l *Listener) UpdateSnd(nowMillis uint64) (err error) {
 	//timeouts, retries, ping, sending packets
 	for _, c := range l.connMap {
-
-		_, _, data := c.rbSnd.ReadyToRetransmit(startMtu, uint64(c.RTT.rto.Milliseconds()), nowMillis)
-
-		if data != nil {
-			slog.Debug("UpdateSnd/ReadyToRetransmit", debugGoroutineID(), slog.Any("len(data)", len(data)))
-			n, err := l.localConn.WriteToUDPAddrPort(data, c.remoteAddr)
-			if err != nil {
-				return c.Close()
+		for _, stream := range c.streams {
+			acks := c.rbRcv.GetAcks()
+			maxData := stream.calcLen(startMtu, len(acks))
+			splitData := c.rbSnd.ReadyToRetransmit(stream.streamId, maxData, uint64(c.RTT.rto.Milliseconds()), nowMillis)
+			if splitData != nil {
+				encData, err := stream.encode(splitData, acks)
+				if err != nil {
+					return err
+				}
+
+				slog.Debug("UpdateSnd/ReadyToRetransmit", debugGoroutineID(), slog.Any("len(encData)", len(encData)))
+				n, err := l.localConn.WriteToUDPAddrPort(encData, c.remoteAddr)
+				if err != nil {
+					return err
+				}
+				c.bytesWritten += uint64(n)
+
+				// we detected a packet loss, halve ssthresh
+				c.BBR.ssthresh = c.BBR.cwnd / 2
+				if c.BBR.ssthresh < uint64(2*c.mtu) {
+					c.BBR.ssthresh = uint64(2 * c.mtu)
+				}
+				continue
 			}
-			c.bytesWritten += uint64(n)
 
-			//we detected a packet loss, reduce ssthresh by 2
-			c.BBR.ssthresh = c.BBR.cwnd / 2
-			if c.BBR.ssthresh < uint64(2*c.mtu) {
-				c.BBR.ssthresh = uint64(2 * c.mtu)
+			splitData = c.rbSnd.ReadyToSend(stream.streamId, maxData, nowMillis)
+			if splitData != nil {
+				encData, err := stream.encode(splitData, acks)
+				if err != nil {
+					return err
+				}
+
+				slog.Debug("UpdateSnd/ReadyToSend", debugGoroutineID(), slog.Any("len(encData)", len(encData)))
+				n, err := l.localConn.WriteToUDPAddrPort(encData, c.remoteAddr)
+				if err != nil {
+					return c.Close()
+				}
+				c.bytesWritten += uint64(n)
+				continue
 			}
-		}
-
-		_, _, data, _ = c.rbSnd.ReadyToSend(startMtu, nowMillis)
-		if data != nil {
-			slog.Debug("UpdateSnd/ReadyToSend", debugGoroutineID(), slog.Any("len(data)", len(data)))
-			n, err := l.localConn.WriteToUDPAddrPort(data, c.remoteAddr)
-			if err != nil {
-				return c.Close()
+
+			// here we check whether we only have acks to send
+			if len(acks) > 0 {
+				encData, err := stream.encode([]byte{}, acks)
+				if err != nil {
+					return err
+				}
+
+				slog.Debug("UpdateSnd/Acks", debugGoroutineID(), slog.Any("len(encData)", len(encData)))
+				n, err := l.localConn.WriteToUDPAddrPort(encData, c.remoteAddr)
+				if err != nil {
+					return c.Close()
+				}
+				c.bytesWritten += uint64(n)
+				continue
 			}
-			c.bytesWritten += uint64(n)
 		}
 	}
 	return nil
@@ -312,6 +341,7 @@ func (l *Listener) newConn(
 		firstPaket: true,
 		mtu:        startMtu,
 		rbSnd:      NewSendBuffer(rcvBufferCapacity),
+		rbRcv:      NewReceiveBuffer(rcvBufferCapacity),
 		RTT: RTT{
 			alpha: 0.125,
 			beta:  0.25,
diff --git a/listener_test.go b/listener_test.go
index fdcffee..d9c0d48 100644
--- a/listener_test.go
+++ b/listener_test.go
@@ -169,7 +169,7 @@ func (c *ChannelNetworkConn) ReadFromUDP(p []byte) (int, net.Addr, error) {
 
 func (c *ChannelNetworkConn) WriteToUDP(p []byte, addr net.Addr) (int, error) {
 	// Sends the message on the out channel.
- //c.out <- &SendBuffer{data: p} + //c.out <- &SendBuffer{dataToSend: p} return len(p), nil } @@ -219,7 +219,7 @@ func NewTestChannel(localAddr1, localAddr2 net.Addr) (*ChannelNetworkConn, *Chan func forwardMessages(sender, receiver *ChannelNetworkConn) { /*for msg := range sender.out { select { - case receiver.in <- msg.data: + case receiver.in <- msg.dataToSend: receiver.mu.Lock() receiver.messageCounter++ receiver.cond.Broadcast() diff --git a/proto.go b/proto.go index 001b4e7..dff0bd8 100644 --- a/proto.go +++ b/proto.go @@ -2,7 +2,6 @@ package tomtp import ( "errors" - "math" ) const ( @@ -10,30 +9,27 @@ const ( CloseStream CloseConnection - uint32Max = math.MaxUint32 - FlagAckMask = 0xf // bits 0-3 for ACK count (0-15) FlagSenderShift = 4 // bit 3 for Sender/Receiver - FlagLrgShift = 5 // bit 5 for large offsets FlagCloseShift = 6 // bits 6-7 for close flags FlagCloseMask = 0x3 + + MinProtoSize = 13 ) type CloseOp uint8 var ( ErrPayloadTooSmall = errors.New("payload Size below minimum of 8 bytes") - ErrInvalidAckCount = errors.New("invalid ACK count") ) -type Payload struct { +type PayloadMeta struct { CloseOp CloseOp IsSender bool Acks []Ack RcvWndSize uint64 StreamId uint32 StreamOffset uint64 - Data []byte } type Ack struct { @@ -53,48 +49,25 @@ func GetCloseOp(streamClose bool, connClose bool) CloseOp { } } -func CalcOverhead(p *Payload) int { - size := 1 //header Size - - if p.Acks != nil { - size += 1 // IsAcksLargeOffset - - // RcvWndSize Size depends on its value - if p.RcvWndSize > uint32Max { - size += 8 // RcvWndSize (64-bit) - } else { - size += 4 // RcvWndSize (32-bit) - } - - for _, ack := range p.Acks { - size += 4 // StreamId - if ack.StreamOffset > uint32Max { - size += 8 // StreamOffset (64-bit) - } else { - size += 4 // StreamOffset (32-bit) - } - size += 2 // Len - } +func CalcProtoOverhead(ackLen int) int { + overhead := 1 //header Size + if ackLen > 0 { + overhead += 8 // RcvWndSize (64-bit) + overhead += ackLen * (4 + 8 + 2) // StreamId, StreamOffset (64-bit), Len } - - size += 4 // StreamId - if p.StreamOffset > uint32Max { - size += 8 // StreamOffset (64-bit) - } else { - size += 4 // StreamOffset (32-bit) - } - size += len(p.Data) // Data - - return size + overhead += 4 // StreamId + overhead += 8 // StreamOffset (64-bit) + // now comes the data... 
-> but not calculated in overhead + return overhead } -func EncodePayload(p *Payload) (encoded []byte, offset int, err error) { +func EncodePayload(p *PayloadMeta, payloadData []byte) (encoded []byte, offset int, err error) { if p.Acks != nil && len(p.Acks) > 15 { return nil, 0, errors.New("too many Acks") } // Calculate total Size - size := CalcOverhead(p) + size := CalcProtoOverhead(len(p.Acks)) + len(payloadData) // Allocate buffer encoded = make([]byte, size) @@ -108,9 +81,6 @@ func EncodePayload(p *Payload) (encoded []byte, offset int, err error) { if p.IsSender { flags |= 1 << FlagSenderShift } - if p.StreamOffset > uint32Max { - flags |= 1 << FlagLrgShift - } // Set close flags flags |= uint8(p.CloseOp) << FlagCloseShift @@ -121,42 +91,16 @@ func EncodePayload(p *Payload) (encoded []byte, offset int, err error) { // Write ACKs section if present if p.Acks != nil { - // Calculate ACK flags byte - var ackFlags uint8 - if p.RcvWndSize > uint32Max { - ackFlags |= 0x1 // Set RcvWndSize Size flag (bit 0) - } - for i, ack := range p.Acks { - if ack.StreamOffset > uint32Max { - ackFlags |= 1 << (i + 1) // Set ACK offset Size flag (bits 1-7) - } - } - - // Write ACK flags - encoded[offset] = ackFlags - offset++ - - // Write RcvWndSize based on its value - if p.RcvWndSize > uint32Max { - PutUint64(encoded[offset:], p.RcvWndSize) - offset += 8 - } else { - PutUint32(encoded[offset:], uint32(p.RcvWndSize)) - offset += 4 - } + PutUint64(encoded[offset:], p.RcvWndSize) + offset += 8 // Write ACKs for _, ack := range p.Acks { PutUint32(encoded[offset:], ack.StreamId) offset += 4 - if ack.StreamOffset > uint32Max { - PutUint64(encoded[offset:], ack.StreamOffset) - offset += 8 - } else { - PutUint32(encoded[offset:], uint32(ack.StreamOffset)) - offset += 4 - } + PutUint64(encoded[offset:], ack.StreamOffset) + offset += 8 PutUint16(encoded[offset:], ack.Len) offset += 2 @@ -167,30 +111,26 @@ func EncodePayload(p *Payload) (encoded []byte, offset int, err error) { PutUint32(encoded[offset:], p.StreamId) offset += 4 - if p.StreamOffset > uint32Max { - PutUint64(encoded[offset:], p.StreamOffset) - offset += 8 - } else { - PutUint32(encoded[offset:], uint32(p.StreamOffset)) - offset += 4 - } - dataLen := uint16(len(p.Data)) + PutUint64(encoded[offset:], p.StreamOffset) + offset += 8 + + dataLen := uint16(len(payloadData)) if dataLen > 0 { - copy(encoded[offset:], p.Data) + copy(encoded[offset:], payloadData) offset += int(dataLen) } return encoded, offset, nil } -func DecodePayload(data []byte) (payload *Payload, offset int, err error) { +func DecodePayload(data []byte) (payload *PayloadMeta, offset int, payloadData []byte, err error) { dataLen := len(data) - if dataLen < MinPayloadSize { - return nil, 0, ErrPayloadTooSmall + if dataLen < MinProtoSize { + return nil, 0, nil, ErrPayloadTooSmall } offset = 0 - payload = &Payload{} + payload = &PayloadMeta{} // Flags (8 bits) flags := data[offset] @@ -198,56 +138,29 @@ func DecodePayload(data []byte) (payload *Payload, offset int, err error) { ackCount := flags & FlagAckMask payload.IsSender = (flags & (1 << FlagSenderShift)) != 0 - isDataLargeOffset := (flags & (1 << FlagLrgShift)) != 0 - payload.CloseOp = CloseOp((flags >> FlagCloseShift) & FlagCloseMask) // Decode ACKs if present if ackCount > 0 { - // Read ACK flags - ackFlags := data[offset] - offset++ - - // Read RcvWndSize based on flag - if ackFlags&0x1 != 0 { - if offset+8 > dataLen { - return nil, 0, ErrPayloadTooSmall - } - payload.RcvWndSize = Uint64(data[offset:]) - offset += 8 - - } else { - 
payload.RcvWndSize = uint64(Uint32(data[offset:])) - offset += 4 + if offset+8 > dataLen { + return nil, 0, nil, ErrPayloadTooSmall } + payload.RcvWndSize = Uint64(data[offset:]) + offset += 8 // Read ACKs payload.Acks = make([]Ack, ackCount) for i := 0; i < int(ackCount); i++ { - if offset+4 > dataLen { - return nil, 0, ErrPayloadTooSmall + if offset+4+8+2 > dataLen { + return nil, 0, nil, ErrPayloadTooSmall } ack := Ack{} ack.StreamId = Uint32(data[offset:]) offset += 4 - if ackFlags&(1<<(i+1)) != 0 { - if offset+8 > dataLen { - return nil, 0, ErrPayloadTooSmall - } - ack.StreamOffset = Uint64(data[offset:]) - offset += 8 - } else { - if offset+4 > dataLen { - return nil, 0, ErrPayloadTooSmall - } - ack.StreamOffset = uint64(Uint32(data[offset:])) - offset += 4 - } + ack.StreamOffset = Uint64(data[offset:]) + offset += 8 - if offset+2 > dataLen { - return nil, 0, ErrPayloadTooSmall - } ack.Len = Uint16(data[offset:]) offset += 2 payload.Acks[i] = ack @@ -256,32 +169,24 @@ func DecodePayload(data []byte) (payload *Payload, offset int, err error) { // Decode Data if offset+4 > dataLen { - return nil, 0, ErrPayloadTooSmall + return nil, 0, nil, ErrPayloadTooSmall } payload.StreamId = Uint32(data[offset:]) offset += 4 - if isDataLargeOffset { - if offset+8 > dataLen { - return nil, 0, ErrPayloadTooSmall - } - payload.StreamOffset = Uint64(data[offset:]) - offset += 8 - } else { - if offset+4 > dataLen { - return nil, 0, ErrPayloadTooSmall - } - payload.StreamOffset = uint64(Uint32(data[offset:])) - offset += 4 + if offset+8 > dataLen { + return nil, 0, nil, ErrPayloadTooSmall } + payload.StreamOffset = Uint64(data[offset:]) + offset += 8 if dataLen > offset { - payload.Data = make([]byte, dataLen-offset) - copy(payload.Data, data[offset:dataLen]) + payloadData = make([]byte, dataLen-offset) + copy(payloadData, data[offset:dataLen]) offset += dataLen } else { - payload.Data = make([]byte, 0) + payloadData = make([]byte, 0) } - return payload, offset, nil + return payload, offset, payloadData, nil } diff --git a/proto_test.go b/proto_test.go index 1f04dd5..0621efc 100644 --- a/proto_test.go +++ b/proto_test.go @@ -1,40 +1,37 @@ package tomtp import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "math" "reflect" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeDecodeMinimalPayload(t *testing.T) { - original := &Payload{ + original := &PayloadMeta{ StreamId: 12345, StreamOffset: 0, - Data: []byte{}, } - encoded, offset, err := EncodePayload(original) + encoded, _, err := EncodePayload(original, []byte{}) require.NoError(t, err, "Failed to encode minimal payload") - require.Greater(t, offset, 0) - decoded, offset, err := DecodePayload(encoded) + decoded, _, decodedData, err := DecodePayload(encoded) require.NoError(t, err, "Failed to decode minimal payload") - require.Greater(t, offset, 0) assert.Equal(t, original.StreamId, decoded.StreamId, "StreamId mismatch") assert.Equal(t, original.StreamOffset, decoded.StreamOffset, "StreamOffset mismatch") - assert.Empty(t, decoded.Data, "Data should be empty") + assert.Empty(t, decodedData, "Data should be empty") } func TestPayloadWithAllFeatures(t *testing.T) { - original := &Payload{ + original := &PayloadMeta{ CloseOp: CloseStream, IsSender: true, StreamId: 1, StreamOffset: 9999, - Data: []byte("test data"), RcvWndSize: 1000, Acks: []Ack{ {StreamId: 1, StreamOffset: 123456, Len: 10}, @@ -42,19 +39,19 @@ func TestPayloadWithAllFeatures(t *testing.T) { }, } - 
encoded, offset, err := EncodePayload(original) + originalData := []byte("test data") + + encoded, _, err := EncodePayload(original, originalData) require.NoError(t, err, "Failed to encode payload") - require.Greater(t, offset, 0) - decoded, offset, err := DecodePayload(encoded) + decoded, _, decodedData, err := DecodePayload(encoded) require.NoError(t, err, "Failed to decode payload") - require.Greater(t, offset, 0) assert.Equal(t, original.CloseOp, decoded.CloseOp) assert.Equal(t, original.IsSender, decoded.IsSender) assert.Equal(t, original.StreamId, decoded.StreamId) assert.Equal(t, original.StreamOffset, decoded.StreamOffset) - assert.Equal(t, original.Data, decoded.Data) + assert.Equal(t, originalData, decodedData) require.NotNil(t, decoded.Acks) assert.Equal(t, original.RcvWndSize, decoded.RcvWndSize) @@ -77,107 +74,40 @@ func TestCloseOpBehavior(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - original := &Payload{ + original := &PayloadMeta{ CloseOp: tc.closeOp, StreamId: 1, StreamOffset: 100, - Data: []byte("test"), - } - - encoded, offset, err := EncodePayload(original) - require.NoError(t, err) - require.Greater(t, offset, 0) - - decoded, offset, err := DecodePayload(encoded) - require.NoError(t, err) - require.Greater(t, offset, 0) - - assert.Equal(t, tc.closeOp, decoded.CloseOp) - }) - } -} - -func TestLargeOffsets(t *testing.T) { - testCases := []struct { - name string - streamOffset uint64 - ackOffsets []uint64 - rcvWndSize uint64 - }{ - { - name: "All 32-bit values", - streamOffset: uint32Max - 1, - ackOffsets: []uint64{uint32Max - 1, uint32Max - 2}, - rcvWndSize: uint32Max - 1, - }, - { - name: "All 64-bit values", - streamOffset: uint64(uint32Max) + 1, - ackOffsets: []uint64{uint64(uint32Max) + 1, uint64(uint32Max) + 2}, - rcvWndSize: uint64(uint32Max) + 1, - }, - { - name: "Mixed values", - streamOffset: uint64(uint32Max) + 1, - ackOffsets: []uint64{uint32Max - 1, uint64(uint32Max) + 1}, - rcvWndSize: uint32Max - 1, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - acks := make([]Ack, len(tc.ackOffsets)) - for i, offset := range tc.ackOffsets { - acks[i] = Ack{ - StreamId: uint32(i + 1), - StreamOffset: offset, - Len: uint16(i + 100), - } } - original := &Payload{ - StreamId: 1, - StreamOffset: tc.streamOffset, - Data: []byte("test"), - Acks: acks, - RcvWndSize: tc.rcvWndSize, - } + originalData := []byte("test") - encoded, offset, err := EncodePayload(original) + encoded, _, err := EncodePayload(original, originalData) require.NoError(t, err) - require.Greater(t, offset, 0) - decoded, offset, err := DecodePayload(encoded) + decoded, _, _, err := DecodePayload(encoded) require.NoError(t, err) - require.Greater(t, offset, 0) - assert.Equal(t, original.StreamOffset, decoded.StreamOffset) - assert.Equal(t, original.RcvWndSize, decoded.RcvWndSize) - for i, ack := range original.Acks { - assert.Equal(t, ack.StreamOffset, decoded.Acks[i].StreamOffset) - } + assert.Equal(t, tc.closeOp, decoded.CloseOp) }) } } func TestEmptyData(t *testing.T) { t.Run("Empty Data With Required Fields", func(t *testing.T) { - original := &Payload{ + original := &PayloadMeta{ StreamId: 1, StreamOffset: 100, - Data: []byte{}, } - encoded, offset, err := EncodePayload(original) + encoded, _, err := EncodePayload(original, []byte{}) require.NoError(t, err) - require.Greater(t, offset, 0) - decoded, offset, err := DecodePayload(encoded) + decoded, _, decodedData, err := DecodePayload(encoded) require.NoError(t, err) - 
require.Greater(t, offset, 0) assert.Equal(t, original.StreamId, decoded.StreamId, "StreamId should be present") assert.Equal(t, original.StreamOffset, decoded.StreamOffset, "StreamOffset should be present") - assert.Empty(t, decoded.Data, "Data should be empty") + assert.Empty(t, decodedData, "Data should be empty") }) } @@ -188,21 +118,20 @@ func TestAckHandling(t *testing.T) { acks[i] = Ack{StreamId: uint32(i), StreamOffset: uint64(i * 1000), Len: uint16(i)} } - original := &Payload{ + original := &PayloadMeta{ StreamId: 1, StreamOffset: 100, - Data: []byte("test"), Acks: acks, RcvWndSize: 1000, } - encoded, offset, err := EncodePayload(original) + originalData := []byte("test") + + encoded, _, err := EncodePayload(original, originalData) require.NoError(t, err) - require.Greater(t, offset, 0) - decoded, offset, err := DecodePayload(encoded) + decoded, _, _, err := DecodePayload(encoded) require.NoError(t, err) - require.Greater(t, offset, 0) assert.Equal(t, len(original.Acks), len(decoded.Acks)) for i := range original.Acks { @@ -212,15 +141,15 @@ func TestAckHandling(t *testing.T) { t.Run("Too Many ACKs", func(t *testing.T) { acks := make([]Ack, 16) // One more than maximum - original := &Payload{ + original := &PayloadMeta{ StreamId: 1, StreamOffset: 100, - Data: []byte("test"), Acks: acks, RcvWndSize: 1000, } + originalData := []byte("test") - _, _, err := EncodePayload(original) + _, _, err := EncodePayload(original, originalData) assert.Error(t, err, "too many Acks") }) } @@ -248,23 +177,22 @@ func TestGetCloseOp(t *testing.T) { func FuzzPayload(f *testing.F) { // Add seed corpus with valid and edge case payloads - payloads := []*Payload{ + payloads := []*PayloadMeta{ { StreamId: 1, StreamOffset: 100, - Data: []byte("test data"), RcvWndSize: 1000, Acks: []Ack{{StreamId: 1, StreamOffset: 200, Len: 10}}, }, { StreamId: math.MaxUint32, StreamOffset: math.MaxUint64, - Data: []byte{}, }, } for _, p := range payloads { - encoded, _, err := EncodePayload(p) + originalData := []byte("test data") + encoded, _, err := EncodePayload(p, originalData) if err != nil { continue } @@ -272,24 +200,23 @@ func FuzzPayload(f *testing.F) { } f.Fuzz(func(t *testing.T, data []byte) { - decoded, _, err := DecodePayload(data) + decoded, _, payloadData, err := DecodePayload(data) if err != nil { t.Skip() } - // Re-encode and decode to verify - reEncoded, _, err := EncodePayload(decoded) + reEncoded, _, err := EncodePayload(decoded, payloadData) if err != nil { t.Skip() } - reDecoded, _, err := DecodePayload(reEncoded) + reDecoded, _, reDecodedData, err := DecodePayload(reEncoded) if err != nil { t.Skip() } // Compare original decoded with re-decoded - if !reflect.DeepEqual(decoded, reDecoded) { + if !reflect.DeepEqual(decoded, reDecoded) || !reflect.DeepEqual(payloadData, reDecodedData) { t.Fatal("re-encoded/decoded payload differs from original") } }) diff --git a/rcv.go b/rcv.go index 420ccf8..95a1e14 100644 --- a/rcv.go +++ b/rcv.go @@ -13,41 +13,57 @@ const ( RcvInsertBufferFull ) -type RcvSegment struct { - offset uint64 - data []byte +type RcvBuffer struct { + segments *skipList[packetKey, []byte] + nextInOrderOffsetToWaitFor uint64 // Next expected offset } type ReceiveBuffer struct { - segments *skipList[packetKey, *RcvSegment] // Store out-of-order segments - nextOffset uint64 // Next expected offset - capacity int // Max buffer size - size int // Current size + streams *linkedHashMap[uint32, *RcvBuffer] + lastStream uint32 + + capacity int // Max buffer size + size int // Current size mu 
*sync.Mutex
 	acks          []Ack
 	dataAvailable chan struct{} // Signal that data is available
 }
 
+func NewRcvBuffer() *RcvBuffer {
+	return &RcvBuffer{
+		segments: newSortedHashMap[packetKey, []byte](func(a, b packetKey) bool { return a.less(b) }),
+	}
+}
+
 func NewReceiveBuffer(capacity int) *ReceiveBuffer {
 	return &ReceiveBuffer{
-		segments: newSortedHashMap[packetKey, *RcvSegment](func(a, b packetKey) bool { return a.less(b) }),
+		streams:       newLinkedHashMap[uint32, *RcvBuffer](),
 		capacity:      capacity,
 		mu:            &sync.Mutex{},
 		dataAvailable: make(chan struct{}, 1),
 	}
 }
 
-func (rb *ReceiveBuffer) Insert(segment *RcvSegment) RcvInsertStatus {
+func (rb *ReceiveBuffer) Insert(streamId uint32, offset uint64, decodedData []byte) RcvInsertStatus {
+	dataLen := len(decodedData)
+	key := createPacketKey(offset, uint16(dataLen))
+
 	rb.mu.Lock()
 	defer rb.mu.Unlock()
 
-	dataLen := len(segment.data)
-	if segment.offset+uint64(dataLen) < rb.nextOffset {
+	// Get or create stream buffer
+	entry := rb.streams.Get(streamId)
+	if entry == nil {
+		stream := NewRcvBuffer()
+		entry = rb.streams.Put(streamId, stream)
+	}
+	stream := entry.value
+
+	if offset+uint64(dataLen) < stream.nextInOrderOffsetToWaitFor {
 		return RcvInsertDuplicate
 	}
 
-	key := createPacketKey(segment.offset, uint16(dataLen))
-	if rb.segments.Contains(key) {
+	if stream.segments.Contains(key) {
 		return RcvInsertDuplicate
 	}
 
@@ -55,15 +71,15 @@ func (rb *ReceiveBuffer) Insert(segment *RcvSegment) RcvInsertStatus {
 		return RcvInsertBufferFull
 	}
 
-	rb.segments.Put(key, segment)
+	stream.segments.Put(key, decodedData)
 	rb.acks = append(rb.acks, Ack{
-		StreamOffset: segment.offset,
+		StreamOffset: offset,
 		Len:          uint16(dataLen),
 	})
 	rb.size += dataLen
 
 	// Signal that data is available (non-blocking send)
 	select {
 	case rb.dataAvailable <- struct{}{}:
 	default: // Non-blocking to prevent deadlocks if someone is already waiting
@@ -72,13 +88,24 @@
 	}
 
 	return RcvInsertOk
 }
 
-func (rb *ReceiveBuffer) RemoveOldestInOrder(ctx context.Context) (*RcvSegment, error) {
+func (rb *ReceiveBuffer) RemoveOldestInOrderBlocking(ctx context.Context, streamId uint32) (offset uint64, data []byte, err error) {
 	rb.mu.Lock()
 	defer rb.mu.Unlock()
 
+	if rb.streams.Size() == 0 {
+		return 0, nil, nil
+	}
+
+	streamPair := rb.streams.Get(streamId)
+	if streamPair == nil {
+		return 0, nil, nil
+	}
+	stream := streamPair.value
+	streamId = streamPair.key
+
 	for {
 		// Check if there is any data at all
-		oldest := rb.segments.Min()
+		oldest := stream.segments.Min()
 		if oldest == nil {
 			// No segments available, so wait
 			rb.mu.Unlock()
@@ -88,38 +115,39 @@ func (rb *ReceiveBuffer) RemoveOldestInOrder(ctx context.Context) (*RcvSegment,
 				continue // Recheck segments size
 			case <-ctx.Done():
 				rb.mu.Lock()
-				return nil, ctx.Err() // Context cancelled
+				return 0, nil, ctx.Err() // Context cancelled
 			}
 		}
 
-		if oldest.value.offset == rb.nextOffset {
-			rb.segments.Remove(oldest.key)
+		if oldest.key.offset() == stream.nextInOrderOffsetToWaitFor {
+			stream.segments.Remove(oldest.key)
 			rb.size -= int(oldest.key.length())
 
-			segment := oldest.value
-			if segment.offset < rb.nextOffset {
-				diff := rb.nextOffset - segment.offset
-				segment.data = segment.data[diff:]
-				segment.offset = rb.nextOffset
+			segmentVal := oldest.value
+			segmentKey := oldest.key
+			off := segmentKey.offset()
+			if off < stream.nextInOrderOffsetToWaitFor {
+				diff := stream.nextInOrderOffsetToWaitFor - segmentKey.offset()
+				segmentVal = segmentVal[diff:]
+				off = stream.nextInOrderOffsetToWaitFor
 			}
 
-			rb.nextOffset = segment.offset + uint64(len(segment.data))
-			return segment, nil
-		} else if oldest.value.offset > rb.nextOffset {
+			stream.nextInOrderOffsetToWaitFor = off + uint64(len(segmentVal))
+			return oldest.key.offset(), segmentVal, nil
+		} else if oldest.key.offset() > stream.nextInOrderOffsetToWaitFor {
 			// Out of order; wait until segment offset available, signal that
 			rb.mu.Unlock()
 			select {
 			case <-rb.dataAvailable:
 				rb.mu.Lock() //get new data signal, re-lock to ensure no one modifies
 				continue // Recheck segments size after getting the data
 			case <-ctx.Done():
 				rb.mu.Lock()
-				return nil, ctx.Err()
+				return 0, nil, ctx.Err()
 			}
 		} else {
-			rb.segments.Remove(oldest.key)
-			rb.size -= int(oldest.key.length())
-			// Dupe data, loop to get more data if exist
+			// Dupe/overlap: do nothing. We could consider adding the non-overlapping
+			// part here, but if it's correctly implemented, this should not happen.
 		}
 	}
 }
diff --git a/rcv_test.go b/rcv_test.go
index 6cee0da..96e821a 100644
--- a/rcv_test.go
+++ b/rcv_test.go
@@ -2,177 +2,206 @@ package tomtp
 
 import (
 	"context"
-	"errors"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"testing"
 )
 
-func TestReceiveBuffer(t *testing.T) {
-	tests := []struct {
-		name             string
-		segments         []*RcvSegment
-		capacity         int
-		want             []*RcvSegment
-		wantInsertStatus []RcvInsertStatus
-	}{
-		{
-			name:     "Single segment",
-			capacity: 1000,
-			segments: []*RcvSegment{
-				{offset: 0, data: []byte("data")},
-			},
-			want: []*RcvSegment{
-				{offset: 0, data: []byte("data")},
-			},
-			wantInsertStatus: []RcvInsertStatus{RcvInsertOk},
-		},
-		{
-			name:     "Duplicate exact segment",
-			capacity: 1000,
-			segments: []*RcvSegment{
-				{offset: 0, data: []byte("data")},
-				{offset: 0, data: []byte("data")},
-			},
-			want: []*RcvSegment{
-				{offset: 0, data: []byte("data")},
-			},
-			wantInsertStatus: []RcvInsertStatus{RcvInsertOk, RcvInsertDuplicate},
-		},
-		{
-			name:     "Gap between segments",
-			capacity: 1000,
-			segments: []*RcvSegment{
-				{offset: 10, data: []byte("later")},
-				{offset: 0, data: []byte("early")},
-			},
-			want: []*RcvSegment{
-				{offset: 0, data: []byte("early")},
-			},
-			wantInsertStatus: []RcvInsertStatus{RcvInsertOk, RcvInsertOk},
-		},
-		{
-			name:     "Buffer full exact",
-			capacity: 4,
-			segments: []*RcvSegment{
-				{offset: 0, data: []byte("data")},
-				{offset: 4, data: []byte("more")},
-			},
-			want: []*RcvSegment{
-				{offset: 0, data: []byte("data")},
-			},
-			wantInsertStatus: []RcvInsertStatus{RcvInsertOk, RcvInsertBufferFull},
-		},
-		{
-			name:     "Zero length segment",
-			capacity: 1000,
-			segments: []*RcvSegment{
-				{offset: 0, data: []byte{}},
-			},
-			want: []*RcvSegment{
-				{offset: 0, data: []byte{}},
-			},
-			wantInsertStatus: []RcvInsertStatus{RcvInsertOk},
-		},
-		{
-			name:     "Consecutive segments different arrival order",
-			capacity: 1000,
-			segments: []*RcvSegment{
-				{offset: 5, data: []byte("second")},
-				{offset: 0, data: []byte("first")},
-				{offset: 11, data: []byte("third")},
-			},
-			want: []*RcvSegment{
-				{offset: 0, data: []byte("first")},
-				{offset: 5, data: []byte("second")},
-				{offset: 11, data: []byte("third")},
-			},
-			wantInsertStatus: []RcvInsertStatus{RcvInsertOk, RcvInsertOk, RcvInsertOk},
-		},
-	}
+func TestReceiveBuffer_SingleSegment(t *testing.T) {
+	rb := NewReceiveBuffer(1000)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	status := rb.Insert(1, 0, []byte("data"))
+	assert.Equal(t, RcvInsertOk, status)
+
+	offset, data, err := rb.RemoveOldestInOrderBlocking(ctx, 1)
+	require.NoError(t, err)
+	assert.Equal(t, uint64(0), offset)
+	assert.Equal(t, []byte("data"), data)
+
+	// Verify empty after reading: the next read would block, so the cancelled context errors
+	_, data, err = rb.RemoveOldestInOrderBlocking(ctx, 1)
+	require.Error(t, err)
+	require.Empty(t, data)
+}
+
+func TestReceiveBuffer_DuplicateSegment(t *testing.T) {
+	rb := NewReceiveBuffer(1000)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	status := rb.Insert(1, 0, []byte("data"))
+	assert.Equal(t, RcvInsertOk, status)
+
+	status = rb.Insert(1, 0, []byte("data"))
+	assert.Equal(t, RcvInsertDuplicate, status)
+
+	offset, data, err := rb.RemoveOldestInOrderBlocking(ctx, 1)
+	require.NoError(t, err)
+	assert.Equal(t, uint64(0), offset)
+	assert.Equal(t, []byte("data"), data)
+}
+
+func TestReceiveBuffer_GapBetweenSegments(t *testing.T) {
+	rb := NewReceiveBuffer(1000)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	status := rb.Insert(1, 10, []byte("later"))
+	assert.Equal(t, RcvInsertOk, status)
+
+	status = rb.Insert(1, 0, []byte("early"))
+	assert.Equal(t, RcvInsertOk, status)
+
+	// Should get early segment first
+	offset, data, err := rb.RemoveOldestInOrderBlocking(ctx, 1)
+	require.NoError(t, err)
+	assert.Equal(t, uint64(0), offset)
+	assert.Equal(t, []byte("early"), data)
+
+	// The later segment is out of order (offset 10, but offset 5 is expected next),
+	// so the read would block and the cancelled context yields an error
+	offset, data, err = rb.RemoveOldestInOrderBlocking(ctx, 1)
+	require.Error(t, err)
+	assert.Equal(t, uint64(0), offset)
+	assert.Equal(t, 0, len(data))
+}
+
+func TestReceiveBuffer_MultipleStreams(t *testing.T) {
+	rb := NewReceiveBuffer(1000)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	// Insert segments from different streams
+	status := rb.Insert(1, 0, []byte("stream1-first"))
+	assert.Equal(t, RcvInsertOk, status)
+
+	status = rb.Insert(2, 0, []byte("stream2-first"))
+	assert.Equal(t, RcvInsertOk, status)
+
+	status = rb.Insert(1, 13, []byte("stream1-second"))
+	assert.Equal(t, RcvInsertOk, status)
+
+	// Read from stream 1
+	offset, data, err := rb.RemoveOldestInOrderBlocking(ctx, 1)
+	require.NoError(t, err)
+	assert.Equal(t, uint64(0), offset)
+	assert.Equal(t, []byte("stream1-first"), data)
+
+	// Read from stream 2
+	offset, data, err = rb.RemoveOldestInOrderBlocking(ctx, 2)
+	require.NoError(t, err)
+	assert.Equal(t, uint64(0), offset)
+	assert.Equal(t, []byte("stream2-first"), data)
+
+	// Read second segment from stream 1
+	offset, data, err = rb.RemoveOldestInOrderBlocking(ctx, 1)
+	require.NoError(t, err)
+	assert.Equal(t, uint64(13), offset)
+	assert.Equal(t, []byte("stream1-second"), data)
+}
+
+func TestReceiveBuffer_BufferFullExact(t *testing.T) {
+	rb := NewReceiveBuffer(4)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
 
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			rb := NewReceiveBuffer(tt.capacity)
-			ctx, _ := context.WithTimeout(context.Background(), 0)
-
-			for i, seg := range tt.segments {
-				status := rb.Insert(seg)
-				assert.Equal(t, status, tt.wantInsertStatus[i])
-			}
-
-			var got []*RcvSegment
-			for {
-				seg, err := rb.RemoveOldestInOrder(ctx)
-				if err != nil {
-					if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
-						break
-					}
-					t.Fatal(err)
-				}
-				if seg == nil {
-					break
-				}
-				got = 
append(got, seg) - } - - assert.Equal(t, len(tt.want), len(got)) - for i := range got { - require.Less(t, i, len(tt.want)) - assert.Equal(t, tt.want[i].offset, got[i].offset) - assert.Equal(t, tt.want[i].data, got[i].data) - } - }) + status := rb.Insert(1, 0, []byte("data")) + assert.Equal(t, RcvInsertOk, status) + + status = rb.Insert(1, 4, []byte("more")) + assert.Equal(t, RcvInsertBufferFull, status) + + offset, data, err := rb.RemoveOldestInOrderBlocking(ctx, 1) + require.NoError(t, err) + assert.Equal(t, uint64(0), offset) + assert.Equal(t, []byte("data"), data) +} + +func TestReceiveBuffer_RemoveWithHigherOffset(t *testing.T) { + rb := NewReceiveBuffer(4) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + status := rb.Insert(1, 0, []byte("12345")) + assert.Equal(t, RcvInsertBufferFull, status) + + offset, data, err := rb.RemoveOldestInOrderBlocking(ctx, 1) + require.Error(t, err) + assert.Equal(t, uint64(0), offset) + assert.Equal(t, 0, len(data)) +} + +func TestReceiveBuffer_RemoveWithHigherOffset_EmptyAfterLast(t *testing.T) { + rb := NewReceiveBuffer(4) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + status := rb.Insert(1, 0, []byte("1")) + assert.Equal(t, RcvInsertOk, status) + + offset, data, err := rb.RemoveOldestInOrderBlocking(ctx, 1) + require.NoError(t, err) + assert.Equal(t, uint64(0), offset) + assert.Equal(t, []byte("1"), data) + + // Should be empty after reading + _, data, err = rb.RemoveOldestInOrderBlocking(ctx, 1) + require.Error(t, err) + require.Empty(t, data) +} + +func TestGetAcks_NoAcks(t *testing.T) { + rb := NewReceiveBuffer(1000) + acks := rb.GetAcks() + assert.Nil(t, acks) +} + +func TestGetAcks_SingleBatch(t *testing.T) { + rb := NewReceiveBuffer(1000) + for i := 0; i < 10; i++ { + rb.Insert(1, uint64(i*10), []byte("data")) } + + acks := rb.GetAcks() + assert.Equal(t, 10, len(acks)) + + acks = rb.GetAcks() + assert.Nil(t, acks) } -func TestGetAcks(t *testing.T) { - tests := []struct { - name string - inserts int - wantLens []int - }{ - { - name: "No acks", - inserts: 0, - wantLens: []int{0}, - }, - { - name: "Single batch under limit", - inserts: 10, - wantLens: []int{10, 0}, - }, - { - name: "Multiple batches", - inserts: 35, - wantLens: []int{15, 15, 5, 0}, - }, - { - name: "Exact batch Size", - inserts: 15, - wantLens: []int{15, 0}, - }, +func TestGetAcks_MultipleBatches(t *testing.T) { + rb := NewReceiveBuffer(1000) + for i := 0; i < 35; i++ { + rb.Insert(1, uint64(i*10), []byte("data")) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rb := NewReceiveBuffer(1000) - - for i := 0; i < tt.inserts; i++ { - rb.Insert(&RcvSegment{ - offset: uint64(i * 10), - data: []byte("data"), - }) - } - - for _, wantLen := range tt.wantLens { - acks := rb.GetAcks() - if wantLen == 0 { - assert.Nil(t, acks) - } else { - assert.Equal(t, wantLen, len(acks)) - } - } - }) + // First batch + acks := rb.GetAcks() + assert.Equal(t, 15, len(acks)) + + // Second batch + acks = rb.GetAcks() + assert.Equal(t, 15, len(acks)) + + // Third batch + acks = rb.GetAcks() + assert.Equal(t, 5, len(acks)) + + // Should be empty now + acks = rb.GetAcks() + assert.Nil(t, acks) +} + +func TestGetAcks_ExactBatchSize(t *testing.T) { + rb := NewReceiveBuffer(1000) + for i := 0; i < 15; i++ { + rb.Insert(1, uint64(i*10), []byte("data")) } + + acks := rb.GetAcks() + assert.Equal(t, 15, len(acks)) + + acks = rb.GetAcks() + assert.Nil(t, acks) } diff --git a/skiplist.go b/skiplist.go index edafdfc..1de78d1 100644 --- a/skiplist.go 
+++ b/skiplist.go
@@ -1,4 +1,4 @@
-// Package tomtp provides data structure implementations.
+// Package tomtp provides the data structure implementations used by TomTP.
 // All exported methods (those starting with capital letters) are thread-safe.
 package tomtp
diff --git a/skiplist_test.go b/skiplist_test.go
index 3320076..e06caad 100644
--- a/skiplist_test.go
+++ b/skiplist_test.go
@@ -70,7 +70,7 @@ func (s *SortedHashMapTestSuite) TestTreeOperations() {
 		{6, "six"},
 	}
 
-	// Insert values
+	// Insert the test values
 	for _, v := range values {
 		s.shm.Put(v.key, v.value)
 	}
diff --git a/snd.go b/snd.go
index 57cc137..3d193b9 100644
--- a/snd.go
+++ b/snd.go
@@ -6,8 +6,6 @@ import (
 	"sync"
 )
 
-type nowMillis uint64
-
 type packetKey [10]byte
 
 func (p packetKey) offset() uint64 {
@@ -37,40 +35,40 @@ func createPacketKey(offset uint64, length uint16) packetKey {
 	return p
 }
 
-// StreamBuffer represents a single stream's data and metadata
+// StreamBuffer represents a single stream's outgoing data and metadata
 type StreamBuffer struct {
-	//here we append the data, after appending, we sent currentOffset.
-	//This is necessary, as when data gets acked, we Remove the acked data,
-	//which will be in front of the array. Thus, len(data) would not work.
-	data []byte
-	// based on offset, which is uint48. This is the offset of the data we did not send yet
+	// here we append new data. Offsets are tracked explicitly: when data gets
+	// acked we Remove it from the front of the slice, so len(dataToSend) alone
+	// cannot serve as an offset.
+	dataToSend []byte
+	// this is the offset of the data we did not send yet
 	unsentOffset uint64
-	// based on offset, which is uint48. This is the offset of the data we did send
+	// this is the offset of the data we did send
 	sentOffset uint64
-	// when data is acked, we Remove the data, however we don't want to update all the offsets, hence this bias
+	// when data is acked, we Remove it; to avoid rewriting all offsets, this bias maps stream offsets to slice indices
 	// TODO: check what happens on an 64bit rollover
 	bias uint64
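+	// Worked example (illustration): once the first 500 bytes are acked and
+	// trimmed, bias is 500, so stream offset 650 lives at dataToSend[650-500] = dataToSend[150].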
-	// inflight data - key is offset, which is uint48, len in 16bit is added to a 64bit key. value is sentTime
+	// in-flight data - the key packs the 64-bit offset and the 16-bit length into a 10-byte packetKey; the value is the sentTime.
 	// If MTU changes for inflight packets and need to be resent. The range is split. Example:
 	// offset: 500, len/mtu: 50 -> 1 range: 500/50,time
 	// retransmit with mtu:20 -> 3 dataInFlightMap: 500/20,time; 520/20,time; 540/10,time
-	dataInFlightMap *linkedHashMap[packetKey, *node[packetKey, nowMillis]]
+	dataInFlightMap *linkedHashMap[packetKey, *node[packetKey, uint64]]
 }
 
 type SendBuffer struct {
 	streams                    *linkedHashMap[uint32, *StreamBuffer] // Changed to LinkedHashMap
 	lastReadToSendStream       uint32 //for round-robin, we continue where we left
 	lastReadToRetransmitStream uint32
-	capacity                   int //len(data) of all streams cannot become larger than capacity
-	totalSize                  int //len(data) of all streams
+	capacity                   int //total len(dataToSend) over all streams cannot become larger than capacity
+	totalSize                  int //total len(dataToSend) over all streams
 	capacityAvailable          chan struct{} // Signal that capacity is now available
 	mu                         *sync.Mutex
 }
 
 func NewStreamBuffer() *StreamBuffer {
 	return &StreamBuffer{
-		data:            []byte{},
-		dataInFlightMap: newLinkedHashMap[packetKey, *node[packetKey, nowMillis]](),
+		dataToSend:      []byte{},
+		dataInFlightMap: newLinkedHashMap[packetKey, *node[packetKey, uint64]](),
 	}
 }
 
@@ -83,195 +81,166 @@ func NewSendBuffer(capacity int) *SendBuffer {
 	}
 }
 
-// Insert stores the data in the dataMap
-func (sb *SendBuffer) Insert(ctx context.Context, streamId uint32, data []byte) error {
-	dataLen := len(data)
-	sb.mu.Lock()
+// InsertBlocking stores the data in the stream's send buffer; it does not send yet.
+// It blocks while the buffer is at capacity.
+func (sb *SendBuffer) InsertBlocking(ctx context.Context, streamId uint32, data []byte) (int, error) {
+	var processedBytes int
+	remainingData := data
 
-	//Blocking wait until totalSize < capacity
-	for sb.capacity < sb.totalSize+dataLen {
-		sb.mu.Unlock() // Unlock before waiting
-		select {
-		case <-sb.capacityAvailable: // Wait for signal
-			sb.mu.Lock() // Re-lock after signal
-			continue
-		case <-ctx.Done():
-			sb.mu.Lock() //Re-Lock to prevent memory corruption
+	for len(remainingData) > 0 {
+		sb.mu.Lock()
+
+		// Calculate how much data we can insert
+		remainingCapacity := sb.capacity - sb.totalSize
+		if remainingCapacity <= 0 {
 			sb.mu.Unlock()
-			return ctx.Err() // Return if context is cancelled
+			select {
+			case <-sb.capacityAvailable:
+				continue
+			case <-ctx.Done():
+				return processedBytes, ctx.Err()
+			}
 		}
-	}
 
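+		// Note on semantics (derived from the control flow above): if the
+		// context is cancelled while we wait for capacity, we return early with
+		// processedBytes < len(data); callers must treat this as a short write.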
-	// Get or create stream buffer
-	entry := sb.streams.Get(streamId)
-	if entry == nil {
-		stream := NewStreamBuffer()
-		entry = sb.streams.Put(streamId, stream)
-	}
+		// Calculate chunk size
+		chunkSize := min(len(remainingData), remainingCapacity)
+		chunk := remainingData[:chunkSize]
 
-	stream := entry.value
+		// Get or create stream buffer
+		entry := sb.streams.Get(streamId)
+		if entry == nil {
+			stream := NewStreamBuffer()
+			entry = sb.streams.Put(streamId, stream)
+		}
+		stream := entry.value
 
-	// Store data
-	stream.data = append(stream.data, data...)
-	stream.unsentOffset = stream.unsentOffset + uint64(dataLen)
-	sb.totalSize += dataLen
+		// Store chunk
+		stream.dataToSend = append(stream.dataToSend, chunk...)
+		stream.unsentOffset = stream.unsentOffset + uint64(chunkSize)
+		sb.totalSize += chunkSize
 
-	sb.mu.Unlock() // Unlock after signal is received
-	return nil
+		// Advance past the inserted chunk
+		remainingData = remainingData[chunkSize:]
+		processedBytes += chunkSize
+
+		sb.mu.Unlock()
+	}
+
+	return processedBytes, nil
 }
 
-// ReadyToSend finds unsent data and creates a range entry for tracking
-func (sb *SendBuffer) ReadyToSend(mtu uint16, nowMillis2 uint64) (streamId uint32, offset uint64, data []byte, err error) {
+// ReadyToSend takes unsent data from the stream's buffer and tracks it as in-flight in dataInFlightMap
+func (sb *SendBuffer) ReadyToSend(streamId uint32, maxData uint16, nowMillis uint64) (splitData []byte) {
 	sb.mu.Lock()
 	defer sb.mu.Unlock()
 
 	if sb.streams.Size() == 0 {
-		return 0, 0, nil, nil
+		return nil
 	}
 
-	streamPair := sb.streams.Get(sb.lastReadToSendStream)
+	streamPair := sb.streams.Get(streamId)
 	if streamPair == nil {
-		streamPair = sb.streams.Oldest()
-	} else {
-		nextStreamPair := streamPair.Next()
-		if nextStreamPair == nil {
-			streamPair = sb.streams.Oldest()
-		} else {
-			streamPair = nextStreamPair
-		}
+		return nil
 	}
+	stream := streamPair.value
+	streamId = streamPair.key
 
-	startStreamId := streamPair.key
-	for {
-		stream := streamPair.value
-		streamId = streamPair.key
-
-		// Check if there's unsent data, if true, we have unsent data
-		if stream.unsentOffset > stream.sentOffset {
-			remainingData := stream.unsentOffset - stream.sentOffset
-
-			//the max length we can send
-			length := uint16(min(uint64(mtu), remainingData))
+	// Check if there is unsent data for this stream
+	if stream.unsentOffset > stream.sentOffset {
+		remainingData := stream.unsentOffset - stream.sentOffset
 
-			// Pack offset and length into key
-			key := createPacketKey(stream.sentOffset, length)
+		// the max length we can send
+		length := uint16(min(uint64(maxData), remainingData))
 
-			// Check if range is already tracked
-			if stream.dataInFlightMap.Get(key) == nil {
-				// Get data slice accounting for bias
-				offset = stream.sentOffset - stream.bias
-				data = stream.data[offset : offset+uint64(length)]
+		// Pack offset and length into key
+		key := createPacketKey(stream.sentOffset, length)
 
-				// Track range
-				stream.dataInFlightMap.Put(key, newNode(key, nowMillis(nowMillis2)))
+		// Check if range is already tracked
+		if stream.dataInFlightMap.Get(key) == nil {
+			// Get the data slice, accounting for bias
+			offset := stream.sentOffset - stream.bias
+			splitData = stream.dataToSend[offset : offset+uint64(length)]
 
-				// Update tracking
-				stream.sentOffset = stream.sentOffset + uint64(length)
-				sb.lastReadToSendStream = streamId
+			// Track range
+			stream.dataInFlightMap.Put(key, newNode(key, nowMillis))
 
-				return streamId, offset, data, nil
-			} else {
-				panic(errors.New("stream range already sent? should not happen"))
-			}
-		}
+			// Update tracking
+			stream.sentOffset = stream.sentOffset + uint64(length)
+			sb.lastReadToSendStream = streamId
 
-		streamPair = streamPair.Next()
-		if streamPair == nil {
-			streamPair = sb.streams.Oldest()
-		}
-		if streamPair.key == startStreamId {
-			break
+			return splitData
+		} else {
+			panic(errors.New("stream range already sent? 
should not happen")) } } - return 0, 0, nil, nil + return nil } // ReadyToRetransmit finds expired dataInFlightMap that need to be resent -func (sb *SendBuffer) ReadyToRetransmit(mtu uint16, rto uint64, nowMillis2 uint64) (streamId uint32, offset uint64, data []byte) { +func (sb *SendBuffer) ReadyToRetransmit(streamId uint32, maxData uint16, rto uint64, nowMillis uint64) (data []byte) { sb.mu.Lock() defer sb.mu.Unlock() if sb.streams.Size() == 0 { - return 0, 0, nil + return nil } - streamPair := sb.streams.Get(sb.lastReadToSendStream) + streamPair := sb.streams.Get(streamId) if streamPair == nil { - streamPair = sb.streams.Oldest() - } else { - nextStreamPair := streamPair.Next() - if nextStreamPair == nil { - streamPair = sb.streams.Oldest() - } else { - streamPair = nextStreamPair - } - + return nil } - - startStreamId := streamPair.key - for { - stream := streamPair.value - streamId = streamPair.key - - // Check Oldest range first - dataInFlight := stream.dataInFlightMap.Oldest() - if dataInFlight != nil { - sentTime := dataInFlight.value.value - if !dataInFlight.value.IsShadow() && nowMillis2-uint64(sentTime) > rto { - // Extract offset and length from key - rangeOffset := dataInFlight.key.offset() - rangeLen := dataInFlight.key.length() - - // Get data using bias - dataOffset := rangeOffset - stream.bias - data = stream.data[dataOffset : dataOffset+uint64(rangeLen)] - - sb.lastReadToRetransmitStream = streamId - if rangeLen <= mtu { - // Remove old range - stream.dataInFlightMap.Remove(dataInFlight.key) - // Same MTU - resend entire range - stream.dataInFlightMap.Put(dataInFlight.key, newNode(dataInFlight.key, nowMillis(nowMillis2))) - return streamPair.key, dataOffset, data - } else { - // Split range due to smaller MTU - leftKey := createPacketKey(rangeOffset, mtu) - // Queue remaining data with nxt offset - remainingOffset := rangeOffset + uint64(mtu) - remainingLen := rangeLen - mtu - rightKey := createPacketKey(remainingOffset, remainingLen) - - l, r := dataInFlight.value.Split(leftKey, nowMillis(nowMillis2), rightKey, dataInFlight.value.value) - oldParentKey := dataInFlight.key - oldParentValue := dataInFlight.value.value - n := newNode(r.key, oldParentValue) - - //we return the left, thus we need to reinsert as we have a new send time - //the right we keep, and Replace it with the old value, so it keeps the send time - dataInFlight.Replace(r.key, n) - stream.dataInFlightMap.Put(l.key, newNode(l.key, nowMillis(nowMillis2))) - stream.dataInFlightMap.Put(oldParentKey, newNode(oldParentKey, nowMillis(nowMillis2))) - - return streamPair.key, dataOffset, data[:mtu] - } + stream := streamPair.value + streamId = streamPair.key + + // Check Oldest range first + dataInFlight := stream.dataInFlightMap.Oldest() + if dataInFlight != nil { + sentTime := dataInFlight.value.value + if !dataInFlight.value.IsShadow() && nowMillis-uint64(sentTime) > rto { + // Extract offset and length from key + rangeOffset := dataInFlight.key.offset() + rangeLen := dataInFlight.key.length() + + // Get dataToSend using bias + dataOffset := rangeOffset - stream.bias + data = stream.dataToSend[dataOffset : dataOffset+uint64(rangeLen)] + + sb.lastReadToRetransmitStream = streamId + if rangeLen <= maxData { + // Remove old range + stream.dataInFlightMap.Remove(dataInFlight.key) + // Same MTU - resend entire range + stream.dataInFlightMap.Put(dataInFlight.key, newNode(dataInFlight.key, nowMillis)) + return data + } else { + // Split range due to smaller MTU + leftKey := createPacketKey(rangeOffset, maxData) + // 
Queue the remaining data at the next offset
+				remainingOffset := rangeOffset + uint64(maxData)
+				remainingLen := rangeLen - maxData
+				rightKey := createPacketKey(remainingOffset, remainingLen)
+
+				l, r := dataInFlight.value.Split(leftKey, nowMillis, rightKey, dataInFlight.value.value)
+				oldParentKey := dataInFlight.key
+				oldParentValue := dataInFlight.value.value
+				n := newNode(r.key, oldParentValue)
+
+				// we return the left part, so we reinsert it with a fresh send time;
+				// the right part keeps the old value, preserving its original send time
+				dataInFlight.Replace(r.key, n)
+				stream.dataInFlightMap.Put(l.key, newNode(l.key, nowMillis))
+				stream.dataInFlightMap.Put(oldParentKey, newNode(oldParentKey, nowMillis))
+
+				return data[:maxData]
 			}
 		}
-
-		streamPair = streamPair.Next()
-		if streamPair == nil {
-			streamPair = sb.streams.Oldest()
-		}
-		if streamPair.key == startStreamId {
-			break
-		}
 	}
-	return 0, 0, nil
+	return nil
 }
 
-// AcknowledgeRange handles acknowledgment of data
-func (sb *SendBuffer) AcknowledgeRange(streamId uint32, offset uint64, length uint16) uint64 {
+// AcknowledgeRange handles acknowledgment of sent data and returns the earliest
+// send time among the acked in-flight entries
+func (sb *SendBuffer) AcknowledgeRange(streamId uint32, offset uint64, length uint16) (sentTimeMillis uint64) {
 	sb.mu.Lock()
 
 	streamPair := sb.streams.Get(streamId)
@@ -290,31 +259,31 @@ func (sb *SendBuffer) AcknowledgeRange(streamId uint32, offset uint64, length ui
 		return 0
 	}
 
-	firstSentTime := rangePair.value.value
+	sentTimeMillis = rangePair.value.value
 	delKeys := rangePair.value.Remove()
 	for _, delKey := range delKeys {
 		deletePair := stream.dataInFlightMap.Remove(delKey)
 		if deletePair != nil {
 			removeSentTime := deletePair.value.value
-			if removeSentTime < firstSentTime {
-				firstSentTime = removeSentTime
+			if removeSentTime < sentTimeMillis {
+				sentTimeMillis = removeSentTime
 			}
 		}
 	}
 
-	// If this range starts at our bias point, we can Remove data
+	// If this range starts at our bias point, we can Remove the acked data
 	if offset == stream.bias {
 		// Check if we have a gap between this ack and nxt range
 		nextRange := stream.dataInFlightMap.Oldest()
 		if nextRange == nil {
-			// No gap, safe to Remove all data
-			stream.data = stream.data[stream.sentOffset-stream.bias:]
-			stream.bias += stream.sentOffset - stream.bias
-			sb.totalSize -= int(stream.sentOffset - stream.bias)
+			// No gap, safe to Remove all acked data
+			stream.dataToSend = stream.dataToSend[stream.sentOffset-stream.bias:]
+			sb.totalSize -= int(stream.sentOffset - stream.bias)
+			stream.bias += stream.sentOffset - stream.bias
 		} else {
 			nextOffset := nextRange.key.offset()
-			stream.data = stream.data[nextOffset-stream.bias:]
+			stream.dataToSend = stream.dataToSend[nextOffset-stream.bias:]
 			stream.bias += nextOffset
 			sb.totalSize -= int(nextOffset)
 		}
@@ -326,5 +295,5 @@ func (sb *SendBuffer) AcknowledgeRange(streamId uint32, offset uint64, length ui
 	}
 	sb.mu.Unlock()
-	return uint64(firstSentTime)
+	return sentTimeMillis
 }
diff --git a/snd_test.go b/snd_test.go
index 5d2ac79..7a3df97 100644
--- a/snd_test.go
+++ b/snd_test.go
@@ -15,12 +15,14 @@ func TestInsert(t *testing.T) {
 	ctx := context.Background()
 
 	// Basic insert
-	err := sb.Insert(ctx, 1, []byte("test"))
+	_, err := sb.InsertBlocking(ctx, 1, []byte("test"))
 	assert.Nil(err)
 
 	// Verify stream created correctly
-	stream := sb.streams.Get(1).value
-	assert.Equal([]byte("test"), stream.data)
+	streamPair := sb.streams.Get(1)
+	assert.NotNil(streamPair)
+	stream := streamPair.value
+	assert.Equal([]byte("test"), stream.dataToSend)
 	assert.Equal(uint64(4), stream.unsentOffset)
 	assert.Equal(uint64(0), stream.sentOffset)
 	assert.Equal(uint64(0), stream.bias)
@@ -30,19 +32,24 @@
 	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) // Timeout
 	defer cancel()
 
-	err = sb.Insert(ctx, 1, []byte("test"))
+	_, err = sb.InsertBlocking(ctx, 1, []byte("test"))
 	assert.Error(err)
 	assert.Equal(context.DeadlineExceeded, err)
 
-	// Test 48-bit wrapping
+	// Test 48-bit wrapping (Go has no uint48, so we use uint64 and let it wrap)
 	sb = NewSendBuffer(1000)
 	stream = NewStreamBuffer()
 	stream.unsentOffset = math.MaxUint64 - 2
 	sb.streams.Put(1, stream)
 
-	err = sb.Insert(context.Background(), 1, []byte("test"))
+	_, err = sb.InsertBlocking(context.Background(), 1, []byte("test"))
 	assert.Nil(err) // Should succeed now
-	assert.Equal(uint64(1), stream.unsentOffset)
+	streamPair = sb.streams.Get(1)
+	assert.NotNil(streamPair)
+	stream = streamPair.value
+
+	assert.Equal(uint64(1), stream.unsentOffset) // MaxUint64-2 plus 4 bytes wraps around to 1 because we use uint64
 	assert.Equal(uint64(0), stream.sentOffset)
 }
 
@@ -50,90 +57,90 @@ func TestReadyToSend(t *testing.T) {
 	assert := require.New(t)
 	sb := NewSendBuffer(1000)
 	ctx := context.Background()
+	nowMillis2 := uint64(100)
 
 	// Insert data
-	sb.Insert(ctx, 1, []byte("test1"))
-	sb.Insert(ctx, 2, []byte("test2"))
+	sb.InsertBlocking(ctx, 1, []byte("test1"))
+	sb.InsertBlocking(ctx, 2, []byte("test2"))
 
 	// Basic send
-	streamId, offset, data, err := sb.ReadyToSend(10, 100)
-	assert.NoError(err)
-	assert.Equal(uint32(1), streamId)
-	assert.Equal(uint64(0), offset)
+	data := sb.ReadyToSend(1, 10, nowMillis2)
 	assert.Equal([]byte("test1"), data)
 
 	// Verify range tracking
-	stream := sb.streams.Get(1).value
+	streamPair := sb.streams.Get(1)
+	assert.NotNil(streamPair)
+	stream := streamPair.value
+
 	rangePair := stream.dataInFlightMap.Oldest()
 	assert.NotNil(rangePair)
 	assert.Equal(uint16(5), rangePair.key.length())
-	assert.Equal(nowMillis(100), rangePair.value.value)
+	assert.Equal(nowMillis2, rangePair.value.value)
 
-	sb.ReadyToSend(10, 100)
+	sb.ReadyToSend(1, 10, nowMillis2)
 
 	// Test MTU limiting
-	sb.Insert(ctx, 3, []byte("toolongdata"))
-	streamId, offset, data, err = sb.ReadyToSend(4, 100)
-	assert.NoError(err)
-	assert.Equal(uint32(3), streamId)
-	assert.Equal(uint64(0), offset)
+	sb.InsertBlocking(ctx, 3, []byte("toolongdata"))
+	data = sb.ReadyToSend(3, 4, nowMillis2)
 	assert.Equal([]byte("tool"), data)
 
-	// Test round-robin
-	streamId, _, _, err = sb.ReadyToSend(10, 100)
-	assert.NoError(err)
-	assert.Equal(uint32(3), streamId)
+	// test no data available
+	data = sb.ReadyToSend(4, 10, nowMillis2)
+	assert.Nil(data)
 }
 
 func TestReadyToRetransmit(t *testing.T) {
 	assert := require.New(t)
 	sb := NewSendBuffer(1000)
 	ctx := context.Background()
 
 	// Setup test data
-	sb.Insert(ctx, 1, []byte("test1"))
-	sb.Insert(ctx, 2, []byte("test2"))
-	sb.ReadyToSend(10, 100)
-	sb.ReadyToSend(10, 100)
+	sb.InsertBlocking(ctx, 1, []byte("test1"))
+	sb.InsertBlocking(ctx, 2, []byte("test2"))
+
+	sb.ReadyToSend(1, 10, 100) // Initial send at time 100
+	sb.ReadyToSend(2, 10, 100) // Initial send at time 100
 
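+	// A range is retransmitted only when now - sentTime is strictly greater than
+	// rto, and a successful retransmit stores a fresh send time for that range.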
 	// Test basic retransmit
-	streamId, offset, data := sb.ReadyToRetransmit(10, 50, 200)
-	assert.Equal(uint32(1), streamId)
-	assert.Equal(uint64(0), offset)
+	data := sb.ReadyToRetransmit(1, 10, 50, 200) // RTO = 50, now = 200: 200-100 > 50, expired
 	assert.Equal([]byte("test1"), data)
 
-	streamId, offset, data = sb.ReadyToRetransmit(10, 100, 200)
+	data = sb.ReadyToRetransmit(2, 10, 100, 200) // RTO = 100, now = 200: 200-100 = 100, not strictly greater, so no retransmit
 	assert.Nil(data)
 
-	streamId, offset, data = sb.ReadyToRetransmit(10, 99, 200)
-	assert.Equal(uint32(2), streamId)
-	assert.Equal(uint64(0), offset)
-	assert.Equal([]byte("test2"), data)
+	data = sb.ReadyToRetransmit(1, 10, 99, 300) // RTO = 99, now = 300: the retransmit above refreshed the send time to 200, and 300-200 > 99
+	assert.Equal([]byte("test1"), data)
 
 	// Test MTU split
 	sb = NewSendBuffer(1000)
-	sb.Insert(ctx, 1, []byte("testdata"))
-	sb.ReadyToSend(8, 100)
+	sb.InsertBlocking(ctx, 1, []byte("testdata"))
+	sb.ReadyToSend(1, 100, 100) // Initial send
 
-	streamId, offset, data = sb.ReadyToRetransmit(4, 50, 200)
-	assert.Equal(uint32(1), streamId)
-	assert.Equal(uint64(0), offset)
+	data = sb.ReadyToRetransmit(1, 4, 99, 200)
 	assert.Equal([]byte("test"), data)
 
 	// Verify range split
-	stream := sb.streams.Get(1).value
+	streamPair := sb.streams.Get(1)
+	assert.NotNil(streamPair)
+	stream := streamPair.value
 	assert.Equal(3, stream.dataInFlightMap.Size())
+	node := stream.dataInFlightMap.Oldest()
+	assert.Equal(uint16(4), node.key.length())
+	assert.Equal(uint64(4), node.key.offset())
 }
 
 func TestAcknowledgeRangeBasic(t *testing.T) {
 	assert := require.New(t)
 	sb := NewSendBuffer(1000)
 	ctx := context.Background()
 
-	sb.Insert(ctx, 1, []byte("testdata"))
-	sb.ReadyToSend(4, 100)
+	sb.InsertBlocking(ctx, 1, []byte("testdata"))
+	sb.ReadyToSend(1, 4, 100)
+	streamPair := sb.streams.Get(1)
+	assert.NotNil(streamPair)
+	stream := streamPair.value
 
 	assert.Equal(uint64(100), sb.AcknowledgeRange(1, 0, 4))
-	stream := sb.streams.Get(1).value
-	assert.Equal(4, len(stream.data))
+	assert.Equal(4, len(stream.dataToSend))
 	assert.Equal(uint64(4), stream.bias)
 }
 
@@ -150,171 +157,3 @@ func TestAcknowledgeRangeNonExistentRange(t *testing.T) {
 	sb.streams.Put(1, stream)
 	assert.Equal(uint64(0), sb.AcknowledgeRange(1, 0, 4))
 }
-
-func TestSendBufferIntegration(t *testing.T) {
-	assert := require.New(t)
-
-	t.Run("edge cases with varying MTU and data sizes", func(t *testing.T) {
-		sb := NewSendBuffer(1000)
-		ctx := context.Background()
-
-		// Test case 1: Insert data near MaxUint48 boundary
-		stream1Data := make([]byte, 100)
-		sb.Insert(ctx, 1, stream1Data)
-		stream := sb.streams.Get(1).value
-		stream.unsentOffset = math.MaxUint64 - 49
-
-		// Insert data that will wrap around
-		wrapData := make([]byte, 100)
-		err := sb.Insert(ctx, 1, wrapData)
-		assert.Nil(err)
-		assert.Equal(uint64(50), stream.unsentOffset)
-
-		// Test case 2: Multiple MTU sizes for ReadyToSend
-		streamId, _, data, err := sb.ReadyToSend(30, 100)
-		assert.NoError(err)
-		assert.Equal(uint32(1), streamId)
-		assert.Equal(30, len(data))
-
-		// Smaller MTU
-		streamId, _, data, err = sb.ReadyToSend(20, 200)
-		assert.NoError(err)
-		assert.Equal(20, len(data))
-
-		// Test case 3: Retransmission with MTU changes
-		// First send with large MTU
-		sb = NewSendBuffer(1000)
-		sb.Insert(ctx, 1, []byte("thisislongdataforretransmission"))
-		streamId, _, _, err = sb.ReadyToSend(30, 100)
-		assert.NoError(err)
-
-		// Retransmit with smaller MTU
-		streamId, _, data = sb.ReadyToRetransmit(10, 50, 200)
-		assert.Equal(uint32(1), streamId)
-		assert.Equal(10, len(data))
-		assert.Equal("thisislong", string(data))
-
-		// Test case 4: Out of order acknowledgments
-		sb = NewSendBuffer(1000)
-		testData := []byte("testdatafortestingacks")
-		sb.Insert(ctx, 1, testData)
-
-		// Send in chunks
-		sb.ReadyToSend(5, 100) // "testd"
-		
sb.ReadyToSend(5, 100) // "atafo" - sb.ReadyToSend(5, 100) // "rtest" - - // Acknowledge in reverse order - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 10, 5)) // "rtest" - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 5, 5)) // "atafo" - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 0, 5)) // "testd" - - stream = sb.streams.Get(1).value - assert.Equal(uint64(15), stream.bias) - - // Test case 5: Mixed operations with multiple streams - sb = NewSendBuffer(1000) - - // Insert into multiple streams - sb.Insert(ctx, 1, []byte("stream1data")) - sb.Insert(ctx, 2, []byte("stream2data")) - sb.Insert(ctx, 3, []byte("stream3data")) - - // Send from different streams with varying MTUs - _, _, data1, _ := sb.ReadyToSend(5, 100) - _, _, data2, _ := sb.ReadyToSend(7, 100) - _, _, data3, _ := sb.ReadyToSend(4, 100) - - assert.Equal(5, len(data1)) - assert.Equal(7, len(data2)) - assert.Equal(4, len(data3)) - - // Retransmit with different MTUs - _, _, retrans1 := sb.ReadyToRetransmit(3, 50, 200) - assert.Equal(3, len(retrans1)) - assert.Equal("str", string(retrans1)) - - // Test case 6: Edge case - acknowledge empty range - assert.Equal(uint64(0), sb.AcknowledgeRange(1, 0, 0)) - - // Test case 7: Complex out-of-order acknowledgments with gaps - sb = NewSendBuffer(1000) - sb.Insert(ctx, 1, []byte("abcdefghijklmnopqrstuvwxyz")) - - // Send in multiple chunks - sb.ReadyToSend(5, 100) // "abcde" - sb.ReadyToSend(5, 100) // "fghij" - sb.ReadyToSend(5, 100) // "klmno" - sb.ReadyToSend(5, 100) // "pqrst" - sb.ReadyToSend(5, 100) // "uvwxy" - - // Acknowledge with gaps: 2,4,1,5,3 - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 5, 5)) // "fghij" - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 15, 5)) // "pqrst" - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 0, 5)) // "abcde" - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 20, 5)) // "uvwxy" - assert.Equal(uint64(100), sb.AcknowledgeRange(1, 10, 5)) // "klmno" - - stream = sb.streams.Get(1).value - assert.Equal(uint64(25), stream.bias) - - // Test case 8: Out-of-order retransmissions with varying MTUs - sb = NewSendBuffer(1000) - sb.Insert(ctx, 1, []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")) - - // Initial send with MTU 10 - sb.ReadyToSend(10, 100) // "ABCDEFGHIJ" - sb.ReadyToSend(10, 100) // "KLMNOPQRST" - sb.ReadyToSend(6, 100) // "UVWXYZ" - - // Retransmit first chunk with larger MTU - _, _, retrans1 = sb.ReadyToRetransmit(10, 50, 200) - assert.Equal(10, len(retrans1)) - assert.Equal("ABCDEFGHIJ", string(retrans1)) - - // Retransmit middle chunk first with smaller MTU - _, _, retrans2 := sb.ReadyToRetransmit(5, 50, 200) // Should split "KLMNOPQRST" - assert.Equal(5, len(retrans2)) - assert.Equal("KLMNO", string(retrans2)) - - // Retransmit remaining part of middle chunk - _, _, retrans3 := sb.ReadyToRetransmit(5, 50, 200) - assert.Equal(5, len(retrans3)) - assert.Equal("PQRST", string(retrans3)) - - // Retransmit remaining part of middle chunk - _, _, retrans4 := sb.ReadyToRetransmit(5, 50, 300) - assert.Equal(5, len(retrans4)) - assert.Equal("UVWXY", string(retrans4)) - - // Test case 9: Edge case - maximum MTU - sb = NewSendBuffer(1000) - ctxLimited, cancelLimited := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancelLimited() - err = sb.Insert(ctxLimited, 1, make([]byte, 65535)) // This should timeout and return an error! 
-		assert.Error(err)
-
-		// Test case 10: Acknowledge large buffer
-		sb = NewSendBuffer(1000)
-		ctx2 := context.Background()
-
-		longTestData := make([]byte, 500)
-		err = sb.Insert(ctx2, 1, longTestData)
-		assert.NoError(err)
-		sb.ReadyToSend(500, 100)
-
-		sb.AcknowledgeRange(1, 0, 500)
-
-		sb = NewSendBuffer(1024)
-		ctxEdgeCases := context.Background()
-		err = sb.Insert(ctxEdgeCases, 1, []byte("test"))
-		assert.NoError(err)
-		stream = sb.streams.Get(1).value
-		sb.AcknowledgeRange(1, 0, 0)
-
-	})
-}
-
-func getOffsetAndLength(k uint64) (uint64, uint16) {
-	return k & ((1 << 48) - 1), uint16(k >> 48)
-}
diff --git a/stream.go b/stream.go
index a6cd259..51dcd78 100644
--- a/stream.go
+++ b/stream.go
@@ -25,17 +25,12 @@ var (
 
 type Stream struct {
 	// Connection info
-	streamId         uint32
-	streamOffsetNext uint64
-	conn             *Connection
-	state            StreamState
-
-	// Flow control
-	rcvWndSize uint64 // Receive window Size
-	sndWndSize uint64 // Send window Size
+	streamId uint32
+	conn     *Connection
+	state    StreamState
 
 	// Reliable delivery buffers
-	rbRcv *ReceiveBuffer // Receive buffer for incoming data
+	// rbRcv moved to Connection: incoming data is now buffered per connection (s.conn.rbRcv)
 
 	// Statistics
 	bytesRead uint64
@@ -60,26 +55,20 @@ func (s *Stream) Write(b []byte) (nTot int, err error) {
 	slog.Debug("Write", debugGoroutineID(), s.debug(), slog.String("b...", string(b[:min(10, len(b))])))
 
 	for len(b) > 0 {
-		var enc []byte
 		var n int
-		enc, n, err = s.encode(b)
+		n, err = s.conn.rbSnd.InsertBlocking(s.closeCtx, s.streamId, b)
 		if err != nil {
 			return nTot, err
 		}
 		nTot += n
-		if n == 0 {
-			break
-		}
-		err = s.conn.rbSnd.Insert(s.closeCtx, s.streamId, enc)
+
+		// Signal the listener that there is data to send
+		err = s.conn.listener.localConn.CancelRead()
 		if err != nil {
 			return nTot, err
 		}
 
-		// Signal the listener that there is data to send
-		if s.conn.listener != nil { //Ensure Listener Exists
-			s.conn.listener.localConn.CancelRead()
-		}
-
 		b = b[n:]
 	}
 
@@ -90,20 +79,20 @@ func (s *Stream) Read(b []byte) (n int, err error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
 	slog.Debug("read data start", debugGoroutineID(), s.debug())
 
-	segment, err := s.rbRcv.RemoveOldestInOrder(s.closeCtx)
+	_, data, err := s.conn.rbRcv.RemoveOldestInOrderBlocking(s.closeCtx, s.streamId)
 	if err != nil {
 		return 0, err
 	}
-	if segment == nil {
+	if data == nil {
 		if s.state >= StreamEnded {
 			return 0, io.EOF
 		}
 		return 0, nil
 	}
 
-	n = copy(b, segment.data)
+	n = copy(b, data)
 	slog.Debug("read Data done", debugGoroutineID(), s.debug(), slog.String("b...", string(b[:min(10, n)])))
 	s.bytesRead += uint64(n)
 	return n, nil
@@ -113,15 +102,15 @@ func (s *Stream) ReadBytes() (b []byte, err error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
 	slog.Debug("read data start", debugGoroutineID(), s.debug())
 
-	segment, err := s.rbRcv.RemoveOldestInOrder(s.closeCtx)
+	_, data, err := s.conn.rbRcv.RemoveOldestInOrderBlocking(s.closeCtx, s.streamId)
 	if err != nil {
 		return nil, err
 	}
 
-	s.bytesRead += uint64(len(segment.data))
-	return segment.data, nil
+	s.bytesRead += uint64(len(data))
+	return data, nil
 }
 
 func (s *Stream) Close() error {
@@ -141,15 +130,15 @@ func (s *Stream) debug() slog.Attr {
 	return s.conn.listener.debug(s.conn.remoteAddr)
 }
 
-func (s *Stream) receive(streamData []byte, streamOffset uint64) {
+func (s *Stream) receive(offset uint64, decodedData []byte) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	if len(streamData) > 0 {
-		r := RcvSegment{
-			offset: streamOffset,
-			data:   streamData,
-		}
-		s.rbRcv.Insert(&r)
+	if len(decodedData) > 0 {
+		s.conn.rbRcv.Insert(s.streamId, offset, decodedData)
 	}
 }
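+
+// calcLen returns the usable payload size of a single packet: the MTU minus the
+// combined crypto and payload header overhead for ackLen ACKs (see Overhead)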
+func (s *Stream) calcLen(mtu int, ackLen int) uint16 {
+	return uint16(mtu - s.Overhead(ackLen))
+}