Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let Serve and Stub traits be send #475

Conversation

stevefan1999-personal
Copy link
Contributor

@stevefan1999-personal stevefan1999-personal commented Oct 7, 2024

This PR address a problem to revive the service registry code of the following:

use bytes::Bytes;
use futures::{SinkExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use std::{
    io::ErrorKind,
    sync::Arc,
};
use tarpc::{
    context,
    server::{serve, Serve},
    ClientMessage, Request, RequestName, Response, ServerError, Transport,
};

/// A request to a named service.
#[derive(Serialize, Deserialize, Clone)]
pub struct ServiceRequest {
    service_name: String,
    request_name: String,
    request: Bytes,
}

impl RequestName for ServiceRequest {
    fn name(&self) -> &str {
        &self.service_name
    }
}

/// A response from a named service.
#[derive(Serialize, Deserialize, Clone)]
pub struct ServiceResponse {
    response: Bytes,
}
/// A list of registered services.
#[derive(Clone)]
pub struct Registry<Codec, Services>
where
    Services: Clone + Send + Sync,
    Codec: Send + Sync,
{
    pub registrations: Services,
    pub codec: Arc<Codec>,
}

impl<Codec> Registry<Codec, Nil>
where
    Codec: Send + Sync,
{
    pub fn new(codec: Codec) -> Self {
        Self {
            registrations: Nil,
            codec: Arc::new(codec),
        }
    }
}

impl<Services, Codec> Registry<Codec, Services>
where
    Services: Serve<Req = ServiceRequest, Resp = ServiceResponse> + Send + Sync + Clone,
    Codec: Send + Sync,
{
    /// Returns a function that serves requests for the registered services.
    pub fn serve(&self) -> impl Serve<Req = ServiceRequest, Resp = ServiceResponse> + Clone {
        serve({
            let registrations = self.registrations.clone();
            move |ctx, req| registrations.clone().serve(ctx, req)
        })
    }
    /// Registers `serve` with the given `name` using the given serialization scheme.
    pub fn register<S, Req, Resp>(
        self,
        name: impl AsRef<str>,
        serve: S,
    ) -> Registry<
        Codec,
        Registration<
            impl Serve<Req = ServiceRequest, Resp = ServiceResponse> + Clone + Sync,
            Services,
        >,
    >
    where
        Req: Send + Sync,
        S: Serve<Req = Req, Resp = Resp> + Clone + Send + Sync,
        Codec: Serializer<Resp> + Deserializer<Req>,
    {
        Registry {
            codec: self.codec.clone(),
            registrations: Registration {
                name: name.as_ref().to_string(),
                serve: tarpc::server::serve({
                    let codec = self.codec.clone();
                    move |cx, req: ServiceRequest| async move {
                        let req = codec
                            .deserialize(req.request)
                            .map_err(|e| ServerError::new(ErrorKind::Other, e.to_string()))?;
                        let response = serve.serve(cx, req).await?;
                        let response = codec
                            .serialize(&response)
                            .map_err(|e| ServerError::new(ErrorKind::Other, e.to_string()))?;
                        Ok(ServiceResponse { response })
                    }
                }),
                rest: self.registrations,
            },
        }
    }
}
/// Creates a client that sends requests to a service
/// named `service_name`, over the given channel, using
/// the specified serialization scheme.
pub fn new_client<Req, Resp, Codec>(
    service_name: impl AsRef<str>,
    channel: impl Transport<ClientMessage<ServiceRequest>, Response<ServiceResponse>>,
    codec: Codec,
) -> impl Transport<ClientMessage<Req>, Response<Resp>>
where
    Req: Send + RequestName,
    Codec: Serializer<Req> + Deserializer<Resp> + Send + Sync,
{
    let codec = Arc::new(codec);
    channel
        .with({
            let codec = codec.clone();
            move |req: ClientMessage<Req>| {
                let codec = codec.clone();
                let service_name = service_name.as_ref().to_string();
                async move {
                    match req {
                        ClientMessage::Request(Request {
                            id,
                            context,
                            message,
                            ..
                        }) => {
                            Ok(ClientMessage::Request(Request {
                                message: ServiceRequest {
                                    service_name: service_name.clone(),
                                    request_name: message.name().to_string(),
                                    // TODO: shouldn't need to unwrap here. Maybe with_request should allow for
                                    // returning Result.
                                    request: codec.serialize(&message).unwrap(),
                                },
                                id,
                                context,
                            }))
                        }
                        ClientMessage::Cancel {
                            trace_context,
                            request_id,
                        } => Ok(ClientMessage::Cancel {
                            trace_context,
                            request_id,
                        }),
                    }
                }
            }
        })
        // TODO: same thing. Maybe this should be more like and_then rather than map.
        .and_then({
            let codec = codec.clone();
            move |resp| {
                let codec = codec.clone();
                async move {
                    let request_id = resp.request_id;
                    Ok(Response {
                        request_id,
                        message: resp.message.map(|x| codec.deserialize(x.response).unwrap()),
                    })
                }
            }
        })
}

/// A registry starting with service S, followed by Rest.
///
/// This type is mostly an implementation detail that is not used directly
/// outside of the registry internals.
#[derive(Clone)]
pub struct Registration<S, Rest>
where
    S: Clone + Send + Sync,
    Rest: Clone + Send + Sync,
{
    /// The registered service's name. Must be unique across all registered services.
    name: String,
    /// The registered service.
    serve: S,
    /// Any remaining registered services.
    rest: Rest,
}
/// An empty registry.
///
/// This type is mostly an implementation detail that is not used directly
/// outside of the registry internals.
#[derive(Clone)]
pub struct Nil;
impl Serve for Nil {
    type Req = ServiceRequest;

    type Resp = ServiceResponse;

    async fn serve(self, _: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError> {
        Err(ServerError::new(
            ErrorKind::NotFound,
            format!("Service {} not registered", req.service_name),
        ))
    }
}

impl<S, Rest> Serve for Registration<S, Rest>
where
    S: Serve<Req = ServiceRequest, Resp = ServiceResponse> + Clone + Sync,
    Rest: Serve<Req = ServiceRequest, Resp = ServiceResponse> + Clone + Sync,
{
    type Req = ServiceRequest;

    type Resp = ServiceResponse;

    async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError> {
        if self.name == req.service_name {
            self.serve.serve(ctx, req).await
        } else {
            self.rest.serve(ctx, req).await
        }
    }
}

pub trait Serializer<T> {
    type Error: std::fmt::Display + std::fmt::Debug;
    fn serialize(&self, item: &T) -> Result<Bytes, Self::Error>;
}

pub trait Deserializer<T> {
    type Error: std::fmt::Display + std::fmt::Debug;
    fn deserialize(&self, src: Bytes) -> Result<T, Self::Error>;
}

pub mod bincode;
pub mod cbor;
pub mod json;

Effectively, this registry code will serialize/deserialize any request/response code with Serde and a codec of your choice, which requires a self-descripted format and will get unintended consequences if otherwise (thus using Bincode is actually unsafe), and then just send it over the wire as usual, but the biggest problem is that Serve itself has no guarantee of Send, so I hacked the code together and eventually got it working.

This PR will make #448 obsolete with a better approach. After this PR is merged I will try to see if I can polish up the registry code.

I admit this can introduce some breaking changes -- But given the fact that nobody is using this without multithreaded async, and I assumed everyone would be very aware of async safety using standard primitives like Mutex, Arc instead of RefCell, Rc.

@tikue
Copy link
Collaborator

tikue commented Oct 7, 2024

Thank you for the PR! Unfortunately, I don't think it is true that nobody uses this without multithreading, so I can't merge a change that would make Send required for all services.

@stevefan1999-personal
Copy link
Contributor Author

Thank you for the PR! Unfortunately, I don't think it is true that nobody uses this without multithreading, so I can't merge a change that would make Send required for all services.

But the point is, since tarpc is heavily Tokio dependent, which is under the assumption of Send futures being every-where (evidence with tokio::spawn function, which itself needs a Send future), and so not having Send in the async trait bound renders the whole thing to only run in local thread only.

@stevefan1999-personal
Copy link
Contributor Author

Closing in favor of #480

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants