Skip to content

Commit

Permalink
✨ Add SSE mcp bridge
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 26, 2025
1 parent b571b60 commit 51bc7d5
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 1 deletion.
34 changes: 34 additions & 0 deletions cmd/mcp-server/cmds/bridge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package cmds

import (
"context"
"fmt"

"github.com/go-go-golems/go-go-mcp/pkg/server/transports/stdio"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
)

func NewBridgeCommand(logger zerolog.Logger) *cobra.Command {
var sseURL string

cmd := &cobra.Command{
Use: "bridge",
Short: "Start a stdio server that bridges to an SSE server",
Long: `Start a stdio server that forwards all requests to an SSE server.
This is useful when you want to connect a stdio client to a remote SSE server.`,
RunE: func(cmd *cobra.Command, args []string) error {
if sseURL == "" {
return fmt.Errorf("SSE URL is required")
}

server := stdio.NewSSEBridgeServer(logger, sseURL)
return server.Start(context.Background())
},
}

cmd.Flags().StringVarP(&sseURL, "sse-url", "s", "", "URL of the SSE server to connect to (required)")
_ = cmd.MarkFlagRequired("sse-url")

return cmd
}
7 changes: 6 additions & 1 deletion cmd/mcp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package main

import (
"fmt"
"github.com/go-go-golems/go-go-mcp/pkg/cmds"
"os"

"github.com/go-go-golems/go-go-mcp/pkg/cmds"

clay "github.com/go-go-golems/clay/pkg"

"github.com/go-go-golems/glazed/pkg/cli"
Expand All @@ -13,6 +14,7 @@ import (
"github.com/go-go-golems/glazed/pkg/help"
server_cmds "github.com/go-go-golems/go-go-mcp/cmd/mcp-server/cmds"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"

glazed_cmds "github.com/go-go-golems/glazed/pkg/cmds"
Expand Down Expand Up @@ -85,6 +87,9 @@ func initRootCmd() (*help.HelpSystem, error) {
cobra.CheckErr(err)
rootCmd.AddCommand(cobraSchemaCmd)

bridgeCmd := server_cmds.NewBridgeCommand(log.Logger)
rootCmd.AddCommand(bridgeCmd)

// Add version command
versionCmd := &cobra.Command{
Use: "version",
Expand Down
209 changes: 209 additions & 0 deletions pkg/server/transports/stdio/sse_bridge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package stdio

import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"

"github.com/go-go-golems/go-go-mcp/pkg/client"
"github.com/go-go-golems/go-go-mcp/pkg/protocol"
"github.com/rs/zerolog"
)

// SSEBridgeServer is a stdio server that forwards all requests to an SSE server
type SSEBridgeServer struct {
scanner *bufio.Scanner
writer *json.Encoder
logger zerolog.Logger
sseClient *client.SSETransport
signalChan chan os.Signal
mu sync.Mutex
}

// NewSSEBridgeServer creates a new stdio server instance that forwards to SSE
func NewSSEBridgeServer(logger zerolog.Logger, sseURL string) *SSEBridgeServer {
scanner := bufio.NewScanner(os.Stdin)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer

// Create a ConsoleWriter that writes to stderr with a SERVER tag
consoleWriter := zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: time.RFC3339,
FormatMessage: func(i interface{}) string {
return fmt.Sprintf("[STDIO-SSE-BRIDGE] %s", i)
},
}

// Create a new logger that writes to the tagged stderr
taggedLogger := logger.Output(consoleWriter)

return &SSEBridgeServer{
scanner: scanner,
writer: json.NewEncoder(os.Stdout),
logger: taggedLogger,
sseClient: client.NewSSETransport(sseURL, taggedLogger),
signalChan: make(chan os.Signal, 1),
}
}

// Start begins listening for and handling messages on stdio
func (s *SSEBridgeServer) Start(ctx context.Context) error {
s.logger.Info().Msg("Starting stdio-sse bridge server...")

// Set up signal handling
signal.Notify(s.signalChan, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(s.signalChan)

// Create a channel for scanner errors
scanErrChan := make(chan error, 1)

// Create a cancellable context for the scanner
scannerCtx, cancelScanner := context.WithCancel(ctx)
defer cancelScanner()

// Start scanning in a goroutine
go func() {
for s.scanner.Scan() {
select {
case <-scannerCtx.Done():
s.logger.Debug().Msg("Context cancelled, stopping scanner")
scanErrChan <- scannerCtx.Err()
return
default:
line := s.scanner.Text()
s.logger.Debug().
Str("line", line).
Msg("Received line")
if err := s.handleMessage(line); err != nil {
s.logger.Error().Err(err).Msg("Error handling message")
// Continue processing messages even if one fails
}
}
}

if err := s.scanner.Err(); err != nil {
s.logger.Error().
Err(err).
Msg("Scanner error")
scanErrChan <- fmt.Errorf("scanner error: %w", err)
return
}

s.logger.Debug().Msg("Scanner reached EOF")
scanErrChan <- nil
}()

// Wait for either a signal, context cancellation, or scanner error
select {
case sig := <-s.signalChan:
s.logger.Debug().
Str("signal", sig.String()).
Msg("Received signal in stdio server")
cancelScanner()
return nil
case <-ctx.Done():
s.logger.Debug().
Err(ctx.Err()).
Msg("Context cancelled in stdio server")
return ctx.Err()
case err := <-scanErrChan:
if err == nil {
s.logger.Debug().Msg("Scanner completed normally")
} else {
s.logger.Error().
Err(err).
Msg("Scanner error in stdio server")
}
return err
}
}

// Stop gracefully stops the stdio server
func (s *SSEBridgeServer) Stop(ctx context.Context) error {
s.logger.Info().Msg("Stopping stdio-sse bridge server")

// Close SSE client connection
if err := s.sseClient.Close(ctx); err != nil {
s.logger.Error().Err(err).Msg("Error closing SSE client")
}

// Wait for context to be done or timeout
select {
case <-ctx.Done():
s.logger.Debug().
Err(ctx.Err()).
Msg("Stop context cancelled before clean shutdown")
return ctx.Err()
case <-time.After(100 * time.Millisecond): // Give a small grace period for cleanup
s.logger.Debug().Msg("Stop completed successfully")
return nil
}
}

// handleMessage processes a single message
func (s *SSEBridgeServer) handleMessage(message string) error {
s.mu.Lock()
defer s.mu.Unlock()

s.logger.Debug().
Str("message", message).
Msg("Processing message")

// Parse the base message structure
var request protocol.Request
if err := json.Unmarshal([]byte(message), &request); err != nil {
s.logger.Error().
Err(err).
Str("message", message).
Msg("Failed to parse message as JSON-RPC request")
return s.sendError(nil, -32700, "Parse error", err)
}

// Forward the request to the SSE server
response, err := s.sseClient.Send(context.Background(), &request)
if err != nil {
s.logger.Error().
Err(err).
Str("method", request.Method).
Msg("Error forwarding request to SSE server")
return s.sendError(&request.ID, -32603, "Internal error", err)
}

// Send the response back over stdio
return s.writer.Encode(response)
}

// sendError sends an error response
func (s *SSEBridgeServer) sendError(id *json.RawMessage, code int, message string, data interface{}) error {
var errorData json.RawMessage
if data != nil {
var err error
errorData, err = json.Marshal(data)
if err != nil {
// If we can't marshal the error data, log it and send a simpler error
s.logger.Error().Err(err).Interface("data", data).Msg("Failed to marshal error data")
return s.sendError(id, -32603, "Internal error marshaling error data", nil)
}
}

response := protocol.Response{
JSONRPC: "2.0",
Error: &protocol.Error{
Code: code,
Message: message,
Data: errorData,
},
}
if id != nil {
response.ID = *id
}

s.logger.Debug().Interface("response", response).Msg("Sending error response")
return s.writer.Encode(response)
}

0 comments on commit 51bc7d5

Please sign in to comment.