From db5815302f5549f1130cb888ed8583e22e3b1ff0 Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Tue, 21 Jan 2025 09:15:26 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=84=20Improve=20signal=20handling=20in?= =?UTF-8?q?=20stdio=20server?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changelog.md | 11 +++- pkg/server/transports/stdio/stdio.go | 97 ++++++++++++++++++++-------- 2 files changed, 79 insertions(+), 29 deletions(-) diff --git a/changelog.md b/changelog.md index 195e8e8..fb07d85 100644 --- a/changelog.md +++ b/changelog.md @@ -130,4 +130,13 @@ Enhanced command server shutdown handling: - Added fallback to Process.Kill() if interrupt signal fails - Improved error handling for process termination - Added detailed debug logging with process IDs -- Fixed issue with programmatic interrupt signals not working \ No newline at end of file +- Fixed issue with programmatic interrupt signals not working + +# Improved Signal Handling + +Enhanced signal handling in stdio server: + +- Added direct signal handling in stdio server +- Fixed issue with signals not breaking scanner reads +- Added detailed debug logging for signal flow +- Improved shutdown coordination between scanner and signals \ No newline at end of file diff --git a/pkg/server/transports/stdio/stdio.go b/pkg/server/transports/stdio/stdio.go index e30d8d7..909ce95 100644 --- a/pkg/server/transports/stdio/stdio.go +++ b/pkg/server/transports/stdio/stdio.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "os" + "os/signal" + "syscall" "time" "github.com/go-go-golems/go-go-mcp/pkg/protocol" @@ -24,6 +26,7 @@ type Server struct { toolService services.ToolService initializeService services.InitializeService done chan struct{} + signalChan chan os.Signal } // NewServer creates a new stdio server instance @@ -52,6 +55,7 @@ func NewServer(logger zerolog.Logger, ps services.PromptService, rs services.Res toolService: ts, initializeService: is, done: make(chan struct{}), + signalChan: make(chan os.Signal, 1), } } @@ -59,38 +63,75 @@ func NewServer(logger zerolog.Logger, ps services.PromptService, rs services.Res func (s *Server) Start(ctx context.Context) error { s.logger.Info().Msg("Starting stdio server...") - // Process messages until stdin is closed, stop is called, or context is cancelled - for s.scanner.Scan() { - select { - case <-s.done: - s.logger.Debug().Msg("Received done signal, stopping stdio server") - return nil - case <-ctx.Done(): - s.logger.Debug(). - Err(ctx.Err()). - Msg("Context cancelled, stopping stdio server") - return ctx.Err() - default: - line := s.scanner.Text() - s.logger.Debug(). - Str("line", line). - Msg("Received line") - if err := s.handleMessage(line); err != nil { - s.logger.Error().Err(err).Msg("Error handling message") - // Continue processing messages even if one fails + // Set up signal handling + signal.Notify(s.signalChan, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(s.signalChan) + + // Create a channel for scanner errors + scanErrChan := make(chan error, 1) + + // Start scanning in a goroutine + go func() { + for s.scanner.Scan() { + select { + case <-s.done: + s.logger.Debug().Msg("Received done signal, stopping scanner") + scanErrChan <- nil + return + case <-ctx.Done(): + s.logger.Debug(). + Err(ctx.Err()). + Msg("Context cancelled, stopping scanner") + scanErrChan <- ctx.Err() + return + default: + line := s.scanner.Text() + s.logger.Debug(). + Str("line", line). + Msg("Received line") + if err := s.handleMessage(line); err != nil { + s.logger.Error().Err(err).Msg("Error handling message") + // Continue processing messages even if one fails + } } } - } - if err := s.scanner.Err(); err != nil { - s.logger.Error(). - Err(err). - Msg("Scanner error") - return fmt.Errorf("scanner error: %w", err) - } + if err := s.scanner.Err(); err != nil { + s.logger.Error(). + Err(err). + Msg("Scanner error") + scanErrChan <- fmt.Errorf("scanner error: %w", err) + return + } - s.logger.Debug().Msg("Scanner reached EOF") - return io.EOF + s.logger.Debug().Msg("Scanner reached EOF") + scanErrChan <- io.EOF + }() + + // Wait for either a signal, context cancellation, or scanner error + select { + case sig := <-s.signalChan: + s.logger.Debug(). + Str("signal", sig.String()). + Msg("Received signal in stdio server") + close(s.done) + return nil + case <-ctx.Done(): + s.logger.Debug(). + Err(ctx.Err()). + Msg("Context cancelled in stdio server") + close(s.done) + return ctx.Err() + case err := <-scanErrChan: + if err == io.EOF { + s.logger.Debug().Msg("Scanner completed normally") + } else if err != nil { + s.logger.Error(). + Err(err). + Msg("Scanner error in stdio server") + } + return err + } } // Stop gracefully stops the stdio server