Skip to content

Commit

Permalink
🔧 Make session ID optional and fix request ID handling
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent 7ea65a6 commit aefff25
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 38 deletions.
11 changes: 10 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,13 @@ Refactored all transports to use context.Context for better control over cancell
- Updated SSE transport to use context for initialization and event handling
- Updated stdio transport to use context for command execution and response handling
- Removed channel-based cancellation in favor of context
- Added proper context propagation throughout the client API
- Added proper context propagation throughout the client API

## Optional Session ID and Request ID Handling

Made session ID optional in SSE transport by using a default session when none is provided. Also ensured request ID handling follows the MCP specification.

- Made session ID optional in server and client SSE transport
- Added default session support in server
- Removed session ID requirement from client
- Updated request ID handling to follow MCP spec
9 changes: 6 additions & 3 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/go-go-golems/go-go-mcp/pkg/protocol"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

// Transport represents a client transport mechanism
Expand Down Expand Up @@ -70,6 +71,7 @@ func (c *Client) Initialize(ctx context.Context, capabilities protocol.ClientCap
}
c.setRequestID(request)

log.Debug().Msgf("Sending initialize request")
response, err := c.transport.Send(ctx, request)
if err != nil {
return fmt.Errorf("failed to send initialize request: %w", err)
Expand Down Expand Up @@ -416,13 +418,14 @@ func (c *Client) setRequestID(request *protocol.Request) {
return // notifications don't have IDs
}

id := json.RawMessage(fmt.Sprintf("%d", c.nextID))
request.ID = id
// According to MCP spec, request IDs can be either numbers or strings
// We'll use numbers for simplicity and compatibility
request.ID = json.RawMessage(fmt.Sprintf("%d", c.nextID))
c.nextID++

c.logger.Debug().
Str("method", request.Method).
RawJSON("id", id).
RawJSON("id", request.ID).
Msg("set request ID")
}

Expand Down
26 changes: 4 additions & 22 deletions pkg/client/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/go-go-golems/go-go-mcp/pkg/protocol"
"github.com/r3labs/sse/v2"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

// SSETransport implements Transport using Server-Sent Events
Expand All @@ -22,7 +23,6 @@ type SSETransport struct {
client *http.Client
sseClient *sse.Client
events chan *sse.Event
sessionID string
closeOnce sync.Once
logger zerolog.Logger
initialized bool
Expand Down Expand Up @@ -137,24 +137,14 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error {
go func() {
defer cancel()

log.Debug().Msgf("Subscribing to SSE")
err := t.sseClient.SubscribeWithContext(subCtx, "", func(msg *sse.Event) {
// Handle session ID event
if string(msg.Event) == "session" {
t.mu.Lock()
t.sessionID = string(msg.Data)
t.initialized = true
t.mu.Unlock()
t.logger.Debug().Str("sessionID", t.sessionID).Msg("Received session ID")
initDone <- nil
return
}

t.logger.Debug().
Str("event", string(msg.Event)).
RawJSON("data", msg.Data).
Msg("Received SSE event")

// Forward other events to the events channel
// Forward events to the events channel
select {
case t.events <- msg:
t.logger.Debug().Msg("Forwarded event to channel")
Expand All @@ -172,15 +162,7 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error {
}
}()

// Wait for initialization or context cancellation
select {
case err := <-initDone:
close(initDone)
return err
case <-ctx.Done():
cancel()
return ctx.Err()
}
return nil
}

// Close closes the transport
Expand Down
24 changes: 12 additions & 12 deletions pkg/server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/go-go-golems/go-go-mcp/pkg"
"github.com/go-go-golems/go-go-mcp/pkg/protocol"
"github.com/go-go-golems/go-go-mcp/pkg/services"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
)
Expand Down Expand Up @@ -123,8 +122,11 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")

// Create unique session ID
sessionID := uuid.New().String()
// Create unique session ID or use default
sessionID := "default"
if r.URL.Query().Get("session_id") != "" {
sessionID = r.URL.Query().Get("session_id")
}
messageChan := make(chan *protocol.Response, 100)

// Register client
Expand Down Expand Up @@ -196,21 +198,19 @@ func (s *SSEServer) handleMessages(w http.ResponseWriter, r *http.Request) {
Str("remote_addr", r.RemoteAddr).
Msg("Received message request")

// Use default session if none provided
if sessionID == "" {
s.logger.Error().Msg("Missing session_id in request")
http.Error(w, "session_id is required", http.StatusBadRequest)
return
sessionID = "default"
s.logger.Debug().Msg("Using default session")
}

s.mu.RLock()
messageChan, exists := s.clients[sessionID]
messageChan, ok := s.clients[sessionID]
s.mu.RUnlock()

if !exists {
s.logger.Error().
Str("session_id", sessionID).
Msg("Session not found")
http.Error(w, "session not found", http.StatusNotFound)
if !ok {
s.logger.Error().Str("session_id", sessionID).Msg("Invalid session_id")
http.Error(w, "Invalid session_id", http.StatusBadRequest)
return
}

Expand Down

0 comments on commit aefff25

Please sign in to comment.