diff --git a/doc/display.md b/doc/display.md index d4e970c5..4d2aff41 100644 --- a/doc/display.md +++ b/doc/display.md @@ -33,6 +33,54 @@ i.e. `_0`, `_1`, `_2`, etc. The syntax does not change, but the name of the attribute is the snake case version of the trait. E.g. `Octal` -> `octal`, `Pointer` -> `pointer`, `UpperHex` -> `upper_hex`. +# Deriving `Display` for generic data types + +When deriving `Display` (or other formatting trait) for a generic struct/enum, all generic type +arguments used during formatting are bound by respective formatting trait. + +E.g., for a structure Foo defined like this: + +```rust +#[derive(Display)] +#[display(fmt = "{} {:?} {} {:p}", a, b, c, d)] +struct Foo<'a, T1, T2: Trait, T3> { + a: T1, + b: ::Type, + c: Bar, + d: &'a T1, +} +``` + +Following where clauses would be generated: + +* `T1: Display + Pointer` +* `::Type: Debug` +* `Bar: Display` + +## Specifying additional trait bounds + +Sometimes you may want to specify additional trait bounds on your generic type parameters, so that they +could be used during formatting. This can be done with a `#[display(bound = "...")]` attribute. + +`#[display(bound = "...")]` accepts single string argument in a format generally similar to a format +used in angle bracket list: `T, U: MyTrait, V: Trait1 + Trait2`. + +Specifying type argument without explicitly specifying trait bounds is a shortcut to bind by formatting +type. + +Only type parameters defined on a struct allowed to appear in bound-string and they can only be bound +by traits, i.e., no lifetime parameters or lifetime bounds allowed in bound-string. + +```rust +#[derive(Display)] +#[display(bound = "T: MyTrait, U")] +#[display(fmt = "{} {}", "a.my_function()", "transform(b.to_string())")] +struct MyStruct { + a: T, + b: U, +} +``` + # Example usage ```rust diff --git a/src/display.rs b/src/display.rs index 809973c4..5659bd70 100644 --- a/src/display.rs +++ b/src/display.rs @@ -1,6 +1,7 @@ use std::{ collections::{HashMap, HashSet}, fmt::Display, + iter::FromIterator, ops::Deref, }; @@ -10,7 +11,25 @@ use quote::{quote, quote_spanned}; use syn::{ parse::{Error, Result}, spanned::Spanned, - Attribute, Data, DeriveInput, Fields, GenericArgument, Lit, Meta, MetaNameValue, NestedMeta, Path, PathArguments, Type, TypeReference + Attribute, + Data, + DeriveInput, + Fields, + GenericArgument, + Lit, + Meta, + MetaList, + MetaNameValue, + NestedMeta, + Path, + PathArguments, + PathSegment, + Token, + TraitBound, + TraitBoundModifier, + Type, + TypePath, + TypeReference, }; /// Provides the hook to expand `#[derive(Display)]` into an implementation of `From` @@ -18,18 +37,7 @@ pub fn expand(input: &DeriveInput, trait_name: &str) -> Result { let trait_name = trait_name.trim_end_matches("Custom"); let trait_ident = Ident::new(trait_name, Span::call_site()); let trait_path = "e!(::core::fmt::#trait_ident); - let trait_attr = match trait_name { - "Display" => "display", - "Binary" => "binary", - "Octal" => "octal", - "LowerHex" => "lower_hex", - "UpperHex" => "upper_hex", - "LowerExp" => "lower_exp", - "UpperExp" => "upper_exp", - "Pointer" => "pointer", - "Debug" => "debug", - _ => unimplemented!(), - }; + let trait_attr = trait_name_to_attribute_name(trait_name); let type_params = input .generics .type_params() @@ -50,9 +58,8 @@ pub fn expand(input: &DeriveInput, trait_name: &str) -> Result { .map(|(ty, trait_names)| { let bounds: Vec<_> = trait_names .into_iter() - .map(|trait_name| { - let trait_ident = Ident::new(trait_name, Span::call_site()); - quote!(::core::fmt::#trait_ident) + .map(|bound| { + quote!(#bound) }) .collect(); quote!(#ty: #(#bounds)+*) @@ -99,6 +106,50 @@ pub fn expand(input: &DeriveInput, trait_name: &str) -> Result { }) } +fn trait_name_to_attribute_name(trait_name: &str) -> &'static str { + match trait_name { + "Display" => "display", + "Binary" => "binary", + "Octal" => "octal", + "LowerHex" => "lower_hex", + "UpperHex" => "upper_hex", + "LowerExp" => "lower_exp", + "UpperExp" => "upper_exp", + "Pointer" => "pointer", + "Debug" => "debug", + _ => unimplemented!(), + } +} + +fn attribute_name_to_trait_name(attribute_name: &str) -> &'static str { + match attribute_name { + "display" => "Display", + "binary" => "Binary", + "octal" => "Octal", + "lower_hex" => "LowerHex", + "upper_hex" => "UpperHex", + "lower_exp" => "LowerExp", + "upper_exp" => "UpperExp", + "pointer" => "Pointer", + _ => unreachable!(), + } +} + +fn trait_name_to_trait_bound(trait_name: &str) -> TraitBound { + let path_segments_iterator = vec!["core", "fmt", trait_name].into_iter() + .map(|segment| PathSegment::from(Ident::new(segment, Span::call_site()))); + + TraitBound { + lifetimes: None, + modifier: TraitBoundModifier::None, + paren_token: None, + path: Path { + leading_colon: Some(Token![::](Span::call_site())), + segments: syn::punctuated::Punctuated::from_iter(path_segments_iterator), + }, + } +} + struct State<'a, 'b> { trait_path: &'b TokenStream, trait_attr: &'static str, @@ -113,6 +164,9 @@ impl<'a, 'b> State<'a, 'b> { self.trait_attr ) } + fn get_proper_bound_syntax(&self) -> impl Display { + format!("Proper syntax: #[{}(bound = \"T, U: Trait1 + Trait2, V: Trait3\")]", self.trait_attr) + } fn get_matcher(&self, fields: &Fields) -> TokenStream { match fields { @@ -139,26 +193,113 @@ impl<'a, 'b> State<'a, 'b> { } } } - fn find_meta(&self, attrs: &[Attribute]) -> Result> { - let mut it = attrs - .iter() - .filter_map(|m| m.parse_meta().ok()) - .filter(|m| { - if let Some(ident) = m.path().segments.first().map(|p| &p.ident) { - ident == self.trait_attr - } else { - false + fn find_meta(&self, attrs: &[Attribute], expected_inner_meta_path: &str) -> Result> { + let mut iterator = attrs.iter() + .filter_map(|attr| attr.parse_meta().ok()) + .filter(|meta| { + match meta { + Meta::List(MetaList { + path: Path { + segments: outer_meta_path, + .. + }, + nested: inner_meta, + .. + }) => { + if outer_meta_path.len() == 1 && outer_meta_path[0].ident == self.trait_attr && !inner_meta.is_empty() { + match &inner_meta[0] { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path: Path { + segments: inner_meta_path, + .. + }, + lit: Lit::Str(meta), + .. + })) => { + if inner_meta_path.len() == 1 && inner_meta_path[0].ident == expected_inner_meta_path { + true + } else { + false + } + } + _ => { + false + }, + } + } else { + false + } + } + _ => { + false + }, } }); - let meta = it.next(); - if it.next().is_some() { - Err(Error::new(meta.span(), "Too many formats given")) - } else { + let meta = iterator.next(); + if iterator.next().is_none() { Ok(meta) + } else { + Err(Error::new(meta.span(), "Too many attributes specified")) } } - fn get_meta_fmt(&self, meta: &Meta, outer_enum: bool) -> Result<(TokenStream, bool)> { + fn parse_meta_bounds(&self, bounds: &syn::LitStr) -> Result>> { + use std::str::FromStr; + use proc_macro2::TokenStream; + use syn::{self, parse::Parser, punctuated::Punctuated, GenericParam, Type, TypeParam, TypeParamBound, Token}; + + let span = bounds.span(); + + let input = bounds.value(); + let tokens = TokenStream::from_str(&input)?; + let parser = Punctuated::::parse_terminated; + let bounds = parser.parse2(tokens).map_err(|error| Error::new(span.clone(), error.to_string()))?; + + if bounds.is_empty() { + return Err(Error::new(span.clone(), "No bounds specified")); + } + + bounds.into_iter().try_fold(HashMap::new(), |mut accumulator, generic_param| { + match generic_param { + GenericParam::Type(type_param) => { + if self.type_params.contains(&type_param.ident) { + if !type_param.attrs.is_empty() { + Err(Error::new(span.clone(), "Attributes aren't allowed")) + } else if type_param.eq_token.is_some() || type_param.default.is_some() { + Err(Error::new(span.clone(), "Default type parameters aren't allowed")) + } else { + let ty = Type::Path(TypePath { qself: None, path: type_param.ident.into() }); + let bounds = accumulator.entry(ty).or_insert_with(|| HashSet::new()); + + let bounds = type_param.bounds.into_iter().try_fold(bounds, |mut accumulator, bound| { + match bound { + TypeParamBound::Trait(bound) => { + if bound.lifetimes.is_some() { + Err(Error::new(span.clone(), "Higher-rank trait bounds aren't allowed")) + } else { + accumulator.insert(bound); + Ok(accumulator) + } + } + _ => Err(Error::new(span.clone(), "Only trait bounds allowed")), + } + })?; + + if bounds.is_empty() { + bounds.insert(trait_name_to_trait_bound(attribute_name_to_trait_name(self.trait_attr))); + } + + Ok(accumulator) + } + } else { + Err(Error::new(span.clone(), "Unknown generic type argument specified")) + } + }, + _ => Err(Error::new(span.clone(), "Only trait bounds allowed")), + } + }) + } + fn parse_meta_fmt(&self, meta: &Meta, outer_enum: bool) -> Result<(TokenStream, bool)> { let list = match meta { Meta::List(list) => list, _ => { @@ -271,16 +412,16 @@ impl<'a, 'b> State<'a, 'b> { } fn get_match_arms_and_extra_bounds( &self, - ) -> Result<(TokenStream, HashMap>)> { - match &self.input.data { + ) -> Result<(TokenStream, HashMap>)> { + let result: Result<_> = match &self.input.data { Data::Enum(e) => { match self - .find_meta(&self.input.attrs) - .and_then(|m| m.map(|m| self.get_meta_fmt(&m, true)).transpose())? + .find_meta(&self.input.attrs, "fmt") + .and_then(|m| m.map(|m| self.parse_meta_fmt(&m, true)).transpose())? { Some((fmt, false)) => { e.variants.iter().try_for_each(|v| { - if let Some(meta) = self.find_meta(&v.attrs)? { + if let Some(meta) = self.find_meta(&v.attrs, "fmt")? { Err(Error::new( meta.span(), "`fmt` cannot be used on variant when the whole enum has a format string without a placeholder, maybe you want to add a placeholder?", @@ -298,8 +439,8 @@ impl<'a, 'b> State<'a, 'b> { Some((outer_fmt, true)) => { let fmt: Result = e.variants.iter().try_fold(TokenStream::new(), |arms, v| { let matcher = self.get_matcher(&v.fields); - let fmt = if let Some(meta) = self.find_meta(&v.attrs)? { - self.get_meta_fmt(&meta, false)?.0 + let fmt = if let Some(meta) = self.find_meta(&v.attrs, "fmt")? { + self.parse_meta_fmt(&meta, false)?.0 } else { self.infer_fmt(&v.fields, &v.ident)? }; @@ -320,8 +461,8 @@ impl<'a, 'b> State<'a, 'b> { let fmt: TokenStream; let bounds: HashMap<_, _>; - if let Some(meta) = self.find_meta(&v.attrs)? { - fmt = self.get_meta_fmt(&meta, false)?.0; + if let Some(meta) = self.find_meta(&v.attrs, "fmt")? { + fmt = self.parse_meta_fmt(&meta, false)?.0; bounds = self.get_used_type_params_bounds(&v.fields, &meta); } else { fmt = self.infer_fmt(&v.fields, v_name)?; @@ -346,8 +487,8 @@ impl<'a, 'b> State<'a, 'b> { let fmt: TokenStream; let bounds: HashMap<_, _>; - if let Some(meta) = self.find_meta(&self.input.attrs)? { - fmt = self.get_meta_fmt(&meta, false)?.0; + if let Some(meta) = self.find_meta(&self.input.attrs, "fmt")? { + fmt = self.parse_meta_fmt(&meta, false)?.0; bounds = self.get_used_type_params_bounds(&s.fields, &meta); } else { fmt = self.infer_fmt(&s.fields, name)?; @@ -360,26 +501,65 @@ impl<'a, 'b> State<'a, 'b> { )) } Data::Union(_) => { - let meta = self.find_meta(&self.input.attrs)?.ok_or_else(|| { + let meta = self.find_meta(&self.input.attrs, "fmt")?.ok_or_else(|| { Error::new( self.input.span(), "Can not automatically infer format for unions", ) })?; - let fmt = self.get_meta_fmt(&meta, false)?.0; + let fmt = self.parse_meta_fmt(&meta, false)?.0; Ok(( quote_spanned!(self.input.span()=> _ => write!(_derive_more_Display_formatter, "{}", #fmt),), HashMap::new(), )) } - } + }; + + let (arms, mut bounds) = result?; + + match self.find_meta(&self.input.attrs, "bound")? { + Some(meta) => { + let span = meta.span(); + + match meta { + Meta::List(MetaList { + nested: meta, + .. + }) => { + if meta.len() != 1 { + Err(Error::new(span, self.get_proper_bound_syntax())) + } else { + match &meta[0] { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + lit: Lit::Str(extra_bounds), + .. + })) => { + let extra_bounds = self.parse_meta_bounds(extra_bounds)?; + + extra_bounds.into_iter().for_each(|(ty, extra_bounds)| { + bounds.entry(ty).or_insert_with(|| HashSet::new()).extend(extra_bounds); + }); + + Ok(()) + } + _ => Err(Error::new(span, self.get_proper_bound_syntax())), + } + } + }, + _ => Err(Error::new(span, self.get_proper_bound_syntax())), + } + }, + None => Ok(()), + }?; + + Ok((arms, bounds)) } fn get_used_type_params_bounds( &self, fields: &Fields, meta: &Meta, - ) -> HashMap> { + ) -> HashMap> { if self.type_params.is_empty() { return HashMap::new(); } @@ -452,14 +632,14 @@ impl<'a, 'b> State<'a, 'b> { bounds .entry(fields_type_params[arg].clone()) .or_insert_with(HashSet::new) - .insert(pl.trait_name); + .insert(trait_name_to_trait_bound(pl.trait_name)); } } bounds }, ) } - fn infer_type_params_bounds(&self, fields: &Fields) -> HashMap> { + fn infer_type_params_bounds(&self, fields: &Fields) -> HashMap> { if self.type_params.is_empty() { return HashMap::new(); } @@ -474,17 +654,7 @@ impl<'a, 'b> State<'a, 'b> { self.get_type_param(&field.ty).map(|ty| { ( ty, - [match self.trait_attr { - "display" => "Display", - "binary" => "Binary", - "octal" => "Octal", - "lower_hex" => "LowerHex", - "upper_hex" => "UpperHex", - "lower_exp" => "LowerExp", - "upper_exp" => "UpperExp", - "pointer" => "Pointer", - _ => unreachable!(), - }] + [trait_name_to_trait_bound(attribute_name_to_trait_name(self.trait_attr))] .iter() .cloned() .collect() @@ -493,6 +663,16 @@ impl<'a, 'b> State<'a, 'b> { }) .collect() } + fn get_type_param(&self, ty: &syn::Type) -> Option { + if self.has_type_param_in(ty) { + match ty { + Type::Reference(TypeReference { elem: ty, .. }) => Some(ty.deref().clone()), + ty => Some(ty.clone()) + } + } else { + None + } + } fn has_type_param_in(&self, ty: &syn::Type) -> bool { match ty { Type::Path(ty) => { @@ -535,16 +715,6 @@ impl<'a, 'b> State<'a, 'b> { _ => false, } } - fn get_type_param(&self, ty: &syn::Type) -> Option { - if self.has_type_param_in(ty) { - match ty { - Type::Reference(TypeReference { elem: ty, .. }) => Some(ty.deref().clone()), - ty => Some(ty.clone()) - } - } else { - None - } - } } /// Representation of formatting placeholder. diff --git a/tests/display.rs b/tests/display.rs index aa58d320..d1cf6789 100644 --- a/tests/display.rs +++ b/tests/display.rs @@ -339,4 +339,50 @@ mod generic { assert_eq!(s.to_string(), "10"); } } + + mod bound { + use super::*; + + #[test] + fn bound_simple() { + #[derive(Display)] + #[display(bound = "T1, T2")] + #[display(fmt = "{} {}", _0, _1)] + struct Struct(T1, T2); + + let s = Struct(10, 20); + assert_eq!(s.to_string(), "10 20"); + } + + #[test] + fn bound_complex() { + trait Trait1 { + fn function1(&self) -> &'static str; + } + + trait Trait2 { + fn function2(&self) -> &'static str; + } + + impl Trait1 for i32 { + fn function1(&self) -> &'static str { + "WHAT" + } + } + + impl Trait2 for i32 { + fn function2(&self) -> &'static str { + "EVER" + } + } + + #[derive(Display)] + #[display(bound = "T1: Trait1 + Trait2, T2: Trait1 + Trait2")] + #[display(fmt = "{} {} {} {}", "_0.function1()", _0, "_1.function2()", _1)] + struct Struct(T1, T2); + + let s = Struct(10, 20); + assert_eq!(s.to_string(), "WHAT 10 EVER 20"); + } + } }