diff --git a/const.go b/const.go index 93fb83f..3ecba41 100644 --- a/const.go +++ b/const.go @@ -45,10 +45,13 @@ func (e *GoAwayError) Temporary() bool { func (e *GoAwayError) Is(target error) bool { // to maintain compatibility with errors returned by previous versions - if e.Remote && target == ErrRemoteGoAway { + if e.Remote && target == ErrRemoteGoAwayNormal { return true } else if !e.Remote && target == ErrSessionShutdown { return true + } else if target == ErrStreamReset { + // A GoAway on a connection also resets all the streams. + return true } if err, ok := target.(*GoAwayError); ok { @@ -111,8 +114,8 @@ var ( // ErrUnexpectedFlag is set when we get an unexpected flag ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} + // ErrRemoteGoAwayNormal is used when we get a go away from the other side + ErrRemoteGoAwayNormal = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} // ErrStreamReset is sent if a stream is reset. This can happen // if the backlog is exceeded, or if there was a remote GoAway. diff --git a/session.go b/session.go index bbecf19..06d20fa 100644 --- a/session.go +++ b/session.go @@ -46,9 +46,9 @@ var nullMemoryManager = &nullMemoryManagerImpl{} type Session struct { rtt int64 // to be accessed atomically, in nanoseconds - // remoteGoAway indicates the remote side does + // remoteGoAwayNormal indicates the remote side does // not want futher connections. Must be first for alignment. - remoteGoAway int32 + remoteGoAwayNormal int32 // localGoAway indicates that we should stop // accepting futher connections. Must be first for alignment. @@ -205,8 +205,8 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { if s.IsClosed() { return nil, s.shutdownErr } - if atomic.LoadInt32(&s.remoteGoAway) == 1 { - return nil, ErrRemoteGoAway + if atomic.LoadInt32(&s.remoteGoAwayNormal) == 1 { + return nil, ErrRemoteGoAwayNormal } // Block if we have too many inflight SYNs @@ -285,15 +285,15 @@ func (s *Session) AcceptStream() (*Stream, error) { } } -// Close is used to close the session and all streams. -// Attempts to send a GoAway before closing the connection. The GoAway may not actually be sent depending on the -// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or -// if there's unread data in the kernel receive buffer. +// Close is used to close the session and all streams. It doesn't send a GoAway before +// closing the connection. func (s *Session) Close() error { return s.close(ErrSessionShutdown, false, goAwayNormal) } // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode. +// Blocks for ConnectionWriteTimeout to write the GoAway message. +// // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn. // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel // receive buffer. @@ -315,7 +315,8 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro close(s.shutdownCh) s.stopKeepalive() - if sendGoAway { + // Only send GoAway if we have an error code. + if sendGoAway && errCode != goAwayNormal { // wait for write loop to exit // We need to write the current frame completely before sending a goaway. // This will wait for at most s.config.ConnectionWriteTimeout @@ -334,7 +335,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro s.streamLock.Lock() defer s.streamLock.Unlock() for id, stream := range s.streams { - stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr)) + stream.forceClose(s.shutdownErr) delete(s.streams, id) stream.memorySpan.Done() } @@ -814,7 +815,7 @@ func (s *Session) handleGoAway(hdr header) error { code := hdr.Length() switch code { case goAwayNormal: - atomic.SwapInt32(&s.remoteGoAway, 1) + atomic.SwapInt32(&s.remoteGoAwayNormal, 1) // Don't close connection on normal go away. Let the existing streams // complete gracefully. return nil diff --git a/session_test.go b/session_test.go index 2c06abb..dc6c3f0 100644 --- a/session_test.go +++ b/session_test.go @@ -651,7 +651,7 @@ func TestGoAway(t *testing.T) { switch err { case nil: s.Close() - case ErrRemoteGoAway: + case ErrRemoteGoAwayNormal: return default: t.Fatalf("err: %v", err) diff --git a/stream.go b/stream.go index e79562d..0835165 100644 --- a/stream.go +++ b/stream.go @@ -310,7 +310,7 @@ func (s *Stream) CloseWrite() error { return nil case halfReset: s.stateLock.Unlock() - return ErrStreamReset + return s.writeErr default: panic("invalid state") } @@ -331,7 +331,8 @@ func (s *Stream) CloseWrite() error { return err } -// CloseRead is used to close the stream for writing. +// CloseRead is used to close the stream for reading. +// Note: Remote is not notified. func (s *Stream) CloseRead() error { cleanup := false s.stateLock.Lock()