Skip to content

Commit

Permalink
🐛 Fix signal handling by using process groups
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent db58153 commit 38af046
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 12 deletions.
12 changes: 11 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,14 @@ Enhanced signal handling in stdio server:
- Added direct signal handling in stdio server
- Fixed issue with signals not breaking scanner reads
- Added detailed debug logging for signal flow
- Improved shutdown coordination between scanner and signals
- Improved shutdown coordination between scanner and signals

# Improved Process Group Handling

Enhanced process group and signal handling in stdio transport:

- Set up command server in its own process group
- Send signals to entire process group instead of just the main process
- Added fallback to direct process signals if process group handling fails
- Improved debug logging for process and signal management
- Fixed issue with signals not being properly received by the server
56 changes: 45 additions & 11 deletions pkg/client/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"os/exec"
"sync"
"syscall"

"github.com/go-go-golems/go-go-mcp/pkg/protocol"
"github.com/rs/zerolog"
Expand Down Expand Up @@ -41,6 +42,11 @@ func NewCommandStdioTransport(command string, args ...string) (*StdioTransport,
Strs("args", args).
Msg("Creating command stdio transport")

// Set up process group
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}

stdin, err := cmd.StdinPipe()
if err != nil {
logger.Error().Err(err).Msg("Failed to create stdin pipe")
Expand All @@ -61,7 +67,9 @@ func NewCommandStdioTransport(command string, args ...string) (*StdioTransport,
return nil, fmt.Errorf("failed to start command: %w", err)
}

logger.Debug().Msg("Command started successfully")
logger.Debug().
Int("pid", cmd.Process.Pid).
Msg("Command started successfully in new process group")

return &StdioTransport{
scanner: bufio.NewScanner(stdout),
Expand Down Expand Up @@ -118,28 +126,54 @@ func (t *StdioTransport) Close() error {
if t.cmd != nil {
t.logger.Debug().
Int("pid", t.cmd.Process.Pid).
Msg("Attempting to send interrupt signal to command")
Msg("Attempting to send interrupt signal to process group")

// First try to send an interrupt signal
if err := t.cmd.Process.Signal(os.Interrupt); err != nil {
// Send interrupt signal to the process group
pgid, err := syscall.Getpgid(t.cmd.Process.Pid)
if err != nil {
t.logger.Warn().
Err(err).
Int("pid", t.cmd.Process.Pid).
Msg("Failed to send interrupt signal, falling back to Kill")
Msg("Failed to get process group ID, falling back to direct process signal")

// If interrupt fails, try to kill the process
if err := t.cmd.Process.Kill(); err != nil {
t.logger.Error().
// Fall back to sending signal directly to process
if err := t.cmd.Process.Signal(os.Interrupt); err != nil {
t.logger.Warn().
Err(err).
Int("pid", t.cmd.Process.Pid).
Msg("Failed to kill process")
return fmt.Errorf("failed to kill process: %w", err)
Msg("Failed to send interrupt signal, falling back to Kill")

// If interrupt fails, try to kill the process
if err := t.cmd.Process.Kill(); err != nil {
t.logger.Error().
Err(err).
Int("pid", t.cmd.Process.Pid).
Msg("Failed to kill process")
return fmt.Errorf("failed to kill process: %w", err)
}
}
} else {
// Send interrupt to the process group
if err := syscall.Kill(-pgid, syscall.SIGINT); err != nil {
t.logger.Warn().
Err(err).
Int("pgid", pgid).
Msg("Failed to send interrupt signal to process group, falling back to Kill")

// If interrupt fails, try to kill the process group
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
t.logger.Error().
Err(err).
Int("pgid", pgid).
Msg("Failed to kill process group")
return fmt.Errorf("failed to kill process group: %w", err)
}
}
}

// Wait for the process to exit
t.logger.Debug().Msg("Waiting for command to exit")
err := t.cmd.Wait()
err = t.cmd.Wait()
if err != nil {
// Check if it's an expected exit error (like signal kill)
if exitErr, ok := err.(*exec.ExitError); ok {
Expand Down

0 comments on commit 38af046

Please sign in to comment.