diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 211db44dd..24dcce2cb 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -122,6 +122,10 @@ path = "src/health/server.rs" name = "autoreload-server" path = "src/autoreload/server.rs" +[[bin]] +name = "optional-server" +path = "src/optional/server.rs" + [dependencies] tonic = { path = "../tonic", features = ["tls"] } prost = "0.6" diff --git a/examples/src/optional/server.rs b/examples/src/optional/server.rs new file mode 100644 index 000000000..f1d4fa685 --- /dev/null +++ b/examples/src/optional/server.rs @@ -0,0 +1,53 @@ +use std::env; +use tonic::{transport::Server, Request, Response, Status}; + +use hello_world::greeter_server::{Greeter, GreeterServer}; +use hello_world::{HelloReply, HelloRequest}; + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + println!("Got a request from {:?}", request.remote_addr()); + + let reply = hello_world::HelloReply { + message: format!("Hello {}!", request.into_inner().name), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args: Vec = env::args().collect(); + let enabled = args.get(1) == Some(&"enable".to_string()); + + let addr = "[::1]:50051".parse().unwrap(); + let greeter = MyGreeter::default(); + + let optional_service = if enabled { + println!("MyGreeter enabled"); + Some(GreeterServer::new(greeter)) + } else { + println!("MyGreeter disabled"); + None + }; + + println!("GreeterServer listening on {}", addr); + + Server::builder() + .add_optional_service(optional_service) + .serve(addr) + .await?; + + Ok(()) +} diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7751029f4..8b2043851 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -41,7 +41,8 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncWrite}; use tower::{ - limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, Service, ServiceBuilder, + limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service, + ServiceBuilder, }; use tracing_futures::{Instrument, Instrumented}; @@ -86,6 +87,10 @@ pub trait NamedService { const NAME: &'static str; } +impl NamedService for Either { + const NAME: &'static str = S::NAME; +} + impl Server { /// Create a new server builder that can configure a [`Server`]. pub fn builder() -> Self { @@ -231,6 +236,34 @@ impl Server { Router::new(self.clone(), svc) } + /// Create a router with the optional `S` typed service as the first service. + /// + /// This will clone the `Server` builder and create a router that will + /// route around different services. + /// + /// # Note + /// Even when the argument given is `None` this will capture *all* requests to this service name. + /// As a result, one cannot use this to toggle between two indentically named implementations. + pub fn add_optional_service( + &mut self, + svc: Option, + ) -> Router, Unimplemented> + where + S: Service, Response = Response> + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + let svc = match svc { + Some(some) => Either::A(some), + None => Either::B(Unimplemented::default()), + }; + Router::new(self.clone(), svc) + } + pub(crate) async fn serve_with_shutdown( self, svc: S, @@ -342,6 +375,42 @@ where Router { server, routes } } + /// Add a new optional service to this router. + /// + /// # Note + /// Even when the argument given is `None` this will capture *all* requests to this service name. + /// As a result, one cannot use this to toggle between two indentically named implementations. + pub fn add_optional_service( + self, + svc: Option, + ) -> Router, Or>> + where + S: Service, Response = Response> + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + let Self { routes, server } = self; + + let svc_name = ::NAME; + let svc_route = format!("/{}", svc_name); + let pred = move |req: &Request| { + let path = req.uri().path(); + + path.starts_with(&svc_route) + }; + let svc = match svc { + Some(some) => Either::A(some), + None => Either::B(Unimplemented::default()), + }; + let routes = routes.push(pred, svc); + + Router { server, routes } + } + /// Consume this [`Server`] creating a future that will execute the server /// on [`tokio`]'s default executor. ///