Skip to content

Commit

Permalink
🐛 Fix SSE server hanging during shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent cb56757 commit 7815ef4
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 6 deletions.
11 changes: 10 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,13 @@ Added proper response types for list operations to ensure consistent JSON encodi
- Fixed type mismatches between service interfaces and response types
- Improved type handling with proper interface conversions
- Ensured empty arrays are always returned instead of null
- Fixed JSON response structure to match API specification
- Fixed JSON response structure to match API specification

# Fix SSE Server Shutdown

Fixed an issue where the SSE server would hang during shutdown due to improper handling of client connections.

- Added proper context cancellation for client goroutines
- Added WaitGroup to track and wait for client goroutines to finish
- Improved shutdown coordination between HTTP server and client cleanup
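- Added timeout handling for client goroutine cleanup: the wait is bounded by the context passed to `Stop`, so shutdown proceeds once that context is done even if some clients have not exited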
48 changes: 43 additions & 5 deletions pkg/server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type SSEServer struct {
toolService services.ToolService
initializeService services.InitializeService
nextClientID int
wg sync.WaitGroup
cancel context.CancelFunc
}

type SSEClient struct {
Expand Down Expand Up @@ -71,6 +73,9 @@ func NewSSEServer(logger zerolog.Logger, ps services.PromptService, rs services.

// Start begins the SSE server
func (s *SSEServer) Start(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
s.cancel = cancel

r := mux.NewRouter()

// SSE endpoint for clients to establish connection
Expand Down Expand Up @@ -102,25 +107,54 @@ func (s *SSEServer) Start(ctx context.Context) error {
case err := <-errChan:
return err
case <-ctx.Done():
return s.Stop(ctx)
return s.Stop(context.Background())
}
}

// Stop gracefully stops the SSE server
func (s *SSEServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

if s.server != nil {
s.logger.Info().Msg("Stopping SSE server")

// Cancel all client goroutines
if s.cancel != nil {
s.cancel()
}

// Close all client connections
for sessionID, ch := range s.clients {
for sessionID, client := range s.clients {
s.logger.Debug().Str("sessionId", sessionID).Msg("Closing client connection")
close(ch.messageChan)
close(client.messageChan)
delete(s.clients, sessionID)
}
return s.server.Shutdown(ctx)

s.mu.Unlock()

// Wait for all client goroutines to finish, giving up once ctx is done
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()

select {
case <-done:
s.logger.Debug().Msg("All client goroutines finished")
case <-ctx.Done():
s.logger.Warn().Msg("Timeout waiting for client goroutines")
}

// Shutdown the HTTP server
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("error shutting down server: %w", err)
}

return nil
}

s.mu.Unlock()
return nil
}

Expand Down Expand Up @@ -180,6 +214,10 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
w.(http.Flusher).Flush()

// Track this handler in the WaitGroup so Stop can wait for it; Done runs when the handler returns
s.wg.Add(1)
defer s.wg.Done()

defer func() {
s.mu.Lock()
if c, exists := s.clients[clientID]; exists {
Expand Down

0 comments on commit 7815ef4

Please sign in to comment.