Skip to content

Commit

Permalink
http server: propagate extensions in http response (#1514)
Browse files Browse the repository at this point in the history
* http server: propagate extensions in http response

* fix tests

* remove needless clone

* fix tests

* Update server/src/tests/http.rs
  • Loading branch information
niklasad1 authored Jan 15, 2025
1 parent 98ea4a6 commit e75d3fc
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 46 deletions.
42 changes: 26 additions & 16 deletions core/src/server/method_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ impl MethodResponse {
}

/// Consume the method response and extract the parts.
pub fn into_parts(self) -> (String, Option<MethodResponseNotifyTx>) {
(self.result, self.on_close)
pub fn into_parts(self) -> (String, Option<MethodResponseNotifyTx>, Extensions) {
(self.result, self.on_close, self.extensions)
}

/// Get the error code
Expand All @@ -120,11 +120,11 @@ impl MethodResponse {
/// Create a method response from [`BatchResponse`].
pub fn from_batch(batch: BatchResponse) -> Self {
Self {
result: batch.0,
result: batch.result,
success_or_error: MethodResponseResult::Success,
kind: ResponseKind::Batch,
on_close: None,
extensions: Extensions::new(),
extensions: batch.extensions,
}
}

Expand Down Expand Up @@ -280,6 +280,8 @@ pub struct BatchResponseBuilder {
result: String,
/// Max limit for the batch
max_response_size: usize,
/// Extensions for the batch response.
extensions: Extensions,
}

impl BatchResponseBuilder {
Expand All @@ -288,17 +290,18 @@ impl BatchResponseBuilder {
let mut initial = String::with_capacity(2048);
initial.push('[');

Self { result: initial, max_response_size: limit }
Self { result: initial, max_response_size: limit, extensions: Extensions::new() }
}

/// Append a result from an individual method to the batch response.
///
/// Fails if the max limit is exceeded and returns to error response to
/// return early in order to not process method call responses which are thrown away anyway.
pub fn append(&mut self, response: &MethodResponse) -> Result<(), MethodResponse> {
pub fn append(&mut self, response: MethodResponse) -> Result<(), MethodResponse> {
// `,` will occupy one extra byte for each entry
// on the last item the `,` is replaced by `]`.
let len = response.result.len() + self.result.len() + 1;
self.extensions.extend(response.extensions);

if len > self.max_response_size {
Err(MethodResponse::error(Id::Null, reject_too_big_batch_response(self.max_response_size)))
Expand All @@ -317,18 +320,24 @@ impl BatchResponseBuilder {
/// Finish the batch response
pub fn finish(mut self) -> BatchResponse {
if self.result.len() == 1 {
BatchResponse(batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)))
BatchResponse {
result: batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)),
extensions: self.extensions,
}
} else {
self.result.pop();
self.result.push(']');
BatchResponse(self.result)
BatchResponse { result: self.result, extensions: self.extensions }
}
}
}

/// Serialized batch response.
#[derive(Debug, Clone)]
pub struct BatchResponse(String);
pub struct BatchResponse {
result: String,
extensions: Extensions,
}

/// Create a JSON-RPC error response.
pub fn batch_response_error(id: Id, err: impl Into<ErrorObject<'static>>) -> String {
Expand Down Expand Up @@ -473,42 +482,43 @@ mod tests {

// Recall a batch appends two bytes for the `[]`.
let mut builder = BatchResponseBuilder::new_with_limit(39);
builder.append(&method).unwrap();
builder.append(method).unwrap();
let batch = builder.finish();

assert_eq!(batch.0, r#"[{"jsonrpc":"2.0","id":1,"result":"a"}]"#)
assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","id":1,"result":"a"}]"#)
}

#[test]
fn batch_with_multiple_works() {
let m1 = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX);
let m11 = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX);
assert_eq!(m1.result.len(), 37);

// Recall a batch appends two bytes for the `[]` and one byte for `,` to append a method call.
// so it should be 2 + (37 * n) + (n-1)
let limit = 2 + (37 * 2) + 1;
let mut builder = BatchResponseBuilder::new_with_limit(limit);
builder.append(&m1).unwrap();
builder.append(&m1).unwrap();
builder.append(m1).unwrap();
builder.append(m11).unwrap();
let batch = builder.finish();

assert_eq!(batch.0, r#"[{"jsonrpc":"2.0","id":1,"result":"a"},{"jsonrpc":"2.0","id":1,"result":"a"}]"#)
assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","id":1,"result":"a"},{"jsonrpc":"2.0","id":1,"result":"a"}]"#)
}

#[test]
fn batch_empty_err() {
let batch = BatchResponseBuilder::new_with_limit(1024).finish();

let exp_err = r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid request"}}"#;
assert_eq!(batch.0, exp_err);
assert_eq!(batch.result, exp_err);
}

#[test]
fn batch_too_big() {
let method = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a".repeat(28)), 128);
assert_eq!(method.result.len(), 64);

let batch = BatchResponseBuilder::new_with_limit(63).append(&method).unwrap_err();
let batch = BatchResponseBuilder::new_with_limit(63).append(method).unwrap_err();

let exp_err = r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32011,"message":"The batch response was too large","data":"Exceeded max limit of 63"}}"#;
assert_eq!(batch.result, exp_err);
Expand Down
2 changes: 1 addition & 1 deletion core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl Methods {
};

let is_success = response.is_success();
let (rp, notif) = response.into_parts();
let (rp, notif, _) = response.into_parts();

if let Some(n) = notif {
n.notify(is_success);
Expand Down
41 changes: 26 additions & 15 deletions server/src/middleware/http/proxy_get_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

//! Middleware that proxies requests at a specified URI to internal
//! RPC method calls.
//! Middleware that proxies HTTP GET requests at a specified URI to an internal RPC method call.
use crate::transport::http;
use crate::{HttpBody, HttpRequest, HttpResponse};
Expand All @@ -34,7 +33,7 @@ use http_body_util::BodyExt;
use hyper::body::Bytes;
use hyper::header::{ACCEPT, CONTENT_TYPE};
use hyper::http::HeaderValue;
use hyper::{Method, Uri};
use hyper::{Method, StatusCode, Uri};
use jsonrpsee_core::BoxError;
use jsonrpsee_types::{ErrorCode, ErrorObject, Id, RequestSer};
use std::collections::HashMap;
Expand Down Expand Up @@ -171,7 +170,8 @@ where
async move {
let res = fut.await.map_err(Into::into)?;

let mut body = http_body_util::BodyStream::new(res.into_body());
let (parts, body) = res.into_parts();
let mut body = http_body_util::BodyStream::new(body);
let mut bytes = Vec::new();

while let Some(frame) = body.frame().await {
Expand All @@ -185,21 +185,14 @@ where
result: &'a serde_json::value::RawValue,
}

#[derive(serde::Deserialize)]
struct ErrorResponse<'a> {
#[serde(borrow)]
error: ErrorObject<'a>,
}

let response = if let Ok(payload) = serde_json::from_slice::<SuccessResponse>(&bytes) {
let mut response = if let Ok(payload) = serde_json::from_slice::<SuccessResponse>(&bytes) {
http::response::ok_response(payload.result.to_string())
} else {
let error = serde_json::from_slice::<ErrorResponse>(&bytes)
.map(|payload| payload.error)
.unwrap_or_else(|_| ErrorObject::from(ErrorCode::InternalError));
http::response::error_response(error)
internal_proxy_error(&bytes)
};

response.extensions_mut().extend(parts.extensions);

Ok(response)
}
.boxed()
Expand All @@ -212,3 +205,21 @@ where
}
}
}

fn internal_proxy_error(bytes: &[u8]) -> HttpResponse {
#[derive(serde::Deserialize)]
struct ErrorResponse<'a> {
#[serde(borrow)]
error: ErrorObject<'a>,
}

let error = serde_json::from_slice::<ErrorResponse>(bytes)
.map(|payload| payload.error)
.unwrap_or_else(|_| ErrorObject::from(ErrorCode::InternalError));

http::response::from_template(
StatusCode::INTERNAL_SERVER_ERROR,
serde_json::to_string(&error).expect("JSON serialization infallible; qed"),
"application/json; charset=utf-8",
)
}
4 changes: 2 additions & 2 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ where
if let Ok(req) = deserialize::from_str_with_extensions(call.get(), extensions.clone()) {
let rp = rpc_service.call(req).await;

if let Err(too_large) = batch_response.append(&rp) {
if let Err(too_large) = batch_response.append(rp) {
return Some(too_large);
}
} else if let Ok(_notif) = serde_json::from_str::<Notif>(call.get()) {
Expand All @@ -1288,7 +1288,7 @@ where
};

if let Err(too_large) =
batch_response.append(&MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)))
batch_response.append(MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)))
{
return Some(too_large);
}
Expand Down
134 changes: 132 additions & 2 deletions server/src/tests/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,23 @@

use std::net::SocketAddr;

use crate::{BatchRequestConfig, RegisterMethodError, RpcModule, ServerBuilder, ServerConfig, ServerHandle};
use jsonrpsee_core::RpcResult;
use crate::middleware::rpc::{RpcServiceBuilder, RpcServiceT};
use crate::types::Request;
use crate::{
BatchRequestConfig, HttpBody, HttpRequest, HttpResponse, MethodResponse, RegisterMethodError, RpcModule,
ServerBuilder, ServerConfig, ServerHandle,
};
use futures_util::future::{BoxFuture, Future, FutureExt};
use hyper::body::Bytes;
use jsonrpsee_core::{BoxError, RpcResult};
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::{Id, StatusCode};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::ErrorObjectOwned;
use serde_json::Value as JsonValue;
use std::pin::Pin;
use std::task::{Context, Poll};
use tower::Service;

use super::helpers::{MyAppError, TestContext};

Expand All @@ -42,6 +52,65 @@ fn init_logger() {
.try_init();
}

#[derive(Clone)]
struct InjectExt<S> {
service: S,
}

impl<'a, S> RpcServiceT<'a> for InjectExt<S>
where
S: Send + Sync + RpcServiceT<'a> + Clone + 'static,
{
type Future = BoxFuture<'a, MethodResponse>;

fn call(&self, mut req: Request<'a>) -> Self::Future {
if req.method_name().contains("err") {
req.extensions_mut().insert(StatusCode::IM_A_TEAPOT);
} else {
req.extensions_mut().insert(StatusCode::OK);
}

self.service.call(req).boxed()
}
}

#[derive(Debug, Clone)]
struct ModifyHttpStatus<S> {
service: S,
}

impl<S, B> Service<HttpRequest<B>> for ModifyHttpStatus<S>
where
S: Service<HttpRequest<B>, Response = HttpResponse<HttpBody>>,
S::Response: 'static,
S::Error: Into<BoxError> + Send + 'static,
S::Future: Send + 'static,
B: http_body::Body<Data = Bytes> + Send + std::fmt::Debug + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Response = S::Response;
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, request: HttpRequest<B>) -> Self::Future {
let fut = self.service.call(request);
async move {
let mut rp = fut.await.map_err(Into::into)?;
let status_code = rp.extensions().get::<StatusCode>().copied().unwrap();

*rp.status_mut() = status_code;

Ok(rp)
}
.boxed()
}
}

async fn server() -> (SocketAddr, ServerHandle) {
let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let ctx = TestContext;
Expand Down Expand Up @@ -579,3 +648,64 @@ async fn http2_method_call_works() {
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, ok_response(JsonValue::Number(3.into()), Id::Num(1)));
}

#[tokio::test]
async fn http_extensions_from_rpc_response_propagated() {
init_logger();

let server = ServerBuilder::default()
.set_rpc_middleware(RpcServiceBuilder::new().layer_fn(|service| InjectExt { service }))
.set_http_middleware(tower::ServiceBuilder::new().layer_fn(|service| ModifyHttpStatus { service }))
.build("127.0.0.1:0")
.await
.unwrap();
let mut module = RpcModule::new(());
module.register_method("err", |_, _ctx, _| "lo").unwrap();
let addr = server.local_addr().unwrap();
let uri = to_http_uri(addr);
let handle = server.start(module);

let req = r#"{"jsonrpc":"2.0","method":"err","id":1}"#;
let response = http_request(req.into(), uri).with_default_timeout().await.unwrap().unwrap();
assert_eq!(response.status, StatusCode::IM_A_TEAPOT);

handle.stop().unwrap();
handle.stopped().await;
}

#[tokio::test]
async fn http_extensions_from_rpc_batch_response_overwrite() {
init_logger();

let server = ServerBuilder::default()
.set_rpc_middleware(RpcServiceBuilder::new().layer_fn(|service| InjectExt { service }))
.set_http_middleware(tower::ServiceBuilder::new().layer_fn(|service| ModifyHttpStatus { service }))
.build("127.0.0.1:0")
.await
.unwrap();
let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _ctx, _| "lo").unwrap();
module.register_method("err", |_, _ctx, _| "e").unwrap();
let addr = server.local_addr().unwrap();
let uri = to_http_uri(addr);
let handle = server.start(module);

// Send a batch which will overwrite the status code Teapot with OK.
let req = r#"[
{"jsonrpc":"2.0","method":"err", "params":[],"id":2},
{"jsonrpc":"2.0","method":"say_hello", "params":[],"id":3}
]"#;
let response = http_request(req.into(), uri.clone()).with_default_timeout().await.unwrap().unwrap();
assert_eq!(response.status, StatusCode::OK);

// Send a batch which will overwrite the status code OK with TEAPOT.
let req = r#"[
{"jsonrpc":"2.0","method":"say_hello", "params":[],"id":2},
{"jsonrpc":"2.0","method":"err", "params":[],"id":3}
]"#;
let response = http_request(req.into(), uri).with_default_timeout().await.unwrap().unwrap();
assert_eq!(response.status, StatusCode::IM_A_TEAPOT);

handle.stop().unwrap();
handle.stopped().await;
}
Loading

0 comments on commit e75d3fc

Please sign in to comment.