diff --git a/derive/examples/extension-dispatch.rs b/derive/examples/extension-dispatch.rs index bcc4862c13b..13bd4d9d735 100644 --- a/derive/examples/extension-dispatch.rs +++ b/derive/examples/extension-dispatch.rs @@ -1,7 +1,9 @@ use trussed::Error; mod backends { - use super::extensions::{TestExtension, TestReply, TestRequest}; + use super::extensions::{ + SampleExtension, SampleReply, SampleRequest, TestExtension, TestReply, TestRequest, + }; use trussed::{ backend::Backend, platform::Platform, serde_extensions::ExtensionImpl, service::ServiceResources, types::CoreContext, Error, @@ -26,6 +28,18 @@ mod backends { } } + impl ExtensionImpl for ABackend { + fn extension_request( + &mut self, + _core_ctx: &mut CoreContext, + _backend_ctx: &mut Self::Context, + _request: &SampleRequest, + _resources: &mut ServiceResources

, + ) -> Result { + Ok(SampleReply) + } + } + #[derive(Default)] pub struct BBackend; @@ -88,6 +102,7 @@ mod extensions { enum Backend { A, + ASample, B, } @@ -104,9 +119,19 @@ enum Extension { Sample = "extensions::SampleExtension" )] struct Dispatch { + #[dispatch(no_core)] #[extensions("Test")] a: backends::ABackend, + + #[dispatch(delegate_to = "a")] + #[extensions("Sample")] + a_sample: (), + b: backends::BBackend, + + #[allow(unused)] + #[dispatch(skip)] + other: String, } fn main() { @@ -135,5 +160,9 @@ fn main() { &[BackendId::Custom(Backend::B)], Some(Error::RequestNotAvailable), ); + run( + &[BackendId::Custom(Backend::ASample)], + Some(Error::RequestNotAvailable), + ); run(&[BackendId::Custom(Backend::A)], None); } diff --git a/derive/src/extension_dispatch.rs b/derive/src/extension_dispatch.rs index 54d0e2a03af..47f1a0abf30 100644 --- a/derive/src/extension_dispatch.rs +++ b/derive/src/extension_dispatch.rs @@ -15,6 +15,7 @@ pub struct ExtensionDispatch { dispatch_attrs: DispatchAttrs, extension_attrs: ExtensionAttrs, backends: Vec, + delegated_backends: Vec, } impl ExtensionDispatch { @@ -27,11 +28,30 @@ impl ExtensionDispatch { }; let dispatch_attrs = DispatchAttrs::new(&input)?; let extension_attrs = ExtensionAttrs::new(&input)?; - let backends = data_struct - .fields - .iter() - .enumerate() - .map(|(i, field)| Backend::new(i, field, &extension_attrs.extensions)) + let mut raw_backends = Vec::new(); + for field in &data_struct.fields { + if let Some(raw_backend) = RawBackend::new(field)? { + raw_backends.push(raw_backend); + } + } + let mut backends = Vec::new(); + let mut delegated_backends = Vec::new(); + for raw_backend in raw_backends { + if let Some(delegate_to) = raw_backend.delegate_to.clone() { + delegated_backends.push((raw_backend, delegate_to)); + } else { + backends.push(Backend::new( + backends.len(), + raw_backend, + &extension_attrs.extensions, + )?); + } + } + let delegated_backends = delegated_backends + .into_iter() + .map(|(raw, delegate_to)| { + DelegatedBackend::new(raw, delegate_to, &backends, &extension_attrs.extensions) + }) .collect::>()?; Ok(Self { name: input.ident, @@ -39,6 +59,7 @@ impl ExtensionDispatch { dispatch_attrs, extension_attrs, backends, + delegated_backends, }) } @@ -49,7 +70,15 @@ impl ExtensionDispatch { let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); let context = self.backends.iter().map(Backend::context); let requests = self.backends.iter().map(Backend::request); + let delegated_requests = self + .delegated_backends + .iter() + .map(DelegatedBackend::request); let extension_requests = self.backends.iter().map(Backend::extension_request); + let delegated_extension_requests = self + .delegated_backends + .iter() + .map(DelegatedBackend::extension_request); let extension_impls = self .extension_attrs .extensions @@ -71,6 +100,7 @@ impl ExtensionDispatch { ) -> ::core::result::Result<::trussed::api::Reply, ::trussed::error::Error> { match backend { #(#requests)* + #(#delegated_requests)* } } @@ -84,6 +114,7 @@ impl ExtensionDispatch { ) -> ::core::result::Result<::trussed::api::reply::SerdeExtension, ::trussed::error::Error> { match backend { #(#extension_requests)* + #(#delegated_extension_requests)* } } } @@ -165,16 +196,40 @@ impl ExtensionAttrs { } } -struct Backend { +struct RawBackend { id: Ident, field: Ident, ty: Type, - index: Index, - extensions: Vec, + no_core: bool, + delegate_to: Option, + extensions: Vec, } -impl Backend { - fn new(i: usize, field: &Field, extension_types: &HashMap) -> Result { +impl RawBackend { + fn new(field: &Field) -> Result> { + let mut delegate_to = None; + let mut no_core = false; + let mut skip = false; + for attr in util::get_attrs(&field.attrs, "dispatch") { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("delegate_to") { + let s: LitStr = meta.value()?.parse()?; + delegate_to = Some(s.parse()?); + Ok(()) + } else if meta.path.is_ident("no_core") { + no_core = true; + Ok(()) + } else if meta.path.is_ident("skip") { + skip = true; + Ok(()) + } else { + Err(meta.error("unsupported dispatch attribute")) + } + })?; + } + if skip { + return Ok(None); + } let ident = field.ident.clone().ok_or_else(|| { Error::new_spanned( field, @@ -184,14 +239,43 @@ impl Backend { let mut extensions = Vec::new(); for attr in util::get_attrs(&field.attrs, "extensions") { for s in attr.parse_args_with(Punctuated::::parse_terminated)? { - extensions.push(Extension::new(&s, extension_types)?); + extensions.push(s.parse()?); } } - Ok(Self { + Ok(Some(Self { id: util::to_camelcase(&ident), field: ident, ty: field.ty.clone(), + no_core, + delegate_to, + extensions, + })) + } +} + +#[derive(Clone)] +struct Backend { + id: Ident, + field: Ident, + ty: Type, + index: Index, + no_core: bool, + extensions: Vec, +} + +impl Backend { + fn new(i: usize, raw: RawBackend, extensions: &HashMap) -> Result { + let extensions = raw + .extensions + .into_iter() + .map(|i| Extension::new(i, extensions)) + .collect::>()?; + Ok(Self { + id: raw.id, + field: raw.field, + ty: raw.ty, index: Index::from(i), + no_core: raw.no_core, extensions, }) } @@ -202,13 +286,23 @@ impl Backend { } fn request(&self) -> TokenStream { - let Self { - index, id, field, .. - } = self; + let id = &self.id; + let request = if self.no_core { + quote! { + Err(::trussed::Error::RequestNotAvailable) + } + } else { + let Self { index, field, .. } = self; + quote! { + ::trussed::backend::Backend::request( + &mut self.#field, &mut ctx.core, &mut ctx.backends.#index, request, resources, + ) + } + }; quote! { - Self::BackendId::#id => ::trussed::backend::Backend::request( - &mut self.#field, &mut ctx.core, &mut ctx.backends.#index, request, resources, - ), + Self::BackendId::#id => { + #request + } } } @@ -224,17 +318,109 @@ impl Backend { } } +struct DelegatedBackend { + id: Ident, + field: Ident, + backend: Backend, + no_core: bool, + extensions: Vec, +} + +impl DelegatedBackend { + fn new( + raw: RawBackend, + delegate_to: Ident, + backends: &[Backend], + extensions: &HashMap, + ) -> Result { + match raw.ty { + Type::Tuple(tuple) if tuple.elems.is_empty() => (), + _ => { + return Err(Error::new_spanned( + &raw.ty, + "delegated backends must use the unit type ()", + )); + } + } + + let extensions = raw + .extensions + .into_iter() + .map(|i| Extension::new(i, extensions)) + .collect::>()?; + let backend = backends + .iter() + .find(|backend| backend.field == delegate_to) + .ok_or_else(|| Error::new_spanned(delegate_to, "unknown backend"))? + .clone(); + Ok(Self { + id: raw.id, + field: raw.field, + backend, + no_core: raw.no_core, + extensions, + }) + } + + fn request(&self) -> TokenStream { + let id = &self.id; + let request = if self.no_core { + quote! { + Err(::trussed::Error::RequestNotAvailable) + } + } else { + let Self { backend, field, .. } = self; + let Backend { + field: delegated_field, + index: delegated_index, + .. + } = backend; + quote! { + let _ = self.#field; + ::trussed::backend::Backend::request( + &mut self.#delegated_field, &mut ctx.core, &mut ctx.backends.#delegated_index, request, resources, + ) + } + }; + quote! { + Self::BackendId::#id => { + #request + } + } + } + + fn extension_request(&self) -> TokenStream { + let Self { + id, + extensions, + backend, + field, + .. + } = self; + let extension_requests = extensions.iter().map(|e| e.extension_request(backend)); + quote! { + Self::BackendId::#id => { + let _ = self.#field; + match extension { + #(#extension_requests)* + _ => Err(::trussed::error::Error::RequestNotAvailable), + } + } + } + } +} + +#[derive(Clone)] struct Extension { id: Ident, ty: Path, } impl Extension { - fn new(s: &LitStr, extensions: &HashMap) -> Result { - let id = s.parse()?; + fn new(id: Ident, extensions: &HashMap) -> Result { let ty = extensions .get(&id) - .ok_or_else(|| Error::new_spanned(s, "unknown extension ID"))? + .ok_or_else(|| Error::new_spanned(&id, "unknown extension ID"))? .clone(); Ok(Self { id, ty }) }