From 325df613d44718b1f3ee6aa94f6f4a145a0ce26c Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Tue, 21 Jan 2025 08:13:55 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20command=20transport=20and=20i?= =?UTF-8?q?mprove=20argument=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/mcp-client/main.go | 14 ++++- cmd/mcp-server/main.go | 27 +++++++++ pkg/client/stdio.go | 33 +++++++++++ pkg/protocol/tools.go | 131 ++++++++++++++++++++++++++++++++++++++++- 4 files changed, 203 insertions(+), 2 deletions(-) diff --git a/cmd/mcp-client/main.go b/cmd/mcp-client/main.go index 6fa0497..14caef7 100644 --- a/cmd/mcp-client/main.go +++ b/cmd/mcp-client/main.go @@ -22,6 +22,8 @@ var ( transport string serverURL string debug bool + command string + cmdArgs []string // Operation flags promptArgs string @@ -44,8 +46,10 @@ Supports both stdio and SSE transports for client-server communication.`, // Add persistent flags rootCmd.PersistentFlags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging") - rootCmd.PersistentFlags().StringVarP(&transport, "transport", "t", "stdio", "Transport type (stdio or sse)") + rootCmd.PersistentFlags().StringVarP(&transport, "transport", "t", "stdio", "Transport type (stdio, sse, or command)") rootCmd.PersistentFlags().StringVarP(&serverURL, "server", "s", "http://localhost:8000", "Server URL for SSE transport") + rootCmd.PersistentFlags().StringVarP(&command, "command", "c", "", "Command to run for command transport") + rootCmd.PersistentFlags().StringSliceVarP(&cmdArgs, "args", "a", []string{}, "Command arguments for command transport") // Prompts command group promptsCmd := &cobra.Command{ @@ -307,6 +311,14 @@ func createClient() (*client.Client, error) { t = client.NewStdioTransport() case "sse": t = client.NewSSETransport(serverURL) + case "command": + if command == "" { + return nil, fmt.Errorf("command is required for command transport") + } + t, err = client.NewCommandStdioTransport(command, cmdArgs...) + if err != nil { + return nil, fmt.Errorf("failed to create command transport: %w", err) + } default: return nil, fmt.Errorf("invalid transport type: %s", transport) } diff --git a/cmd/mcp-server/main.go b/cmd/mcp-server/main.go index 780bd3e..5ce6971 100644 --- a/cmd/mcp-server/main.go +++ b/cmd/mcp-server/main.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "fmt" "os" "time" @@ -85,6 +86,32 @@ Available transports: }, }) + schemaJson := `{ + "type": "object", + "properties": { + "message": { + "type": "string" + } + } + }` + toolRegistry.RegisterToolWithHandler( + protocol.Tool{ + Name: "echo", + Description: "Echo the input arguments", + InputSchema: json.RawMessage(schemaJson), + }, + func(tool protocol.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error) { + message, ok := arguments["message"].(string) + if !ok { + return protocol.NewToolResult( + protocol.WithError("message argument must be a string"), + ), nil + } + return protocol.NewToolResult( + protocol.WithText(message), + ), nil + }) + // Register registries with the server srv.GetRegistry().RegisterPromptProvider(promptRegistry) srv.GetRegistry().RegisterResourceProvider(resourceRegistry) diff --git a/pkg/client/stdio.go b/pkg/client/stdio.go index 5d374a0..a6639c0 100644 --- a/pkg/client/stdio.go +++ b/pkg/client/stdio.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "os/exec" "sync" "github.com/go-go-golems/go-mcp/pkg/protocol" @@ -16,6 +17,7 @@ type StdioTransport struct { mu sync.Mutex scanner *bufio.Scanner writer *json.Encoder + cmd *exec.Cmd } // NewStdioTransport creates a new stdio transport @@ -26,6 +28,31 @@ func NewStdioTransport() *StdioTransport { } } +// NewCommandStdioTransport creates a new stdio transport that launches a command +func NewCommandStdioTransport(command string, args ...string) (*StdioTransport, error) { + cmd := exec.Command(command, args...) + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start command: %w", err) + } + + return &StdioTransport{ + scanner: bufio.NewScanner(stdout), + writer: json.NewEncoder(stdin), + cmd: cmd, + }, nil +} + // Send sends a request and returns the response func (t *StdioTransport) Send(request *protocol.Request) (*protocol.Response, error) { t.mu.Lock() @@ -54,5 +81,11 @@ func (t *StdioTransport) Send(request *protocol.Request) (*protocol.Response, er // Close closes the transport func (t *StdioTransport) Close() error { + if t.cmd != nil { + if err := t.cmd.Process.Signal(os.Interrupt); err != nil { + return fmt.Errorf("failed to send interrupt signal: %w", err) + } + return t.cmd.Wait() + } return nil // Nothing to close for stdio } diff --git a/pkg/protocol/tools.go b/pkg/protocol/tools.go index 3b0eb60..001228d 100644 --- a/pkg/protocol/tools.go +++ b/pkg/protocol/tools.go @@ -1,6 +1,9 @@ package protocol -import "encoding/json" +import ( + "encoding/json" + "fmt" +) // Tool represents a tool that can be invoked type Tool struct { @@ -23,3 +26,129 @@ type ToolContent struct { MimeType string `json:"mimeType,omitempty"` Resource *ResourceContent `json:"resource,omitempty"` // For resource content } + +// ToolResultOption is a function that modifies a ToolResult +type ToolResultOption func(*ToolResult) + +// NewToolResult creates a new ToolResult with the given options +func NewToolResult(opts ...ToolResultOption) *ToolResult { + tr := &ToolResult{ + Content: []ToolContent{}, + IsError: false, + } + + for _, opt := range opts { + opt(tr) + } + + return tr +} + +// WithText adds a text content to the ToolResult +func WithText(text string) ToolResultOption { + return func(tr *ToolResult) { + tr.Content = append(tr.Content, NewTextContent(text)) + } +} + +// WithJSON adds JSON-serialized content to the ToolResult +// If marshaling fails, it adds an error message instead +func WithJSON(data interface{}) ToolResultOption { + return func(tr *ToolResult) { + content, err := NewJSONContent(data) + if err != nil { + tr.Content = append(tr.Content, NewTextContent(fmt.Sprintf("Error marshaling JSON: %v", err))) + tr.IsError = true + return + } + tr.Content = append(tr.Content, content) + } +} + +// WithImage adds an image content to the ToolResult +func WithImage(base64Data, mimeType string) ToolResultOption { + return func(tr *ToolResult) { + tr.Content = append(tr.Content, NewImageContent(base64Data, mimeType)) + } +} + +// WithResource adds a resource content to the ToolResult +func WithResource(resource *ResourceContent) ToolResultOption { + return func(tr *ToolResult) { + tr.Content = append(tr.Content, NewResourceContent(resource)) + } +} + +// WithError marks the ToolResult as an error and optionally adds an error message +func WithError(errorMsg string) ToolResultOption { + return func(tr *ToolResult) { + tr.IsError = true + if errorMsg != "" { + tr.Content = append(tr.Content, NewTextContent(errorMsg)) + } + } +} + +// WithContent adds raw ToolContent to the ToolResult +func WithContent(content ToolContent) ToolResultOption { + return func(tr *ToolResult) { + tr.Content = append(tr.Content, content) + } +} + +// NewErrorToolResult creates a new ToolResult marked as error with the given content +func NewErrorToolResult(content ...ToolContent) *ToolResult { + return &ToolResult{ + Content: content, + IsError: true, + } +} + +// NewTextContent creates a new ToolContent with text type +func NewTextContent(text string) ToolContent { + return ToolContent{ + Type: "text", + Text: text, + } +} + +// NewJSONContent creates a new ToolContent with JSON-serialized data +func NewJSONContent(data interface{}) (ToolContent, error) { + jsonBytes, err := json.Marshal(data) + if err != nil { + return ToolContent{}, err + } + + return ToolContent{ + Type: "text", + Text: string(jsonBytes), + MimeType: "application/json", + }, nil +} + +// MustNewJSONContent creates a new ToolContent with JSON-serialized data +// Panics if marshaling fails +func MustNewJSONContent(data interface{}) ToolContent { + content, err := NewJSONContent(data) + if err != nil { + panic(err) + } + return content +} + +// NewImageContent creates a new ToolContent with base64-encoded image data +func NewImageContent(base64Data, mimeType string) ToolContent { + return ToolContent{ + Type: "image", + Data: base64Data, + MimeType: mimeType, + } +} + +// NewResourceContent creates a new ToolContent with resource data +func NewResourceContent(resource *ResourceContent) ToolContent { + return ToolContent{ + Type: "resource", + Resource: resource, + } +}