diff --git a/src/gen.rs b/src/gen.rs index 2729624..0b7fc92 100644 --- a/src/gen.rs +++ b/src/gen.rs @@ -2,8 +2,9 @@ use crate::proc_macro::Span; use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, TokenStreamExt}; use syn::{ - FnArg, Ident, ItemTrait, Lifetime, MethodSig, Pat, PatIdent, TraitItem, TraitItemConst, - TraitItemMethod, TraitItemType, + FnArg, Ident, ItemTrait, Lifetime, MethodSig, Pat, PatIdent, ReturnType, TraitBound, + TraitBoundModifier, TraitItem, TraitItemConst, TraitItemMethod, TraitItemType, Type, + TypeParamBound, WherePredicate, }; use crate::{ @@ -65,12 +66,84 @@ fn gen_header( // The `'{proxy_lt_param}` in the beginning is only added when the proxy // type is `&` or `&mut`. let impl_generics = { + // Determine whether we can add a `?Sized` relaxation to allow trait + // objects. We can do that as long as there is no method that has a + // `self` by value receiver and no `where Self: Sized` bound. + let sized_required = trait_def.items.iter() + // Only interested in methods + .filter_map(|item| if let TraitItem::Method(m) = item { Some(m) } else { None }) + // We also ignore methods that we will not override. In the case of + // invalid attributes it is save to assume default behavior. + .filter(|m| !should_keep_default_for(m, proxy_type).unwrap_or(false)) + .any(|m| { + // Check if there is a `Self: Sized` bound on the method. + let self_is_bounded_sized = m.sig.decl.generics.where_clause.iter() + .flat_map(|wc| &wc.predicates) + .filter_map(|pred| { + if let WherePredicate::Type(p) = pred { Some(p) } else { None } + }) + .any(|pred| { + // Check if the type is `Self` + match &pred.bounded_ty { + Type::Path(p) if p.path.is_ident("Self") => { + // Check if the bound contains `Sized` + pred.bounds.iter().any(|b| { + match b { + TypeParamBound::Trait(TraitBound { + modifier: TraitBoundModifier::None, + path, + .. + }) => path.is_ident("Sized"), + _ => false, + } + }) + } + _ => false, + } + }); + + // Check if the first parameter is `self` by value. In that + // case, we might require `Self` to be `Sized`. + let self_value_param = match m.sig.decl.inputs.first().map(|p| p.into_value()) { + Some(FnArg::SelfValue(_)) => true, + _ => false, + }; + + // Check if return type is `Self` + let self_value_return = match &m.sig.decl.output { + ReturnType::Type(_, t) => { + if let Type::Path(p) = &**t { + p.path.is_ident("Self") + } else { + false + } + } + _ => false, + }; + + // TODO: check for `Self` parameter in any other argument. + + // If for this method, `Self` is used in a position that + // requires `Self: Sized` or this bound is added explicitly, we + // cannot add the `?Sized` relaxation to the impl body. + self_value_param || self_value_return || self_is_bounded_sized + }); + + let relaxation = if sized_required { + quote! {} + } else { + quote! { + ?::std::marker::Sized } + }; + // Determine if our proxy type needs a lifetime parameter let (mut params, ty_bounds) = match proxy_type { - ProxyType::Ref | ProxyType::RefMut => { - (quote! { #proxy_lt_param, }, quote! { : #proxy_lt_param + #trait_path }) + ProxyType::Ref | ProxyType::RefMut => ( + quote! { #proxy_lt_param, }, + quote! { : #proxy_lt_param + #trait_path #relaxation } + ), + ProxyType::Box | ProxyType::Rc | ProxyType::Arc => { + (quote!{}, quote! { : #trait_path #relaxation }) } - ProxyType::Box | ProxyType::Rc | ProxyType::Arc => (quote!{}, quote! { : #trait_path }), ProxyType::Fn | ProxyType::FnMut | ProxyType::FnOnce => { let fn_bound = gen_fn_type_for_trait(proxy_type, trait_def)?; (quote!{}, quote! { : #fn_bound }) diff --git a/tests/compile-fail/trait_obj_value_self.rs b/tests/compile-fail/trait_obj_value_self.rs new file mode 100644 index 0000000..2e818d0 --- /dev/null +++ b/tests/compile-fail/trait_obj_value_self.rs @@ -0,0 +1,13 @@ +use auto_impl::auto_impl; + + +#[auto_impl(Box)] +trait Trait { + fn foo(self); +} + +fn assert_impl() {} + +fn main() { + assert_impl::>(); +} diff --git a/tests/compile-pass/trait_obj_default_method.rs b/tests/compile-pass/trait_obj_default_method.rs new file mode 100644 index 0000000..9ee6962 --- /dev/null +++ b/tests/compile-pass/trait_obj_default_method.rs @@ -0,0 +1,22 @@ +use auto_impl::auto_impl; + + +#[auto_impl(Box)] +trait Trait { + fn bar(&self); + + #[auto_impl(keep_default_for(Box))] + fn foo(self) where Self: Sized {} +} + +fn assert_impl() {} + +struct Foo {} +impl Trait for Foo { + fn bar(&self) {} +} + +fn main() { + assert_impl::(); + assert_impl::>(); +} diff --git a/tests/compile-pass/trait_obj_immutable_self.rs b/tests/compile-pass/trait_obj_immutable_self.rs new file mode 100644 index 0000000..1ac30c0 --- /dev/null +++ b/tests/compile-pass/trait_obj_immutable_self.rs @@ -0,0 +1,19 @@ +use auto_impl::auto_impl; + + +#[auto_impl(&, &mut, Box, Rc, Arc)] +trait Trait { + fn foo(&self); +} + +fn assert_impl() {} + +fn main() { + use std::{rc::Rc, sync::Arc}; + + assert_impl::<&dyn Trait>(); + assert_impl::<&mut dyn Trait>(); + assert_impl::>(); + assert_impl::>(); + assert_impl::>(); +} diff --git a/tests/compile-pass/trait_obj_value_self.rs b/tests/compile-pass/trait_obj_value_self.rs new file mode 100644 index 0000000..ebf30eb --- /dev/null +++ b/tests/compile-pass/trait_obj_value_self.rs @@ -0,0 +1,7 @@ +use auto_impl::auto_impl; + + +#[auto_impl(Box)] +trait Trait { + fn foo(self); +}