diff --git a/core/src/server/method_response.rs b/core/src/server/method_response.rs index 30b3e84ea8..d0ba38b05b 100644 --- a/core/src/server/method_response.rs +++ b/core/src/server/method_response.rs @@ -101,8 +101,8 @@ impl MethodResponse { } /// Consume the method response and extract the parts. - pub fn into_parts(self) -> (String, Option) { - (self.result, self.on_close) + pub fn into_parts(self) -> (String, Option, Extensions) { + (self.result, self.on_close, self.extensions) } /// Get the error code @@ -120,11 +120,11 @@ impl MethodResponse { /// Create a method response from [`BatchResponse`]. pub fn from_batch(batch: BatchResponse) -> Self { Self { - result: batch.0, + result: batch.result, success_or_error: MethodResponseResult::Success, kind: ResponseKind::Batch, on_close: None, - extensions: Extensions::new(), + extensions: batch.extensions, } } @@ -280,6 +280,8 @@ pub struct BatchResponseBuilder { result: String, /// Max limit for the batch max_response_size: usize, + /// Extensions for the batch response. + extensions: Extensions, } impl BatchResponseBuilder { @@ -288,17 +290,18 @@ impl BatchResponseBuilder { let mut initial = String::with_capacity(2048); initial.push('['); - Self { result: initial, max_response_size: limit } + Self { result: initial, max_response_size: limit, extensions: Extensions::new() } } /// Append a result from an individual method to the batch response. /// /// Fails if the max limit is exceeded and returns to error response to /// return early in order to not process method call responses which are thrown away anyway. - pub fn append(&mut self, response: &MethodResponse) -> Result<(), MethodResponse> { + pub fn append(&mut self, response: MethodResponse) -> Result<(), MethodResponse> { // `,` will occupy one extra byte for each entry // on the last item the `,` is replaced by `]`. let len = response.result.len() + self.result.len() + 1; + self.extensions.extend(response.extensions); if len > self.max_response_size { Err(MethodResponse::error(Id::Null, reject_too_big_batch_response(self.max_response_size))) @@ -317,18 +320,24 @@ impl BatchResponseBuilder { /// Finish the batch response pub fn finish(mut self) -> BatchResponse { if self.result.len() == 1 { - BatchResponse(batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))) + BatchResponse { + result: batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)), + extensions: self.extensions, + } } else { self.result.pop(); self.result.push(']'); - BatchResponse(self.result) + BatchResponse { result: self.result, extensions: self.extensions } } } } /// Serialized batch response. #[derive(Debug, Clone)] -pub struct BatchResponse(String); +pub struct BatchResponse { + result: String, + extensions: Extensions, +} /// Create a JSON-RPC error response. pub fn batch_response_error(id: Id, err: impl Into>) -> String { @@ -473,26 +482,27 @@ mod tests { // Recall a batch appends two bytes for the `[]`. let mut builder = BatchResponseBuilder::new_with_limit(39); - builder.append(&method).unwrap(); + builder.append(method).unwrap(); let batch = builder.finish(); - assert_eq!(batch.0, r#"[{"jsonrpc":"2.0","id":1,"result":"a"}]"#) + assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","id":1,"result":"a"}]"#) } #[test] fn batch_with_multiple_works() { let m1 = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX); + let m11 = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX); assert_eq!(m1.result.len(), 37); // Recall a batch appends two bytes for the `[]` and one byte for `,` to append a method call. // so it should be 2 + (37 * n) + (n-1) let limit = 2 + (37 * 2) + 1; let mut builder = BatchResponseBuilder::new_with_limit(limit); - builder.append(&m1).unwrap(); - builder.append(&m1).unwrap(); + builder.append(m1).unwrap(); + builder.append(m11).unwrap(); let batch = builder.finish(); - assert_eq!(batch.0, r#"[{"jsonrpc":"2.0","id":1,"result":"a"},{"jsonrpc":"2.0","id":1,"result":"a"}]"#) + assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","id":1,"result":"a"},{"jsonrpc":"2.0","id":1,"result":"a"}]"#) } #[test] @@ -500,7 +510,7 @@ mod tests { let batch = BatchResponseBuilder::new_with_limit(1024).finish(); let exp_err = r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid request"}}"#; - assert_eq!(batch.0, exp_err); + assert_eq!(batch.result, exp_err); } #[test] @@ -508,7 +518,7 @@ mod tests { let method = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a".repeat(28)), 128); assert_eq!(method.result.len(), 64); - let batch = BatchResponseBuilder::new_with_limit(63).append(&method).unwrap_err(); + let batch = BatchResponseBuilder::new_with_limit(63).append(method).unwrap_err(); let exp_err = r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32011,"message":"The batch response was too large","data":"Exceeded max limit of 63"}}"#; assert_eq!(batch.result, exp_err); diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 83990f7a97..4fcd5aab32 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -399,7 +399,7 @@ impl Methods { }; let is_success = response.is_success(); - let (rp, notif) = response.into_parts(); + let (rp, notif, _) = response.into_parts(); if let Some(n) = notif { n.notify(is_success); diff --git a/server/src/middleware/http/proxy_get_request.rs b/server/src/middleware/http/proxy_get_request.rs index 0c48000501..0ae473ff0a 100644 --- a/server/src/middleware/http/proxy_get_request.rs +++ b/server/src/middleware/http/proxy_get_request.rs @@ -24,8 +24,7 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! Middleware that proxies requests at a specified URI to internal -//! RPC method calls. +//! Middleware that proxies HTTP GET requests at a specified URI to an internal RPC method call. use crate::transport::http; use crate::{HttpBody, HttpRequest, HttpResponse}; @@ -34,7 +33,7 @@ use http_body_util::BodyExt; use hyper::body::Bytes; use hyper::header::{ACCEPT, CONTENT_TYPE}; use hyper::http::HeaderValue; -use hyper::{Method, Uri}; +use hyper::{Method, StatusCode, Uri}; use jsonrpsee_core::BoxError; use jsonrpsee_types::{ErrorCode, ErrorObject, Id, RequestSer}; use std::collections::HashMap; @@ -171,7 +170,8 @@ where async move { let res = fut.await.map_err(Into::into)?; - let mut body = http_body_util::BodyStream::new(res.into_body()); + let (parts, body) = res.into_parts(); + let mut body = http_body_util::BodyStream::new(body); let mut bytes = Vec::new(); while let Some(frame) = body.frame().await { @@ -185,21 +185,14 @@ where result: &'a serde_json::value::RawValue, } - #[derive(serde::Deserialize)] - struct ErrorResponse<'a> { - #[serde(borrow)] - error: ErrorObject<'a>, - } - - let response = if let Ok(payload) = serde_json::from_slice::(&bytes) { + let mut response = if let Ok(payload) = serde_json::from_slice::(&bytes) { http::response::ok_response(payload.result.to_string()) } else { - let error = serde_json::from_slice::(&bytes) - .map(|payload| payload.error) - .unwrap_or_else(|_| ErrorObject::from(ErrorCode::InternalError)); - http::response::error_response(error) + internal_proxy_error(&bytes) }; + response.extensions_mut().extend(parts.extensions); + Ok(response) } .boxed() @@ -212,3 +205,21 @@ where } } } + +fn internal_proxy_error(bytes: &[u8]) -> HttpResponse { + #[derive(serde::Deserialize)] + struct ErrorResponse<'a> { + #[serde(borrow)] + error: ErrorObject<'a>, + } + + let error = serde_json::from_slice::(bytes) + .map(|payload| payload.error) + .unwrap_or_else(|_| ErrorObject::from(ErrorCode::InternalError)); + + http::response::from_template( + StatusCode::INTERNAL_SERVER_ERROR, + serde_json::to_string(&error).expect("JSON serialization infallible; qed"), + "application/json; charset=utf-8", + ) +} diff --git a/server/src/server.rs b/server/src/server.rs index e06e837fd3..e554169215 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1274,7 +1274,7 @@ where if let Ok(req) = deserialize::from_str_with_extensions(call.get(), extensions.clone()) { let rp = rpc_service.call(req).await; - if let Err(too_large) = batch_response.append(&rp) { + if let Err(too_large) = batch_response.append(rp) { return Some(too_large); } } else if let Ok(_notif) = serde_json::from_str::(call.get()) { @@ -1288,7 +1288,7 @@ where }; if let Err(too_large) = - batch_response.append(&MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest))) + batch_response.append(MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest))) { return Some(too_large); } diff --git a/server/src/tests/http.rs b/server/src/tests/http.rs index 13ad53af61..7d185c77b2 100644 --- a/server/src/tests/http.rs +++ b/server/src/tests/http.rs @@ -26,13 +26,23 @@ use std::net::SocketAddr; -use crate::{BatchRequestConfig, RegisterMethodError, RpcModule, ServerBuilder, ServerConfig, ServerHandle}; -use jsonrpsee_core::RpcResult; +use crate::middleware::rpc::{RpcServiceBuilder, RpcServiceT}; +use crate::types::Request; +use crate::{ + BatchRequestConfig, HttpBody, HttpRequest, HttpResponse, MethodResponse, RegisterMethodError, RpcModule, + ServerBuilder, ServerConfig, ServerHandle, +}; +use futures_util::future::{BoxFuture, Future, FutureExt}; +use hyper::body::Bytes; +use jsonrpsee_core::{BoxError, RpcResult}; use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::mocks::{Id, StatusCode}; use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_types::ErrorObjectOwned; use serde_json::Value as JsonValue; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tower::Service; use super::helpers::{MyAppError, TestContext}; @@ -42,6 +52,65 @@ fn init_logger() { .try_init(); } +#[derive(Clone)] +struct InjectExt { + service: S, +} + +impl<'a, S> RpcServiceT<'a> for InjectExt +where + S: Send + Sync + RpcServiceT<'a> + Clone + 'static, +{ + type Future = BoxFuture<'a, MethodResponse>; + + fn call(&self, mut req: Request<'a>) -> Self::Future { + if req.method_name().contains("err") { + req.extensions_mut().insert(StatusCode::IM_A_TEAPOT); + } else { + req.extensions_mut().insert(StatusCode::OK); + } + + self.service.call(req).boxed() + } +} + +#[derive(Debug, Clone)] +struct ModifyHttpStatus { + service: S, +} + +impl Service> for ModifyHttpStatus +where + S: Service, Response = HttpResponse>, + S::Response: 'static, + S::Error: Into + Send + 'static, + S::Future: Send + 'static, + B: http_body::Body + Send + std::fmt::Debug + 'static, + B::Data: Send, + B::Error: Into, +{ + type Response = S::Response; + type Error = BoxError; + type Future = Pin> + Send + 'static>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, request: HttpRequest) -> Self::Future { + let fut = self.service.call(request); + async move { + let mut rp = fut.await.map_err(Into::into)?; + let status_code = rp.extensions().get::().copied().unwrap(); + + *rp.status_mut() = status_code; + + Ok(rp) + } + .boxed() + } +} + async fn server() -> (SocketAddr, ServerHandle) { let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap(); let ctx = TestContext; @@ -579,3 +648,64 @@ async fn http2_method_call_works() { assert_eq!(response.status, StatusCode::OK); assert_eq!(response.body, ok_response(JsonValue::Number(3.into()), Id::Num(1))); } + +#[tokio::test] +async fn http_extensions_from_rpc_response_propagated() { + init_logger(); + + let server = ServerBuilder::default() + .set_rpc_middleware(RpcServiceBuilder::new().layer_fn(|service| InjectExt { service })) + .set_http_middleware(tower::ServiceBuilder::new().layer_fn(|service| ModifyHttpStatus { service })) + .build("127.0.0.1:0") + .await + .unwrap(); + let mut module = RpcModule::new(()); + module.register_method("err", |_, _ctx, _| "lo").unwrap(); + let addr = server.local_addr().unwrap(); + let uri = to_http_uri(addr); + let handle = server.start(module); + + let req = r#"{"jsonrpc":"2.0","method":"err","id":1}"#; + let response = http_request(req.into(), uri).with_default_timeout().await.unwrap().unwrap(); + assert_eq!(response.status, StatusCode::IM_A_TEAPOT); + + handle.stop().unwrap(); + handle.stopped().await; +} + +#[tokio::test] +async fn http_extensions_from_rpc_batch_response_overwrite() { + init_logger(); + + let server = ServerBuilder::default() + .set_rpc_middleware(RpcServiceBuilder::new().layer_fn(|service| InjectExt { service })) + .set_http_middleware(tower::ServiceBuilder::new().layer_fn(|service| ModifyHttpStatus { service })) + .build("127.0.0.1:0") + .await + .unwrap(); + let mut module = RpcModule::new(()); + module.register_method("say_hello", |_, _ctx, _| "lo").unwrap(); + module.register_method("err", |_, _ctx, _| "e").unwrap(); + let addr = server.local_addr().unwrap(); + let uri = to_http_uri(addr); + let handle = server.start(module); + + // Send a batch which will overwrite the status code Teapot with OK. + let req = r#"[ + {"jsonrpc":"2.0","method":"err", "params":[],"id":2}, + {"jsonrpc":"2.0","method":"say_hello", "params":[],"id":3} + ]"#; + let response = http_request(req.into(), uri.clone()).with_default_timeout().await.unwrap().unwrap(); + assert_eq!(response.status, StatusCode::OK); + + // Send a batch which will overwrite the status code OK with TEAPOT. + let req = r#"[ + {"jsonrpc":"2.0","method":"say_hello", "params":[],"id":2}, + {"jsonrpc":"2.0","method":"err", "params":[],"id":3} + ]"#; + let response = http_request(req.into(), uri).with_default_timeout().await.unwrap().unwrap(); + assert_eq!(response.status, StatusCode::IM_A_TEAPOT); + + handle.stop().unwrap(); + handle.stopped().await; +} diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index 1d0b7d5cd7..c60e506899 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -95,12 +95,15 @@ where } }; - let rp = handle_rpc_call(&body, is_single, batch_config, max_response_size, &rpc_service, parts.extensions) - .await; - - // If the response is empty it means that it was a notification or empty batch. - // For HTTP these are just ACK:ed with a empty body. - response::ok_response(rp.map_or(String::new(), |r| r.into_result())) + if let Some(rp) = + handle_rpc_call(&body, is_single, batch_config, max_response_size, &rpc_service, parts.extensions).await + { + response::from_method_response(rp) + } else { + // If the response is empty it means that it was a notification or empty batch. + // For HTTP these are just ACK:ed with a empty body. + response::ok_response("") + } } // Error scenarios: Method::POST => response::unsupported_content_type(), @@ -110,6 +113,7 @@ where /// HTTP response helpers. pub mod response { + use jsonrpsee_core::server::MethodResponse; use jsonrpsee_types::error::{reject_too_big_request, ErrorCode}; use jsonrpsee_types::{ErrorObject, ErrorObjectOwned, Id, Response, ResponsePayload}; @@ -165,7 +169,11 @@ pub mod response { } /// Create a response body. - fn from_template(status: hyper::StatusCode, body: impl Into, content_type: &'static str) -> HttpResponse { + pub(crate) fn from_template( + status: hyper::StatusCode, + body: impl Into, + content_type: &'static str, + ) -> HttpResponse { HttpResponse::builder() .status(status) .header("content-type", hyper::header::HeaderValue::from_static(content_type)) @@ -180,6 +188,17 @@ pub mod response { from_template(hyper::StatusCode::OK, body, JSON) } + /// Create a response from a method response. + /// + /// This will include the body and extensions from the method response. + pub fn from_method_response(rp: MethodResponse) -> HttpResponse { + let (body, _, extensions) = rp.into_parts(); + + let mut rp = from_template(hyper::StatusCode::OK, body, JSON); + rp.extensions_mut().extend(extensions); + rp + } + /// Create a response for unsupported content type. pub fn unsupported_content_type() -> HttpResponse { from_template( diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 8bccb14e25..d5a7721b26 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -169,7 +169,7 @@ where { if !rp.is_subscription() { let is_success = rp.is_success(); - let (serialized_rp, mut on_close) = rp.into_parts(); + let (serialized_rp, mut on_close, _) = rp.into_parts(); // The connection is closed, just quit. if sink.send(serialized_rp).await.is_err() { diff --git a/test-utils/src/mocks.rs b/test-utils/src/mocks.rs index c131e3222e..a8aecf4cef 100644 --- a/test-utils/src/mocks.rs +++ b/test-utils/src/mocks.rs @@ -36,8 +36,7 @@ use futures_util::io::{BufReader, BufWriter}; use futures_util::sink::SinkExt; use futures_util::stream::{self, StreamExt}; use futures_util::{pin_mut, select}; -use hyper_util::rt::TokioExecutor; -use hyper_util::rt::TokioIo; +use hyper_util::rt::{TokioExecutor, TokioIo}; use serde::{Deserialize, Serialize}; use soketto::handshake::{self, http::is_upgrade_request, server::Response, Error as SokettoError, Server}; use tokio::net::TcpStream;