diff --git a/changelog.md b/changelog.md index 2f666a8..016e40b 100644 --- a/changelog.md +++ b/changelog.md @@ -215,4 +215,13 @@ Refactored all transports to use context.Context for better control over cancell - Updated SSE transport to use context for initialization and event handling - Updated stdio transport to use context for command execution and response handling - Removed channel-based cancellation in favor of context -- Added proper context propagation throughout the client API \ No newline at end of file +- Added proper context propagation throughout the client API + +## Optional Session ID and Request ID Handling + +Made session ID optional in SSE transport by using a default session when none is provided. Also ensured request ID handling follows the MCP specification. + +- Made session ID optional in server and client SSE transport +- Added default session support in server +- Removed session ID requirement from client +- Updated request ID handling to follow MCP spec \ No newline at end of file diff --git a/pkg/client/client.go b/pkg/client/client.go index bd2975e..5e23f25 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -8,6 +8,7 @@ import ( "github.com/go-go-golems/go-go-mcp/pkg/protocol" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) // Transport represents a client transport mechanism @@ -70,6 +71,7 @@ func (c *Client) Initialize(ctx context.Context, capabilities protocol.ClientCap } c.setRequestID(request) + log.Debug().Msgf("Sending initialize request") response, err := c.transport.Send(ctx, request) if err != nil { return fmt.Errorf("failed to send initialize request: %w", err) @@ -416,13 +418,14 @@ func (c *Client) setRequestID(request *protocol.Request) { return // notifications don't have IDs } - id := json.RawMessage(fmt.Sprintf("%d", c.nextID)) - request.ID = id + // According to MCP spec, request IDs can be either numbers or strings + // We'll use numbers for simplicity and compatibility + request.ID = json.RawMessage(fmt.Sprintf("%d", c.nextID)) c.nextID++ c.logger.Debug(). Str("method", request.Method). - RawJSON("id", id). + RawJSON("id", request.ID). Msg("set request ID") } diff --git a/pkg/client/sse.go b/pkg/client/sse.go index fc4a047..07526cc 100644 --- a/pkg/client/sse.go +++ b/pkg/client/sse.go @@ -13,6 +13,7 @@ import ( "github.com/go-go-golems/go-go-mcp/pkg/protocol" "github.com/r3labs/sse/v2" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) // SSETransport implements Transport using Server-Sent Events @@ -22,7 +23,6 @@ type SSETransport struct { client *http.Client sseClient *sse.Client events chan *sse.Event - sessionID string closeOnce sync.Once logger zerolog.Logger initialized bool @@ -137,24 +137,14 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error { go func() { defer cancel() + log.Debug().Msgf("Subscribing to SSE") err := t.sseClient.SubscribeWithContext(subCtx, "", func(msg *sse.Event) { - // Handle session ID event - if string(msg.Event) == "session" { - t.mu.Lock() - t.sessionID = string(msg.Data) - t.initialized = true - t.mu.Unlock() - t.logger.Debug().Str("sessionID", t.sessionID).Msg("Received session ID") - initDone <- nil - return - } - t.logger.Debug(). Str("event", string(msg.Event)). RawJSON("data", msg.Data). Msg("Received SSE event") - // Forward other events to the events channel + // Forward events to the events channel select { case t.events <- msg: t.logger.Debug().Msg("Forwarded event to channel") @@ -172,15 +162,7 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error { } }() - // Wait for initialization or context cancellation - select { - case err := <-initDone: - close(initDone) - return err - case <-ctx.Done(): - cancel() - return ctx.Err() - } + return nil } // Close closes the transport diff --git a/pkg/server/sse.go b/pkg/server/sse.go index 5c1c306..1d62f89 100644 --- a/pkg/server/sse.go +++ b/pkg/server/sse.go @@ -12,7 +12,6 @@ import ( "github.com/go-go-golems/go-go-mcp/pkg" "github.com/go-go-golems/go-go-mcp/pkg/protocol" "github.com/go-go-golems/go-go-mcp/pkg/services" - "github.com/google/uuid" "github.com/gorilla/mux" "github.com/rs/zerolog" ) @@ -123,8 +122,11 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") - // Create unique session ID - sessionID := uuid.New().String() + // Create unique session ID or use default + sessionID := "default" + if r.URL.Query().Get("session_id") != "" { + sessionID = r.URL.Query().Get("session_id") + } messageChan := make(chan *protocol.Response, 100) // Register client @@ -196,21 +198,19 @@ func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) { Str("remote_addr", r.RemoteAddr). Msg("Received message request") + // Use default session if none provided if sessionID == "" { - s.logger.Error().Msg("Missing session_id in request") - http.Error(w, "session_id is required", http.StatusBadRequest) - return + sessionID = "default" + s.logger.Debug().Msg("Using default session") } s.mu.RLock() - messageChan, exists := s.clients[sessionID] + messageChan, ok := s.clients[sessionID] s.mu.RUnlock() - if !exists { - s.logger.Error(). - Str("session_id", sessionID). - Msg("Session not found") - http.Error(w, "session not found", http.StatusNotFound) + if !ok { + s.logger.Error().Str("session_id", sessionID).Msg("Invalid session_id") + http.Error(w, "Invalid session_id", http.StatusBadRequest) return }