Skip to content

Commit

Permalink
feat(tonic): add Request extensions
Browse files Browse the repository at this point in the history
Adds `tonic::Extensions` which is a newtype around `http::Extensions`.

Extensions can be set by interceptors with `Request::extensions_mut` and
retrieved from RPCs with `Request::extensions`. Extensions can also be
set in tower middleware and will be carried through to the RPC.

Fixes #255
  • Loading branch information
davidpdrsn committed May 13, 2021
1 parent f33316d commit b937f78
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 6 deletions.
15 changes: 14 additions & 1 deletion examples/src/interceptor/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
let extension = request.extensions().get::<MyExtension>().unwrap();
println!("extension data = {}", extension.some_piece_of_data);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Expand All @@ -40,7 +43,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// This function will get called on each inbound request, if a `Status`
/// is returned, it will cancel the request and return that status to the
/// client.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
fn intercept(mut req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting request: {:?}", req);

// Set an extension that can be retrieved by `say_hello`
req.extensions_mut().insert(MyExtension {
some_piece_of_data: "foo".to_string(),
});

Ok(req)
}

struct MyExtension {
some_piece_of_data: String,
}
3 changes: 3 additions & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ bytes = "1.0"
[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
tokio-stream = { version = "0.1.5", features = ["net"] }
tower-service = "0.3"
hyper = "0.14"
futures = "0.3"

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
150 changes: 150 additions & 0 deletions tests/integration_tests/tests/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use futures_util::FutureExt;
use hyper::{Body, Request as HyperRequest, Response as HyperResponse};
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::{
task::{Context, Poll},
time::Duration,
};
use tokio::sync::oneshot;
use tonic::{
body::BoxBody,
transport::{Endpoint, NamedService, Server},
Request, Response, Status,
};
use tower_service::Service;

struct ExtensionValue(i32);

#[tokio::test]
async fn setting_extension_from_interceptor() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let value = req.extensions().get::<ExtensionValue>().unwrap();
assert_eq!(value.0, 42);

Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::with_interceptor(Svc, |mut req: Request<()>| {
req.extensions_mut().insert(ExtensionValue(42));
Ok(req)
});

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1323".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1323")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

match client.unary_call(Input {}).await {
Ok(_) => {}
Err(status) => panic!("{}", status.message()),
}

tx.send(()).unwrap();

jh.await.unwrap();
}

#[tokio::test]
async fn setting_extension_from_tower() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let value = req.extensions().get::<ExtensionValue>().unwrap();
assert_eq!(value.0, 42);

Ok(Response::new(Output {}))
}
}

let svc = InterceptedService {
inner: test_server::TestServer::new(Svc),
};

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1324".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1324")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

match client.unary_call(Input {}).await {
Ok(_) => {}
Err(status) => panic!("{}", status.message()),
}

tx.send(()).unwrap();

jh.await.unwrap();
}

#[derive(Debug, Clone)]
struct InterceptedService<S> {
inner: S,
}

impl<S> Service<HyperRequest<Body>> for InterceptedService<S>
where
S: Service<HyperRequest<Body>, Response = HyperResponse<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: HyperRequest<Body>) -> Self::Future {
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);

req.extensions_mut().insert(ExtensionValue(42));

Box::pin(async move {
let response = inner.call(req).await?;
Ok(response)
})
}
}

impl<S: NamedService> NamedService for InterceptedService<S> {
const NAME: &'static str = S::NAME;
}
67 changes: 67 additions & 0 deletions tonic/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::fmt;

/// A type map of protocol extensions.
///
/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from
/// the underlying protocol.
///
/// [`Interceptor`]: crate::Interceptor
/// [`Request`]: crate::Request
pub struct Extensions(http::Extensions);

impl Extensions {
pub(crate) fn new() -> Self {
Self(http::Extensions::new())
}

/// Insert a type into this `Extensions`.
///
/// If a extension of this type already existed, it will
/// be returned.
#[inline]
pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
self.0.insert(val)
}

/// Get a reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.0.get()
}

/// Get a mutable reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
self.0.get_mut()
}

/// Remove a type from this `Extensions`.
///
/// If a extension of this type existed, it will be returned.
#[inline]
pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
self.0.remove()
}

/// Clear the `Extensions` of all inserted extensions.
#[inline]
pub fn clear(&mut self) {
self.0.clear()
}

#[inline]
pub(crate) fn from_http(http: http::Extensions) -> Self {
Self(http)
}

#[inline]
pub(crate) fn into_http(self) -> http::Extensions {
self.0
}
}

impl fmt::Debug for Extensions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Extensions").finish()
}
}
3 changes: 3 additions & 0 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
//! [`transport`]: transport/index.html
#![recursion_limit = "256"]
#![allow(clippy::inconsistent_struct_constructor)]
#![warn(
missing_debug_implementations,
missing_docs,
Expand All @@ -87,6 +88,7 @@ pub mod server;
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
pub mod transport;

mod extensions;
mod interceptor;
mod macros;
mod request;
Expand All @@ -100,6 +102,7 @@ pub use async_trait::async_trait;

#[doc(inline)]
pub use codec::Streaming;
pub use extensions::Extensions;
pub use interceptor::Interceptor;
pub use request::{IntoRequest, IntoStreamingRequest, Request};
pub use response::Response;
Expand Down
64 changes: 59 additions & 5 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::metadata::{MetadataMap, MetadataValue};
#[cfg(feature = "transport")]
use crate::transport::Certificate;
use crate::Extensions;
use futures_core::Stream;
use http::Extensions;
#[cfg(feature = "transport")]
use std::sync::Arc;
use std::{net::SocketAddr, time::Duration};
Expand Down Expand Up @@ -116,7 +116,7 @@ impl<T> Request<T> {
Request {
metadata: MetadataMap::new(),
message,
extensions: Extensions::default(),
extensions: Extensions::new(),
}
}

Expand Down Expand Up @@ -161,7 +161,7 @@ impl<T> Request<T> {
Request {
metadata: MetadataMap::from_headers(parts.headers),
message,
extensions: parts.extensions,
extensions: Extensions::from_http(parts.extensions),
}
}

Expand All @@ -178,7 +178,7 @@ impl<T> Request<T> {
*request.method_mut() = http::Method::POST;
*request.uri_mut() = uri;
*request.headers_mut() = self.metadata.into_sanitized_headers();
*request.extensions_mut() = self.extensions;
*request.extensions_mut() = self.extensions.into_http();

request
}
Expand All @@ -193,7 +193,7 @@ impl<T> Request<T> {
Request {
metadata: self.metadata,
message,
extensions: Extensions::default(),
extensions: Extensions::new(),
}
}

Expand Down Expand Up @@ -254,6 +254,60 @@ impl<T> Request<T> {
self.metadata_mut()
.insert(crate::metadata::GRPC_TIMEOUT_HEADER, value);
}

/// Returns a reference to the associated extensions.
pub fn extensions(&self) -> &Extensions {
&self.extensions
}

/// Returns a mutable reference to the associated extensions.
///
/// # Example
///
/// Extensions can be set in interceptors:
///
/// ```no_run
/// use tonic::{Request, Interceptor};
///
/// struct MyExtension {
/// some_piece_of_data: String,
/// }
///
/// Interceptor::new(|mut request: Request<()>| {
/// request.extensions_mut().insert(MyExtension {
/// some_piece_of_data: "foo".to_string(),
/// });
///
/// Ok(request)
/// });
/// ```
///
/// And picked up by RPCs:
///
/// ```no_run
/// use tonic::{async_trait, Status, Request, Response};
/// #
/// # struct Output {}
/// # struct Input;
/// # struct MyService;
/// # struct MyExtension;
/// # #[async_trait]
/// # trait TestService {
/// # async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status>;
/// # }
///
/// #[async_trait]
/// impl TestService for MyService {
/// async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
/// let value: &MyExtension = req.extensions().get::<MyExtension>().unwrap();
///
/// Ok(Response::new(Output {}))
/// }
/// }
/// ```
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
}

impl<T> IntoRequest<T> for T {
Expand Down

0 comments on commit b937f78

Please sign in to comment.