diff --git a/changelog.md b/changelog.md index 32d0da6..2f666a8 100644 --- a/changelog.md +++ b/changelog.md @@ -206,4 +206,13 @@ Made the SSE client run asynchronously to prevent blocking on subscription: - Added initialization synchronization channel - Moved SSE subscription to a goroutine - Improved error handling and state management -- Added proper mutex protection for shared state \ No newline at end of file +- Added proper mutex protection for shared state + +# Context-Based Transport Control + +Refactored all transports to use context.Context for better control over cancellation and timeouts: +- Added context support to Transport interface methods (Send and Close) +- Updated SSE transport to use context for initialization and event handling +- Updated stdio transport to use context for command execution and response handling +- Removed channel-based cancellation in favor of context +- Added proper context propagation throughout the client API \ No newline at end of file diff --git a/cmd/mcp-client/main.go b/cmd/mcp-client/main.go index 3379546..ab9aded 100644 --- a/cmd/mcp-client/main.go +++ b/cmd/mcp-client/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "os" @@ -81,13 +82,13 @@ Supports both stdio and SSE transports for client-server communication.`, Use: "list", Short: "List available prompts", RunE: func(cmd *cobra.Command, args []string) error { - client, err := createClient() + client, err := createClient(cmd.Context()) if err != nil { return err } - defer client.Close() + defer client.Close(cmd.Context()) - prompts, cursor, err := client.ListPrompts("") + prompts, cursor, err := client.ListPrompts(cmd.Context(), "") if err != nil { return err } @@ -116,11 +117,11 @@ Supports both stdio and SSE transports for client-server communication.`, Short: "Execute a specific prompt", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := createClient() + client, err := createClient(cmd.Context()) if err != nil { return err } - defer client.Close() + defer client.Close(cmd.Context()) // Parse prompt arguments promptArgMap := make(map[string]string) @@ -130,7 +131,7 @@ Supports both stdio and SSE transports for client-server communication.`, } } - message, err := client.GetPrompt(args[0], promptArgMap) + message, err := client.GetPrompt(cmd.Context(), args[0], promptArgMap) if err != nil { return err } @@ -154,13 +155,13 @@ Supports both stdio and SSE transports for client-server communication.`, Use: "list", Short: "List available tools", RunE: func(cmd *cobra.Command, args []string) error { - client, err := createClient() + client, err := createClient(cmd.Context()) if err != nil { return err } - defer client.Close() + defer client.Close(cmd.Context()) - tools, cursor, err := client.ListTools("") + tools, cursor, err := client.ListTools(cmd.Context(), "") if err != nil { return err } @@ -184,11 +185,11 @@ Supports both stdio and SSE transports for client-server communication.`, Short: "Call a specific tool", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := createClient() + client, err := createClient(cmd.Context()) if err != nil { return err } - defer client.Close() + defer client.Close(cmd.Context()) // Parse tool arguments toolArgMap := make(map[string]interface{}) @@ -231,13 +232,13 @@ Supports both stdio and SSE transports for client-server communication.`, Use: "list", Short: "List available resources", RunE: func(cmd *cobra.Command, args []string) error { - client, err := createClient() + client, err := createClient(cmd.Context()) if err != nil { return err } - defer client.Close() + defer client.Close(cmd.Context()) - resources, cursor, err := client.ListResources("") + resources, cursor, err := client.ListResources(cmd.Context(), "") if err != nil { return err } @@ -263,13 +264,13 @@ Supports both stdio and SSE transports for client-server communication.`, Short: "Read a specific resource", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := createClient() + client, err := createClient(cmd.Context()) if err != nil { return err } - defer client.Close() + defer client.Close(cmd.Context()) - content, err := client.ReadResource(args[0]) + content, err := client.ReadResource(cmd.Context(), args[0]) if err != nil { return err } @@ -317,7 +318,7 @@ Supports both stdio and SSE transports for client-server communication.`, } } -func createClient() (*client.Client, error) { +func createClient(ctx context.Context) (*client.Client, error) { // Use ConsoleWriter for colored output consoleWriter := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339} logger := zerolog.New(consoleWriter).With().Timestamp().Logger() @@ -342,7 +343,7 @@ func createClient() (*client.Client, error) { // Create and initialize client c := client.NewClient(logger, t) - err = c.Initialize(protocol.ClientCapabilities{ + err = c.Initialize(ctx, protocol.ClientCapabilities{ Sampling: &protocol.SamplingCapability{}, }) if err != nil { diff --git a/pkg/client/client.go b/pkg/client/client.go index ab0b249..bd2975e 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -13,9 +13,9 @@ import ( // Transport represents a client transport mechanism type Transport interface { // Send sends a request and returns the response - Send(request *protocol.Request) (*protocol.Response, error) + Send(ctx context.Context, request *protocol.Request) (*protocol.Response, error) // Close closes the transport connection - Close() error + Close(ctx context.Context) error } // Client represents an MCP client that can use different transports @@ -42,7 +42,7 @@ func NewClient(logger zerolog.Logger, transport Transport) *Client { } // Initialize initializes the connection with the server -func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error { +func (c *Client) Initialize(ctx context.Context, capabilities protocol.ClientCapabilities) error { c.mu.Lock() defer c.mu.Unlock() @@ -70,7 +70,7 @@ func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error { } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return fmt.Errorf("failed to send initialize request: %w", err) } @@ -96,7 +96,7 @@ func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error { JSONRPC: "2.0", Method: "notifications/initialized", } - _, err = c.transport.Send(notification) + _, err = c.transport.Send(ctx, notification) if err != nil { return fmt.Errorf("failed to send initialized notification: %w", err) } @@ -105,7 +105,7 @@ func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error { } // ListPrompts retrieves the list of available prompts from the server -func (c *Client) ListPrompts(cursor string) ([]protocol.Prompt, string, error) { +func (c *Client) ListPrompts(ctx context.Context, cursor string) ([]protocol.Prompt, string, error) { if !c.initialized { return nil, "", fmt.Errorf("client not initialized") } @@ -122,7 +122,7 @@ func (c *Client) ListPrompts(cursor string) ([]protocol.Prompt, string, error) { } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return nil, "", fmt.Errorf("failed to send prompts/list request: %w", err) } @@ -143,7 +143,7 @@ func (c *Client) ListPrompts(cursor string) ([]protocol.Prompt, string, error) { } // GetPrompt retrieves a specific prompt from the server -func (c *Client) GetPrompt(name string, arguments map[string]string) (*protocol.PromptMessage, error) { +func (c *Client) GetPrompt(ctx context.Context, name string, arguments map[string]string) (*protocol.PromptMessage, error) { if !c.initialized { return nil, fmt.Errorf("client not initialized") } @@ -163,7 +163,7 @@ func (c *Client) GetPrompt(name string, arguments map[string]string) (*protocol. } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return nil, fmt.Errorf("failed to send prompts/get request: %w", err) } @@ -187,7 +187,7 @@ func (c *Client) GetPrompt(name string, arguments map[string]string) (*protocol. } // ListResources retrieves the list of available resources from the server -func (c *Client) ListResources(cursor string) ([]protocol.Resource, string, error) { +func (c *Client) ListResources(ctx context.Context, cursor string) ([]protocol.Resource, string, error) { if !c.initialized { return nil, "", fmt.Errorf("client not initialized") } @@ -204,7 +204,7 @@ func (c *Client) ListResources(cursor string) ([]protocol.Resource, string, erro } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return nil, "", fmt.Errorf("failed to send resources/list request: %w", err) } @@ -225,7 +225,7 @@ func (c *Client) ListResources(cursor string) ([]protocol.Resource, string, erro } // ReadResource retrieves the content of a specific resource from the server -func (c *Client) ReadResource(uri string) (*protocol.ResourceContent, error) { +func (c *Client) ReadResource(ctx context.Context, uri string) (*protocol.ResourceContent, error) { if !c.initialized { return nil, fmt.Errorf("client not initialized") } @@ -243,7 +243,7 @@ func (c *Client) ReadResource(uri string) (*protocol.ResourceContent, error) { } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return nil, fmt.Errorf("failed to send resources/read request: %w", err) } @@ -267,7 +267,7 @@ func (c *Client) ReadResource(uri string) (*protocol.ResourceContent, error) { } // ListTools retrieves the list of available tools from the server -func (c *Client) ListTools(cursor string) ([]protocol.Tool, string, error) { +func (c *Client) ListTools(ctx context.Context, cursor string) ([]protocol.Tool, string, error) { if !c.initialized { return nil, "", fmt.Errorf("client not initialized") } @@ -284,7 +284,7 @@ func (c *Client) ListTools(cursor string) ([]protocol.Tool, string, error) { } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return nil, "", fmt.Errorf("failed to send tools/list request: %w", err) } @@ -325,7 +325,7 @@ func (c *Client) CallTool(ctx context.Context, name string, arguments map[string } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return nil, fmt.Errorf("failed to send tools/call request: %w", err) } @@ -343,7 +343,7 @@ func (c *Client) CallTool(ctx context.Context, name string, arguments map[string } // CreateMessage sends a sampling request to create a message -func (c *Client) CreateMessage(messages []protocol.Message, modelPreferences protocol.ModelPreferences, systemPrompt string, maxTokens int) (*protocol.Message, error) { +func (c *Client) CreateMessage(ctx context.Context, messages []protocol.Message, modelPreferences protocol.ModelPreferences, systemPrompt string, maxTokens int) (*protocol.Message, error) { if !c.initialized { return nil, fmt.Errorf("client not initialized") } @@ -367,7 +367,7 @@ func (c *Client) CreateMessage(messages []protocol.Message, modelPreferences pro } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return nil, fmt.Errorf("failed to send sampling/createMessage request: %w", err) } @@ -385,14 +385,14 @@ func (c *Client) CreateMessage(messages []protocol.Message, modelPreferences pro } // Ping sends a ping request to the server -func (c *Client) Ping() error { +func (c *Client) Ping(ctx context.Context) error { request := &protocol.Request{ JSONRPC: "2.0", Method: "ping", } c.setRequestID(request) - response, err := c.transport.Send(request) + response, err := c.transport.Send(ctx, request) if err != nil { return fmt.Errorf("failed to send ping request: %w", err) } @@ -405,9 +405,9 @@ func (c *Client) Ping() error { } // Close closes the client connection -func (c *Client) Close() error { +func (c *Client) Close(ctx context.Context) error { c.logger.Debug().Msg("closing client") - return c.transport.Close() + return c.transport.Close(ctx) } // setRequestID sets a unique ID for the request diff --git a/pkg/client/sse.go b/pkg/client/sse.go index d040850..fc4a047 100644 --- a/pkg/client/sse.go +++ b/pkg/client/sse.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -23,10 +24,8 @@ type SSETransport struct { events chan *sse.Event sessionID string closeOnce sync.Once - closeChan chan struct{} - initialized bool - initChan chan struct{} // Channel to signal initialization completion logger zerolog.Logger + initialized bool } // NewSSETransport creates a new SSE transport @@ -36,29 +35,20 @@ func NewSSETransport(baseURL string) *SSETransport { client: &http.Client{}, sseClient: sse.NewClient(baseURL + "/sse"), events: make(chan *sse.Event), - closeChan: make(chan struct{}), - initChan: make(chan struct{}), logger: zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Logger(), } } // Send sends a request and returns the response -func (t *SSETransport) Send(request *protocol.Request) (*protocol.Response, error) { +func (t *SSETransport) Send(ctx context.Context, request *protocol.Request) (*protocol.Response, error) { t.mu.Lock() if !t.initialized { t.mu.Unlock() t.logger.Debug().Msg("Initializing SSE connection") - if err := t.initializeSSE(); err != nil { + if err := t.initializeSSE(ctx); err != nil { t.logger.Error().Err(err).Msg("Failed to initialize SSE") return nil, fmt.Errorf("failed to initialize SSE: %w", err) } - // Wait for initialization to complete - select { - case <-t.initChan: - t.logger.Debug().Msg("SSE initialization completed") - case <-t.closeChan: - return nil, fmt.Errorf("transport closed during initialization") - } t.mu.Lock() } defer t.mu.Unlock() @@ -80,7 +70,15 @@ func (t *SSETransport) Send(request *protocol.Request) (*protocol.Response, erro RawJSON("request", reqBody). Msg("Sending HTTP POST request") - resp, err := t.client.Post(t.baseURL+"/messages", "application/json", bytes.NewReader(reqBody)) + // Create a new request with context + req, err := http.NewRequestWithContext(ctx, "POST", t.baseURL+"/messages", bytes.NewReader(reqBody)) + if err != nil { + t.logger.Error().Err(err).Msg("Failed to create HTTP request") + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := t.client.Do(req) if err != nil { t.logger.Error().Err(err).Msg("Failed to send HTTP request") return nil, fmt.Errorf("failed to send request: %w", err) @@ -96,7 +94,7 @@ func (t *SSETransport) Send(request *protocol.Request) (*protocol.Response, erro return nil, fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body)) } - // Wait for response event + // Wait for response event with context t.logger.Debug().Msg("Waiting for response event") select { case event := <-t.events: @@ -120,21 +118,26 @@ func (t *SSETransport) Send(request *protocol.Request) (*protocol.Response, erro return &response, nil - case <-t.closeChan: - t.logger.Debug().Msg("Transport closed while waiting for response") - return nil, fmt.Errorf("transport closed") + case <-ctx.Done(): + t.logger.Debug().Msg("Context cancelled while waiting for response") + return nil, ctx.Err() } } // initializeSSE sets up the SSE connection -func (t *SSETransport) initializeSSE() error { +func (t *SSETransport) initializeSSE(ctx context.Context) error { t.logger.Debug().Str("url", t.baseURL+"/sse").Msg("Setting up SSE connection") - // Start SSE subscription in a goroutine + // Create a new context with cancellation for the subscription + subCtx, cancel := context.WithCancel(ctx) + + // Create a channel to receive initialization result + initDone := make(chan error, 1) + go func() { - defer close(t.initChan) + defer cancel() - err := t.sseClient.SubscribeRaw(func(msg *sse.Event) { + err := t.sseClient.SubscribeWithContext(subCtx, "", func(msg *sse.Event) { // Handle session ID event if string(msg.Event) == "session" { t.mu.Lock() @@ -142,6 +145,7 @@ func (t *SSETransport) initializeSSE() error { t.initialized = true t.mu.Unlock() t.logger.Debug().Str("sessionID", t.sessionID).Msg("Received session ID") + initDone <- nil return } @@ -154,29 +158,37 @@ func (t *SSETransport) initializeSSE() error { select { case t.events <- msg: t.logger.Debug().Msg("Forwarded event to channel") - case <-t.closeChan: - t.logger.Debug().Msg("Transport closed while forwarding event") + case <-subCtx.Done(): + t.logger.Debug().Msg("Context cancelled while forwarding event") } }) if err != nil { t.logger.Error().Err(err).Msg("SSE subscription failed") - // Signal initialization failure t.mu.Lock() t.initialized = false t.mu.Unlock() + initDone <- err } }() - return nil + // Wait for initialization or context cancellation + select { + case err := <-initDone: + close(initDone) + return err + case <-ctx.Done(): + cancel() + return ctx.Err() + } } // Close closes the transport -func (t *SSETransport) Close() error { +func (t *SSETransport) Close(ctx context.Context) error { t.logger.Debug().Msg("Closing transport") t.closeOnce.Do(func() { - close(t.closeChan) t.sseClient.Unsubscribe(t.events) + close(t.events) t.logger.Debug().Msg("Transport closed") }) return nil diff --git a/pkg/client/stdio.go b/pkg/client/stdio.go index 8a0f17c..be158b9 100644 --- a/pkg/client/stdio.go +++ b/pkg/client/stdio.go @@ -2,6 +2,7 @@ package client import ( "bufio" + "context" "encoding/json" "fmt" "io" @@ -80,7 +81,7 @@ func NewCommandStdioTransport(command string, args ...string) (*StdioTransport, } // Send sends a request and returns the response -func (t *StdioTransport) Send(request *protocol.Request) (*protocol.Response, error) { +func (t *StdioTransport) Send(ctx context.Context, request *protocol.Request) (*protocol.Response, error) { t.mu.Lock() defer t.mu.Unlock() @@ -97,31 +98,62 @@ func (t *StdioTransport) Send(request *protocol.Request) (*protocol.Response, er t.logger.Debug().Msg("Waiting for response") - // Read response - if !t.scanner.Scan() { - if err := t.scanner.Err(); err != nil { - t.logger.Error().Err(err).Msg("Failed to read response") - return nil, fmt.Errorf("failed to read response: %w", err) + // Create a channel for the response + responseCh := make(chan struct { + response *protocol.Response + err error + }, 1) + + // Read response in a goroutine + go func() { + if !t.scanner.Scan() { + var err error + if scanErr := t.scanner.Err(); scanErr != nil { + err = fmt.Errorf("failed to read response: %w", scanErr) + t.logger.Error().Err(err).Msg("Failed to read response") + } else { + err = io.EOF + t.logger.Debug().Msg("EOF while reading response") + } + responseCh <- struct { + response *protocol.Response + err error + }{nil, err} + return } - t.logger.Debug().Msg("EOF while reading response") - return nil, io.EOF - } - t.logger.Debug(). - RawJSON("response", t.scanner.Bytes()). - Msg("Received response") + t.logger.Debug(). + RawJSON("response", t.scanner.Bytes()). + Msg("Received response") + + var response protocol.Response + if err := json.Unmarshal(t.scanner.Bytes(), &response); err != nil { + t.logger.Error().Err(err).Msg("Failed to parse response") + responseCh <- struct { + response *protocol.Response + err error + }{nil, fmt.Errorf("failed to parse response: %w", err)} + return + } - var response protocol.Response - if err := json.Unmarshal(t.scanner.Bytes(), &response); err != nil { - t.logger.Error().Err(err).Msg("Failed to parse response") - return nil, fmt.Errorf("failed to parse response: %w", err) + responseCh <- struct { + response *protocol.Response + err error + }{&response, nil} + }() + + // Wait for either response or context cancellation + select { + case result := <-responseCh: + return result.response, result.err + case <-ctx.Done(): + t.logger.Debug().Msg("Context cancelled while waiting for response") + return nil, ctx.Err() } - - return &response, nil } // Close closes the transport -func (t *StdioTransport) Close() error { +func (t *StdioTransport) Close(ctx context.Context) error { t.logger.Debug().Msg("Closing transport") if t.cmd != nil { t.logger.Debug(). @@ -171,26 +203,35 @@ func (t *StdioTransport) Close() error { } } - // Wait for the process to exit - t.logger.Debug().Msg("Waiting for command to exit") - err = t.cmd.Wait() - if err != nil { - // Check if it's an expected exit error (like signal kill) - if exitErr, ok := err.(*exec.ExitError); ok { - t.logger.Debug(). + // Create a channel to receive the wait result + waitCh := make(chan error, 1) + go func() { + waitCh <- t.cmd.Wait() + }() + + // Wait for either the process to exit or context to be cancelled + select { + case err := <-waitCh: + if err != nil { + // Check if it's an expected exit error (like signal kill) + if exitErr, ok := err.(*exec.ExitError); ok { + t.logger.Debug(). + Err(err). + Int("exit_code", exitErr.ExitCode()). + Msg("Command exited with error (expected for signal termination)") + return nil + } + t.logger.Error(). Err(err). - Int("exit_code", exitErr.ExitCode()). - Msg("Command exited with error (expected for signal termination)") - return nil + Msg("Command exited with unexpected error") + return err } - t.logger.Error(). - Err(err). - Msg("Command exited with unexpected error") - return err + t.logger.Debug().Msg("Command exited successfully") + return nil + case <-ctx.Done(): + t.logger.Debug().Msg("Context cancelled while waiting for command to exit") + return ctx.Err() } - - t.logger.Debug().Msg("Command exited successfully") - return nil } t.logger.Debug().Msg("No command to close") return nil