Skip to content

Commit

Permalink
initial mcp for stdio
Browse files Browse the repository at this point in the history
  • Loading branch information
hatchan committed Feb 28, 2025
1 parent 95eb7f1 commit 5244142
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 107 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion otel-worker-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ url = { version = "2.5" }


futures = "0.3"
tokio-stream = "0.1"
tokio-stream = { version = "0.1", features = ["sync"] }
axum-jrpc = "0.7"
rust-mcp-schema = { version = "0.1.0", features = ["2024_11_05"] }
async-stream = "0.3"
Expand Down
38 changes: 17 additions & 21 deletions otel-worker-cli/src/commands/mcp.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
use anyhow::{bail, Context, Result};
use axum::response::sse::Event;
use futures::StreamExt;
use otel_worker_core::api::client::{self, ApiClient};
use otel_worker_core::api::models::{ServerMessage, ServerMessageDetails};
use otel_worker_core::api::models;
use rust_mcp_schema::schema_utils::ServerMessage;
use rust_mcp_schema::{
Implementation, InitializeRequestParams, InitializeResult, ListResourcesRequestParams,
ListResourcesResult, ReadResourceRequestParams, ReadResourceResult,
ReadResourceResultContentsItem, Resource, ResourceListChangedNotification, ServerCapabilities,
ServerCapabilitiesResources, TextResourceContents,
};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, info, warn};
use url::Url;

mod http_sse;
mod stdio;

#[derive(clap::Args, Debug)]
pub struct Args {
Expand Down Expand Up @@ -45,15 +47,15 @@ pub async fn handle_command(args: Args) -> Result<()> {
// have to worry about error handling inside the [`tokio::task`].
let websocket_url = get_ws_url(&args.otel_worker_url)?;

let client = client::builder(args.otel_worker_url)
let api_client = client::builder(args.otel_worker_url)
.set_bearer_token(args.otel_worker_token)
.build();

// This broadcast pair is used for async communication back to the MCP
// client through SSE.
let (notifications, _) = tokio::sync::broadcast::channel(100);
let (notifications_tx, notifications_rx) = broadcast::channel(100);

let ws_sender = notifications.clone();
let ws_sender = notifications_tx.clone();
let ws_handle = tokio::spawn(async move {
info!(?websocket_url, "Connecting to websocket");

Expand All @@ -67,24 +69,15 @@ pub async fn handle_command(args: Args) -> Result<()> {
break;
};

debug!("Yay message!");

if let Message::Text(content) = message {
let msg: ServerMessage =
let msg: models::ServerMessage =
serde_json::from_str(&content).expect("Should be able to deserialize it");

match msg.details {
ServerMessageDetails::SpanAdded(_span_added) => {
models::ServerMessageDetails::SpanAdded(_span_added) => {
let data = ResourceListChangedNotification::new(None);
ws_sender
.send(
Event::default()
.event("message")
.json_data(data)
.expect("serialization should just work"),
)
.ok();
debug!("list_changed message send");
let message = ServerMessage::Notification(data.into());
ws_sender.send(message).ok();
}
_ => debug!("Irrelevant message"),
}
Expand All @@ -96,8 +89,10 @@ pub async fn handle_command(args: Args) -> Result<()> {
});

match args.transport {
Transport::Stdio => todo!("implement me!"),
Transport::HttpSse => http_sse::serve(&args.listen_address, notifications, client).await?,
Transport::Stdio => stdio::serve(notifications_tx, notifications_rx, api_client).await?,
Transport::HttpSse => {
http_sse::serve(&args.listen_address, notifications_tx, api_client).await?
}
}

ws_handle.abort();
Expand Down Expand Up @@ -137,6 +132,7 @@ async fn handle_initialize(params: InitializeRequestParams) -> Result<Initialize

// We only support one version for now
if params.protocol_version != MCP_VERSION {
debug!(?params, "unsupported version");
anyhow::bail!("unsupported version")
}

Expand Down Expand Up @@ -177,7 +173,7 @@ async fn handle_initialize(params: InitializeRequestParams) -> Result<Initialize

async fn handle_resources_list(
client: &ApiClient,
_params: ListResourcesRequestParams,
_params: Option<ListResourcesRequestParams>,
) -> Result<ListResourcesResult> {
let resources = client
.trace_list()

Check failure on line 179 in otel-worker-cli/src/commands/mcp.rs

View workflow job for this annotation

GitHub Actions / Create binary for x86_64-unknown-linux-gnu

[clippy] reported by reviewdog 🐶 error[E0061]: this method takes 2 arguments but 0 arguments were supplied --> otel-worker-cli/src/commands/mcp.rs:179:10 | 179 | .trace_list() | ^^^^^^^^^^-- two arguments of type `std::option::Option<u32>` and `std::option::Option<time::OffsetDateTime>` are missing | note: method defined here --> /home/runner/work/otel-worker/otel-worker/otel-worker-core/src/api/client.rs:230:18 | 230 | pub async fn trace_list( | ^^^^^^^^^^ help: provide the arguments | 179 | .trace_list(/* std::option::Option<u32> */, /* std::option::Option<time::OffsetDateTime> */) | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Raw Output: otel-worker-cli/src/commands/mcp.rs:179:20:e:error[E0061]: this method takes 2 arguments but 0 arguments were supplied --> otel-worker-cli/src/commands/mcp.rs:179:10 | 179 | .trace_list() | ^^^^^^^^^^-- two arguments of type `std::option::Option<u32>` and `std::option::Option<time::OffsetDateTime>` are missing | note: method defined here --> /home/runner/work/otel-worker/otel-worker/otel-worker-core/src/api/client.rs:230:18 | 230 | pub async fn trace_list( | ^^^^^^^^^^ help: provide the arguments | 179 | .trace_list(/* std::option::Option<u32> */, /* std::option::Option<time::OffsetDateTime> */) | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ __END__
Expand Down
140 changes: 55 additions & 85 deletions otel-worker-cli/src/commands/mcp/http_sse.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
use anyhow::{Context, Result};
use async_stream::try_stream;
use super::{handle_initialize, handle_resources_list, handle_resources_read};
use anyhow::{Context, Error, Result};
use axum::extract::{MatchedPath, Request, State};
use axum::middleware::{self, Next};
use axum::response::sse::Event;
use axum::response::{IntoResponse, Sse};
use axum::routing::{get, post};
use axum_jrpc::error::{JsonRpcError, JsonRpcErrorReason};
use axum_jrpc::{JsonRpcExtractor, JsonRpcResponse, Value};
use futures::Stream;
use axum_jrpc::JsonRpcExtractor;
use futures::{Stream, StreamExt};
use http::StatusCode;
use otel_worker_core::api::client::ApiClient;
use std::convert::Infallible;
use rust_mcp_schema::schema_utils::{
ResultFromServer, RpcErrorCodes, ServerJsonrpcResponse, ServerMessage,
};
use rust_mcp_schema::{JsonrpcError, RequestId};
use std::process::exit;
use std::time::{Duration, Instant};
use tokio::net::TcpListener;
use tokio::sync::broadcast::{self, Sender};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tracing::{debug, error, info, info_span, warn, Instrument};

use super::{handle_initialize, handle_resources_list, handle_resources_read};

pub async fn serve(
pub(crate) async fn serve(
listen_address: &str,
notifications: broadcast::Sender<Event>,
notifications: broadcast::Sender<ServerMessage>,
api_client: ApiClient,
) -> Result<()> {
let listener = tokio::net::TcpListener::bind(listen_address)
let listener = TcpListener::bind(listen_address)
.await
.with_context(|| format!("Failed to bind to address: {}", listen_address))?;

Expand All @@ -41,7 +43,7 @@ pub async fn serve(
Ok(())
}

pub fn build_mcp_service(notifications: Sender<Event>, api_client: ApiClient) -> axum::Router {
fn build_mcp_service(notifications: Sender<ServerMessage>, api_client: ApiClient) -> axum::Router {
let state = McpState {
api_client,
notifications,
Expand All @@ -56,17 +58,12 @@ pub fn build_mcp_service(notifications: Sender<Event>, api_client: ApiClient) ->
#[derive(Clone)]
struct McpState {
api_client: ApiClient,
notifications: Sender<Event>,
notifications: Sender<ServerMessage>,
}

impl McpState {
fn reply(&self, response: JsonRpcResponse) {
let event = Event::default()
.event("message")
.json_data(response)
.expect("unable to serialize data");

if let Err(err) = self.notifications.send(event) {
fn reply(&self, message: ServerMessage) {
if let Err(err) = self.notifications.send(message) {
warn!(?err, "A reply was send, but client is connected");
}
}
Expand All @@ -79,67 +76,43 @@ async fn json_rpc_handler(
) -> impl IntoResponse {
tokio::spawn(async move {
let answer_id = req.get_answer_id();
let result = match req.method() {
let result: Result<ResultFromServer> = match req.method() {
"initialize" => handle_initialize(req.parse_params().unwrap())
.await
.map(|result| JsonRpcResponse::success(answer_id.clone(), result))
.unwrap_or_else(|err| {
error!(?err, "initialization failed");
JsonRpcResponse::error(
answer_id,
JsonRpcError::new(
JsonRpcErrorReason::InternalError,
"message".to_string(),
Value::Null,
),
)
}),
.map(Into::into),
"resources/list" => {
handle_resources_list(&state.api_client, req.parse_params().unwrap())
.await
.map(|result| JsonRpcResponse::success(answer_id.clone(), result))
.unwrap_or_else(|err| {
error!(?err, "handle_resources_list");
JsonRpcResponse::error(
answer_id,
JsonRpcError::new(
JsonRpcErrorReason::InternalError,
"message".to_string(),
Value::Null,
),
)
})
.map(Into::into)
}
"resources/read" => {
handle_resources_read(&state.api_client, req.parse_params().unwrap())
.await
.map(|result| JsonRpcResponse::success(answer_id.clone(), result))
.unwrap_or_else(|err| {
error!(?err, "handle_resources_read");
JsonRpcResponse::error(
answer_id,
JsonRpcError::new(
JsonRpcErrorReason::InternalError,
"message".to_string(),
Value::Null,
),
)
})
.map(Into::into)
}
method => {
error!(?method, "RPC used a unsupported method");
JsonRpcResponse::error(
answer_id,
JsonRpcError::new(
JsonRpcErrorReason::MethodNotFound,
"message".to_string(),
Value::Null,
),
)
Err(Error::msg("unknown method"))
}
};

state.reply(result);
let id = match answer_id {
axum_jrpc::Id::Num(val) => RequestId::Integer(val),
axum_jrpc::Id::Str(val) => RequestId::String(val),
axum_jrpc::Id::None(_) => panic!("id should be set"),
};

let response: ServerMessage = match result {
Ok(result) => ServerMessage::Response(ServerJsonrpcResponse::new(id, result)),
Err(_) => ServerMessage::Error(JsonrpcError::create(
id,
RpcErrorCodes::INTERNAL_ERROR,
"error_message".to_string(),
None,
)),
};

state.reply(response);
});

StatusCode::ACCEPTED
Expand All @@ -148,30 +121,27 @@ async fn json_rpc_handler(
#[tracing::instrument(skip(state))]
async fn sse_handler(
State(state): State<McpState>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
) -> Sse<impl Stream<Item = Result<Event, BroadcastStreamRecvError>>> {
debug!("MCP client connected to the SSE handler");

// We will subscribe to the global sender and convert those messages into
// a stream, which in turn gets converted into SSE events.
let mut receiver = state.notifications.subscribe();
let stream = try_stream! {
loop {
let recv = receiver.recv().await.expect("should work");
yield recv;
}
};

// Part of the SSE protocol is sending the location where the client needs
// to do its POST request. This will be sent to the channel which will be
// queued there until the client is connected to the SSE stream.
if let Err(err) = state
.notifications
.send(Event::default().event("endpoint").data("/messages"))
{
error!(?err, "unable to send initial message to MCP client");
}
let receiver = state.notifications.subscribe();

// This message needs to be send as soon as the client accesses the page.
let initial_event =
futures::stream::once(async { Ok(Event::default().event("endpoint").data("/messages")) });

let events = tokio_stream::wrappers::BroadcastStream::new(receiver).map(|message| {
message.map(|message| {
Event::default()
.event("message")
.json_data(message)
.expect("unable to serialize data")
})
});

Sse::new(stream).keep_alive(
Sse::new(initial_event.chain(events)).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(5))
.text("keep-alive-text"),
Expand Down
Loading

0 comments on commit 5244142

Please sign in to comment.