Skip to content

Commit

Permalink
♻️ Refactor server to use context.Context for lifecycle control
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent 1e22117 commit 18cfeac
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 24 deletions.
12 changes: 11 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,14 @@ Added graceful shutdown support to handle interrupt signals (SIGINT, SIGTERM) pr
- Implemented graceful shutdown for SSE server with proper client connection cleanup
- Added graceful shutdown for stdio server
- Updated main program to handle interrupt signals and coordinate shutdown
- Added proper error handling during shutdown process
- Added proper error handling during shutdown process

# Context-Based Server Control

Refactored the server to use context.Context for better control over server lifecycle and cancellation.

- Added context support to Transport interface methods (Start and Stop)
- Updated SSE server to use context for connection handling and shutdown
- Updated stdio server to handle context cancellation
- Added context with timeout for graceful shutdown in main program
- Improved error handling with context cancellation
16 changes: 13 additions & 3 deletions cmd/mcp-server/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -140,6 +141,10 @@ Available transports:
srv.GetRegistry().RegisterResourceProvider(resourceRegistry)
srv.GetRegistry().RegisterToolProvider(toolRegistry)

// Create root context with cancellation
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
Expand All @@ -152,10 +157,10 @@ Available transports:
switch transport {
case "stdio":
logger.Info().Msg("Starting server with stdio transport")
err = srv.StartStdio()
err = srv.StartStdio(ctx)
case "sse":
logger.Info().Int("port", port).Msg("Starting server with SSE transport")
err = srv.StartSSE(port)
err = srv.StartSSE(ctx, port)
default:
err = fmt.Errorf("invalid transport type: %s", transport)
}
Expand All @@ -172,7 +177,12 @@ Available transports:
return nil
case sig := <-sigChan:
logger.Info().Str("signal", sig.String()).Msg("Received signal, initiating graceful shutdown")
if err := srv.Stop(); err != nil {
// Cancel context to initiate shutdown
cancel()
// Create a timeout context for shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
if err := srv.Stop(shutdownCtx); err != nil {
logger.Error().Err(err).Msg("Error during shutdown")
return err
}
Expand Down
21 changes: 11 additions & 10 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"sync"

"github.com/go-go-golems/go-go-mcp/pkg"
Expand All @@ -12,10 +13,10 @@ import (

// Transport represents a server transport mechanism
type Transport interface {
// Start starts the transport
Start() error
// Stop gracefully stops the transport
Stop() error
// Start starts the transport with the given context
Start(ctx context.Context) error
// Stop gracefully stops the transport with the given context
Stop(ctx context.Context) error
}

// Server represents an MCP server that can use different transports
Expand Down Expand Up @@ -58,25 +59,25 @@ func (s *Server) GetRegistry() *pkg.ProviderRegistry {
}

// StartStdio starts the server with stdio transport
func (s *Server) StartStdio() error {
func (s *Server) StartStdio(ctx context.Context) error {
s.mu.Lock()
stdioServer := stdio.NewServer(s.logger, s.promptService, s.resourceService, s.toolService, s.initializeService)
s.transport = stdioServer
s.mu.Unlock()
return stdioServer.Start()
return stdioServer.Start(ctx)
}

// StartSSE starts the server with SSE transport on the specified port
func (s *Server) StartSSE(port int) error {
func (s *Server) StartSSE(ctx context.Context, port int) error {
s.mu.Lock()
sseServer := NewSSEServer(s.logger, s.registry, port)
s.transport = sseServer
s.mu.Unlock()
return sseServer.Start()
return sseServer.Start(ctx)
}

// Stop gracefully stops the server
func (s *Server) Stop() error {
func (s *Server) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -85,5 +86,5 @@ func (s *Server) Stop() error {
}

s.logger.Info().Msg("Stopping server")
return s.transport.Stop()
return s.transport.Stop(ctx)
}
34 changes: 27 additions & 7 deletions pkg/server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"sync"

Expand Down Expand Up @@ -35,7 +36,7 @@ func NewSSEServer(logger zerolog.Logger, registry *pkg.ProviderRegistry, port in
}

// Start begins the SSE server
func (s *SSEServer) Start() error {
func (s *SSEServer) Start(ctx context.Context) error {
r := mux.NewRouter()

// SSE endpoint for clients to establish connection
Expand All @@ -47,14 +48,32 @@ func (s *SSEServer) Start() error {
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: r,
BaseContext: func(l net.Listener) context.Context {
return ctx
},
}

s.logger.Info().Int("port", s.port).Msg("Starting SSE server")
return s.server.ListenAndServe()
// Create a channel to capture server errors
errChan := make(chan error, 1)
go func() {
s.logger.Info().Int("port", s.port).Msg("Starting SSE server")
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errChan <- err
}
close(errChan)
}()

// Wait for context cancellation or server error
select {
case err := <-errChan:
return err
case <-ctx.Done():
return s.Stop(ctx)
}
}

// Stop gracefully stops the SSE server
func (s *SSEServer) Stop() error {
func (s *SSEServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -66,7 +85,7 @@ func (s *SSEServer) Stop() error {
close(ch)
delete(s.clients, sessionID)
}
return s.server.Shutdown(context.Background())
return s.server.Shutdown(ctx)
}
return nil
}
Expand All @@ -83,6 +102,7 @@ func (s *SSEServer) marshalJSON(v interface{}) (json.RawMessage, error) {

// handleSSE handles new SSE connections
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
s.logger.Debug().
Str("remote_addr", r.RemoteAddr).
Str("user_agent", r.UserAgent()).
Expand Down Expand Up @@ -168,10 +188,10 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
w.(http.Flusher).Flush()

case <-r.Context().Done():
case <-ctx.Done():
s.logger.Debug().
Str("session_id", sessionID).
Msg("Client context done, closing connection")
Msg("Context done, closing connection")
return
}
}
Expand Down
8 changes: 5 additions & 3 deletions pkg/server/transports/stdio/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ func NewServer(logger zerolog.Logger, ps services.PromptService, rs services.Res
}

// Start begins listening for and handling messages on stdio
func (s *Server) Start() error {
func (s *Server) Start(ctx context.Context) error {
s.logger.Info().Msg("Starting stdio server...")

// Process messages until stdin is closed or stop is called
// Process messages until stdin is closed, stop is called, or context is cancelled
for s.scanner.Scan() {
select {
case <-s.done:
return nil
case <-ctx.Done():
return ctx.Err()
default:
line := s.scanner.Text()
s.logger.Debug().Str("line", line).Msg("Received line")
Expand All @@ -69,7 +71,7 @@ func (s *Server) Start() error {
}

// Stop gracefully stops the stdio server
func (s *Server) Stop() error {
func (s *Server) Stop(ctx context.Context) error {
s.logger.Info().Msg("Stopping stdio server")
close(s.done)
return nil
Expand Down

0 comments on commit 18cfeac

Please sign in to comment.