From cdbfab8a2e5a17accafda5b1decfd7ec9d84ffcb Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Tue, 21 Jan 2025 09:40:33 -0500 Subject: [PATCH] :sparkles: Use context for tools --- cmd/mcp-client/main.go | 4 +- cmd/mcp-server/main.go | 14 +++--- pkg/protocol/tools.go | 74 ++-------------------------- pkg/services/defaults/tools.go | 2 +- pkg/tools/registry.go | 16 +++--- pkg/tools/tool.go | 89 ++++++++++++++++++++++++++++++++++ 6 files changed, 112 insertions(+), 87 deletions(-) create mode 100644 pkg/tools/tool.go diff --git a/cmd/mcp-client/main.go b/cmd/mcp-client/main.go index b457da0..3379546 100644 --- a/cmd/mcp-client/main.go +++ b/cmd/mcp-client/main.go @@ -166,8 +166,8 @@ Supports both stdio and SSE transports for client-server communication.`, } for _, tool := range tools { - fmt.Printf("Name: %s\n", tool.GetName()) - fmt.Printf("Description: %s\n", tool.GetDescription()) + fmt.Printf("Name: %s\n", tool.Name) + fmt.Printf("Description: %s\n", tool.Description) fmt.Println() } diff --git a/cmd/mcp-server/main.go b/cmd/mcp-server/main.go index d344692..e9417ee 100644 --- a/cmd/mcp-server/main.go +++ b/cmd/mcp-server/main.go @@ -118,13 +118,15 @@ Available transports: } } }` + + tool, err := tools.NewToolImpl("echo", "Echo the input arguments", json.RawMessage(schemaJson)) + if err != nil { + logger.Error().Err(err).Msg("Error creating tool") + return err + } toolRegistry.RegisterToolWithHandler( - protocol.Tool{ - Name: "echo", - Description: "Echo the input arguments", - InputSchema: json.RawMessage(schemaJson), - }, - func(tool protocol.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error) { + tool, + func(ctx context.Context, tool tools.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error) { message, ok := arguments["message"].(string) if !ok { return protocol.NewToolResult( diff --git a/pkg/protocol/tools.go b/pkg/protocol/tools.go index 7edb7e3..2bd0aad 100644 --- a/pkg/protocol/tools.go +++ b/pkg/protocol/tools.go @@ -1,80 +1,14 @@ package protocol import ( - "context" "encoding/json" "fmt" ) -// Tool represents a tool that can be invoked -type Tool interface { - GetName() string - GetDescription() string - GetInputSchema() json.RawMessage - Call(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) -} - -// ToolImpl is a basic implementation of the Tool interface -type ToolImpl struct { - name string - description string - inputSchema json.RawMessage -} - -// NewToolImpl creates a new ToolImpl with the given parameters -func NewToolImpl(name, description string, inputSchema interface{}) (*ToolImpl, error) { - var schema json.RawMessage - switch s := inputSchema.(type) { - case json.RawMessage: - schema = s - case string: - schema = json.RawMessage(s) - default: - var err error - schema, err = json.Marshal(s) - if err != nil { - return nil, fmt.Errorf("failed to marshal input schema: %w", err) - } - } - - return &ToolImpl{ - name: name, - description: description, - inputSchema: schema, - }, nil -} - -// GetName returns the tool's name -func (t *ToolImpl) GetName() string { - return t.name -} - -// GetDescription returns the tool's description -func (t *ToolImpl) GetDescription() string { - return t.description -} - -// GetInputSchema returns the tool's input schema -func (t *ToolImpl) GetInputSchema() json.RawMessage { - return t.inputSchema -} - -// Call implements the Tool interface but panics as it should be overridden -func (t *ToolImpl) Call(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) { - panic("Call not implemented for ToolImpl - must be overridden") -} - -// MarshalJSON implements json.Marshaler -func (t *ToolImpl) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema json.RawMessage `json:"inputSchema"` - }{ - Name: t.name, - Description: t.description, - InputSchema: t.inputSchema, - }) +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema json.RawMessage `json:"inputSchema"` } // ToolResult represents the result of a tool invocation diff --git a/pkg/services/defaults/tools.go b/pkg/services/defaults/tools.go index 8fd7af8..6c086d4 100644 --- a/pkg/services/defaults/tools.go +++ b/pkg/services/defaults/tools.go @@ -41,7 +41,7 @@ func (s *DefaultToolService) ListTools(ctx context.Context, cursor string) ([]pr func (s *DefaultToolService) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (interface{}, error) { for _, provider := range s.registry.GetToolProviders() { - result, err := provider.CallTool(name, arguments) + result, err := provider.CallTool(ctx, name, arguments) if err == nil { return result, nil } diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 96b1dc3..7db4f7a 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -12,30 +12,30 @@ import ( // Registry provides a simple way to register individual tools type Registry struct { mu sync.RWMutex - tools map[string]protocol.Tool + tools map[string]Tool handlers map[string]Handler } // Handler is a function that executes a tool with given arguments -type Handler func(ctx context.Context, tool protocol.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error) +type Handler func(ctx context.Context, tool Tool, arguments map[string]interface{}) (*protocol.ToolResult, error) // NewRegistry creates a new tool registry func NewRegistry() *Registry { return &Registry{ - tools: make(map[string]protocol.Tool), + tools: make(map[string]Tool), handlers: make(map[string]Handler), } } // RegisterTool adds a tool to the registry -func (r *Registry) RegisterTool(tool protocol.Tool) { +func (r *Registry) RegisterTool(tool Tool) { r.mu.Lock() defer r.mu.Unlock() r.tools[tool.GetName()] = tool } // RegisterToolWithHandler adds a tool with a custom handler -func (r *Registry) RegisterToolWithHandler(tool protocol.Tool, handler Handler) { +func (r *Registry) RegisterToolWithHandler(tool Tool, handler Handler) { r.mu.Lock() defer r.mu.Unlock() r.tools[tool.GetName()] = tool @@ -57,11 +57,11 @@ func (r *Registry) ListTools(cursor string) ([]protocol.Tool, string, error) { tools := make([]protocol.Tool, 0, len(r.tools)) for _, t := range r.tools { - tools = append(tools, t) + tools = append(tools, t.GetToolDefinition()) } sort.Slice(tools, func(i, j int) bool { - return tools[i].GetName() < tools[j].GetName() + return tools[i].Name < tools[j].Name }) if cursor == "" { @@ -70,7 +70,7 @@ func (r *Registry) ListTools(cursor string) ([]protocol.Tool, string, error) { pos := -1 for i, t := range tools { - if t.GetName() == cursor { + if t.Name == cursor { pos = i break } diff --git a/pkg/tools/tool.go b/pkg/tools/tool.go new file mode 100644 index 0000000..05b762b --- /dev/null +++ b/pkg/tools/tool.go @@ -0,0 +1,89 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/go-go-golems/go-go-mcp/pkg/protocol" +) + +// Tool represents a tool that can be invoked +type Tool interface { + GetName() string + GetDescription() string + GetInputSchema() json.RawMessage + GetToolDefinition() protocol.Tool + Call(ctx context.Context, arguments map[string]interface{}) (*protocol.ToolResult, error) +} + +// ToolImpl is a basic implementation of the Tool interface +type ToolImpl struct { + name string + description string + inputSchema json.RawMessage +} + +// NewToolImpl creates a new ToolImpl with the given parameters +func NewToolImpl(name, description string, inputSchema interface{}) (*ToolImpl, error) { + var schema json.RawMessage + switch s := inputSchema.(type) { + case json.RawMessage: + schema = s + case string: + schema = json.RawMessage(s) + default: + var err error + schema, err = json.Marshal(s) + if err != nil { + return nil, fmt.Errorf("failed to marshal input schema: %w", err) + } + } + + return &ToolImpl{ + name: name, + description: description, + inputSchema: schema, + }, nil +} + +// GetName returns the tool's name +func (t *ToolImpl) GetName() string { + return t.name +} + +// GetDescription returns the tool's description +func (t *ToolImpl) GetDescription() string { + return t.description +} + +// GetInputSchema returns the tool's input schema +func (t *ToolImpl) GetInputSchema() json.RawMessage { + return t.inputSchema +} + +// GetToolDefinition returns the tool's definition +func (t *ToolImpl) GetToolDefinition() protocol.Tool { + return protocol.Tool{ + Name: t.name, + Description: t.description, + InputSchema: t.inputSchema, + } +} +// Call implements the Tool interface but panics as it should be overridden +func (t *ToolImpl) Call(ctx context.Context, arguments map[string]interface{}) (*protocol.ToolResult, error) { + panic("Call not implemented for ToolImpl - must be overridden") +} + +// MarshalJSON implements json.Marshaler +func (t *ToolImpl) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema json.RawMessage `json:"inputSchema"` + }{ + Name: t.name, + Description: t.description, + InputSchema: t.inputSchema, + }) +}