Skip to content

Commit

Permalink
Merge branch 'sukun/stream-error-code' into sukun/conn-error-2
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Sep 12, 2024
2 parents 5727def + 9190b78 commit 43cd707
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 11 deletions.
21 changes: 21 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ func (e *GoAwayError) Is(target error) bool {
return false

Check warning on line 57 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L57

Added line #L57 was not covered by tests
}

// A StreamError is used for errors returned from Read and Write calls after the stream is Reset
type StreamError struct {
ErrorCode uint32
Remote bool
}

func (s *StreamError) Error() string {
if s.Remote {
return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode)

Check warning on line 68 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L66-L68

Added lines #L66 - L68 were not covered by tests
}
return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode)

Check warning on line 70 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L70

Added line #L70 was not covered by tests
}

func (s *StreamError) Is(target error) bool {
if target == ErrStreamReset {
return true
}
e, ok := target.(*StreamError)
return ok && *e == *s

Check warning on line 78 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L77-L78

Added lines #L77 - L78 were not covered by tests
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
Expand Down
2 changes: 1 addition & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
stream.forceClose()
stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr))
delete(s.streams, id)
stream.memorySpan.Done()
}
Expand Down
53 changes: 52 additions & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1571,6 +1572,56 @@ func TestStreamResetRead(t *testing.T) {
wc.Wait()
}

func TestStreamResetWithError(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

wc := new(sync.WaitGroup)
wc.Add(2)
go func() {
defer wc.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Error(err)
}

se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: true, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

stream, err := client.OpenStream(context.Background())
if err != nil {
t.Error(err)
}

go func() {
defer wc.Done()

se := &StreamError{}
_, err := io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: false, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

time.Sleep(1 * time.Second)
err = stream.ResetWithError(42)
if err != nil {
t.Fatal(err)
}
wc.Wait()
}

func TestLotsOfWritesWithStreamDeadline(t *testing.T) {
config := testConf()
config.EnableKeepAlive = false
Expand Down Expand Up @@ -1809,7 +1860,7 @@ func TestMaxIncomingStreams(t *testing.T) {
require.NoError(t, err)
str.SetDeadline(time.Now().Add(time.Second))
_, err = str.Read([]byte{0})
require.EqualError(t, err, "stream reset")
require.ErrorIs(t, err, ErrStreamReset)

// Now close one of the streams.
// This should then allow the client to open a new stream.
Expand Down
37 changes: 28 additions & 9 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type Stream struct {

state streamState
writeState, readState halfStreamState
writeErr, readErr error
stateLock sync.Mutex

recvBuf segmentedBuffer
Expand Down Expand Up @@ -89,6 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) {
START:
s.stateLock.Lock()
state := s.readState
resetErr := s.readErr
s.stateLock.Unlock()

switch state {
Expand All @@ -101,7 +103,7 @@ START:
}
// Closed, but we have data pending -> read.
case halfReset:
return 0, ErrStreamReset
return 0, resetErr
default:
panic("unknown state")
}
Expand Down Expand Up @@ -147,6 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) {
START:
s.stateLock.Lock()
state := s.writeState
resetErr := s.writeErr
s.stateLock.Unlock()

switch state {
Expand All @@ -155,7 +158,7 @@ START:
case halfClosed:
return 0, ErrStreamClosed
case halfReset:
return 0, ErrStreamReset
return 0, resetErr
default:
panic("unknown state")
}
Expand Down Expand Up @@ -250,13 +253,17 @@ func (s *Stream) sendClose() error {
}

// sendReset is used to send a RST
func (s *Stream) sendReset() error {
hdr := encode(typeWindowUpdate, flagRST, s.id, 0)
func (s *Stream) sendReset(errCode uint32) error {
hdr := encode(typeWindowUpdate, flagRST, s.id, errCode)
return s.session.sendMsg(hdr, nil, nil)
}

// Reset resets the stream (forcibly closes the stream)
func (s *Stream) Reset() error {
return s.ResetWithError(0)
}

func (s *Stream) ResetWithError(errCode uint32) error {
sendReset := false
s.stateLock.Lock()
switch s.state {
Expand All @@ -276,15 +283,17 @@ func (s *Stream) Reset() error {
// If we've already sent/received an EOF, no need to reset that side.
if s.writeState == halfOpen {
s.writeState = halfReset
s.writeErr = &StreamError{Remote: false, ErrorCode: errCode}
}
if s.readState == halfOpen {
s.readState = halfReset
s.readErr = &StreamError{Remote: false, ErrorCode: errCode}
}
s.state = streamFinished
s.notifyWaiting()
s.stateLock.Unlock()
if sendReset {
_ = s.sendReset()
_ = s.sendReset(errCode)
}
s.cleanup()
return nil
Expand Down Expand Up @@ -336,6 +345,7 @@ func (s *Stream) CloseRead() error {
panic("invalid state")
}
s.readState = halfReset
s.readErr = ErrStreamReset
cleanup = s.writeState != halfOpen
if cleanup {
s.state = streamFinished
Expand All @@ -357,13 +367,15 @@ func (s *Stream) Close() error {
}

// forceClose is used for when the session is exiting
func (s *Stream) forceClose() {
func (s *Stream) forceClose(err error) {
s.stateLock.Lock()
if s.readState == halfOpen {
s.readState = halfReset
s.readErr = err
}
if s.writeState == halfOpen {
s.writeState = halfReset
s.writeErr = err
}
s.state = streamFinished
s.notifyWaiting()
Expand All @@ -382,7 +394,7 @@ func (s *Stream) cleanup() {

// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16) {
func (s *Stream) processFlags(flags uint16, hdr header) {
// Close the stream without holding the state lock
var closeStream bool
defer func() {
Expand Down Expand Up @@ -418,11 +430,18 @@ func (s *Stream) processFlags(flags uint16) {
}
if flags&flagRST == flagRST {
s.stateLock.Lock()
var resetErr error = ErrStreamReset
// Length in a window update frame with RST flag encodes an error code.
if hdr.MsgType() == typeWindowUpdate {
resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()}
}
if s.readState == halfOpen {
s.readState = halfReset
s.readErr = resetErr
}
if s.writeState == halfOpen {
s.writeState = halfReset
s.writeErr = resetErr
}
s.state = streamFinished
s.stateLock.Unlock()
Expand All @@ -439,15 +458,15 @@ func (s *Stream) notifyWaiting() {

// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
s.processFlags(flags)
s.processFlags(flags, hdr)
// Increase window, unblock a sender
atomic.AddUint32(&s.sendWindow, hdr.Length())
asyncNotify(s.sendNotifyCh)
}

// readData is used to handle a data frame
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
s.processFlags(flags)
s.processFlags(flags, hdr)

// Check that our recv window is not exceeded
length := hdr.Length()
Expand Down

0 comments on commit 43cd707

Please sign in to comment.