Skip to content

Commit

Permalink
✨ Use context for tools
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent 2b61b22 commit cdbfab8
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 87 deletions.
4 changes: 2 additions & 2 deletions cmd/mcp-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ Supports both stdio and SSE transports for client-server communication.`,
}

for _, tool := range tools {
fmt.Printf("Name: %s\n", tool.GetName())
fmt.Printf("Description: %s\n", tool.GetDescription())
fmt.Printf("Name: %s\n", tool.Name)
fmt.Printf("Description: %s\n", tool.Description)
fmt.Println()
}

Expand Down
14 changes: 8 additions & 6 deletions cmd/mcp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,15 @@ Available transports:
}
}
}`

tool, err := tools.NewToolImpl("echo", "Echo the input arguments", json.RawMessage(schemaJson))
if err != nil {
logger.Error().Err(err).Msg("Error creating tool")
return err
}
toolRegistry.RegisterToolWithHandler(
protocol.Tool{
Name: "echo",
Description: "Echo the input arguments",
InputSchema: json.RawMessage(schemaJson),
},
func(tool protocol.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error) {
tool,
func(ctx context.Context, tool tools.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error) {
message, ok := arguments["message"].(string)
if !ok {
return protocol.NewToolResult(
Expand Down
74 changes: 4 additions & 70 deletions pkg/protocol/tools.go
Original file line number Diff line number Diff line change
@@ -1,80 +1,14 @@
package protocol

import (
"context"
"encoding/json"
"fmt"
)

// Tool represents a tool that can be invoked
type Tool interface {
GetName() string
GetDescription() string
GetInputSchema() json.RawMessage
Call(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)
}

// ToolImpl is a basic implementation of the Tool interface
type ToolImpl struct {
name string
description string
inputSchema json.RawMessage
}

// NewToolImpl creates a new ToolImpl with the given parameters
func NewToolImpl(name, description string, inputSchema interface{}) (*ToolImpl, error) {
var schema json.RawMessage
switch s := inputSchema.(type) {
case json.RawMessage:
schema = s
case string:
schema = json.RawMessage(s)
default:
var err error
schema, err = json.Marshal(s)
if err != nil {
return nil, fmt.Errorf("failed to marshal input schema: %w", err)
}
}

return &ToolImpl{
name: name,
description: description,
inputSchema: schema,
}, nil
}

// GetName returns the tool's name
func (t *ToolImpl) GetName() string {
return t.name
}

// GetDescription returns the tool's description
func (t *ToolImpl) GetDescription() string {
return t.description
}

// GetInputSchema returns the tool's input schema
func (t *ToolImpl) GetInputSchema() json.RawMessage {
return t.inputSchema
}

// Call implements the Tool interface but panics as it should be overridden
func (t *ToolImpl) Call(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
panic("Call not implemented for ToolImpl - must be overridden")
}

// MarshalJSON implements json.Marshaler
func (t *ToolImpl) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema json.RawMessage `json:"inputSchema"`
}{
Name: t.name,
Description: t.description,
InputSchema: t.inputSchema,
})
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema json.RawMessage `json:"inputSchema"`
}

// ToolResult represents the result of a tool invocation
Expand Down
2 changes: 1 addition & 1 deletion pkg/services/defaults/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (s *DefaultToolService) ListTools(ctx context.Context, cursor string) ([]pr

func (s *DefaultToolService) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (interface{}, error) {
for _, provider := range s.registry.GetToolProviders() {
result, err := provider.CallTool(name, arguments)
result, err := provider.CallTool(ctx, name, arguments)
if err == nil {
return result, nil
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/tools/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@ import (
// Registry provides a simple way to register individual tools
type Registry struct {
mu sync.RWMutex
tools map[string]protocol.Tool
tools map[string]Tool
handlers map[string]Handler
}

// Handler is a function that executes a tool with given arguments
type Handler func(ctx context.Context, tool protocol.Tool, arguments map[string]interface{}) (*protocol.ToolResult, error)
type Handler func(ctx context.Context, tool Tool, arguments map[string]interface{}) (*protocol.ToolResult, error)

// NewRegistry creates a new tool registry
func NewRegistry() *Registry {
return &Registry{
tools: make(map[string]protocol.Tool),
tools: make(map[string]Tool),
handlers: make(map[string]Handler),
}
}

// RegisterTool adds a tool to the registry
func (r *Registry) RegisterTool(tool protocol.Tool) {
func (r *Registry) RegisterTool(tool Tool) {
r.mu.Lock()
defer r.mu.Unlock()
r.tools[tool.GetName()] = tool
}

// RegisterToolWithHandler adds a tool with a custom handler
func (r *Registry) RegisterToolWithHandler(tool protocol.Tool, handler Handler) {
func (r *Registry) RegisterToolWithHandler(tool Tool, handler Handler) {
r.mu.Lock()
defer r.mu.Unlock()
r.tools[tool.GetName()] = tool
Expand All @@ -57,11 +57,11 @@ func (r *Registry) ListTools(cursor string) ([]protocol.Tool, string, error) {

tools := make([]protocol.Tool, 0, len(r.tools))
for _, t := range r.tools {
tools = append(tools, t)
tools = append(tools, t.GetToolDefinition())
}

sort.Slice(tools, func(i, j int) bool {
return tools[i].GetName() < tools[j].GetName()
return tools[i].Name < tools[j].Name
})

if cursor == "" {
Expand All @@ -70,7 +70,7 @@ func (r *Registry) ListTools(cursor string) ([]protocol.Tool, string, error) {

pos := -1
for i, t := range tools {
if t.GetName() == cursor {
if t.Name == cursor {
pos = i
break
}
Expand Down
89 changes: 89 additions & 0 deletions pkg/tools/tool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package tools

import (
"context"
"encoding/json"
"fmt"

"github.com/go-go-golems/go-go-mcp/pkg/protocol"
)

// Tool represents a tool that can be invoked
type Tool interface {
GetName() string
GetDescription() string
GetInputSchema() json.RawMessage
GetToolDefinition() protocol.Tool
Call(ctx context.Context, arguments map[string]interface{}) (*protocol.ToolResult, error)
}

// ToolImpl is a basic implementation of the Tool interface
type ToolImpl struct {
name string
description string
inputSchema json.RawMessage
}

// NewToolImpl creates a new ToolImpl with the given parameters
func NewToolImpl(name, description string, inputSchema interface{}) (*ToolImpl, error) {
var schema json.RawMessage
switch s := inputSchema.(type) {
case json.RawMessage:
schema = s
case string:
schema = json.RawMessage(s)
default:
var err error
schema, err = json.Marshal(s)
if err != nil {
return nil, fmt.Errorf("failed to marshal input schema: %w", err)
}
}

return &ToolImpl{
name: name,
description: description,
inputSchema: schema,
}, nil
}

// GetName returns the tool's name
func (t *ToolImpl) GetName() string {
return t.name
}

// GetDescription returns the tool's description
func (t *ToolImpl) GetDescription() string {
return t.description
}

// GetInputSchema returns the tool's input schema
func (t *ToolImpl) GetInputSchema() json.RawMessage {
return t.inputSchema
}

// GetToolDefinition returns the tool's definition
func (t *ToolImpl) GetToolDefinition() protocol.Tool {
return protocol.Tool{
Name: t.name,
Description: t.description,
InputSchema: t.inputSchema,
}
}
// Call implements the Tool interface but panics as it should be overridden
func (t *ToolImpl) Call(ctx context.Context, arguments map[string]interface{}) (*protocol.ToolResult, error) {
panic("Call not implemented for ToolImpl - must be overridden")
}

// MarshalJSON implements json.Marshaler
func (t *ToolImpl) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema json.RawMessage `json:"inputSchema"`
}{
Name: t.name,
Description: t.description,
InputSchema: t.inputSchema,
})
}

0 comments on commit cdbfab8

Please sign in to comment.