diff --git a/session.go b/session.go index 06d20fa..8b151c9 100644 --- a/session.go +++ b/session.go @@ -535,8 +535,14 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { - // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code - // received in a GoAway frame received just before the TCP RST that closed the sendLoop + // If we are shutting down because remote closed the connection, prefer the recvLoop error + // over the sendLoop error. The receive loop might have error code received in a GoAway frame, + // which was received just before the TCP RST that closed the sendLoop. + // + // If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop. + // We hold the shutdownLock, close the connection, and wait for the receive loop to finish and + // use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close + // but the sendLoop does. s.shutdownLock.Lock() if s.shutdownErr == nil { s.conn.Close() diff --git a/session_test.go b/session_test.go index dc6c3f0..8077e18 100644 --- a/session_test.go +++ b/session_test.go @@ -1578,7 +1578,7 @@ func TestStreamResetWithError(t *testing.T) { defer server.Close() wc := new(sync.WaitGroup) - wc.Add(2) + wc.Add(1) go func() { defer wc.Done() stream, err := server.AcceptStream() @@ -1589,7 +1589,7 @@ func TestStreamResetWithError(t *testing.T) { se := &StreamError{} _, err = io.ReadAll(stream) if !errors.As(err, &se) { - t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) + t.Errorf("expected StreamError, got type:%T, err: %s", err, err) return } expected := &StreamError{Remote: true, ErrorCode: 42} @@ -1601,24 +1601,19 @@ func TestStreamResetWithError(t *testing.T) { t.Error(err) } - go func() { - defer wc.Done() - - se := &StreamError{} - _, err := io.ReadAll(stream) - if !errors.As(err, &se) { - t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) - return - } - expected := &StreamError{Remote: false, ErrorCode: 42} - assert.Equal(t, se, expected) - }() - time.Sleep(1 * time.Second) err = stream.ResetWithError(42) if err != nil { t.Fatal(err) } + se := &StreamError{} + _, err = io.ReadAll(stream) + if !errors.As(err, &se) { + t.Errorf("expected StreamError, got type:%T, err: %s", err, err) + return + } + expected := &StreamError{Remote: false, ErrorCode: 42} + assert.Equal(t, se, expected) wc.Wait() } diff --git a/stream.go b/stream.go index 0835165..15a8b56 100644 --- a/stream.go +++ b/stream.go @@ -395,7 +395,7 @@ func (s *Stream) cleanup() { // processFlags is used to update the state of the stream // based on set flags, if any. Lock must be held -func (s *Stream) processFlags(flags uint16, hdr header) { +func (s *Stream) processFlags(hdr header, flags uint16) { // Close the stream without holding the state lock var closeStream bool defer func() { @@ -459,7 +459,7 @@ func (s *Stream) notifyWaiting() { // incrSendWindow updates the size of our send window func (s *Stream) incrSendWindow(hdr header, flags uint16) { - s.processFlags(flags, hdr) + s.processFlags(hdr, flags) // Increase window, unblock a sender atomic.AddUint32(&s.sendWindow, hdr.Length()) asyncNotify(s.sendNotifyCh) @@ -467,7 +467,7 @@ func (s *Stream) incrSendWindow(hdr header, flags uint16) { // readData is used to handle a data frame func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { - s.processFlags(flags, hdr) + s.processFlags(hdr, flags) // Check that our recv window is not exceeded length := hdr.Length()