Skip to content

Commit

Permalink
✨ Add context support to client CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent eeed63f commit 58da7aa
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 107 deletions.
11 changes: 10 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,13 @@ Made the SSE client run asynchronously to prevent blocking on subscription:
- Added initialization synchronization channel
- Moved SSE subscription to a goroutine
- Improved error handling and state management
- Added proper mutex protection for shared state
- Added proper mutex protection for shared state

# Context-Based Transport Control

Refactored all transports to use context.Context for better control over cancellation and timeouts:
- Added context support to Transport interface methods (Send and Close)
- Updated SSE transport to use context for initialization and event handling
- Updated stdio transport to use context for command execution and response handling
- Removed channel-based cancellation in favor of context
- Added proper context propagation throughout the client API
39 changes: 20 additions & 19 deletions cmd/mcp-client/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -81,13 +82,13 @@ Supports both stdio and SSE transports for client-server communication.`,
Use: "list",
Short: "List available prompts",
RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient()
client, err := createClient(cmd.Context())
if err != nil {
return err
}
defer client.Close()
defer client.Close(cmd.Context())

prompts, cursor, err := client.ListPrompts("")
prompts, cursor, err := client.ListPrompts(cmd.Context(), "")
if err != nil {
return err
}
Expand Down Expand Up @@ -116,11 +117,11 @@ Supports both stdio and SSE transports for client-server communication.`,
Short: "Execute a specific prompt",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient()
client, err := createClient(cmd.Context())
if err != nil {
return err
}
defer client.Close()
defer client.Close(cmd.Context())

// Parse prompt arguments
promptArgMap := make(map[string]string)
Expand All @@ -130,7 +131,7 @@ Supports both stdio and SSE transports for client-server communication.`,
}
}

message, err := client.GetPrompt(args[0], promptArgMap)
message, err := client.GetPrompt(cmd.Context(), args[0], promptArgMap)
if err != nil {
return err
}
Expand All @@ -154,13 +155,13 @@ Supports both stdio and SSE transports for client-server communication.`,
Use: "list",
Short: "List available tools",
RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient()
client, err := createClient(cmd.Context())
if err != nil {
return err
}
defer client.Close()
defer client.Close(cmd.Context())

tools, cursor, err := client.ListTools("")
tools, cursor, err := client.ListTools(cmd.Context(), "")
if err != nil {
return err
}
Expand All @@ -184,11 +185,11 @@ Supports both stdio and SSE transports for client-server communication.`,
Short: "Call a specific tool",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient()
client, err := createClient(cmd.Context())
if err != nil {
return err
}
defer client.Close()
defer client.Close(cmd.Context())

// Parse tool arguments
toolArgMap := make(map[string]interface{})
Expand Down Expand Up @@ -231,13 +232,13 @@ Supports both stdio and SSE transports for client-server communication.`,
Use: "list",
Short: "List available resources",
RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient()
client, err := createClient(cmd.Context())
if err != nil {
return err
}
defer client.Close()
defer client.Close(cmd.Context())

resources, cursor, err := client.ListResources("")
resources, cursor, err := client.ListResources(cmd.Context(), "")
if err != nil {
return err
}
Expand All @@ -263,13 +264,13 @@ Supports both stdio and SSE transports for client-server communication.`,
Short: "Read a specific resource",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient()
client, err := createClient(cmd.Context())
if err != nil {
return err
}
defer client.Close()
defer client.Close(cmd.Context())

content, err := client.ReadResource(args[0])
content, err := client.ReadResource(cmd.Context(), args[0])
if err != nil {
return err
}
Expand Down Expand Up @@ -317,7 +318,7 @@ Supports both stdio and SSE transports for client-server communication.`,
}
}

func createClient() (*client.Client, error) {
func createClient(ctx context.Context) (*client.Client, error) {
// Use ConsoleWriter for colored output
consoleWriter := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}
logger := zerolog.New(consoleWriter).With().Timestamp().Logger()
Expand All @@ -342,7 +343,7 @@ func createClient() (*client.Client, error) {

// Create and initialize client
c := client.NewClient(logger, t)
err = c.Initialize(protocol.ClientCapabilities{
err = c.Initialize(ctx, protocol.ClientCapabilities{
Sampling: &protocol.SamplingCapability{},
})
if err != nil {
Expand Down
44 changes: 22 additions & 22 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
// Transport represents a client transport mechanism
type Transport interface {
// Send sends a request and returns the response
Send(request *protocol.Request) (*protocol.Response, error)
Send(ctx context.Context, request *protocol.Request) (*protocol.Response, error)
// Close closes the transport connection
Close() error
Close(ctx context.Context) error
}

// Client represents an MCP client that can use different transports
Expand All @@ -42,7 +42,7 @@ func NewClient(logger zerolog.Logger, transport Transport) *Client {
}

// Initialize initializes the connection with the server
func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error {
func (c *Client) Initialize(ctx context.Context, capabilities protocol.ClientCapabilities) error {
c.mu.Lock()
defer c.mu.Unlock()

Expand Down Expand Up @@ -70,7 +70,7 @@ func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error {
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return fmt.Errorf("failed to send initialize request: %w", err)
}
Expand All @@ -96,7 +96,7 @@ func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error {
JSONRPC: "2.0",
Method: "notifications/initialized",
}
_, err = c.transport.Send(notification)
_, err = c.transport.Send(ctx, notification)
if err != nil {
return fmt.Errorf("failed to send initialized notification: %w", err)
}
Expand All @@ -105,7 +105,7 @@ func (c *Client) Initialize(capabilities protocol.ClientCapabilities) error {
}

// ListPrompts retrieves the list of available prompts from the server
func (c *Client) ListPrompts(cursor string) ([]protocol.Prompt, string, error) {
func (c *Client) ListPrompts(ctx context.Context, cursor string) ([]protocol.Prompt, string, error) {
if !c.initialized {
return nil, "", fmt.Errorf("client not initialized")
}
Expand All @@ -122,7 +122,7 @@ func (c *Client) ListPrompts(cursor string) ([]protocol.Prompt, string, error) {
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return nil, "", fmt.Errorf("failed to send prompts/list request: %w", err)
}
Expand All @@ -143,7 +143,7 @@ func (c *Client) ListPrompts(cursor string) ([]protocol.Prompt, string, error) {
}

// GetPrompt retrieves a specific prompt from the server
func (c *Client) GetPrompt(name string, arguments map[string]string) (*protocol.PromptMessage, error) {
func (c *Client) GetPrompt(ctx context.Context, name string, arguments map[string]string) (*protocol.PromptMessage, error) {
if !c.initialized {
return nil, fmt.Errorf("client not initialized")
}
Expand All @@ -163,7 +163,7 @@ func (c *Client) GetPrompt(name string, arguments map[string]string) (*protocol.
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return nil, fmt.Errorf("failed to send prompts/get request: %w", err)
}
Expand All @@ -187,7 +187,7 @@ func (c *Client) GetPrompt(name string, arguments map[string]string) (*protocol.
}

// ListResources retrieves the list of available resources from the server
func (c *Client) ListResources(cursor string) ([]protocol.Resource, string, error) {
func (c *Client) ListResources(ctx context.Context, cursor string) ([]protocol.Resource, string, error) {
if !c.initialized {
return nil, "", fmt.Errorf("client not initialized")
}
Expand All @@ -204,7 +204,7 @@ func (c *Client) ListResources(cursor string) ([]protocol.Resource, string, erro
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return nil, "", fmt.Errorf("failed to send resources/list request: %w", err)
}
Expand All @@ -225,7 +225,7 @@ func (c *Client) ListResources(cursor string) ([]protocol.Resource, string, erro
}

// ReadResource retrieves the content of a specific resource from the server
func (c *Client) ReadResource(uri string) (*protocol.ResourceContent, error) {
func (c *Client) ReadResource(ctx context.Context, uri string) (*protocol.ResourceContent, error) {
if !c.initialized {
return nil, fmt.Errorf("client not initialized")
}
Expand All @@ -243,7 +243,7 @@ func (c *Client) ReadResource(uri string) (*protocol.ResourceContent, error) {
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return nil, fmt.Errorf("failed to send resources/read request: %w", err)
}
Expand All @@ -267,7 +267,7 @@ func (c *Client) ReadResource(uri string) (*protocol.ResourceContent, error) {
}

// ListTools retrieves the list of available tools from the server
func (c *Client) ListTools(cursor string) ([]protocol.Tool, string, error) {
func (c *Client) ListTools(ctx context.Context, cursor string) ([]protocol.Tool, string, error) {
if !c.initialized {
return nil, "", fmt.Errorf("client not initialized")
}
Expand All @@ -284,7 +284,7 @@ func (c *Client) ListTools(cursor string) ([]protocol.Tool, string, error) {
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return nil, "", fmt.Errorf("failed to send tools/list request: %w", err)
}
Expand Down Expand Up @@ -325,7 +325,7 @@ func (c *Client) CallTool(ctx context.Context, name string, arguments map[string
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return nil, fmt.Errorf("failed to send tools/call request: %w", err)
}
Expand All @@ -343,7 +343,7 @@ func (c *Client) CallTool(ctx context.Context, name string, arguments map[string
}

// CreateMessage sends a sampling request to create a message
func (c *Client) CreateMessage(messages []protocol.Message, modelPreferences protocol.ModelPreferences, systemPrompt string, maxTokens int) (*protocol.Message, error) {
func (c *Client) CreateMessage(ctx context.Context, messages []protocol.Message, modelPreferences protocol.ModelPreferences, systemPrompt string, maxTokens int) (*protocol.Message, error) {
if !c.initialized {
return nil, fmt.Errorf("client not initialized")
}
Expand All @@ -367,7 +367,7 @@ func (c *Client) CreateMessage(messages []protocol.Message, modelPreferences pro
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return nil, fmt.Errorf("failed to send sampling/createMessage request: %w", err)
}
Expand All @@ -385,14 +385,14 @@ func (c *Client) CreateMessage(messages []protocol.Message, modelPreferences pro
}

// Ping sends a ping request to the server
func (c *Client) Ping() error {
func (c *Client) Ping(ctx context.Context) error {
request := &protocol.Request{
JSONRPC: "2.0",
Method: "ping",
}
c.setRequestID(request)

response, err := c.transport.Send(request)
response, err := c.transport.Send(ctx, request)
if err != nil {
return fmt.Errorf("failed to send ping request: %w", err)
}
Expand All @@ -405,9 +405,9 @@ func (c *Client) Ping() error {
}

// Close closes the client connection
func (c *Client) Close() error {
func (c *Client) Close(ctx context.Context) error {
c.logger.Debug().Msg("closing client")
return c.transport.Close()
return c.transport.Close(ctx)
}

// setRequestID sets a unique ID for the request
Expand Down
Loading

0 comments on commit 58da7aa

Please sign in to comment.