diff --git a/doc/display.md b/doc/display.md index d4e970c5..3c5c30eb 100644 --- a/doc/display.md +++ b/doc/display.md @@ -33,6 +33,68 @@ i.e. `_0`, `_1`, `_2`, etc. The syntax does not change, but the name of the attribute is the snake case version of the trait. E.g. `Octal` -> `octal`, `Pointer` -> `pointer`, `UpperHex` -> `upper_hex`. +# Generic data types + +When deriving `Display` (or other formatting trait) for a generic struct/enum, all generic type +arguments used during formatting are bound by respective formatting trait. + +E.g., for a structure `Foo` defined like this: +```rust +# #[macro_use] extern crate derive_more; +# trait Trait { type Type; } + +#[derive(Display)] +#[display(fmt = "{} {} {:?} {:p}", a, b, c, d)] +struct Foo<'a, T1, T2: Trait, T3> { + a: T1, + b: ::Type, + c: Vec, + d: &'a T1, +} +``` + +The following where clauses would be generated: +* `T1: Display + Pointer` +* `::Type: Debug` +* `Bar: Display` + +## Custom trait bounds + +Sometimes you may want to specify additional trait bounds on your generic type parameters, so that they +could be used during formatting. This can be done with a `#[display(bound = "...")]` attribute. + +`#[display(bound = "...")]` accepts a single string argument in a format similar to the format +used in angle bracket list: `T: MyTrait, U: Trait1 + Trait2`. + +Only type parameters defined on a struct allowed to appear in bound-string and they can only be bound +by traits, i.e. no lifetime parameters or lifetime bounds allowed in bound-string. + +As double-quote `fmt` arguments are parsed as an arbitrary Rust expression and passed to generated +`write!` as-is, it's impossible to meaningfully infer any kind of trait bounds for generic type parameters +used this way. That means that you'll **have to** explicitly specify all trait bound used. Either in the +struct/enum definition, or via `#[display(bound = "...")]` attribute. + +Note how we have to bound `U` and `V` by `Display` in the following example, as no bound is inferred. +Not even `Display`. + +Also note, that `"c"` case is just a curious example. Bound inference works as expected if you simply +write `c` without double-quotes. + +```rust +# #[macro_use] extern crate derive_more; +# use std::fmt::Display; +# trait MyTrait { fn my_function(&self) -> i32; } + +#[derive(Display)] +#[display(bound = "T: MyTrait, U: Display, V: Display")] +#[display(fmt = "{} {} {}", "a.my_function()", "b.to_string().len()", "c")] +struct MyStruct { + a: T, + b: U, + c: V, +} +``` + # Example usage ```rust diff --git a/src/display.rs b/src/display.rs index 809973c4..a5090554 100644 --- a/src/display.rs +++ b/src/display.rs @@ -1,35 +1,30 @@ use std::{ collections::{HashMap, HashSet}, fmt::Display, - ops::Deref, + iter::FromIterator as _, + ops::Deref as _, + str::FromStr as _, }; use crate::utils::add_extra_where_clauses; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, quote_spanned}; use syn::{ - parse::{Error, Result}, - spanned::Spanned, - Attribute, Data, DeriveInput, Fields, GenericArgument, Lit, Meta, MetaNameValue, NestedMeta, Path, PathArguments, Type, TypeReference + parse::{ + Error, + Result, + Parser as _, + }, + punctuated::Punctuated, + spanned::Spanned as _, }; /// Provides the hook to expand `#[derive(Display)]` into an implementation of `From` -pub fn expand(input: &DeriveInput, trait_name: &str) -> Result { +pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> Result { let trait_name = trait_name.trim_end_matches("Custom"); - let trait_ident = Ident::new(trait_name, Span::call_site()); + let trait_ident = syn::Ident::new(trait_name, Span::call_site()); let trait_path = "e!(::core::fmt::#trait_ident); - let trait_attr = match trait_name { - "Display" => "display", - "Binary" => "binary", - "Octal" => "octal", - "LowerHex" => "lower_hex", - "UpperHex" => "upper_hex", - "LowerExp" => "lower_exp", - "UpperExp" => "upper_exp", - "Pointer" => "pointer", - "Debug" => "debug", - _ => unimplemented!(), - }; + let trait_attr = trait_name_to_attribute_name(trait_name); let type_params = input .generics .type_params() @@ -50,9 +45,8 @@ pub fn expand(input: &DeriveInput, trait_name: &str) -> Result { .map(|(ty, trait_names)| { let bounds: Vec<_> = trait_names .into_iter() - .map(|trait_name| { - let trait_ident = Ident::new(trait_name, Span::call_site()); - quote!(::core::fmt::#trait_ident) + .map(|bound| { + quote!(#bound) }) .collect(); quote!(#ty: #(#bounds)+*) @@ -99,10 +93,54 @@ pub fn expand(input: &DeriveInput, trait_name: &str) -> Result { }) } +fn trait_name_to_attribute_name(trait_name: &str) -> &'static str { + match trait_name { + "Display" => "display", + "Binary" => "binary", + "Octal" => "octal", + "LowerHex" => "lower_hex", + "UpperHex" => "upper_hex", + "LowerExp" => "lower_exp", + "UpperExp" => "upper_exp", + "Pointer" => "pointer", + "Debug" => "debug", + _ => unimplemented!(), + } +} + +fn attribute_name_to_trait_name(attribute_name: &str) -> &'static str { + match attribute_name { + "display" => "Display", + "binary" => "Binary", + "octal" => "Octal", + "lower_hex" => "LowerHex", + "upper_hex" => "UpperHex", + "lower_exp" => "LowerExp", + "upper_exp" => "UpperExp", + "pointer" => "Pointer", + _ => unreachable!(), + } +} + +fn trait_name_to_trait_bound(trait_name: &str) -> syn::TraitBound { + let path_segments_iterator = vec!["core", "fmt", trait_name].into_iter() + .map(|segment| syn::PathSegment::from(Ident::new(segment, Span::call_site()))); + + syn::TraitBound { + lifetimes: None, + modifier: syn::TraitBoundModifier::None, + paren_token: None, + path: syn::Path { + leading_colon: Some(syn::Token![::](Span::call_site())), + segments: syn::punctuated::Punctuated::from_iter(path_segments_iterator), + }, + } +} + struct State<'a, 'b> { trait_path: &'b TokenStream, trait_attr: &'static str, - input: &'a DeriveInput, + input: &'a syn::DeriveInput, type_params: HashSet, } @@ -113,11 +151,14 @@ impl<'a, 'b> State<'a, 'b> { self.trait_attr ) } + fn get_proper_bound_syntax(&self) -> impl Display { + format!("Proper syntax: #[{}(bound = \"T, U: Trait1 + Trait2, V: Trait3\")]", self.trait_attr) + } - fn get_matcher(&self, fields: &Fields) -> TokenStream { + fn get_matcher(&self, fields: &syn::Fields) -> TokenStream { match fields { - Fields::Unit => TokenStream::new(), - Fields::Unnamed(fields) => { + syn::Fields::Unit => TokenStream::new(), + syn::Fields::Unnamed(fields) => { let fields: TokenStream = (0..fields.unnamed.len()) .map(|n| { let i = Ident::new(&format!("_{}", n), Span::call_site()); @@ -126,7 +167,7 @@ impl<'a, 'b> State<'a, 'b> { .collect(); quote!((#fields)) } - Fields::Named(fields) => { + syn::Fields::Named(fields) => { let fields: TokenStream = fields .named .iter() @@ -139,37 +180,106 @@ impl<'a, 'b> State<'a, 'b> { } } } - fn find_meta(&self, attrs: &[Attribute]) -> Result> { - let mut it = attrs - .iter() - .filter_map(|m| m.parse_meta().ok()) - .filter(|m| { - if let Some(ident) = m.path().segments.first().map(|p| &p.ident) { - ident == self.trait_attr - } else { - false + fn find_meta(&self, attrs: &[syn::Attribute], meta_key: &str) -> Result> { + let mut iterator = attrs.iter() + .filter_map(|attr| attr.parse_meta().ok()) + .filter(|meta| { + let meta = match meta { + syn::Meta::List(meta) => meta, + _ => return false, + }; + + if !meta.path.is_ident(self.trait_attr) || meta.nested.is_empty() { + return false; } + + let meta = match &meta.nested[0] { + syn::NestedMeta::Meta(meta) => meta, + _ => return false, + }; + + let meta = match meta { + syn::Meta::NameValue(meta) => meta, + _ => return false, + }; + + meta.path.is_ident(meta_key) }); - let meta = it.next(); - if it.next().is_some() { - Err(Error::new(meta.span(), "Too many formats given")) - } else { + let meta = iterator.next(); + if iterator.next().is_none() { Ok(meta) + } else { + Err(Error::new(meta.span(), "Too many attributes specified")) + } + } + fn parse_meta_bounds(&self, bounds: &syn::LitStr) -> Result>> { + let span = bounds.span(); + + let input = bounds.value(); + let tokens = TokenStream::from_str(&input)?; + let parser = Punctuated::::parse_terminated; + + let generic_params = parser.parse2(tokens) + .map_err(|error| Error::new(span, error.to_string()))?; + + if generic_params.is_empty() { + return Err(Error::new(span, "No bounds specified")); + } + + let mut bounds = HashMap::new(); + + for generic_param in generic_params { + let type_param = match generic_param { + syn::GenericParam::Type(type_param) => type_param, + _ => return Err(Error::new(span, "Only trait bounds allowed")), + }; + + if !self.type_params.contains(&type_param.ident) { + return Err(Error::new(span, "Unknown generic type argument specified")); + } else if !type_param.attrs.is_empty() { + return Err(Error::new(span, "Attributes aren't allowed")); + } else if type_param.eq_token.is_some() || type_param.default.is_some() { + return Err(Error::new(span, "Default type parameters aren't allowed")); + } + + let ident = type_param.ident.to_string(); + + let ty = syn::Type::Path(syn::TypePath { qself: None, path: type_param.ident.into() }); + let bounds = bounds.entry(ty).or_insert_with(HashSet::new); + + for bound in type_param.bounds { + let bound = match bound { + syn::TypeParamBound::Trait(bound) => bound, + _ => return Err(Error::new(span, "Only trait bounds allowed")), + }; + + if bound.lifetimes.is_some() { + return Err(Error::new(span, "Higher-rank trait bounds aren't allowed")); + } + + bounds.insert(bound); + } + + if bounds.is_empty() { + return Err(Error::new(span, format!("No bounds specified for type parameter {}", ident))); + } } + + Ok(bounds) } - fn get_meta_fmt(&self, meta: &Meta, outer_enum: bool) -> Result<(TokenStream, bool)> { + fn parse_meta_fmt(&self, meta: &syn::Meta, outer_enum: bool) -> Result<(TokenStream, bool)> { let list = match meta { - Meta::List(list) => list, + syn::Meta::List(list) => list, _ => { return Err(Error::new(meta.span(), self.get_proper_fmt_syntax())); } }; match &list.nested[0] { - NestedMeta::Meta(Meta::NameValue(MetaNameValue { + syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue { path, - lit: Lit::Str(fmt), + lit: syn::Lit::Str(fmt), .. })) => match path { op if op.segments.first().expect("path shouldn't be empty").ident == "fmt" => { @@ -182,9 +292,9 @@ impl<'a, 'b> State<'a, 'b> { } // TODO: Check for a single `Display` group? let fmt_string = match &list.nested[0] { - NestedMeta::Meta(Meta::NameValue(MetaNameValue { + syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue { path, - lit: Lit::Str(s), + lit: syn::Lit::Str(s), .. })) if path .segments @@ -215,8 +325,8 @@ impl<'a, 'b> State<'a, 'b> { .skip(1) // skip fmt = "..." .try_fold(TokenStream::new(), |args, arg| { let arg = match arg { - NestedMeta::Lit(Lit::Str(s)) => s, - NestedMeta::Meta(Meta::Path(i)) => { + syn::NestedMeta::Lit(syn::Lit::Str(s)) => s, + syn::NestedMeta::Meta(syn::Meta::Path(i)) => { return Ok(quote_spanned!(list.span()=> #args #i,)); } _ => { @@ -247,11 +357,11 @@ impl<'a, 'b> State<'a, 'b> { )), } } - fn infer_fmt(&self, fields: &Fields, name: &Ident) -> Result { + fn infer_fmt(&self, fields: &syn::Fields, name: &Ident) -> Result { let fields = match fields { - Fields::Unit => return Ok(quote!(stringify!(#name))), - Fields::Named(fields) => &fields.named, - Fields::Unnamed(fields) => &fields.unnamed, + syn::Fields::Unit => return Ok(quote!(stringify!(#name))), + syn::Fields::Named(fields) => &fields.named, + syn::Fields::Unnamed(fields) => &fields.unnamed, }; if fields.is_empty() { return Ok(quote!(stringify!(#name))); @@ -271,16 +381,16 @@ impl<'a, 'b> State<'a, 'b> { } fn get_match_arms_and_extra_bounds( &self, - ) -> Result<(TokenStream, HashMap>)> { - match &self.input.data { - Data::Enum(e) => { + ) -> Result<(TokenStream, HashMap>)> { + let result: Result<_> = match &self.input.data { + syn::Data::Enum(e) => { match self - .find_meta(&self.input.attrs) - .and_then(|m| m.map(|m| self.get_meta_fmt(&m, true)).transpose())? + .find_meta(&self.input.attrs, "fmt") + .and_then(|m| m.map(|m| self.parse_meta_fmt(&m, true)).transpose())? { Some((fmt, false)) => { e.variants.iter().try_for_each(|v| { - if let Some(meta) = self.find_meta(&v.attrs)? { + if let Some(meta) = self.find_meta(&v.attrs, "fmt")? { Err(Error::new( meta.span(), "`fmt` cannot be used on variant when the whole enum has a format string without a placeholder, maybe you want to add a placeholder?", @@ -298,8 +408,8 @@ impl<'a, 'b> State<'a, 'b> { Some((outer_fmt, true)) => { let fmt: Result = e.variants.iter().try_fold(TokenStream::new(), |arms, v| { let matcher = self.get_matcher(&v.fields); - let fmt = if let Some(meta) = self.find_meta(&v.attrs)? { - self.get_meta_fmt(&meta, false)?.0 + let fmt = if let Some(meta) = self.find_meta(&v.attrs, "fmt")? { + self.parse_meta_fmt(&meta, false)?.0 } else { self.infer_fmt(&v.fields, &v.ident)? }; @@ -320,8 +430,8 @@ impl<'a, 'b> State<'a, 'b> { let fmt: TokenStream; let bounds: HashMap<_, _>; - if let Some(meta) = self.find_meta(&v.attrs)? { - fmt = self.get_meta_fmt(&meta, false)?.0; + if let Some(meta) = self.find_meta(&v.attrs, "fmt")? { + fmt = self.parse_meta_fmt(&meta, false)?.0; bounds = self.get_used_type_params_bounds(&v.fields, &meta); } else { fmt = self.infer_fmt(&v.fields, v_name)?; @@ -340,14 +450,14 @@ impl<'a, 'b> State<'a, 'b> { }), } } - Data::Struct(s) => { + syn::Data::Struct(s) => { let matcher = self.get_matcher(&s.fields); let name = &self.input.ident; let fmt: TokenStream; let bounds: HashMap<_, _>; - if let Some(meta) = self.find_meta(&self.input.attrs)? { - fmt = self.get_meta_fmt(&meta, false)?.0; + if let Some(meta) = self.find_meta(&self.input.attrs, "fmt")? { + fmt = self.parse_meta_fmt(&meta, false)?.0; bounds = self.get_used_type_params_bounds(&s.fields, &meta); } else { fmt = self.infer_fmt(&s.fields, name)?; @@ -359,32 +469,68 @@ impl<'a, 'b> State<'a, 'b> { bounds, )) } - Data::Union(_) => { - let meta = self.find_meta(&self.input.attrs)?.ok_or_else(|| { + syn::Data::Union(_) => { + let meta = self.find_meta(&self.input.attrs, "fmt")?.ok_or_else(|| { Error::new( self.input.span(), "Can not automatically infer format for unions", ) })?; - let fmt = self.get_meta_fmt(&meta, false)?.0; + let fmt = self.parse_meta_fmt(&meta, false)?.0; Ok(( quote_spanned!(self.input.span()=> _ => write!(_derive_more_Display_formatter, "{}", #fmt),), HashMap::new(), )) } + }; + + let (arms, mut bounds) = result?; + + let meta = match self.find_meta(&self.input.attrs, "bound")? { + Some(meta) => meta, + _ => return Ok((arms, bounds)), + }; + + let span = meta.span(); + + let meta = match meta { + syn::Meta::List(meta) => meta.nested, + _ => return Err(Error::new(span, self.get_proper_bound_syntax())), + }; + + if meta.len() != 1 { + return Err(Error::new(span, self.get_proper_bound_syntax())); } + + let meta = match &meta[0] { + syn::NestedMeta::Meta(syn::Meta::NameValue(meta)) => meta, + _ => return Err(Error::new(span, self.get_proper_bound_syntax())), + }; + + let extra_bounds = match &meta.lit { + syn::Lit::Str(extra_bounds) => extra_bounds, + _ => return Err(Error::new(span, self.get_proper_bound_syntax())), + }; + + let extra_bounds = self.parse_meta_bounds(extra_bounds)?; + + for (ty, extra_bounds) in extra_bounds { + bounds.entry(ty).or_insert_with(HashSet::new).extend(extra_bounds); + } + + Ok((arms, bounds)) } fn get_used_type_params_bounds( &self, - fields: &Fields, - meta: &Meta, - ) -> HashMap> { + fields: &syn::Fields, + meta: &syn::Meta, + ) -> HashMap> { if self.type_params.is_empty() { return HashMap::new(); } - let fields_type_params: HashMap = fields + let fields_type_params: HashMap = fields .iter() .enumerate() .filter_map(|(i, field)| { @@ -405,7 +551,7 @@ impl<'a, 'b> State<'a, 'b> { } let list = match meta { - Meta::List(list) => list, + syn::Meta::List(list) => list, // This one has been checked already in get_meta_fmt() method. _ => unreachable!(), }; @@ -415,10 +561,10 @@ impl<'a, 'b> State<'a, 'b> { .skip(1) // skip fmt = "..." .enumerate() .filter_map(|(i, arg)| match arg { - NestedMeta::Lit(Lit::Str(ref s)) => { + syn::NestedMeta::Lit(syn::Lit::Str(ref s)) => { syn::parse_str(&s.value()).ok().map(|id| (i, id)) } - NestedMeta::Meta(Meta::Path(ref id)) => Some((i, id.clone())), + syn::NestedMeta::Meta(syn::Meta::Path(ref id)) => Some((i, id.clone())), // This one has been checked already in get_meta_fmt() method. _ => unreachable!(), }) @@ -427,9 +573,9 @@ impl<'a, 'b> State<'a, 'b> { return HashMap::new(); } let fmt_string = match &list.nested[0] { - NestedMeta::Meta(Meta::NameValue(MetaNameValue { + syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue { path, - lit: Lit::Str(s), + lit: syn::Lit::Str(s), .. })) if path .segments @@ -452,18 +598,18 @@ impl<'a, 'b> State<'a, 'b> { bounds .entry(fields_type_params[arg].clone()) .or_insert_with(HashSet::new) - .insert(pl.trait_name); + .insert(trait_name_to_trait_bound(pl.trait_name)); } } bounds }, ) } - fn infer_type_params_bounds(&self, fields: &Fields) -> HashMap> { + fn infer_type_params_bounds(&self, fields: &syn::Fields) -> HashMap> { if self.type_params.is_empty() { return HashMap::new(); } - if let Fields::Unit = fields { + if let syn::Fields::Unit = fields { return HashMap::new(); } // infer_fmt() uses only first field. @@ -474,17 +620,7 @@ impl<'a, 'b> State<'a, 'b> { self.get_type_param(&field.ty).map(|ty| { ( ty, - [match self.trait_attr { - "display" => "Display", - "binary" => "Binary", - "octal" => "Octal", - "lower_hex" => "LowerHex", - "upper_hex" => "UpperHex", - "lower_exp" => "LowerExp", - "upper_exp" => "UpperExp", - "pointer" => "Pointer", - _ => unreachable!(), - }] + [trait_name_to_trait_bound(attribute_name_to_trait_name(self.trait_attr))] .iter() .cloned() .collect() @@ -493,9 +629,19 @@ impl<'a, 'b> State<'a, 'b> { }) .collect() } + fn get_type_param(&self, ty: &syn::Type) -> Option { + if self.has_type_param_in(ty) { + match ty { + syn::Type::Reference(syn::TypeReference { elem: ty, .. }) => Some(ty.deref().clone()), + ty => Some(ty.clone()) + } + } else { + None + } + } fn has_type_param_in(&self, ty: &syn::Type) -> bool { match ty { - Type::Path(ty) => { + syn::Type::Path(ty) => { if let Some(qself) = &ty.qself { if self.has_type_param_in(&qself.ty) { return true; @@ -510,13 +656,13 @@ impl<'a, 'b> State<'a, 'b> { ty.path.segments.iter() .any(|segment| { - if let PathArguments::AngleBracketed(arguments) = &segment.arguments { + if let syn::PathArguments::AngleBracketed(arguments) = &segment.arguments { arguments.args.iter().any(|argument| { match argument { - GenericArgument::Type(ty) => { + syn::GenericArgument::Type(ty) => { self.has_type_param_in(ty) }, - GenericArgument::Constraint(constraint) => { + syn::GenericArgument::Constraint(constraint) => { self.type_params.contains(&constraint.ident) }, _ => false, @@ -528,23 +674,13 @@ impl<'a, 'b> State<'a, 'b> { }) }, - Type::Reference(ty) => { + syn::Type::Reference(ty) => { self.has_type_param_in(&ty.elem) }, _ => false, } } - fn get_type_param(&self, ty: &syn::Type) -> Option { - if self.has_type_param_in(ty) { - match ty { - Type::Reference(TypeReference { elem: ty, .. }) => Some(ty.deref().clone()), - ty => Some(ty.clone()) - } - } else { - None - } - } } /// Representation of formatting placeholder. diff --git a/tests/display.rs b/tests/display.rs index aa58d320..2085d02f 100644 --- a/tests/display.rs +++ b/tests/display.rs @@ -339,4 +339,60 @@ mod generic { assert_eq!(s.to_string(), "10"); } } + + mod bound { + use super::*; + + #[test] + fn simple() { + #[derive(Display)] + #[display(fmt = "{} {}", _0, _1)] + struct Struct(T1, T2); + + let s = Struct(10, 20); + assert_eq!(s.to_string(), "10 20"); + } + + #[test] + fn redundant() { + #[derive(Display)] + #[display(bound = "T1: ::std::fmt::Display, T2: ::std::fmt::Display")] + #[display(fmt = "{} {}", _0, _1)] + struct Struct(T1, T2); + + let s = Struct(10, 20); + assert_eq!(s.to_string(), "10 20"); + } + + #[test] + fn complex() { + trait Trait1 { + fn function1(&self) -> &'static str; + } + + trait Trait2 { + fn function2(&self) -> &'static str; + } + + impl Trait1 for i32 { + fn function1(&self) -> &'static str { + "WHAT" + } + } + + impl Trait2 for i32 { + fn function2(&self) -> &'static str { + "EVER" + } + } + + #[derive(Display)] + #[display(bound = "T1: Trait1 + Trait2, T2: Trait1 + Trait2")] + #[display(fmt = "{} {} {} {}", "_0.function1()", _0, "_1.function2()", _1)] + struct Struct(T1, T2); + + let s = Struct(10, 20); + assert_eq!(s.to_string(), "WHAT 10 EVER 20"); + } + } }