Skip to content

Commit

Permalink
✨ Update SSE server to fully comply with MCP protocol spec
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent 7b8025b commit a26a10f
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 55 deletions.
6 changes: 4 additions & 2 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
- [ ] REPL mode / TUI
- [X] Add debug logging
- [ ] make web ui to easily debug / interact
- [ ] add notification handler

### Bugs
- [ ] BUG: figure out why closing the client seems to hang
- [x] BUG: figure out why closing the client seems to hang

## MCP server

Expand All @@ -26,4 +27,5 @@
- [ ] dynamic loading / enabling / removing servers

- [X] Allow debug logging
- [ ] Implement missing SSE methods
- [x] Implement missing SSE methods
- [ ] BUG: killing server doesn't seem to kill hanging connections (when using inspector, for example)
19 changes: 15 additions & 4 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,11 @@ Enhanced error handling in SSE server to better comply with JSON-RPC 2.0 specifi

# SSE Protocol Compliance

Removed non-standard endpoint event from SSE server to better align with the official protocol specification.

- Removed endpoint event from SSE server implementation
Updated the SSE server implementation to fully comply with the MCP protocol specification:
- Added proper CORS headers for cross-origin requests
- Implemented unique session ID generation using UUID
- Added initial endpoint event with session ID
- Ensured proper SSE headers according to spec

# Async SSE Client

Expand Down Expand Up @@ -254,4 +256,13 @@ Consolidated logging setup in client transports:
- Updated SSE transport to use passed logger instead of creating its own
- Updated stdio transport to use passed logger instead of creating its own
- Modified transport constructors to accept logger parameter
- Ensured consistent logger propagation through all client components
- Ensured consistent logger propagation through all client components

# Improved SSE Server Client Management

Improved the SSE server's client management to better handle multiple clients and sessions:
- Added unique client IDs for better tracking and debugging
- Improved session management to handle multiple clients per session
- Added client metadata tracking (creation time, remote address, user agent)
- Fixed race conditions in client management
- Better error handling for invalid sessions
11 changes: 7 additions & 4 deletions pkg/client/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error {
// Create a new context with cancellation for the subscription
subCtx, cancel := context.WithCancel(ctx)

// Create a channel to receive initialization result
initDone := make(chan error, 1)

go func() {
defer cancel()

Expand All @@ -156,10 +153,16 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error {
t.mu.Lock()
t.initialized = false
t.mu.Unlock()
initDone <- err
return
}
}()

t.mu.Lock()
t.initialized = true
t.mu.Unlock()

t.logger.Debug().Msg("SSE initialization successful")

return nil
}

Expand Down
156 changes: 111 additions & 45 deletions pkg/server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (
"net/http"
"strings"
"sync"
"time"

"github.com/go-go-golems/go-go-mcp/pkg"
"github.com/go-go-golems/go-go-mcp/pkg/protocol"
"github.com/go-go-golems/go-go-mcp/pkg/services"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
)
Expand All @@ -21,20 +23,30 @@ type SSEServer struct {
mu sync.RWMutex
logger zerolog.Logger
registry *pkg.ProviderRegistry
clients map[string]chan *protocol.Response
clients map[string]*SSEClient
server *http.Server
port int
promptService services.PromptService
resourceService services.ResourceService
toolService services.ToolService
initializeService services.InitializeService
nextClientID int
}

type SSEClient struct {
id string
sessionID string
messageChan chan *protocol.Response
createdAt time.Time
remoteAddr string
userAgent string
}

// NewSSEServer creates a new SSE server instance
func NewSSEServer(logger zerolog.Logger, ps services.PromptService, rs services.ResourceService, ts services.ToolService, is services.InitializeService, port int) *SSEServer {
return &SSEServer{
logger: logger,
clients: make(map[string]chan *protocol.Response),
clients: make(map[string]*SSEClient),
port: port,
promptService: ps,
resourceService: rs,
Expand Down Expand Up @@ -90,7 +102,7 @@ func (s *SSEServer) Stop(ctx context.Context) error {
// Close all client connections
for sessionID, ch := range s.clients {
s.logger.Debug().Str("session_id", sessionID).Msg("Closing client connection")
close(ch)
close(ch.messageChan)
delete(s.clients, sessionID)
}
return s.server.Shutdown(ctx)
Expand All @@ -111,52 +123,71 @@ func (s *SSEServer) marshalJSON(v interface{}) (json.RawMessage, error) {
// handleSSE handles new SSE connections
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
s.logger.Debug().
Str("remote_addr", r.RemoteAddr).
Str("user_agent", r.UserAgent()).
Msg("New SSE connection")

// Set SSE headers
// Set SSE headers according to protocol spec
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")

// Create unique session ID or use default
sessionID := "default"
if r.URL.Query().Get("session_id") != "" {
sessionID = r.URL.Query().Get("session_id")
// Create unique session ID
sessionID := r.URL.Query().Get("session_id")
if sessionID == "" {
sessionID = fmt.Sprintf("%s", uuid.New())
}
messageChan := make(chan *protocol.Response, 100)

// Register client
s.mu.Lock()
s.clients[sessionID] = messageChan
s.nextClientID++
clientID := fmt.Sprintf("client-%d", s.nextClientID)
client := &SSEClient{
id: clientID,
sessionID: sessionID,
messageChan: make(chan *protocol.Response, 100),
createdAt: time.Now(),
remoteAddr: r.RemoteAddr,
userAgent: r.UserAgent(),
}
s.clients[clientID] = client
clientCount := len(s.clients)
s.mu.Unlock()

s.logger.Debug().
Str("client_id", clientID).
Str("session_id", sessionID).
Str("remote_addr", r.RemoteAddr).
Str("user_agent", r.UserAgent()).
Int("total_clients", clientCount).
Msg("Client registered")
Msg("New SSE connection")

// Send initial endpoint event with session ID
endpoint := fmt.Sprintf("%s?sessionId=%s", "/messages", sessionID)
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
w.(http.Flusher).Flush()

defer func() {
s.mu.Lock()
delete(s.clients, sessionID)
close(messageChan)
s.logger.Debug().
Str("session_id", sessionID).
Int("total_clients", len(s.clients)).
Msg("Client disconnected")
if c, exists := s.clients[clientID]; exists {
close(c.messageChan)
delete(s.clients, clientID)
s.logger.Debug().
Str("client_id", clientID).
Str("session_id", sessionID).
Int("total_clients", len(s.clients)).
Dur("connection_duration", time.Since(c.createdAt)).
Msg("Client disconnected")
}
s.mu.Unlock()
}()

// Keep connection open and send messages
for {
select {
case msg := <-messageChan:
case msg := <-client.messageChan:
if msg == nil {
s.logger.Debug().
Str("client_id", clientID).
Str("session_id", sessionID).
Msg("Received nil message, closing connection")
return
Expand All @@ -166,23 +197,26 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
if err != nil {
s.logger.Error().
Err(err).
Str("client_id", clientID).
Str("session_id", sessionID).
Interface("message", msg).
Msg("Failed to marshal message")
continue
}

s.logger.Debug().
Str("client_id", clientID).
Str("session_id", sessionID).
RawJSON("message", data).
Msg("Sending message to client")

// Send message event
// Send message event according to protocol spec
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
w.(http.Flusher).Flush()

case <-ctx.Done():
s.logger.Debug().
Str("client_id", clientID).
Str("session_id", sessionID).
Msg("Context done, closing connection")
return
Expand All @@ -192,10 +226,11 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {

// handleMessages processes incoming client messages
func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
sessionID := r.URL.Query().Get("session_id")
s.logger.Debug().
Str("session_id", sessionID).
Str("remote_addr", r.RemoteAddr).
Str("session_id", sessionID).
Msg("Received message request")

// Use default session if none provided
Expand All @@ -204,13 +239,21 @@ func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) {
s.logger.Debug().Msg("Using default session")
}

// Find all clients for this session
s.mu.RLock()
messageChan, ok := s.clients[sessionID]
var sessionClients []*SSEClient
for _, client := range s.clients {
if client.sessionID == sessionID {
sessionClients = append(sessionClients, client)
}
}
s.mu.RUnlock()

if !ok {
s.logger.Error().Str("session_id", sessionID).Msg("Invalid session_id")
http.Error(w, "Invalid session_id", http.StatusBadRequest)
if len(sessionClients) == 0 {
s.logger.Error().
Str("session_id", sessionID).
Msg("No active clients found for session")
http.Error(w, "No active clients found for session", http.StatusBadRequest)
return
}

Expand All @@ -227,7 +270,17 @@ func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) {
Message: "Parse error",
},
}
messageChan <- response
// Send error to all session clients
for _, client := range sessionClients {
select {
case client.messageChan <- response:
default:
s.logger.Error().
Str("client_id", client.id).
Str("session_id", sessionID).
Msg("Failed to send error response to client")
}
}
w.WriteHeader(http.StatusAccepted)
return
}
Expand All @@ -242,7 +295,17 @@ func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) {
Data: json.RawMessage(data),
},
}
messageChan <- response
// Send error to all session clients
for _, client := range sessionClients {
select {
case client.messageChan <- response:
default:
s.logger.Error().
Str("client_id", client.id).
Str("session_id", sessionID).
Msg("Failed to send error response to client")
}
}
w.WriteHeader(http.StatusAccepted)
return
}
Expand All @@ -255,7 +318,6 @@ func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) {

// Process the request based on method
var response *protocol.Response
ctx := r.Context()

switch request.Method {
case "initialize":
Expand Down Expand Up @@ -556,18 +618,22 @@ func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) {
}
}

// Send response through the client's message channel
select {
case messageChan <- response:
s.logger.Debug().
Str("session_id", sessionID).
Interface("response", response).
Msg("Response sent to client")
w.WriteHeader(http.StatusAccepted)
default:
s.logger.Error().
Str("session_id", sessionID).
Msg("Failed to send response to client")
http.Error(w, "failed to send response", http.StatusInternalServerError)
// Send response to all session clients
for _, client := range sessionClients {
select {
case client.messageChan <- response:
s.logger.Debug().
Str("client_id", client.id).
Str("session_id", sessionID).
Interface("response", response).
Msg("Response sent to client")
default:
s.logger.Error().
Str("client_id", client.id).
Str("session_id", sessionID).
Msg("Failed to send response to client")
}
}

w.WriteHeader(http.StatusAccepted)
}
Loading

0 comments on commit a26a10f

Please sign in to comment.