From c0637c77a381b4ce0c9f7f4c77a2496f24c94d6d Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Sun, 16 Feb 2025 11:40:43 -0500 Subject: [PATCH] :sparkles: Polish new protocol structure --- README.md | 10 ++--- changelog.md | 61 +++++++++++++++++++++++++++- cmd/go-go-mcp/cmds/server/start.go | 4 +- pkg/protocol/base.go | 9 ++++- pkg/server/handler.go | 26 ++++++------ pkg/transport/errors.go | 34 +++++++++------- pkg/transport/sse/transport.go | 9 +++-- pkg/transport/stdio/transport.go | 39 +++++++++--------- pkg/transport/transport.go | 65 +++++++++++++----------------- 9 files changed, 160 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 1dba535..ab20585 100644 --- a/README.md +++ b/README.md @@ -87,10 +87,10 @@ Start the server with either stdio or SSE transport: ```bash # Start with stdio transport (default) -go-go-mcp start --transport stdio +go-go-mcp server start --transport stdio # Start with SSE transport -go-go-mcp start --transport sse --port 3001 +go-go-mcp server start --transport sse --port 3001 ``` The server automatically watches configured repositories and files for changes, reloading tools when: @@ -142,7 +142,7 @@ go-go-mcp server tools list --profile data Use the client subcommand to interact with an MCP server: ```bash -# List available prompts (uses default server: go-go-mcp start --transport stdio) +# List available prompts (uses default server: go-go-mcp server start --transport stdio) go-go-mcp client prompts list # List available tools @@ -171,7 +171,7 @@ go-go-mcp can be used as a bridge to expose an SSE server as a stdio server. Thi ```bash # Start an SSE server on port 3000 -go-go-mcp start --transport sse --port 3000 +go-go-mcp server start --transport sse --port 3000 # In another terminal, start the bridge to expose the SSE server as stdio go-go-mcp bridge --sse-url http://localhost:3000 --log-level debug @@ -186,7 +186,7 @@ This is particularly useful when integrating with tools that only support stdio Add the `--debug` flag to enable detailed logging: ```bash -go-go-mcp start --debug +go-go-mcp server start --debug ``` ### Version Information diff --git a/changelog.md b/changelog.md index 1ed436c..8a1b64e 100644 --- a/changelog.md +++ b/changelog.md @@ -1063,4 +1063,63 @@ Updated cobra command handling to support both full and minimal Glazed command l - Added support for GlazedMinimalCommandLayer in cobra command processing - Unified handling of common flags (print-yaml, print-parsed-parameters, etc.) between both layers - Maintained backward compatibility with full GlazedCommandLayer features -- Added placeholder for schema printing functionality \ No newline at end of file +- Added placeholder for schema printing functionality + +# Transport Layer Refactoring + +Implemented new transport layer architecture as described in RFC-01. This change: +- Creates a clean interface for different transport mechanisms +- Separates transport concerns from business logic +- Provides consistent error handling across transports +- Adds support for transport-specific options and capabilities + +- Created new transport package with core interfaces and types +- Implemented SSE transport using new architecture +- Added transport options system +- Added standardized error handling + +# Transport Layer Implementation + +Added stdio transport implementation using new transport layer architecture: +- Implemented stdio transport with proper signal handling and graceful shutdown +- Added support for configurable buffer sizes and logging +- Added proper error handling and JSON-RPC message processing +- Added context-based cancellation and cleanup + +# Server Layer Updates + +Updated server implementation to use new transport layer: +- Refactored Server struct to use transport interface +- Added RequestHandler to implement transport.RequestHandler interface +- Updated server command to support multiple transport types +- Improved error handling and logging throughout server layer + +# Enhanced SSE Transport + +Added support for integrating SSE transport with existing HTTP servers: +- Added standalone and integrated modes for SSE transport +- Added GetHandlers method to get SSE endpoint handlers +- Added RegisterHandlers method for router integration +- Added support for path prefixes and middleware +- Improved configuration options for HTTP server integration + +# Transport Interface Refactoring + +Simplified transport interface to use protocol types directly instead of custom types. +- Removed duplicate type definitions from transport package +- Use protocol.Request/Response/Notification types directly +- Improved type safety by removing interface{} usage + +# Transport Request ID Handling + +Added proper request ID handling to transport package: +- Added IsNotification helper to check for empty/null request IDs +- Improved notification detection for JSON-RPC messages +- Consistent handling of request IDs across transports + +# Transport ID Type Conversion + +Added helper functions for converting between string and JSON-RPC ID types: +- Added StringToID to convert string to json.RawMessage +- Added IDToString to convert json.RawMessage to string +- Improved type safety in ID handling across transports \ No newline at end of file diff --git a/cmd/go-go-mcp/cmds/server/start.go b/cmd/go-go-mcp/cmds/server/start.go index e66b31e..fe80ada 100644 --- a/cmd/go-go-mcp/cmds/server/start.go +++ b/cmd/go-go-mcp/cmds/server/start.go @@ -128,7 +128,9 @@ func (c *StartCommand) Run( // Start file watcher g.Go(func() error { if err := toolProvider.Watch(gctx); err != nil { - logger.Error().Err(err).Msg("failed to start file watcher") + if !errors.Is(err, context.Canceled) { + logger.Error().Err(err).Msg("failed to run file watcher") + } return err } return nil diff --git a/pkg/protocol/base.go b/pkg/protocol/base.go index e131fc1..c2b5c6e 100644 --- a/pkg/protocol/base.go +++ b/pkg/protocol/base.go @@ -1,6 +1,9 @@ package protocol -import "encoding/json" +import ( + "encoding/json" + "fmt" +) // Request represents a JSON-RPC 2.0 request. type Request struct { @@ -25,6 +28,10 @@ type Error struct { Data json.RawMessage `json:"data,omitempty"` } +func (e *Error) Error() string { + return fmt.Sprintf("code: %d, message: %s, data: %s", e.Code, e.Message, e.Data) +} + // Notification represents a JSON-RPC 2.0 notification. type Notification struct { JSONRPC string `json:"jsonrpc"` diff --git a/pkg/server/handler.go b/pkg/server/handler.go index 785918a..9151dfc 100644 --- a/pkg/server/handler.go +++ b/pkg/server/handler.go @@ -21,9 +21,9 @@ func NewRequestHandler(s *Server) *RequestHandler { } // HandleRequest processes a request and returns a response -func (h *RequestHandler) HandleRequest(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) HandleRequest(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { // Validate JSON-RPC version - if req.Headers["jsonrpc"] != "2.0" { + if req.JSONRPC != "2.0" { return nil, transport.NewInvalidRequestError("invalid JSON-RPC version") } @@ -50,7 +50,7 @@ func (h *RequestHandler) HandleRequest(ctx context.Context, req *transport.Reque } // HandleNotification processes a notification (no response expected) -func (h *RequestHandler) HandleNotification(ctx context.Context, notif *transport.Notification) error { +func (h *RequestHandler) HandleNotification(ctx context.Context, notif *protocol.Notification) error { switch notif.Method { case "notifications/initialized": h.server.logger.Info().Msg("Client initialized") @@ -62,20 +62,20 @@ func (h *RequestHandler) HandleNotification(ctx context.Context, notif *transpor } // Helper method to create success response -func (h *RequestHandler) newSuccessResponse(id string, result interface{}) (*transport.Response, error) { +func (h *RequestHandler) newSuccessResponse(id json.RawMessage, result interface{}) (*protocol.Response, error) { resultJSON, err := json.Marshal(result) if err != nil { return nil, fmt.Errorf("failed to marshal result: %w", err) } - return &transport.Response{ + return &protocol.Response{ ID: id, Result: resultJSON, }, nil } // Individual request handlers -func (h *RequestHandler) handleInitialize(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handleInitialize(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { var params protocol.InitializeParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, transport.NewInvalidParamsError(err.Error()) @@ -120,11 +120,11 @@ func (h *RequestHandler) handleInitialize(ctx context.Context, req *transport.Re return h.newSuccessResponse(req.ID, result) } -func (h *RequestHandler) handlePing(_ context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handlePing(_ context.Context, req *protocol.Request) (*protocol.Response, error) { return h.newSuccessResponse(req.ID, struct{}{}) } -func (h *RequestHandler) handlePromptsList(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handlePromptsList(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { var params struct { Cursor string `json:"cursor"` } @@ -147,7 +147,7 @@ func (h *RequestHandler) handlePromptsList(ctx context.Context, req *transport.R }) } -func (h *RequestHandler) handlePromptsGet(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handlePromptsGet(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { var params struct { Name string `json:"name"` Arguments map[string]string `json:"arguments"` @@ -164,7 +164,7 @@ func (h *RequestHandler) handlePromptsGet(ctx context.Context, req *transport.Re return h.newSuccessResponse(req.ID, prompt) } -func (h *RequestHandler) handleResourcesList(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handleResourcesList(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { var params struct { Cursor string `json:"cursor"` } @@ -187,7 +187,7 @@ func (h *RequestHandler) handleResourcesList(ctx context.Context, req *transport }) } -func (h *RequestHandler) handleResourcesRead(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handleResourcesRead(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { var params struct { Name string `json:"name"` } @@ -205,7 +205,7 @@ func (h *RequestHandler) handleResourcesRead(ctx context.Context, req *transport }) } -func (h *RequestHandler) handleToolsList(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handleToolsList(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { var params struct { Cursor string `json:"cursor"` } @@ -228,7 +228,7 @@ func (h *RequestHandler) handleToolsList(ctx context.Context, req *transport.Req }) } -func (h *RequestHandler) handleToolsCall(ctx context.Context, req *transport.Request) (*transport.Response, error) { +func (h *RequestHandler) handleToolsCall(ctx context.Context, req *protocol.Request) (*protocol.Response, error) { var params struct { Name string `json:"name"` Arguments map[string]interface{} `json:"arguments"` diff --git a/pkg/transport/errors.go b/pkg/transport/errors.go index c6da4e2..6226e36 100644 --- a/pkg/transport/errors.go +++ b/pkg/transport/errors.go @@ -1,6 +1,10 @@ package transport -import "fmt" +import ( + "fmt" + + "github.com/go-go-golems/go-go-mcp/pkg/protocol" +) // Common error codes const ( @@ -14,50 +18,50 @@ const ( ) // Error constructors -func NewParseError(msg string) *ResponseError { - return &ResponseError{ +func NewParseError(msg string) *protocol.Error { + return &protocol.Error{ Code: ErrCodeParse, Message: fmt.Sprintf("Parse error: %s", msg), } } -func NewInvalidRequestError(msg string) *ResponseError { - return &ResponseError{ +func NewInvalidRequestError(msg string) *protocol.Error { + return &protocol.Error{ Code: ErrCodeInvalidRequest, Message: fmt.Sprintf("Invalid request: %s", msg), } } -func NewMethodNotFoundError(msg string) *ResponseError { - return &ResponseError{ +func NewMethodNotFoundError(msg string) *protocol.Error { + return &protocol.Error{ Code: ErrCodeMethodNotFound, Message: fmt.Sprintf("Method not found: %s", msg), } } -func NewInvalidParamsError(msg string) *ResponseError { - return &ResponseError{ +func NewInvalidParamsError(msg string) *protocol.Error { + return &protocol.Error{ Code: ErrCodeInvalidParams, Message: fmt.Sprintf("Invalid params: %s", msg), } } -func NewInternalError(msg string) *ResponseError { - return &ResponseError{ +func NewInternalError(msg string) *protocol.Error { + return &protocol.Error{ Code: ErrCodeInternal, Message: fmt.Sprintf("Internal error: %s", msg), } } -func NewTransportError(msg string) *ResponseError { - return &ResponseError{ +func NewTransportError(msg string) *protocol.Error { + return &protocol.Error{ Code: ErrCodeTransport, Message: fmt.Sprintf("Transport error: %s", msg), } } -func NewTimeoutError(msg string) *ResponseError { - return &ResponseError{ +func NewTimeoutError(msg string) *protocol.Error { + return &protocol.Error{ Code: ErrCodeTimeout, Message: fmt.Sprintf("Timeout error: %s", msg), } diff --git a/pkg/transport/sse/transport.go b/pkg/transport/sse/transport.go index 5eda558..330bf16 100644 --- a/pkg/transport/sse/transport.go +++ b/pkg/transport/sse/transport.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/go-go-golems/go-go-mcp/pkg/protocol" "github.com/go-go-golems/go-go-mcp/pkg/transport" "github.com/google/uuid" "github.com/gorilla/mux" @@ -45,7 +46,7 @@ type SSETransport struct { type SSEClient struct { id string sessionID string - messageChan chan *transport.Response + messageChan chan *protocol.Response createdAt time.Time remoteAddr string userAgent string @@ -140,7 +141,7 @@ func (s *SSETransport) Listen(ctx context.Context, handler transport.RequestHand } } -func (s *SSETransport) Send(ctx context.Context, response *transport.Response) error { +func (s *SSETransport) Send(ctx context.Context, response *protocol.Response) error { s.mu.RLock() defer s.mu.RUnlock() @@ -255,7 +256,7 @@ func (s *SSETransport) handleSSE(w http.ResponseWriter, r *http.Request) { client := &SSEClient{ id: clientID, sessionID: sessionID, - messageChan: make(chan *transport.Response, 100), + messageChan: make(chan *protocol.Response, 100), createdAt: time.Now(), remoteAddr: r.RemoteAddr, userAgent: r.UserAgent(), @@ -318,7 +319,7 @@ func (s *SSETransport) handleMessages(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), sessionIDKey, sessionID) - var request transport.Request + var request protocol.Request if err := json.NewDecoder(r.Body).Decode(&request); err != nil { s.logger.Error().Err(err).Msg("Failed to decode request") w.WriteHeader(http.StatusBadRequest) diff --git a/pkg/transport/stdio/transport.go b/pkg/transport/stdio/transport.go index 9a2d463..9f5130c 100644 --- a/pkg/transport/stdio/transport.go +++ b/pkg/transport/stdio/transport.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/go-go-golems/go-go-mcp/pkg/protocol" "github.com/go-go-golems/go-go-mcp/pkg/transport" "github.com/rs/zerolog" ) @@ -141,7 +142,7 @@ func (s *StdioTransport) Listen(ctx context.Context, handler transport.RequestHa } } -func (s *StdioTransport) Send(ctx context.Context, response *transport.Response) error { +func (s *StdioTransport) Send(ctx context.Context, response *protocol.Response) error { s.mu.Lock() defer s.mu.Unlock() @@ -197,7 +198,7 @@ func (s *StdioTransport) handleMessage(message string) error { Msg("Processing message") // Parse the base message structure - var request transport.Request + var request protocol.Request if err := json.Unmarshal([]byte(message), &request); err != nil { s.logger.Error(). Err(err). @@ -207,9 +208,9 @@ func (s *StdioTransport) handleMessage(message string) error { } // Handle requests vs notifications based on ID presence - if request.ID != "" { + if !transport.IsNotification(&request) { s.logger.Debug(). - Str("id", request.ID). + RawJSON("id", request.ID). Str("method", request.Method). Msg("Handling request") return s.handleRequest(request) @@ -218,17 +219,20 @@ func (s *StdioTransport) handleMessage(message string) error { s.logger.Debug(). Str("method", request.Method). Msg("Handling notification") - return s.handleNotification(request) + return s.handleNotification(protocol.Notification{ + Method: request.Method, + Params: request.Params, + }) } -func (s *StdioTransport) handleRequest(request transport.Request) error { +func (s *StdioTransport) handleRequest(request protocol.Request) error { response, err := s.handler.HandleRequest(context.Background(), &request) if err != nil { s.logger.Error(). Err(err). Str("method", request.Method). Msg("Error handling request") - return s.sendError(&request.ID, transport.ErrCodeInternal, "Internal error", err) + return s.sendError(request.ID, transport.ErrCodeInternal, "Internal error", err) } if response != nil { @@ -238,17 +242,12 @@ func (s *StdioTransport) handleRequest(request transport.Request) error { return nil } -func (s *StdioTransport) handleNotification(request transport.Request) error { - notification := &transport.Notification{ - Method: request.Method, - Params: request.Params, - Headers: request.Headers, - } +func (s *StdioTransport) handleNotification(notification protocol.Notification) error { - if err := s.handler.HandleNotification(context.Background(), notification); err != nil { + if err := s.handler.HandleNotification(context.Background(), ¬ification); err != nil { s.logger.Error(). Err(err). - Str("method", request.Method). + Str("method", notification.Method). Msg("Error handling notification") // Don't send error responses for notifications } @@ -256,7 +255,7 @@ func (s *StdioTransport) handleNotification(request transport.Request) error { return nil } -func (s *StdioTransport) sendError(id *string, code int, message string, data interface{}) error { +func (s *StdioTransport) sendError(id json.RawMessage, code int, message string, data interface{}) error { var errorData json.RawMessage if data != nil { var err error @@ -268,15 +267,13 @@ func (s *StdioTransport) sendError(id *string, code int, message string, data in } } - response := &transport.Response{ - Error: &transport.ResponseError{ + response := &protocol.Response{ + Error: &protocol.Error{ Code: code, Message: message, Data: errorData, }, - } - if id != nil { - response.ID = *id + ID: id, } s.logger.Debug().Interface("response", response).Msg("Sending error response") diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index 2a74101..cf9137a 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -3,7 +3,8 @@ package transport import ( "context" "encoding/json" - "fmt" + + "github.com/go-go-golems/go-go-mcp/pkg/protocol" ) // Transport handles the low-level communication between client and server @@ -12,7 +13,7 @@ type Transport interface { Listen(ctx context.Context, handler RequestHandler) error // Send transmits a response back to the client - Send(ctx context.Context, response *Response) error + Send(ctx context.Context, response *protocol.Response) error // Close cleanly shuts down the transport Close(ctx context.Context) error @@ -29,45 +30,37 @@ type TransportInfo struct { Metadata map[string]string // Additional transport metadata } -// RequestHandler processes incoming requests and notifications -type RequestHandler interface { - // HandleRequest processes a request and returns a response - HandleRequest(ctx context.Context, req *Request) (*Response, error) - - // HandleNotification processes a notification (no response expected) - HandleNotification(ctx context.Context, notif *Notification) error +// IsNotification checks if a request is a notification (no ID) +func IsNotification(req *protocol.Request) bool { + return req.ID == nil || string(req.ID) == "null" || len(req.ID) == 0 } -// Request represents an incoming JSON-RPC request -type Request struct { - ID string - Method string - Params json.RawMessage - Headers map[string]string +// StringToID converts a string to a JSON-RPC ID (json.RawMessage) +func StringToID(s string) json.RawMessage { + if s == "" { + return nil + } + // Quote the string to make it a valid JSON string + return json.RawMessage(`"` + s + `"`) } -// Response represents an outgoing JSON-RPC response -type Response struct { - ID string - Result json.RawMessage - Error *ResponseError - Headers map[string]string +// IDToString converts a JSON-RPC ID to a string +func IDToString(id json.RawMessage) string { + if id == nil { + return "" + } + var s string + if err := json.Unmarshal(id, &s); err != nil { + return string(id) + } + return s } -// Notification represents an incoming notification -type Notification struct { - Method string - Params json.RawMessage - Headers map[string]string -} - -// ResponseError represents a JSON-RPC error response -type ResponseError struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data,omitempty"` -} +// RequestHandler processes incoming requests and notifications +type RequestHandler interface { + // HandleRequest processes a request and returns a response + HandleRequest(ctx context.Context, req *protocol.Request) (*protocol.Response, error) -func (r *ResponseError) Error() string { - return fmt.Sprintf("code: %d, message: %s, data: %s", r.Code, r.Message, string(r.Data)) + // HandleNotification processes a notification (no response expected) + HandleNotification(ctx context.Context, notif *protocol.Notification) error }