From 352b0f584be33bc49ca266698c9224d16a6825ff Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 13 May 2021 15:54:20 +0200 Subject: [PATCH] feat(tonic): add `Request` and `Response` extensions (#642) Adds `tonic::Extensions` which is a newtype around `http::Extensions`. Request extensions can be set by interceptors with `Request::extensions_mut` and retrieved from RPCs with `Request::extensions`. Extensions can also be set in tower middleware and will be carried through to the RPC. Since response extensions cannot be set by interceptors the main use case is to set them in RPCs and retrieve them in tower middlewares. Figured that might be useful. Fixes https://github.com/hyperium/tonic/issues/255 --- examples/src/interceptor/server.rs | 15 +- tests/integration_tests/Cargo.toml | 3 + tests/integration_tests/tests/extensions.rs | 144 ++++++++++++++++++++ tonic/src/client/grpc.rs | 5 +- tonic/src/extensions.rs | 71 ++++++++++ tonic/src/lib.rs | 3 + tonic/src/request.rs | 64 ++++++++- tonic/src/response.rs | 29 +++- 8 files changed, 321 insertions(+), 13 deletions(-) create mode 100644 tests/integration_tests/tests/extensions.rs create mode 100644 tonic/src/extensions.rs diff --git a/examples/src/interceptor/server.rs b/examples/src/interceptor/server.rs index b73a15d9f..a79d98e2e 100644 --- a/examples/src/interceptor/server.rs +++ b/examples/src/interceptor/server.rs @@ -16,6 +16,9 @@ impl Greeter for MyGreeter { &self, request: Request, ) -> Result, Status> { + let extension = request.extensions().get::().unwrap(); + println!("extension data = {}", extension.some_piece_of_data); + let reply = hello_world::HelloReply { message: format!("Hello {}!", request.into_inner().name), }; @@ -40,7 +43,17 @@ async fn main() -> Result<(), Box> { /// This function will get called on each inbound request, if a `Status` /// is returned, it will cancel the request and return that status to the /// client. -fn intercept(req: Request<()>) -> Result, Status> { +fn intercept(mut req: Request<()>) -> Result, Status> { println!("Intercepting request: {:?}", req); + + // Set an extension that can be retrieved by `say_hello` + req.extensions_mut().insert(MyExtension { + some_piece_of_data: "foo".to_string(), + }); + Ok(req) } + +struct MyExtension { + some_piece_of_data: String, +} diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 2e6836862..1b4acef2a 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -17,6 +17,9 @@ bytes = "1.0" [dev-dependencies] tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] } tokio-stream = { version = "0.1.5", features = ["net"] } +tower-service = "0.3" +hyper = "0.14" +futures = "0.3" [build-dependencies] tonic-build = { path = "../../tonic-build" } diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs new file mode 100644 index 000000000..c60c98414 --- /dev/null +++ b/tests/integration_tests/tests/extensions.rs @@ -0,0 +1,144 @@ +use futures_util::FutureExt; +use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::{ + task::{Context, Poll}, + time::Duration, +}; +use tokio::sync::oneshot; +use tonic::{ + body::BoxBody, + transport::{Endpoint, NamedService, Server}, + Request, Response, Status, +}; +use tower_service::Service; + +struct ExtensionValue(i32); + +#[tokio::test] +async fn setting_extension_from_interceptor() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let value = req.extensions().get::().unwrap(); + assert_eq!(value.0, 42); + + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::with_interceptor(Svc, |mut req: Request<()>| { + req.extensions_mut().insert(ExtensionValue(42)); + Ok(req) + }); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1323".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1323") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client.unary_call(Input {}).await.unwrap(); + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + +#[tokio::test] +async fn setting_extension_from_tower() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let value = req.extensions().get::().unwrap(); + assert_eq!(value.0, 42); + + Ok(Response::new(Output {})) + } + } + + let svc = InterceptedService { + inner: test_server::TestServer::new(Svc), + }; + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1324".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1324") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client.unary_call(Input {}).await.unwrap(); + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + +#[derive(Debug, Clone)] +struct InterceptedService { + inner: S, +} + +impl Service> for InterceptedService +where + S: Service, Response = HyperResponse> + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = futures::future::BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: HyperRequest) -> Self::Future { + let clone = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, clone); + + req.extensions_mut().insert(ExtensionValue(42)); + + Box::pin(async move { + let response = inner.call(req).await?; + Ok(response) + }) + } +} + +impl NamedService for InterceptedService { + const NAME: &'static str = S::NAME; +} diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 2e3c2de85..4142be8d0 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -97,7 +97,8 @@ impl Grpc { M1: Send + Sync + 'static, M2: Send + Sync + 'static, { - let (mut parts, body) = self.streaming(request, path, codec).await?.into_parts(); + let (mut parts, body, extensions) = + self.streaming(request, path, codec).await?.into_parts(); futures_util::pin_mut!(body); @@ -114,7 +115,7 @@ impl Grpc { parts.merge(trailers); } - Ok(Response::from_parts(parts, message)) + Ok(Response::from_parts(parts, message, extensions)) } /// Send a server side streaming gRPC request. diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs new file mode 100644 index 000000000..a42a4c276 --- /dev/null +++ b/tonic/src/extensions.rs @@ -0,0 +1,71 @@ +use std::fmt; + +/// A type map of protocol extensions. +/// +/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from +/// the underlying protocol. +/// +/// [`Interceptor`]: crate::Interceptor +/// [`Request`]: crate::Request +pub struct Extensions { + inner: http::Extensions, +} + +impl Extensions { + pub(crate) fn new() -> Self { + Self { + inner: http::Extensions::new(), + } + } + + /// Insert a type into this `Extensions`. + /// + /// If a extension of this type already existed, it will + /// be returned. + #[inline] + pub fn insert(&mut self, val: T) -> Option { + self.inner.insert(val) + } + + /// Get a reference to a type previously inserted on this `Extensions`. + #[inline] + pub fn get(&self) -> Option<&T> { + self.inner.get() + } + + /// Get a mutable reference to a type previously inserted on this `Extensions`. + #[inline] + pub fn get_mut(&mut self) -> Option<&mut T> { + self.inner.get_mut() + } + + /// Remove a type from this `Extensions`. + /// + /// If a extension of this type existed, it will be returned. + #[inline] + pub fn remove(&mut self) -> Option { + self.inner.remove() + } + + /// Clear the `Extensions` of all inserted extensions. + #[inline] + pub fn clear(&mut self) { + self.inner.clear() + } + + #[inline] + pub(crate) fn from_http(http: http::Extensions) -> Self { + Self { inner: http } + } + + #[inline] + pub(crate) fn into_http(self) -> http::Extensions { + self.inner + } +} + +impl fmt::Debug for Extensions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Extensions").finish() + } +} diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 266a45c50..d5eb61ec3 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -62,6 +62,7 @@ //! [`transport`]: transport/index.html #![recursion_limit = "256"] +#![allow(clippy::inconsistent_struct_constructor)] #![warn( missing_debug_implementations, missing_docs, @@ -87,6 +88,7 @@ pub mod server; #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] pub mod transport; +mod extensions; mod interceptor; mod macros; mod request; @@ -100,6 +102,7 @@ pub use async_trait::async_trait; #[doc(inline)] pub use codec::Streaming; +pub use extensions::Extensions; pub use interceptor::Interceptor; pub use request::{IntoRequest, IntoStreamingRequest, Request}; pub use response::Response; diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 7d8f80260..e0a033ef6 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,8 +1,8 @@ use crate::metadata::{MetadataMap, MetadataValue}; #[cfg(feature = "transport")] use crate::transport::Certificate; +use crate::Extensions; use futures_core::Stream; -use http::Extensions; #[cfg(feature = "transport")] use std::sync::Arc; use std::{net::SocketAddr, time::Duration}; @@ -116,7 +116,7 @@ impl Request { Request { metadata: MetadataMap::new(), message, - extensions: Extensions::default(), + extensions: Extensions::new(), } } @@ -161,7 +161,7 @@ impl Request { Request { metadata: MetadataMap::from_headers(parts.headers), message, - extensions: parts.extensions, + extensions: Extensions::from_http(parts.extensions), } } @@ -178,7 +178,7 @@ impl Request { *request.method_mut() = http::Method::POST; *request.uri_mut() = uri; *request.headers_mut() = self.metadata.into_sanitized_headers(); - *request.extensions_mut() = self.extensions; + *request.extensions_mut() = self.extensions.into_http(); request } @@ -193,7 +193,7 @@ impl Request { Request { metadata: self.metadata, message, - extensions: Extensions::default(), + extensions: Extensions::new(), } } @@ -254,6 +254,60 @@ impl Request { self.metadata_mut() .insert(crate::metadata::GRPC_TIMEOUT_HEADER, value); } + + /// Returns a reference to the associated extensions. + pub fn extensions(&self) -> &Extensions { + &self.extensions + } + + /// Returns a mutable reference to the associated extensions. + /// + /// # Example + /// + /// Extensions can be set in interceptors: + /// + /// ```no_run + /// use tonic::{Request, Interceptor}; + /// + /// struct MyExtension { + /// some_piece_of_data: String, + /// } + /// + /// Interceptor::new(|mut request: Request<()>| { + /// request.extensions_mut().insert(MyExtension { + /// some_piece_of_data: "foo".to_string(), + /// }); + /// + /// Ok(request) + /// }); + /// ``` + /// + /// And picked up by RPCs: + /// + /// ```no_run + /// use tonic::{async_trait, Status, Request, Response}; + /// # + /// # struct Output {} + /// # struct Input; + /// # struct MyService; + /// # struct MyExtension; + /// # #[async_trait] + /// # trait TestService { + /// # async fn handler(&self, req: Request) -> Result, Status>; + /// # } + /// + /// #[async_trait] + /// impl TestService for MyService { + /// async fn handler(&self, req: Request) -> Result, Status> { + /// let value: &MyExtension = req.extensions().get::().unwrap(); + /// + /// Ok(Response::new(Output {})) + /// } + /// } + /// ``` + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } } impl IntoRequest for T { diff --git a/tonic/src/response.rs b/tonic/src/response.rs index ab6eaa66a..87f59b4e4 100644 --- a/tonic/src/response.rs +++ b/tonic/src/response.rs @@ -1,10 +1,11 @@ -use crate::metadata::MetadataMap; +use crate::{metadata::MetadataMap, Extensions}; /// A gRPC response and metadata from an RPC call. #[derive(Debug)] pub struct Response { metadata: MetadataMap, message: T, + extensions: Extensions, } impl Response { @@ -24,6 +25,7 @@ impl Response { Response { metadata: MetadataMap::new(), message, + extensions: Extensions::new(), } } @@ -52,12 +54,16 @@ impl Response { self.message } - pub(crate) fn into_parts(self) -> (MetadataMap, T) { - (self.metadata, self.message) + pub(crate) fn into_parts(self) -> (MetadataMap, T, Extensions) { + (self.metadata, self.message, self.extensions) } - pub(crate) fn from_parts(metadata: MetadataMap, message: T) -> Self { - Self { metadata, message } + pub(crate) fn from_parts(metadata: MetadataMap, message: T, extensions: Extensions) -> Self { + Self { + metadata, + message, + extensions, + } } pub(crate) fn from_http(res: http::Response) -> Self { @@ -65,6 +71,7 @@ impl Response { Response { metadata: MetadataMap::from_headers(head.headers), message, + extensions: Extensions::from_http(head.extensions), } } @@ -73,6 +80,7 @@ impl Response { *res.version_mut() = http::Version::HTTP_2; *res.headers_mut() = self.metadata.into_sanitized_headers(); + *res.extensions_mut() = self.extensions.into_http(); res } @@ -86,8 +94,19 @@ impl Response { Response { metadata: self.metadata, message, + extensions: self.extensions, } } + + /// Returns a reference to the associated extensions. + pub fn extensions(&self) -> &Extensions { + &self.extensions + } + + /// Returns a mutable reference to the associated extensions. + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } } #[cfg(test)]