Skip to content

Commit

Permalink
🚀 Make SSE client run asynchronously
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 21, 2025
1 parent d6cec0a commit eeed63f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 35 deletions.
10 changes: 9 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,12 @@ Enhanced error handling in SSE server to better comply with JSON-RPC 2.0 specifi

Removed non-standard endpoint event from SSE server to better align with the official protocol specification.

- Removed endpoint event from SSE server implementation
- Removed endpoint event from SSE server implementation

# Async SSE Client

Made the SSE client run asynchronously to prevent blocking on subscription:
- Added initialization synchronization channel
- Moved SSE subscription to a goroutine
- Improved error handling and state management
- Added proper mutex protection for shared state
87 changes: 53 additions & 34 deletions pkg/client/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type SSETransport struct {
closeOnce sync.Once
closeChan chan struct{}
initialized bool
initChan chan struct{} // Channel to signal initialization completion
logger zerolog.Logger
}

Expand All @@ -36,28 +37,36 @@ func NewSSETransport(baseURL string) *SSETransport {
sseClient: sse.NewClient(baseURL + "/sse"),
events: make(chan *sse.Event),
closeChan: make(chan struct{}),
initChan: make(chan struct{}),
logger: zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Logger(),
}
}

// Send sends a request and returns the response
func (t *SSETransport) Send(request *protocol.Request) (*protocol.Response, error) {
t.mu.Lock()
defer t.mu.Unlock()

t.logger.Debug().
Str("method", request.Method).
Interface("params", request.Params).
Msg("Sending request")

// Initialize SSE connection if not already done
if !t.initialized {
t.mu.Unlock()
t.logger.Debug().Msg("Initializing SSE connection")
if err := t.initializeSSE(); err != nil {
t.logger.Error().Err(err).Msg("Failed to initialize SSE")
return nil, fmt.Errorf("failed to initialize SSE: %w", err)
}
// Wait for initialization to complete
select {
case <-t.initChan:
t.logger.Debug().Msg("SSE initialization completed")
case <-t.closeChan:
return nil, fmt.Errorf("transport closed during initialization")
}
t.mu.Lock()
}
defer t.mu.Unlock()

t.logger.Debug().
Str("method", request.Method).
Interface("params", request.Params).
Msg("Sending request")

// Send request via HTTP POST
reqBody, err := json.Marshal(request)
Expand Down Expand Up @@ -121,34 +130,44 @@ func (t *SSETransport) Send(request *protocol.Request) (*protocol.Response, erro
func (t *SSETransport) initializeSSE() error {
t.logger.Debug().Str("url", t.baseURL+"/sse").Msg("Setting up SSE connection")

// Subscribe to SSE events
if err := t.sseClient.SubscribeRaw(func(msg *sse.Event) {
// Handle session ID event
if string(msg.Event) == "session" {
t.sessionID = string(msg.Data)
t.logger.Debug().Str("sessionID", t.sessionID).Msg("Received session ID")
return
// Start SSE subscription in a goroutine
go func() {
defer close(t.initChan)

err := t.sseClient.SubscribeRaw(func(msg *sse.Event) {
// Handle session ID event
if string(msg.Event) == "session" {
t.mu.Lock()
t.sessionID = string(msg.Data)
t.initialized = true
t.mu.Unlock()
t.logger.Debug().Str("sessionID", t.sessionID).Msg("Received session ID")
return
}

t.logger.Debug().
Str("event", string(msg.Event)).
RawJSON("data", msg.Data).
Msg("Received SSE event")

// Forward other events to the events channel
select {
case t.events <- msg:
t.logger.Debug().Msg("Forwarded event to channel")
case <-t.closeChan:
t.logger.Debug().Msg("Transport closed while forwarding event")
}
})

if err != nil {
t.logger.Error().Err(err).Msg("SSE subscription failed")
// Signal initialization failure
t.mu.Lock()
t.initialized = false
t.mu.Unlock()
}
}()

t.logger.Debug().
Str("event", string(msg.Event)).
RawJSON("data", msg.Data).
Msg("Received SSE event")

// Forward other events to the events channel
select {
case t.events <- msg:
t.logger.Debug().Msg("Forwarded event to channel")
case <-t.closeChan:
t.logger.Debug().Msg("Transport closed while forwarding event")
}
}); err != nil {
t.logger.Error().Err(err).Msg("Failed to subscribe to SSE")
return fmt.Errorf("failed to subscribe to SSE: %w", err)
}

t.initialized = true
t.logger.Debug().Msg("SSE connection initialized")
return nil
}

Expand Down

0 comments on commit eeed63f

Please sign in to comment.