Skip to content

Commit

Permalink
change CloseWithError to CloseWithErrorChan
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Nov 21, 2024
1 parent 39abe7e commit 4ad8059
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
64 changes: 35 additions & 29 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,25 +281,26 @@ func (s *Session) AcceptStream() (*Stream, error) {
// Close is used to close the session and all streams. It doesn't send a GoAway before
// closing the connection.
func (s *Session) Close() error {
return s.close(ErrSessionShutdown, false, goAwayNormal)
return <-s.close(ErrSessionShutdown, false, goAwayNormal)
}

// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
// CloseWithErrorChan is used to close the session and all streams after sending a GoAway message with errCode.
// Blocks for ConnectionWriteTimeout to write the GoAway message.
//
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
// receive buffer.
func (s *Session) CloseWithError(errCode uint32) error {
func (s *Session) CloseWithErrorChan(errCode uint32) chan error {
return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode)
}

func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error {
func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) chan error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

errCh := make(chan error, 1)
if s.shutdown {
return nil
errCh <- nil
return errCh
}
s.shutdown = true
if s.shutdownErr == nil {
Expand All @@ -308,35 +309,43 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
close(s.shutdownCh)
s.stopKeepalive()

// Only send GoAway if we have an error code.
if sendGoAway && errCode != goAwayNormal {
// wait for write loop to exit
// We need to write the current frame completely before sending a goaway.
// This will wait for at most s.config.ConnectionWriteTimeout
<-s.sendDoneCh
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
s.conn.SetWriteDeadline(time.Time{})
}

s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh

resetErr := shutdownErr
resetErr := s.shutdownErr
if _, ok := resetErr.(*GoAwayError); !ok {
resetErr = fmt.Errorf("%w: connection closed: %w", ErrStreamReset, shutdownErr)
}

s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
stream.forceClose(resetErr)
delete(s.streams, id)
stream.memorySpan.Done()
}
return nil
s.streamLock.Unlock()

if sendGoAway {
go func() {
// wait for write loop to exit
// We need to write the current frame completely before sending a goaway.
// This will wait for at most s.config.ConnectionWriteTimeout
<-s.sendDoneCh
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
s.conn.SetWriteDeadline(time.Time{})
s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh
errCh <- nil
}()
return errCh
}

errCh <- nil
s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh
return errCh
}

// GoAway can be used to prevent accepting further
Expand Down Expand Up @@ -748,12 +757,10 @@ func (s *Session) handleStreamMessage(hdr header) error {
return err
}
}

// Get the stream
s.streamLock.Lock()
stream := s.streams[id]
s.streamLock.Unlock()

// If we do not have a stream, likely we sent a RST and/or closed the stream for reading.
if stream == nil {
// Drain any data on the wire
Expand Down Expand Up @@ -850,7 +857,6 @@ func (s *Session) incomingStream(id uint32) error {
return err
}
stream := newStream(s, id, streamSYNReceived, initialStreamWindow, span)

s.streamLock.Lock()
defer s.streamLock.Unlock()

Expand Down
3 changes: 2 additions & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,8 @@ func TestCloseWithError(t *testing.T) {
defer client.Close()
defer server.Close()

if err := server.CloseWithError(42); err != nil {
errCh := server.CloseWithErrorChan(42)
if err := <-errCh; err != nil {
t.Fatalf("err: %v", err)
}

Expand Down

0 comments on commit 4ad8059

Please sign in to comment.