Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Nov 19, 2024
1 parent ede18a5 commit af8e895
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
10 changes: 8 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,14 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
// send is a long running goroutine that sends data
func (s *Session) send() {
if err := s.sendLoop(); err != nil {
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
// received in a GoAway frame received just before the TCP RST that closed the sendLoop
// If we are shutting down because remote closed the connection, prefer the recvLoop error
// over the sendLoop error. The receive loop might have error code received in a GoAway frame,
// which was received just before the TCP RST that closed the sendLoop.
//
// If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
// We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
// use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
// but the sendLoop does.
s.shutdownLock.Lock()
if s.shutdownErr == nil {
s.conn.Close()
Expand Down
25 changes: 10 additions & 15 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ func TestStreamResetWithError(t *testing.T) {
defer server.Close()

wc := new(sync.WaitGroup)
wc.Add(2)
wc.Add(1)
go func() {
defer wc.Done()
stream, err := server.AcceptStream()
Expand All @@ -1589,7 +1589,7 @@ func TestStreamResetWithError(t *testing.T) {
se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: true, ErrorCode: 42}
Expand All @@ -1601,24 +1601,19 @@ func TestStreamResetWithError(t *testing.T) {
t.Error(err)
}

go func() {
defer wc.Done()

se := &StreamError{}
_, err := io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: false, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

time.Sleep(1 * time.Second)
err = stream.ResetWithError(42)
if err != nil {
t.Fatal(err)
}
se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: false, ErrorCode: 42}
assert.Equal(t, se, expected)
wc.Wait()
}

Expand Down
6 changes: 3 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ func (s *Stream) cleanup() {

// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16, hdr header) {
func (s *Stream) processFlags(hdr header, flags uint16) {
// Close the stream without holding the state lock
var closeStream bool
defer func() {
Expand Down Expand Up @@ -459,15 +459,15 @@ func (s *Stream) notifyWaiting() {

// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
s.processFlags(flags, hdr)
s.processFlags(hdr, flags)
// Increase window, unblock a sender
atomic.AddUint32(&s.sendWindow, hdr.Length())
asyncNotify(s.sendNotifyCh)
}

// readData is used to handle a data frame
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
s.processFlags(flags, hdr)
s.processFlags(hdr, flags)

// Check that our recv window is not exceeded
length := hdr.Length()
Expand Down

0 comments on commit af8e895

Please sign in to comment.