Skip to content

Commit

Permalink
✨ Try to be more lenient with notifications
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Feb 11, 2025
1 parent 7e56df0 commit 41bc93e
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 56 deletions.
20 changes: 19 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -841,4 +841,22 @@ Added CORS headers and OPTIONS request handling to the /messages endpoint to fix
Improved the SSE transport to handle notifications more efficiently by not waiting for responses when handling notification messages.

- Modified SSE transport to skip response waiting for notifications
- Added support for both empty ID and notifications/ prefix detection
- Added support for both empty ID and notifications/ prefix detection

# Improved SSE Notification Handling

Enhanced the SSE transport to handle notifications and responses separately for better efficiency and clarity:

- Added separate channels for notifications and responses
- Added notification handler support to Transport interface
- Updated SSE bridge to forward notifications to stdout
- Added default notification logging in client
- Improved notification detection and routing

# Added Notification Support to Stdio Transport

Added notification handling support to the stdio transport:
- Added notification handler to StdioTransport struct
- Implemented SetNotificationHandler method
- Added notification detection and handling in Send method
- Improved response handling to properly handle interleaved notifications
11 changes: 10 additions & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type Transport interface {
Send(ctx context.Context, request *protocol.Request) (*protocol.Response, error)
// Close closes the transport connection
Close(ctx context.Context) error
// SetNotificationHandler sets a handler for notifications
SetNotificationHandler(handler func(*protocol.Response))
}

// Client represents an MCP client that can use different transports
Expand All @@ -34,11 +36,18 @@ type Client struct {

// NewClient creates a new client instance
func NewClient(logger zerolog.Logger, transport Transport) *Client {
return &Client{
client := &Client{
logger: logger,
transport: transport,
nextID: 1,
}

// Set default notification handler to log notifications
transport.SetNotificationHandler(func(response *protocol.Response) {
logger.Debug().Interface("notification", response).Msg("Received notification")
})

return client
}

// Initialize initializes the connection with the server
Expand Down
96 changes: 72 additions & 24 deletions pkg/client/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,50 @@ import (

// SSETransport implements Transport using Server-Sent Events
type SSETransport struct {
mu sync.Mutex
baseURL string
client *http.Client
sseClient *sse.Client
events chan *sse.Event
closeOnce sync.Once
logger zerolog.Logger
initialized bool
sessionID string
endpoint string
mu sync.Mutex
baseURL string
client *http.Client
sseClient *sse.Client
events chan *sse.Event
responses chan *sse.Event
notifications chan *sse.Event
closeOnce sync.Once
logger zerolog.Logger
initialized bool
sessionID string
endpoint string
notificationHandler func(*protocol.Response)
}

// NewSSETransport creates a new SSE transport
func NewSSETransport(baseURL string, logger zerolog.Logger) *SSETransport {
return &SSETransport{
baseURL: baseURL,
client: &http.Client{},
sseClient: sse.NewClient(baseURL + "/sse"),
events: make(chan *sse.Event),
logger: logger,
baseURL: baseURL,
client: &http.Client{},
sseClient: sse.NewClient(baseURL + "/sse"),
events: make(chan *sse.Event),
responses: make(chan *sse.Event),
notifications: make(chan *sse.Event),
logger: logger,
}
}

// SetNotificationHandler sets the handler for notifications
func (t *SSETransport) SetNotificationHandler(handler func(*protocol.Response)) {
t.mu.Lock()
defer t.mu.Unlock()
t.notificationHandler = handler
}

// isNotification checks if an event is a notification
func isNotification(event *sse.Event) bool {
var response protocol.Response
if err := json.Unmarshal(event.Data, &response); err != nil {
return false
}
return len(response.ID) == 0 || string(response.ID) == "null"
}

// Send sends a request and returns the response
func (t *SSETransport) Send(ctx context.Context, request *protocol.Request) (*protocol.Response, error) {
t.mu.Lock()
Expand Down Expand Up @@ -95,17 +116,16 @@ func (t *SSETransport) Send(ctx context.Context, request *protocol.Request) (*pr
return nil, fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body))
}

// If this is a notification (empty ID or notifications/ prefix), don't wait for response
// If this is a notification, don't wait for response
if len(request.ID) == 0 || string(request.ID) == "null" || strings.HasPrefix(request.Method, "notifications/") {
t.logger.Debug().Msg("Request is a notification, not waiting for response")
return nil, nil
}

// Wait for response event with context
// XXX manuel(2025-02-10) This is overly complex and not needed, in fact I should just dump all incoming events to stdout and not care at all
t.logger.Debug().Msg("Waiting for response event")
select {
case event := <-t.events:
case event := <-t.responses:
if string(event.Event) == "error" {
t.logger.Error().
Str("error", string(event.Data)).
Expand Down Expand Up @@ -142,6 +162,25 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error {
// Channel to wait for endpoint event
endpointCh := make(chan string, 1)

// Start notification handler goroutine
go func() {
for {
select {
case event := <-t.notifications:
if t.notificationHandler != nil {
var response protocol.Response
if err := json.Unmarshal(event.Data, &response); err != nil {
t.logger.Error().Err(err).Msg("Failed to parse notification")
continue
}
t.notificationHandler(&response)
}
case <-subCtx.Done():
return
}
}
}()

go func() {
defer cancel()

Expand Down Expand Up @@ -173,12 +212,21 @@ func (t *SSETransport) initializeSSE(ctx context.Context) error {
}
}

// Forward other events to the events channel
select {
case t.events <- msg:
t.logger.Debug().Msg("Forwarded event to channel")
case <-subCtx.Done():
t.logger.Debug().Msg("Context cancelled while forwarding event")
// Route event to appropriate channel
if isNotification(msg) {
select {
case t.notifications <- msg:
t.logger.Debug().Msg("Forwarded notification event")
case <-subCtx.Done():
t.logger.Debug().Msg("Context cancelled while forwarding notification")
}
} else {
select {
case t.responses <- msg:
t.logger.Debug().Msg("Forwarded response event")
case <-subCtx.Done():
t.logger.Debug().Msg("Context cancelled while forwarding response")
}
}
})

Expand Down
79 changes: 51 additions & 28 deletions pkg/client/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"os"
"os/exec"
"strings"
"sync"
"syscall"

Expand All @@ -17,11 +18,12 @@ import (

// StdioTransport implements Transport using standard input/output
type StdioTransport struct {
mu sync.Mutex
scanner *bufio.Scanner
writer *json.Encoder
cmd *exec.Cmd
logger zerolog.Logger
mu sync.Mutex
scanner *bufio.Scanner
writer *json.Encoder
cmd *exec.Cmd
logger zerolog.Logger
notificationHandler func(*protocol.Response)
}

// NewStdioTransport creates a new stdio transport
Expand All @@ -38,6 +40,15 @@ func NewStdioTransport(logger zerolog.Logger) *StdioTransport {
}
}

// SetNotificationHandler sets the handler for notifications
func (t *StdioTransport) SetNotificationHandler(handler func(*protocol.Response)) {
t.mu.Lock()
defer t.mu.Unlock()
t.notificationHandler = handler
}

var _ Transport = &StdioTransport{}

// NewCommandStdioTransport creates a new stdio transport that launches a command
func NewCommandStdioTransport(logger zerolog.Logger, command string, args ...string) (*StdioTransport, error) {
cmd := exec.Command(command, args...)
Expand Down Expand Up @@ -100,6 +111,12 @@ func (t *StdioTransport) Send(ctx context.Context, request *protocol.Request) (*
return nil, fmt.Errorf("failed to write request: %w", err)
}

// If this is a notification, don't wait for response
if len(request.ID) == 0 || string(request.ID) == "null" || strings.HasPrefix(request.Method, "notifications/") {
t.logger.Debug().Msg("Request is a notification, not waiting for response")
return nil, nil
}

t.logger.Debug().Msg("Waiting for response")

// Create a channel for the response
Expand All @@ -110,40 +127,46 @@ func (t *StdioTransport) Send(ctx context.Context, request *protocol.Request) (*

// Read response in a goroutine
go func() {
if !t.scanner.Scan() {
var err error
if scanErr := t.scanner.Err(); scanErr != nil {
err = fmt.Errorf("failed to read response: %w", scanErr)
t.logger.Error().Err(err).Msg("Failed to read response")
} else {
err = io.EOF
t.logger.Debug().Msg("EOF while reading response")
for t.scanner.Scan() {
data := t.scanner.Bytes()
t.logger.Debug().RawJSON("data", data).Msg("Received data")

var response protocol.Response
if err := json.Unmarshal(data, &response); err != nil {
t.logger.Error().Err(err).Msg("Failed to parse response")
continue
}

// If this is a notification, handle it and continue scanning
if len(response.ID) == 0 || string(response.ID) == "null" {
t.mu.Lock()
if t.notificationHandler != nil {
t.notificationHandler(&response)
}
t.mu.Unlock()
continue
}

// This is a response, send it to the channel
responseCh <- struct {
response *protocol.Response
err error
}{nil, err}
}{&response, nil}
return
}

t.logger.Debug().
RawJSON("response", t.scanner.Bytes()).
Msg("Received response")

var response protocol.Response
if err := json.Unmarshal(t.scanner.Bytes(), &response); err != nil {
t.logger.Error().Err(err).Msg("Failed to parse response")
// Handle scanner errors
if err := t.scanner.Err(); err != nil {
responseCh <- struct {
response *protocol.Response
err error
}{nil, fmt.Errorf("failed to parse response: %w", err)}
return
}{nil, fmt.Errorf("failed to read response: %w", err)}
} else {
responseCh <- struct {
response *protocol.Response
err error
}{nil, io.EOF}
}

responseCh <- struct {
response *protocol.Response
err error
}{&response, nil}
}()

// Wait for either response or context cancellation
Expand Down
17 changes: 15 additions & 2 deletions pkg/server/transports/stdio/sse_bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,26 @@ func NewSSEBridgeServer(logger zerolog.Logger, sseURL string) *SSEBridgeServer {
// Strip trailing slashes from the SSE URL
sseURL = strings.TrimRight(sseURL, "/")

return &SSEBridgeServer{
sseClient := client.NewSSETransport(sseURL, taggedLogger)

s := &SSEBridgeServer{
scanner: scanner,
writer: json.NewEncoder(os.Stdout),
logger: taggedLogger,
sseClient: client.NewSSETransport(sseURL, taggedLogger),
sseClient: sseClient,
signalChan: make(chan os.Signal, 1),
}

// Set up notification handler to write to stdout
sseClient.SetNotificationHandler(func(response *protocol.Response) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.writer.Encode(response); err != nil {
s.logger.Error().Err(err).Msg("Failed to write notification to stdout")
}
})

return s
}

// Start begins listening for and handling messages on stdio
Expand Down

0 comments on commit 41bc93e

Please sign in to comment.