Skip to content

Commit

Permalink
🔧 refactor: make Tool an interface with context support
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent 16e523a commit df35321
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 21 deletions.
11 changes: 10 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,13 @@ Simplified stdio server shutdown by using only context-based coordination:
- Removed redundant done channel in favor of context cancellation
- Added dedicated scanner context for cleaner shutdown
- Simplified shutdown logic and error handling
- Improved logging messages for shutdown events
- Improved logging messages for shutdown events

# Tool Interface Improvements

Made Tool an interface with accessors and context-aware Call method for better extensibility and context propagation.

- Changed Tool from struct to interface with GetName, GetDescription, GetInputSchema and Call methods
- Added ToolImpl as a basic implementation of the Tool interface
- Updated Registry, Client and CLI to use the new Tool interface
- Added context support throughout the tool invocation chain
6 changes: 3 additions & 3 deletions cmd/mcp-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ Supports both stdio and SSE transports for client-server communication.`,
}

for _, tool := range tools {
fmt.Printf("Name: %s\n", tool.Name)
fmt.Printf("Description: %s\n", tool.Description)
fmt.Printf("Name: %s\n", tool.GetName())
fmt.Printf("Description: %s\n", tool.GetDescription())
fmt.Println()
}

Expand Down Expand Up @@ -198,7 +198,7 @@ Supports both stdio and SSE transports for client-server communication.`,
}
}

result, err := client.CallTool(args[0], toolArgMap)
result, err := client.CallTool(cmd.Context(), args[0], toolArgMap)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"fmt"
"sync"
Expand Down Expand Up @@ -304,7 +305,7 @@ func (c *Client) ListTools(cursor string) ([]protocol.Tool, string, error) {
}

// CallTool calls a specific tool on the server
func (c *Client) CallTool(name string, arguments map[string]interface{}) (*protocol.ToolResult, error) {
func (c *Client) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (*protocol.ToolResult, error) {
if !c.initialized {
return nil, fmt.Errorf("client not initialized")
}
Expand Down
73 changes: 69 additions & 4 deletions pkg/protocol/tools.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,80 @@
package protocol

import (
"context"
"encoding/json"
"fmt"
)

// Tool represents a tool that can be invoked
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema json.RawMessage `json:"inputSchema"` // JSON Schema
type Tool interface {
GetName() string
GetDescription() string
GetInputSchema() json.RawMessage
Call(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)
}

// ToolImpl is a basic implementation of the Tool interface
type ToolImpl struct {
name string
description string
inputSchema json.RawMessage
}

// NewToolImpl creates a new ToolImpl with the given parameters
func NewToolImpl(name, description string, inputSchema interface{}) (*ToolImpl, error) {
var schema json.RawMessage
switch s := inputSchema.(type) {
case json.RawMessage:
schema = s
case string:
schema = json.RawMessage(s)
default:
var err error
schema, err = json.Marshal(s)
if err != nil {
return nil, fmt.Errorf("failed to marshal input schema: %w", err)
}
}

return &ToolImpl{
name: name,
description: description,
inputSchema: schema,
}, nil
}

// GetName returns the tool's name
func (t *ToolImpl) GetName() string {
return t.name
}

// GetDescription returns the tool's description
func (t *ToolImpl) GetDescription() string {
return t.description
}

// GetInputSchema returns the tool's input schema
func (t *ToolImpl) GetInputSchema() json.RawMessage {
return t.inputSchema
}

// Call implements the Tool interface but panics as it should be overridden
func (t *ToolImpl) Call(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
panic("Call not implemented for ToolImpl - must be overridden")
}

// MarshalJSON implements json.Marshaler
func (t *ToolImpl) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema json.RawMessage `json:"inputSchema"`
}{
Name: t.name,
Description: t.description,
InputSchema: t.inputSchema,
})
}

// ToolResult represents the result of a tool invocation
Expand Down
8 changes: 6 additions & 2 deletions pkg/providers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package pkg

import "github.com/go-go-golems/go-go-mcp/pkg/protocol"
import (
"context"

"github.com/go-go-golems/go-go-mcp/pkg/protocol"
)

// PromptProvider defines the interface for serving prompts
type PromptProvider interface {
Expand Down Expand Up @@ -33,7 +37,7 @@ type ToolProvider interface {
ListTools(cursor string) ([]protocol.Tool, string, error)

// CallTool invokes a specific tool with the given arguments
CallTool(name string, arguments map[string]interface{}) (*protocol.ToolResult, error)
CallTool(ctx context.Context, name string, arguments map[string]interface{}) (*protocol.ToolResult, error)
}

// Provider combines all provider interfaces
Expand Down
21 changes: 11 additions & 10 deletions pkg/tools/registry.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tools

import (
"context"
"sort"
"sync"

Expand All @@ -16,7 +17,7 @@ type Registry struct {
}

// Handler is a function that executes a tool with given arguments
type Handler func(tool protocol.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error)
type Handler func(ctx context.Context, tool protocol.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error)

// NewRegistry creates a new tool registry
func NewRegistry() *Registry {
Expand All @@ -30,15 +31,15 @@ func NewRegistry() *Registry {
func (r *Registry) RegisterTool(tool protocol.Tool) {
r.mu.Lock()
defer r.mu.Unlock()
r.tools[tool.Name] = tool
r.tools[tool.GetName()] = tool
}

// RegisterToolWithHandler adds a tool with a custom handler
func (r *Registry) RegisterToolWithHandler(tool protocol.Tool, handler Handler) {
r.mu.Lock()
defer r.mu.Unlock()
r.tools[tool.Name] = tool
r.handlers[tool.Name] = handler
r.tools[tool.GetName()] = tool
r.handlers[tool.GetName()] = handler
}

// UnregisterTool removes a tool from the registry
Expand All @@ -60,7 +61,7 @@ func (r *Registry) ListTools(cursor string) ([]protocol.Tool, string, error) {
}

sort.Slice(tools, func(i, j int) bool {
return tools[i].Name < tools[j].Name
return tools[i].GetName() < tools[j].GetName()
})

if cursor == "" {
Expand All @@ -69,7 +70,7 @@ func (r *Registry) ListTools(cursor string) ([]protocol.Tool, string, error) {

pos := -1
for i, t := range tools {
if t.Name == cursor {
if t.GetName() == cursor {
pos = i
break
}
Expand All @@ -83,7 +84,7 @@ func (r *Registry) ListTools(cursor string) ([]protocol.Tool, string, error) {
}

// CallTool implements ToolProvider interface
func (r *Registry) CallTool(name string, arguments map[string]interface{}) (*protocol.ToolResult, error) {
func (r *Registry) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (*protocol.ToolResult, error) {
r.mu.RLock()
defer r.mu.RUnlock()

Expand All @@ -93,9 +94,9 @@ func (r *Registry) CallTool(name string, arguments map[string]interface{}) (*pro
}

if handler, ok := r.handlers[name]; ok {
return handler(tool, arguments)
return handler(ctx, tool, arguments)
}

// Return empty result if no handler is registered
return &protocol.ToolResult{}, nil
// If no handler is registered, use the tool's Call method
return tool.Call(ctx, arguments)
}

0 comments on commit df35321

Please sign in to comment.