diff --git a/benches/helpers.rs b/benches/helpers.rs index 53ccef87f3..7935650d77 100644 --- a/benches/helpers.rs +++ b/benches/helpers.rs @@ -163,11 +163,17 @@ fn gen_rpc_module() -> jsonrpsee::RpcModule<()> { let mut module = jsonrpsee::RpcModule::new(()); module.register_method(SYNC_FAST_CALL, |_, _| Ok("lo")).unwrap(); - module.register_async_method(ASYNC_FAST_CALL, |_, _| async { Ok("lo") }).unwrap(); + module + .register_async_method(ASYNC_FAST_CALL, |_, _| async { Result::<_, jsonrpsee::core::Error>::Ok("lo") }) + .unwrap(); module.register_method(SYNC_MEM_CALL, |_, _| Ok("A".repeat(MIB))).unwrap(); - module.register_async_method(ASYNC_MEM_CALL, |_, _| async move { Ok("A".repeat(MIB)) }).unwrap(); + module + .register_async_method(ASYNC_MEM_CALL, |_, _| async move { + Result::<_, jsonrpsee::core::Error>::Ok("A".repeat(MIB)) + }) + .unwrap(); module .register_method(SYNC_SLOW_CALL, |_, _| { @@ -179,7 +185,7 @@ fn gen_rpc_module() -> jsonrpsee::RpcModule<()> { module .register_async_method(ASYNC_SLOW_CALL, |_, _| async move { tokio::time::sleep(SLOW_CALL).await; - Ok("slow call async") + Result::<_, jsonrpsee::core::Error>::Ok("slow call async") }) .unwrap(); diff --git a/core/src/server/resource_limiting.rs b/core/src/server/resource_limiting.rs index 685712caab..e50398b82c 100644 --- a/core/src/server/resource_limiting.rs +++ b/core/src/server/resource_limiting.rs @@ -75,7 +75,7 @@ //! module //! .register_async_method("my_expensive_method", |_, _| async move { //! // Do work -//! Ok("hello") +//! Result::<_, jsonrpsee::core::Error>::Ok("hello") //! })? //! .resource("cpu", 5)? //! .resource("mem", 2)?; diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 8a3a4d5ae6..8a0fff850d 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -569,14 +569,15 @@ impl RpcModule { } /// Register a new asynchronous RPC method, which computes the response with the given callback. - pub fn register_async_method( + pub fn register_async_method( &mut self, method_name: &'static str, callback: Fun, ) -> Result where R: Serialize + Send + Sync + 'static, - Fut: Future> + Send, + E: Into, + Fut: Future> + Send, Fun: (Fn(Params<'static>, Arc) -> Fut) + Clone + Send + Sync + 'static, { let ctx = self.ctx.clone(); @@ -589,7 +590,7 @@ impl RpcModule { let future = async move { let result = match callback(params, ctx).await { Ok(res) => MethodResponse::response(id, res, max_response_size), - Err(err) => MethodResponse::error(id, err), + Err(err) => MethodResponse::error(id, err.into()), }; // Release claimed resources @@ -606,7 +607,7 @@ impl RpcModule { /// Register a new **blocking** synchronous RPC method, which computes the response with the given callback. /// Unlike the regular [`register_method`](RpcModule::register_method), this method can block its thread and perform expensive computations. - pub fn register_blocking_method( + pub fn register_blocking_method( &mut self, method_name: &'static str, callback: F, @@ -614,7 +615,8 @@ impl RpcModule { where Context: Send + Sync + 'static, R: Serialize, - F: Fn(Params, Arc) -> Result + Clone + Send + Sync + 'static, + E: Into, + F: Fn(Params, Arc) -> Result + Clone + Send + Sync + 'static, { let ctx = self.ctx.clone(); let callback = self.methods.verify_and_insert( @@ -626,7 +628,7 @@ impl RpcModule { tokio::task::spawn_blocking(move || { let result = match callback(params, ctx) { Ok(result) => MethodResponse::response(id, result, max_response_size), - Err(err) => MethodResponse::error(id, err), + Err(err) => MethodResponse::error(id, err.into()), }; // Release claimed resources diff --git a/examples/examples/tokio_console.rs b/examples/examples/tokio_console.rs index 03310c1d5f..7d826fb059 100644 --- a/examples/examples/tokio_console.rs +++ b/examples/examples/tokio_console.rs @@ -36,6 +36,7 @@ use std::net::SocketAddr; +use jsonrpsee::core::Error; use jsonrpsee::server::ServerBuilder; use jsonrpsee::RpcModule; @@ -55,7 +56,7 @@ async fn run_server() -> anyhow::Result { module.register_method("memory_call", |_, _| Ok("A".repeat(1024 * 1024)))?; module.register_async_method("sleep", |_, _| async { tokio::time::sleep(std::time::Duration::from_millis(100)).await; - Ok("lo") + Result::<_, Error>::Ok("lo") })?; let addr = server.local_addr()?; diff --git a/proc-macros/Cargo.toml b/proc-macros/Cargo.toml index 091b0805d1..79b5fc4c86 100644 --- a/proc-macros/Cargo.toml +++ b/proc-macros/Cargo.toml @@ -25,3 +25,4 @@ trybuild = "1.0" tokio = { version = "1.16", features = ["rt", "macros"] } futures-channel = { version = "0.3.14", default-features = false } futures-util = { version = "0.3.14", default-features = false } +serde_json = "1" diff --git a/proc-macros/src/render_client.rs b/proc-macros/src/render_client.rs index 652f3486ff..08826767da 100644 --- a/proc-macros/src/render_client.rs +++ b/proc-macros/src/render_client.rs @@ -27,8 +27,9 @@ use crate::attributes::ParamKind; use crate::helpers::generate_where_clause; use crate::rpc_macro::{RpcDescription, RpcMethod, RpcSubscription}; use proc_macro2::TokenStream as TokenStream2; -use quote::quote; -use syn::{FnArg, Pat, PatIdent, PatType, TypeParam}; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; +use syn::{AngleBracketedGenericArguments, FnArg, Pat, PatIdent, PatType, PathArguments, TypeParam}; impl RpcDescription { pub(super) fn render_client(&self) -> Result { @@ -68,6 +69,46 @@ impl RpcDescription { Ok(trait_impl) } + /// Verify and rewrite the return type (for methods). + fn return_result_type(&self, mut ty: syn::Type) -> TokenStream2 { + // We expect a valid type path. + let syn::Type::Path(ref mut type_path) = ty else { + return quote_spanned!(ty.span() => compile_error!("Expecting something like 'Result' here. (1)")); + }; + + // The path (eg std::result::Result) should have a final segment like 'Result'. + let Some(type_name) = type_path.path.segments.last_mut() else { + return quote_spanned!(ty.span() => compile_error!("Expecting this path to end in something like 'Result'")); + }; + + // Get the generic args eg the in Result. + let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = &mut type_name.arguments else { + return quote_spanned!(ty.span() => compile_error!("Expecting something like 'Result' here, but got no generic args (eg no '').")); + }; + + if type_name.ident == "Result" { + // Result should have 2 generic args. + if args.len() != 2 { + return quote_spanned!(args.span() => compile_error!("Result must be have two arguments")); + } + + // Force the last argument to be `jsonrpsee::core::Error`: + let error_arg = args.last_mut().unwrap(); + *error_arg = syn::GenericArgument::Type(syn::Type::Verbatim(self.jrps_client_item(quote! { core::Error }))); + + quote!(#ty) + } else if type_name.ident == "RpcResult" { + // RpcResult (an alias we export) should have 1 generic arg. + if args.len() != 1 { + return quote_spanned!(args.span() => compile_error!("RpcResult must have one argument")); + } + quote!(#ty) + } else { + // Any other type name isn't allowed. + quote_spanned!(type_name.span() => compile_error!("The return type must be Result or RpcResult")) + } + } + fn render_method(&self, method: &RpcMethod) -> Result { // `jsonrpsee::Error` let jrps_error = self.jrps_client_item(quote! { core::Error }); @@ -83,6 +124,7 @@ impl RpcDescription { // `returns` represent the return type of the *rust method* (`Result< <..>, jsonrpsee::core::Error`). let (called_method, returns) = if let Some(returns) = &method.returns { let called_method = quote::format_ident!("request"); + let returns = self.return_result_type(returns.clone()); let returns = quote! { #returns }; (called_method, returns) diff --git a/proc-macros/tests/ui/correct/errors.rs b/proc-macros/tests/ui/correct/errors.rs new file mode 100644 index 0000000000..da5e57af8a --- /dev/null +++ b/proc-macros/tests/ui/correct/errors.rs @@ -0,0 +1,88 @@ +//! Example of using custom errors. + +use std::net::SocketAddr; + +use jsonrpsee::core::async_trait; +use jsonrpsee::proc_macros::rpc; +use jsonrpsee::server::ServerBuilder; +use jsonrpsee::ws_client::*; + +pub enum CustomError { + One, + Two { custom_data: u32 }, +} + +impl From for jsonrpsee::core::Error { + fn from(err: CustomError) -> Self { + let code = match &err { + CustomError::One => 101, + CustomError::Two { .. } => 102, + }; + let data = match &err { + CustomError::One => None, + CustomError::Two { custom_data } => Some(serde_json::json!({ "customData": custom_data })), + }; + + let data = data.map(|val| serde_json::value::to_raw_value(&val).unwrap()); + + let error_object = jsonrpsee::types::ErrorObjectOwned::owned(code, "custom_error", data); + + Self::Call(jsonrpsee::types::error::CallError::Custom(error_object)) + } +} + +#[rpc(client, server, namespace = "foo")] +pub trait Rpc { + #[method(name = "method1")] + async fn method1(&self) -> Result; + + #[method(name = "method2")] + async fn method2(&self) -> Result; +} + +pub struct RpcServerImpl; + +#[async_trait] +impl RpcServer for RpcServerImpl { + async fn method1(&self) -> Result { + Err(CustomError::One) + } + + async fn method2(&self) -> Result { + Err(CustomError::Two { custom_data: 123 }) + } +} + +pub async fn server() -> SocketAddr { + let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap(); + let addr = server.local_addr().unwrap(); + let server_handle = server.start(RpcServerImpl.into_rpc()).unwrap(); + + tokio::spawn(server_handle.stopped()); + + addr +} + +#[tokio::main] +async fn main() { + let server_addr = server().await; + let server_url = format!("ws://{}", server_addr); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + let get_error_object = |err| match err { + jsonrpsee::core::Error::Call(jsonrpsee::types::error::CallError::Custom(object)) => object, + _ => panic!("wrong error kind: {:?}", err), + }; + + let error = client.method1().await.unwrap_err(); + let error_object = get_error_object(error); + assert_eq!(error_object.code(), 101); + assert_eq!(error_object.message(), "custom_error"); + assert!(error_object.data().is_none()); + + let error = client.method2().await.unwrap_err(); + let error_object = get_error_object(error); + assert_eq!(error_object.code(), 102); + assert_eq!(error_object.message(), "custom_error"); + assert_eq!(error_object.data().unwrap().get(), r#"{"customData":123}"#); +} diff --git a/proc-macros/tests/ui/incorrect/method/method_non_result_return_type.rs b/proc-macros/tests/ui/incorrect/method/method_non_result_return_type.rs new file mode 100644 index 0000000000..7ce91455bb --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_non_result_return_type.rs @@ -0,0 +1,9 @@ +use jsonrpsee::proc_macros::rpc; + +#[rpc(client)] +pub trait NonResultReturnType { + #[method(name = "a")] + async fn a(&self) -> u16; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/method/method_non_result_return_type.stderr b/proc-macros/tests/ui/incorrect/method/method_non_result_return_type.stderr new file mode 100644 index 0000000000..c8d42bffe8 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_non_result_return_type.stderr @@ -0,0 +1,5 @@ +error: Expecting something like 'Result' here, but got no generic args (eg no ''). + --> tests/ui/incorrect/method/method_non_result_return_type.rs:6:23 + | +6 | async fn a(&self) -> u16; + | ^^^ diff --git a/server/src/tests/helpers.rs b/server/src/tests/helpers.rs index 1330a0b819..6259d6697b 100644 --- a/server/src/tests/helpers.rs +++ b/server/src/tests/helpers.rs @@ -59,7 +59,7 @@ pub(crate) async fn server_with_handles() -> (SocketAddr, ServerHandle) { tracing::debug!("server respond to hello"); // Call some async function inside. futures_util::future::ready(()).await; - Ok("hello") + Result::<_, Error>::Ok("hello") } }) .unwrap(); @@ -67,7 +67,7 @@ pub(crate) async fn server_with_handles() -> (SocketAddr, ServerHandle) { .register_async_method("add_async", |params, _| async move { let params: Vec = params.parse()?; let sum: u64 = params.into_iter().sum(); - Ok(sum) + Result::<_, Error>::Ok(sum) }) .unwrap(); module @@ -111,7 +111,7 @@ pub(crate) async fn server_with_handles() -> (SocketAddr, ServerHandle) { module .register_async_method("should_ok_async", |_p, ctx| async move { ctx.ok().map_err(CallError::Failed)?; - Ok("ok") + Result::<_, Error>::Ok("ok") }) .unwrap(); @@ -146,7 +146,7 @@ pub(crate) async fn server_with_context() -> SocketAddr { .register_async_method("should_ok_async", |_p, ctx| async move { ctx.ok().map_err(CallError::Failed)?; // Call some async function inside. - Ok(futures_util::future::ready("ok!").await) + Result::<_, Error>::Ok(futures_util::future::ready("ok!").await) }) .unwrap(); @@ -154,7 +154,7 @@ pub(crate) async fn server_with_context() -> SocketAddr { .register_async_method("err_async", |_p, ctx| async move { ctx.ok().map_err(CallError::Failed)?; // Async work that returns an error - futures_util::future::err::<(), _>(anyhow!("nah").into()).await + futures_util::future::err::<(), Error>(anyhow!("nah").into()).await }) .unwrap(); diff --git a/server/src/tests/http.rs b/server/src/tests/http.rs index 1a3bfdbd53..de33fa5d8a 100644 --- a/server/src/tests/http.rs +++ b/server/src/tests/http.rs @@ -46,7 +46,7 @@ async fn server() -> (SocketAddr, ServerHandle) { let mut module = RpcModule::new(ctx); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| Ok("lo")).unwrap(); - module.register_async_method("say_hello_async", |_, _| async move { Ok("lo") }).unwrap(); + module.register_async_method("say_hello_async", |_, _| async move { Result::<_, Error>::Ok("lo") }).unwrap(); module .register_method("add", |params, _| { let params: Vec = params.parse()?; @@ -78,7 +78,7 @@ async fn server() -> (SocketAddr, ServerHandle) { module .register_async_method("should_ok_async", |_p, ctx| async move { ctx.ok().map_err(CallError::Failed)?; - Ok("ok") + Result::<_, Error>::Ok("ok") }) .unwrap(); diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index f81176b529..2009167a1f 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -196,7 +196,7 @@ pub async fn server() -> SocketAddr { module .register_async_method("slow_hello", |_, _| async { tokio::time::sleep(std::time::Duration::from_secs(1)).await; - Ok("hello") + Result::<_, Error>::Ok("hello") }) .unwrap(); diff --git a/tests/tests/resource_limiting.rs b/tests/tests/resource_limiting.rs index 27650490f5..14bfd11a2b 100644 --- a/tests/tests/resource_limiting.rs +++ b/tests/tests/resource_limiting.rs @@ -46,20 +46,20 @@ fn module_manual() -> Result, Error> { module.register_async_method("say_hello", |_, _| async move { sleep(Duration::from_millis(50)).await; - Ok("hello") + Result::<_, Error>::Ok("hello") })?; module .register_async_method("expensive_call", |_, _| async move { sleep(Duration::from_millis(50)).await; - Ok("hello expensive call") + Result::<_, Error>::Ok("hello expensive call") })? .resource("CPU", 3)?; module .register_async_method("memory_hog", |_, _| async move { sleep(Duration::from_millis(50)).await; - Ok("hello memory hog") + Result::<_, Error>::Ok("hello memory hog") })? .resource("CPU", 0)? .resource("MEM", 8)?; diff --git a/tests/tests/rpc_module.rs b/tests/tests/rpc_module.rs index 9382de1248..28b6b38653 100644 --- a/tests/tests/rpc_module.rs +++ b/tests/tests/rpc_module.rs @@ -126,7 +126,7 @@ async fn calling_method_without_server() { module .register_async_method("roo", |params, ctx| { let ns: Vec = params.parse().expect("valid params please"); - async move { Ok(ctx.roo(ns)) } + async move { Result::<_, Error>::Ok(ctx.roo(ns)) } }) .unwrap(); let res: u64 = module.call("roo", [12, 13]).await.unwrap();