diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..f0c5321 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,25 @@ +name: Go + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.23' + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v ./... \ No newline at end of file diff --git a/codec.go b/codec.go index df2b95b..6232303 100644 --- a/codec.go +++ b/codec.go @@ -16,7 +16,7 @@ func (s *Stream) encode(b []byte) (enc []byte, offset int, err error) { p := &Payload{ CloseOp: GetCloseOp(s.state == StreamEnding, s.conn.state == ConnectionEnding), IsSender: s.conn.sender, - RcvWndSize: uint64(s.rbRcv.Size()), + RcvWndSize: s.rcvWndSize - uint64(s.rbRcv.Size()), Acks: s.rbRcv.GetAcks(), StreamId: s.streamId, StreamOffset: s.streamOffsetNext, diff --git a/conn.go b/conn.go index ad49131..ff04862 100644 --- a/conn.go +++ b/conn.go @@ -1,6 +1,7 @@ package tomtp import ( + "context" "crypto/ecdh" "log/slog" "net/netip" @@ -120,12 +121,15 @@ func (c *Connection) GetOrNewStreamRcv(streamId uint32) (*Stream, bool) { } if stream, ok := c.streams[streamId]; !ok { + ctx, cancel := context.WithCancel(context.Background()) s := &Stream{ streamId: streamId, streamOffsetNext: 0, state: StreamStarting, conn: c, rbRcv: NewReceiveBuffer(rcvBufferCapacity), + closeCtx: ctx, + closeCancelFn: cancel, mu: sync.Mutex{}, } c.streams[streamId] = s diff --git a/conn_test.go b/conn_test.go index 1fa8000..4ff2e0f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -73,25 +73,23 @@ func TestConnection_GetOrNewStreamRcv(t *testing.T) { { name: "new stream", streamID: 1, - setup: false, + setup: true, }, { name: "existing stream", - streamID: 2, - setup: true, + streamID: 1, + setup: false, }, } - + conn := &Connection{ + streams: make(map[uint32]*Stream), + } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - conn := &Connection{ - streams: make(map[uint32]*Stream), - } - stream, isNew := conn.GetOrNewStreamRcv(tt.streamID) assert.NotNil(t, stream) assert.Equal(t, tt.streamID, stream.streamId) - assert.Equal(t, !tt.setup, isNew) + assert.Equal(t, tt.setup, isNew) }) } } diff --git a/end2end_test.go b/end2end_test.go index 2471559..9708007 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -157,7 +157,7 @@ func relayData(connSrc, connDest *inMemoryNetworkConn, maxBytes int) (int, error return int(bytesWritten), nil } -func createConnectedStreams( +func createTwoStreams( nConnA *inMemoryNetworkConn, nConnB *inMemoryNetworkConn, prvKeyA *ecdh.PrivateKey, @@ -184,7 +184,7 @@ func createConnectedStreams( return nil, nil, errors.New("failed to create listener B: " + err.Error()) } - connA, err := listenerA.DialString(nConnB.LocalAddr().String(), hexPublicKey2) + connA, err := listenerA.DialString(nConnB.LocalAddr().String(), hexPubKey2) if err != nil { listenerA.Close() // clean up everything here! listenerB.Close() @@ -203,26 +203,23 @@ func createConnectedStreams( func TestEndToEndInMemory(t *testing.T) { nConnA, nConnB, err := setupInMemoryPair() - if err != nil { - t.Fatalf("failed to setup in-memory connections: %v", err) - } + assert.Nil(t, err) defer nConnA.Close() defer nConnB.Close() var streamB *Stream - acceptB := func(s *Stream) { - slog.Info("A: accept connection") - streamB = s - } - - streamA, listenerB, err := createConnectedStreams(nConnA, nConnB, testPrivateKey1, testPrivateKey2, acceptB) + streamA, listenerB, err := createTwoStreams(nConnA, nConnB, testPrvKey1, testPrvKey2, func(s *Stream) { streamB = s }) assert.Nil(t, err) a := []byte("hallo") - streamA.Write(a) - streamA.conn.listener.Update(0) - relayData(nConnA, nConnB, startMtu) - listenerB.Update(0) + _, err = streamA.Write(a) + assert.Nil(t, err) + err = streamA.conn.listener.Update(0) + assert.Nil(t, err) + _, err = relayData(nConnA, nConnB, startMtu) + assert.Nil(t, err) + err = listenerB.Update(0) + assert.Nil(t, err) b, err := streamB.ReadBytes() assert.Nil(t, err) assert.Equal(t, a, b) diff --git a/listener_test.go b/listener_test.go index 767fa6b..fdcffee 100644 --- a/listener_test.go +++ b/listener_test.go @@ -11,19 +11,19 @@ import ( ) var ( - testPrivateSeed1 = [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - testPrivateSeed2 = [32]byte{2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} - testPrivateKey1, _ = ecdh.X25519().NewPrivateKey(testPrivateSeed1[:]) - testPrivateKey2, _ = ecdh.X25519().NewPrivateKey(testPrivateSeed2[:]) + testPrvSeed1 = [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + testPrvSeed2 = [32]byte{2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} + testPrvKey1, _ = ecdh.X25519().NewPrivateKey(testPrvSeed1[:]) + testPrvKey2, _ = ecdh.X25519().NewPrivateKey(testPrvSeed2[:]) - hexPublicKey1 = fmt.Sprintf("0x%x", testPrivateKey1.PublicKey().Bytes()) - hexPublicKey2 = fmt.Sprintf("0x%x", testPrivateKey2.PublicKey().Bytes()) + hexPubKey1 = fmt.Sprintf("0x%x", testPrvKey1.PublicKey().Bytes()) + hexPubKey2 = fmt.Sprintf("0x%x", testPrvKey2.PublicKey().Bytes()) ) func TestNewListener(t *testing.T) { // Test case 1: Create a new listener with a valid address addr := "127.0.0.1:8080" - listener, err := ListenString(addr, func(s *Stream) {}, WithSeed(testPrivateSeed1)) + listener, err := ListenString(addr, func(s *Stream) {}, WithSeed(testPrvSeed1)) defer listener.Close() if err != nil { t.Errorf("Expected no error, but got: %v", err) @@ -34,7 +34,7 @@ func TestNewListener(t *testing.T) { // Test case 2: Create a new listener with an invalid address invalidAddr := "127.0.0.1:99999" - _, err = ListenString(invalidAddr, func(s *Stream) {}, WithSeed(testPrivateSeed1)) + _, err = ListenString(invalidAddr, func(s *Stream) {}, WithSeed(testPrvSeed1)) if err == nil { t.Errorf("Expected an error, but got nil") } @@ -42,17 +42,17 @@ func TestNewListener(t *testing.T) { func TestNewStream(t *testing.T) { // Test case 1: Create a new multi-stream with a valid remote address - listener, err := ListenString("127.0.0.1:9080", func(s *Stream) {}, WithSeed(testPrivateSeed1)) + listener, err := ListenString("127.0.0.1:9080", func(s *Stream) {}, WithSeed(testPrvSeed1)) defer listener.Close() assert.Nil(t, err) - conn, err := listener.DialString("127.0.0.1:9081", hexPublicKey1) + conn, err := listener.DialString("127.0.0.1:9081", hexPubKey1) assert.Nil(t, err) if conn == nil { t.Errorf("Expected a multi-stream, but got nil") } // Test case 2: Create a new multi-stream with an invalid remote address - conn, err = listener.DialString("127.0.0.1:99999", hexPublicKey1) + conn, err = listener.DialString("127.0.0.1:99999", hexPubKey1) if conn != nil { t.Errorf("Expected nil, but got a multi-stream") } @@ -61,10 +61,10 @@ func TestNewStream(t *testing.T) { func TestClose(t *testing.T) { // Test case 1: Close a listener with no multi-streams - listener, err := ListenString("127.0.0.1:9080", func(s *Stream) {}, WithSeed(testPrivateSeed1)) + listener, err := ListenString("127.0.0.1:9080", func(s *Stream) {}, WithSeed(testPrvSeed1)) assert.NoError(t, err) // Test case 2: Close a listener with multi-streams - listener.DialString("127.0.0.1:9081", hexPublicKey1) + listener.DialString("127.0.0.1:9081", hexPubKey1) err = listener.Close() if err != nil { t.Errorf("Expected no error, but got: %v", err) @@ -78,7 +78,7 @@ func TestListenerUpdate_NoActivity(t *testing.T) { acceptFn := func(s *Stream) { acceptCalled = true } - listener, err := ListenString("127.0.0.1:9080", acceptFn, WithSeed(testPrivateSeed1)) + listener, err := ListenString("127.0.0.1:9080", acceptFn, WithSeed(testPrvSeed1)) assert.NoError(t, err) defer listener.Close() @@ -103,16 +103,16 @@ func TestListenerUpdate_ReceiveData(t *testing.T) { acceptFn := func(s *Stream) { acceptCalled = true } - listenerSnd, err := ListenString(":8881", func(stream *Stream) {}, WithSeed(testPrivateSeed1)) + listenerSnd, err := ListenString(":8881", func(stream *Stream) {}, WithSeed(testPrvSeed1)) assert.NoError(t, err) defer listenerSnd.Close() - connectionSnd, err := listenerSnd.DialString("127.0.0.1:8882", hexPublicKey2) + connectionSnd, err := listenerSnd.DialString("127.0.0.1:8882", hexPubKey2) assert.NoError(t, err) streamSnd, _ := connectionSnd.GetOrNewStreamRcv(0) - listenerRcv, err := ListenString(":8882", acceptFn, WithSeed(testPrivateSeed2)) + listenerRcv, err := ListenString(":8882", acceptFn, WithSeed(testPrvSeed2)) // Sender setup streamSnd.Write([]byte("hello")) diff --git a/rcv.go b/rcv.go index 80aeac4..420ccf8 100644 --- a/rcv.go +++ b/rcv.go @@ -1,6 +1,7 @@ package tomtp import ( + "context" "sync" ) @@ -18,19 +19,21 @@ type RcvSegment struct { } type ReceiveBuffer struct { - segments *skipList[packetKey, *RcvSegment] // Store out-of-order segments - nextOffset uint64 // Next expected offset - capacity int // Max buffer size - size int // Current size - mu *sync.Mutex - acks []Ack + segments *skipList[packetKey, *RcvSegment] // Store out-of-order segments + nextOffset uint64 // Next expected offset + capacity int // Max buffer size + size int // Current size + mu *sync.Mutex + acks []Ack + dataAvailable chan struct{} // Signal that data is available } func NewReceiveBuffer(capacity int) *ReceiveBuffer { return &ReceiveBuffer{ - segments: newSortedHashMap[packetKey, *RcvSegment](func(a, b packetKey) bool { return a.less(b) }), - capacity: capacity, - mu: &sync.Mutex{}, + segments: newSortedHashMap[packetKey, *RcvSegment](func(a, b packetKey) bool { return a.less(b) }), + capacity: capacity, + mu: &sync.Mutex{}, + dataAvailable: make(chan struct{}, 1), } } @@ -60,31 +63,65 @@ func (rb *ReceiveBuffer) Insert(segment *RcvSegment) RcvInsertStatus { rb.size += dataLen + // Signal that data is available (non-blocking send) + select { + case rb.dataAvailable <- struct{}{}: + default: // Non-blocking to prevent deadlocks if someone is already waiting + } + return RcvInsertOk } -func (rb *ReceiveBuffer) RemoveOldestInOrder() *RcvSegment { +func (rb *ReceiveBuffer) RemoveOldestInOrder(ctx context.Context) (*RcvSegment, error) { rb.mu.Lock() defer rb.mu.Unlock() - // Get the Oldest segment, check if we have data in order - oldest := rb.segments.Min() - if oldest == nil || oldest.value.offset > rb.nextOffset { - return nil + for { + // Check if there is any data at all + oldest := rb.segments.Min() + if oldest == nil { + // No segments available, so wait + rb.mu.Unlock() + select { + case <-rb.dataAvailable: // Wait for new segment signal + rb.mu.Lock() + continue // Recheck segments size + case <-ctx.Done(): + rb.mu.Lock() + return nil, ctx.Err() // Context cancelled + } + } + + if oldest.value.offset == rb.nextOffset { + rb.segments.Remove(oldest.key) + rb.size -= int(oldest.key.length()) + + segment := oldest.value + if segment.offset < rb.nextOffset { + diff := rb.nextOffset - segment.offset + segment.data = segment.data[diff:] + segment.offset = rb.nextOffset + } + + rb.nextOffset = segment.offset + uint64(len(segment.data)) + return segment, nil + } else if oldest.value.offset > rb.nextOffset { + // Out of order; wait until segment offset available, signal that + rb.mu.Unlock() + select { + case <-rb.dataAvailable: + rb.mu.Lock() //get new data signal, re-lock to ensure no one modifies + continue // Recheck segments size after getting the data + case <-ctx.Done(): + rb.mu.Lock() + return nil, ctx.Err() + } + } else { + rb.segments.Remove(oldest.key) + rb.size -= int(oldest.key.length()) + // Dupe data, loop to get more data if exist + } } - - rb.segments.Remove(oldest.key) - rb.size -= int(oldest.key.length()) - - segment := oldest.value - if segment.offset < rb.nextOffset { - diff := rb.nextOffset - segment.offset - segment.data = segment.data[diff:] - segment.offset = rb.nextOffset - } - - rb.nextOffset = segment.offset + uint64(len(segment.data)) - return segment } func (rb *ReceiveBuffer) Size() int { diff --git a/rcv_test.go b/rcv_test.go index edee5ba..6cee0da 100644 --- a/rcv_test.go +++ b/rcv_test.go @@ -1,6 +1,8 @@ package tomtp import ( + "context" + "errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" @@ -92,15 +94,22 @@ func TestReceiveBuffer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rb := NewReceiveBuffer(tt.capacity) + ctx, _ := context.WithTimeout(context.Background(), 0) for i, seg := range tt.segments { status := rb.Insert(seg) - assert.Equal(t, tt.wantInsertStatus[i], status) + assert.Equal(t, status, tt.wantInsertStatus[i]) } var got []*RcvSegment for { - seg := rb.RemoveOldestInOrder() + seg, err := rb.RemoveOldestInOrder(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + break + } + t.Fatal(err) + } if seg == nil { break } diff --git a/snd.go b/snd.go index a2dba23..57cc137 100644 --- a/snd.go +++ b/snd.go @@ -1,6 +1,7 @@ package tomtp import ( + "context" "errors" "sync" ) @@ -57,15 +58,13 @@ type StreamBuffer struct { } type SendBuffer struct { - streams *linkedHashMap[uint32, *StreamBuffer] // Changed to LinkedHashMap - //for round-robin, make sure we continue where we left - lastReadToSendStream uint32 + streams *linkedHashMap[uint32, *StreamBuffer] // Changed to LinkedHashMap + lastReadToSendStream uint32 //for round-robin, we continue where we left lastReadToRetransmitStream uint32 - //len(data) of all streams cannot become larger than capacity. With this we can throttle sending - capacity int - //len(data) of all streams - totalSize int - mu *sync.Mutex + capacity int //len(data) of all streams cannot become larger than capacity + totalSize int //len(data) of all streams + capacityAvailable chan struct{} // Signal that capacity is now available + mu *sync.Mutex } func NewStreamBuffer() *StreamBuffer { @@ -77,21 +76,30 @@ func NewStreamBuffer() *StreamBuffer { func NewSendBuffer(capacity int) *SendBuffer { return &SendBuffer{ - streams: newLinkedHashMap[uint32, *StreamBuffer](), - capacity: capacity, - mu: &sync.Mutex{}, + streams: newLinkedHashMap[uint32, *StreamBuffer](), + capacity: capacity, + capacityAvailable: make(chan struct{}, 1), // Buffered channel of size 1 + mu: &sync.Mutex{}, } } // Insert stores the data in the dataMap -func (sb *SendBuffer) Insert(streamId uint32, data []byte) bool { +func (sb *SendBuffer) Insert(ctx context.Context, streamId uint32, data []byte) error { + dataLen := len(data) sb.mu.Lock() - defer sb.mu.Unlock() - // Check capacity - dataLen := len(data) - if sb.capacity < sb.totalSize+dataLen { - return false + //Blocking wait until totalSize < capacity + for sb.capacity < sb.totalSize+dataLen { + sb.mu.Unlock() // Unlock before waiting + select { + case <-sb.capacityAvailable: // Wait for signal + sb.mu.Lock() // Re-lock after signal + continue + case <-ctx.Done(): + sb.mu.Lock() //Re-Lock to prevent memory corruption + sb.mu.Unlock() + return ctx.Err() // Return if context is cancelled + } } // Get or create stream buffer @@ -108,7 +116,8 @@ func (sb *SendBuffer) Insert(streamId uint32, data []byte) bool { stream.unsentOffset = stream.unsentOffset + uint64(dataLen) sb.totalSize += dataLen - return true + sb.mu.Unlock() // Unlock after signal is received + return nil } // ReadyToSend finds unsent data and creates a range entry for tracking @@ -264,10 +273,10 @@ func (sb *SendBuffer) ReadyToRetransmit(mtu uint16, rto uint64, nowMillis2 uint6 // AcknowledgeRange handles acknowledgment of data func (sb *SendBuffer) AcknowledgeRange(streamId uint32, offset uint64, length uint16) uint64 { sb.mu.Lock() - defer sb.mu.Unlock() streamPair := sb.streams.Get(streamId) if streamPair == nil { + sb.mu.Unlock() return 0 } stream := streamPair.value @@ -277,6 +286,7 @@ func (sb *SendBuffer) AcknowledgeRange(streamId uint32, offset uint64, length ui rangePair := stream.dataInFlightMap.Remove(key) if rangePair == nil { + sb.mu.Unlock() return 0 } @@ -300,15 +310,21 @@ func (sb *SendBuffer) AcknowledgeRange(streamId uint32, offset uint64, length ui if nextRange == nil { // No gap, safe to Remove all data stream.data = stream.data[stream.sentOffset-stream.bias:] - stream.bias += stream.sentOffset - sb.totalSize -= int(stream.sentOffset) + stream.bias += stream.sentOffset - stream.bias + sb.totalSize -= int(stream.sentOffset - stream.bias) } else { nextOffset := nextRange.key.offset() stream.data = stream.data[nextOffset-stream.bias:] stream.bias += nextOffset sb.totalSize -= int(nextOffset) } + // Broadcast capacity availability + select { + case sb.capacityAvailable <- struct{}{}: //Signal the release + default: // Non-blocking send to avoid blocking when the channel is full + // another goroutine is already aware of this, skipping + } } - + sb.mu.Unlock() return uint64(firstSentTime) } diff --git a/snd_test.go b/snd_test.go index 7e42a69..5d2ac79 100644 --- a/snd_test.go +++ b/snd_test.go @@ -1,18 +1,22 @@ package tomtp import ( - "github.com/stretchr/testify/require" + "context" "math" "testing" + "time" + + "github.com/stretchr/testify/require" ) func TestInsert(t *testing.T) { assert := require.New(t) sb := NewSendBuffer(1000) + ctx := context.Background() // Basic insert - ret := sb.Insert(1, []byte("test")) - assert.Equal(true, ret) + err := sb.Insert(ctx, 1, []byte("test")) + assert.Nil(err) // Verify stream created correctly stream := sb.streams.Get(1).value @@ -23,19 +27,21 @@ func TestInsert(t *testing.T) { // Test capacity limit sb = NewSendBuffer(3) - ret = sb.Insert(1, []byte("test")) - assert.Equal(false, ret) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) // Timeout + defer cancel() - ret = sb.Insert(1, []byte("test")) - assert.Equal(false, ret) + err = sb.Insert(ctx, 1, []byte("test")) + assert.Error(err) + assert.Equal(context.DeadlineExceeded, err) // Test 48-bit wrapping sb = NewSendBuffer(1000) stream = NewStreamBuffer() stream.unsentOffset = math.MaxUint64 - 2 sb.streams.Put(1, stream) - ret = sb.Insert(1, []byte("test")) - assert.Equal(true, ret) + err = sb.Insert(context.Background(), 1, []byte("test")) + assert.Nil(err) // Should succeed now + assert.Equal(uint64(1), stream.unsentOffset) assert.Equal(uint64(0), stream.sentOffset) } @@ -43,10 +49,11 @@ func TestInsert(t *testing.T) { func TestReadyToSend(t *testing.T) { assert := require.New(t) sb := NewSendBuffer(1000) + ctx := context.Background() // Insert data - sb.Insert(1, []byte("test1")) - sb.Insert(2, []byte("test2")) + sb.Insert(ctx, 1, []byte("test1")) + sb.Insert(ctx, 2, []byte("test2")) // Basic send streamId, offset, data, err := sb.ReadyToSend(10, 100) @@ -65,7 +72,7 @@ func TestReadyToSend(t *testing.T) { sb.ReadyToSend(10, 100) // Test MTU limiting - sb.Insert(3, []byte("toolongdata")) + sb.Insert(ctx, 3, []byte("toolongdata")) streamId, offset, data, err = sb.ReadyToSend(4, 100) assert.NoError(err) assert.Equal(uint32(3), streamId) @@ -81,10 +88,11 @@ func TestReadyToSend(t *testing.T) { func TestReadyToRetransmit(t *testing.T) { assert := require.New(t) sb := NewSendBuffer(1000) + ctx := context.Background() // Setup test data - sb.Insert(1, []byte("test1")) - sb.Insert(2, []byte("test2")) + sb.Insert(ctx, 1, []byte("test1")) + sb.Insert(ctx, 2, []byte("test2")) sb.ReadyToSend(10, 100) sb.ReadyToSend(10, 100) @@ -104,7 +112,7 @@ func TestReadyToRetransmit(t *testing.T) { // Test MTU split sb = NewSendBuffer(1000) - sb.Insert(1, []byte("testdata")) + sb.Insert(ctx, 1, []byte("testdata")) sb.ReadyToSend(8, 100) streamId, offset, data = sb.ReadyToRetransmit(4, 50, 200) @@ -120,7 +128,8 @@ func TestReadyToRetransmit(t *testing.T) { func TestAcknowledgeRangeBasic(t *testing.T) { assert := require.New(t) sb := NewSendBuffer(1000) - sb.Insert(1, []byte("testdata")) + ctx := context.Background() + sb.Insert(ctx, 1, []byte("testdata")) sb.ReadyToSend(4, 100) assert.Equal(uint64(100), sb.AcknowledgeRange(1, 0, 4)) stream := sb.streams.Get(1).value @@ -147,17 +156,18 @@ func TestSendBufferIntegration(t *testing.T) { t.Run("edge cases with varying MTU and data sizes", func(t *testing.T) { sb := NewSendBuffer(1000) + ctx := context.Background() // Test case 1: Insert data near MaxUint48 boundary stream1Data := make([]byte, 100) - sb.Insert(1, stream1Data) + sb.Insert(ctx, 1, stream1Data) stream := sb.streams.Get(1).value stream.unsentOffset = math.MaxUint64 - 49 // Insert data that will wrap around wrapData := make([]byte, 100) - ret := sb.Insert(1, wrapData) - assert.Equal(true, ret) + err := sb.Insert(ctx, 1, wrapData) + assert.Nil(err) assert.Equal(uint64(50), stream.unsentOffset) // Test case 2: Multiple MTU sizes for ReadyToSend @@ -174,7 +184,7 @@ func TestSendBufferIntegration(t *testing.T) { // Test case 3: Retransmission with MTU changes // First send with large MTU sb = NewSendBuffer(1000) - sb.Insert(1, []byte("thisislongdataforretransmission")) + sb.Insert(ctx, 1, []byte("thisislongdataforretransmission")) streamId, _, _, err = sb.ReadyToSend(30, 100) assert.NoError(err) @@ -182,11 +192,12 @@ func TestSendBufferIntegration(t *testing.T) { streamId, _, data = sb.ReadyToRetransmit(10, 50, 200) assert.Equal(uint32(1), streamId) assert.Equal(10, len(data)) + assert.Equal("thisislong", string(data)) // Test case 4: Out of order acknowledgments sb = NewSendBuffer(1000) testData := []byte("testdatafortestingacks") - sb.Insert(1, testData) + sb.Insert(ctx, 1, testData) // Send in chunks sb.ReadyToSend(5, 100) // "testd" @@ -205,9 +216,9 @@ func TestSendBufferIntegration(t *testing.T) { sb = NewSendBuffer(1000) // Insert into multiple streams - sb.Insert(1, []byte("stream1data")) - sb.Insert(2, []byte("stream2data")) - sb.Insert(3, []byte("stream3data")) + sb.Insert(ctx, 1, []byte("stream1data")) + sb.Insert(ctx, 2, []byte("stream2data")) + sb.Insert(ctx, 3, []byte("stream3data")) // Send from different streams with varying MTUs _, _, data1, _ := sb.ReadyToSend(5, 100) @@ -221,13 +232,14 @@ func TestSendBufferIntegration(t *testing.T) { // Retransmit with different MTUs _, _, retrans1 := sb.ReadyToRetransmit(3, 50, 200) assert.Equal(3, len(retrans1)) + assert.Equal("str", string(retrans1)) // Test case 6: Edge case - acknowledge empty range assert.Equal(uint64(0), sb.AcknowledgeRange(1, 0, 0)) // Test case 7: Complex out-of-order acknowledgments with gaps sb = NewSendBuffer(1000) - sb.Insert(1, []byte("abcdefghijklmnopqrstuvwxyz")) + sb.Insert(ctx, 1, []byte("abcdefghijklmnopqrstuvwxyz")) // Send in multiple chunks sb.ReadyToSend(5, 100) // "abcde" @@ -243,9 +255,12 @@ func TestSendBufferIntegration(t *testing.T) { assert.Equal(uint64(100), sb.AcknowledgeRange(1, 20, 5)) // "uvwxy" assert.Equal(uint64(100), sb.AcknowledgeRange(1, 10, 5)) // "klmno" + stream = sb.streams.Get(1).value + assert.Equal(uint64(25), stream.bias) + // Test case 8: Out-of-order retransmissions with varying MTUs sb = NewSendBuffer(1000) - sb.Insert(1, []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")) + sb.Insert(ctx, 1, []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")) // Initial send with MTU 10 sb.ReadyToSend(10, 100) // "ABCDEFGHIJ" @@ -274,9 +289,29 @@ func TestSendBufferIntegration(t *testing.T) { // Test case 9: Edge case - maximum MTU sb = NewSendBuffer(1000) - sb.Insert(1, make([]byte, 65535)) - _, _, maxMtuData, _ := sb.ReadyToSend(65535, 100) - assert.Equal(0, len(maxMtuData)) + ctxLimited, cancelLimited := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelLimited() + err = sb.Insert(ctxLimited, 1, make([]byte, 65535)) // This should timeout and return an error! + assert.Error(err) + + // Test case 10: Acknowledge large buffer + sb = NewSendBuffer(1000) + ctx2 := context.Background() + + longTestData := make([]byte, 500) + err = sb.Insert(ctx2, 1, longTestData) + assert.NoError(err) + sb.ReadyToSend(500, 100) + + sb.AcknowledgeRange(1, 0, 500) + + sb = NewSendBuffer(1024) + ctxEdgeCases := context.Background() + err = sb.Insert(ctxEdgeCases, 1, []byte("test")) + assert.NoError(err) + stream = sb.streams.Get(1).value + sb.AcknowledgeRange(1, 0, 0) + }) } diff --git a/stream.go b/stream.go index 48fe0e5..a6cd259 100644 --- a/stream.go +++ b/stream.go @@ -1,6 +1,7 @@ package tomtp import ( + "context" "errors" "io" "log/slog" @@ -46,9 +47,10 @@ type Stream struct { closeInitiated bool closePending bool - mu sync.Mutex - closeOnce sync.Once - cond *sync.Cond + closeCtx context.Context + closeCancelFn context.CancelFunc + + mu sync.Mutex } func (s *Stream) Write(b []byte) (nTot int, err error) { @@ -58,7 +60,9 @@ func (s *Stream) Write(b []byte) (nTot int, err error) { slog.Debug("Write", debugGoroutineID(), s.debug(), slog.String("b...", string(b[:min(10, len(b))]))) for len(b) > 0 { - enc, n, err := s.encode(b) + var enc []byte + var n int + enc, n, err = s.encode(b) if err != nil { return nTot, err } @@ -66,8 +70,9 @@ func (s *Stream) Write(b []byte) (nTot int, err error) { if n == 0 { break } - if !s.conn.rbSnd.Insert(s.streamId, enc) { - break + err = s.conn.rbSnd.Insert(s.closeCtx, s.streamId, enc) + if err != nil { + return nTot, err } // Signal the listener that there is data to send @@ -87,7 +92,10 @@ func (s *Stream) Read(b []byte) (n int, err error) { slog.Debug("read data start", debugGoroutineID(), s.debug()) - segment := s.rbRcv.RemoveOldestInOrder() + segment, err := s.rbRcv.RemoveOldestInOrder(s.closeCtx) + if err != nil { + return 0, err + } if segment == nil { if s.state >= StreamEnded { return 0, io.EOF @@ -107,12 +115,9 @@ func (s *Stream) ReadBytes() (b []byte, err error) { slog.Debug("read data start", debugGoroutineID(), s.debug()) - segment := s.rbRcv.RemoveOldestInOrder() - if segment == nil { - if s.state >= StreamEnded { - return nil, io.EOF - } - return nil, nil + segment, err := s.rbRcv.RemoveOldestInOrder(s.closeCtx) + if err != nil { + return nil, err } s.bytesRead += uint64(len(segment.data)) @@ -128,6 +133,7 @@ func (s *Stream) Close() error { } s.state = StreamEnding + s.closeCancelFn() return nil }