Skip to content

Commit

Permalink
🔄 Improve signal handling in stdio server
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent 3c8caf6 commit db58153
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 29 deletions.
11 changes: 10 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,13 @@ Enhanced command server shutdown handling:
- Added fallback to Process.Kill() if interrupt signal fails
- Improved error handling for process termination
- Added detailed debug logging with process IDs
- Fixed issue with programmatic interrupt signals not working
- Fixed issue with programmatic interrupt signals not working

# Improved Signal Handling

Enhanced signal handling in stdio server:

- Added direct signal handling in stdio server
- Fixed issue with signals not breaking scanner reads
- Added detailed debug logging for signal flow
- Improved shutdown coordination between scanner and signals
97 changes: 69 additions & 28 deletions pkg/server/transports/stdio/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"io"
"os"
"os/signal"
"syscall"
"time"

"github.com/go-go-golems/go-go-mcp/pkg/protocol"
Expand All @@ -24,6 +26,7 @@ type Server struct {
toolService services.ToolService
initializeService services.InitializeService
done chan struct{}
signalChan chan os.Signal
}

// NewServer creates a new stdio server instance
Expand Down Expand Up @@ -52,45 +55,83 @@ func NewServer(logger zerolog.Logger, ps services.PromptService, rs services.Res
toolService: ts,
initializeService: is,
done: make(chan struct{}),
signalChan: make(chan os.Signal, 1),
}
}

// Start begins listening for and handling messages on stdio
func (s *Server) Start(ctx context.Context) error {
s.logger.Info().Msg("Starting stdio server...")

// Process messages until stdin is closed, stop is called, or context is cancelled
for s.scanner.Scan() {
select {
case <-s.done:
s.logger.Debug().Msg("Received done signal, stopping stdio server")
return nil
case <-ctx.Done():
s.logger.Debug().
Err(ctx.Err()).
Msg("Context cancelled, stopping stdio server")
return ctx.Err()
default:
line := s.scanner.Text()
s.logger.Debug().
Str("line", line).
Msg("Received line")
if err := s.handleMessage(line); err != nil {
s.logger.Error().Err(err).Msg("Error handling message")
// Continue processing messages even if one fails
// Set up signal handling
signal.Notify(s.signalChan, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(s.signalChan)

// Create a channel for scanner errors
scanErrChan := make(chan error, 1)

// Start scanning in a goroutine
go func() {
for s.scanner.Scan() {
select {
case <-s.done:
s.logger.Debug().Msg("Received done signal, stopping scanner")
scanErrChan <- nil
return
case <-ctx.Done():
s.logger.Debug().
Err(ctx.Err()).
Msg("Context cancelled, stopping scanner")
scanErrChan <- ctx.Err()
return
default:
line := s.scanner.Text()
s.logger.Debug().
Str("line", line).
Msg("Received line")
if err := s.handleMessage(line); err != nil {
s.logger.Error().Err(err).Msg("Error handling message")
// Continue processing messages even if one fails
}
}
}
}

if err := s.scanner.Err(); err != nil {
s.logger.Error().
Err(err).
Msg("Scanner error")
return fmt.Errorf("scanner error: %w", err)
}
if err := s.scanner.Err(); err != nil {
s.logger.Error().
Err(err).
Msg("Scanner error")
scanErrChan <- fmt.Errorf("scanner error: %w", err)
return
}

s.logger.Debug().Msg("Scanner reached EOF")
return io.EOF
s.logger.Debug().Msg("Scanner reached EOF")
scanErrChan <- io.EOF
}()

// Wait for either a signal, context cancellation, or scanner error
select {
case sig := <-s.signalChan:
s.logger.Debug().
Str("signal", sig.String()).
Msg("Received signal in stdio server")
close(s.done)
return nil
case <-ctx.Done():
s.logger.Debug().
Err(ctx.Err()).
Msg("Context cancelled in stdio server")
close(s.done)
return ctx.Err()
case err := <-scanErrChan:
if err == io.EOF {
s.logger.Debug().Msg("Scanner completed normally")
} else if err != nil {
s.logger.Error().
Err(err).
Msg("Scanner error in stdio server")
}
return err
}
}

// Stop gracefully stops the stdio server
Expand Down

0 comments on commit db58153

Please sign in to comment.