Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve derive-Display trait bounds inference (#93) #95

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 82 additions & 38 deletions src/display.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
collections::{HashMap, HashSet},
fmt::Display,
ops::Deref,
};

use crate::utils::add_extra_where_clauses;
Expand All @@ -9,7 +10,7 @@ use quote::{quote, quote_spanned};
use syn::{
parse::{Error, Result},
spanned::Spanned,
Attribute, Data, DeriveInput, Fields, Lit, Meta, MetaNameValue, NestedMeta, Path, Type,
Attribute, Data, DeriveInput, Fields, GenericArgument, Lit, Meta, MetaNameValue, NestedMeta, Path, PathArguments, Type, TypeReference
};

/// Provides the hook to expand `#[derive(Display)]` into an implementation of `From`
Expand Down Expand Up @@ -383,19 +384,20 @@ impl<'a, 'b> State<'a, 'b> {
return HashMap::new();
}

let fields_type_params: HashMap<_, _> = fields
let fields_type_params: HashMap<Path, _> = fields
.iter()
.enumerate()
.filter_map(|(i, field)| {
if !self.has_type_param_in(field) {
return None;
}
let path: Path = field
.ident
.clone()
.unwrap_or_else(|| Ident::new(&format!("_{}", i), Span::call_site()))
.into();
Some((path, field.ty.clone()))
self.get_type_param(&field.ty).map(|ty| {
(
field
.ident
.clone()
.unwrap_or_else(|| Ident::new(&format!("_{}", i), Span::call_site()))
.into(),
ty
)
})
})
.collect();
if fields_type_params.is_empty() {
Expand Down Expand Up @@ -469,37 +471,79 @@ impl<'a, 'b> State<'a, 'b> {
.iter()
.take(1)
.filter_map(|field| {
if !self.has_type_param_in(field) {
return None;
}
Some((
field.ty.clone(),
[match self.trait_attr {
"display" => "Display",
"binary" => "Binary",
"octal" => "Octal",
"lower_hex" => "LowerHex",
"upper_hex" => "UpperHex",
"lower_exp" => "LowerExp",
"upper_exp" => "UpperExp",
"pointer" => "Pointer",
_ => unreachable!(),
}]
.iter()
.cloned()
.collect(),
))
self.get_type_param(&field.ty).map(|ty| {
(
ty,
[match self.trait_attr {
"display" => "Display",
"binary" => "Binary",
"octal" => "Octal",
"lower_hex" => "LowerHex",
"upper_hex" => "UpperHex",
"lower_exp" => "LowerExp",
"upper_exp" => "UpperExp",
"pointer" => "Pointer",
_ => unreachable!(),
}]
.iter()
.cloned()
.collect()
)
})
})
.collect()
}
fn has_type_param_in(&self, field: &syn::Field) -> bool {
if let Type::Path(ref ty) = field.ty {
return match ty.path.segments.first() {
Some(t) => self.type_params.contains(&t.ident),
_ => false,
};
fn has_type_param_in(&self, ty: &syn::Type) -> bool {
match ty {
Type::Path(ty) => {
if let Some(qself) = &ty.qself {
if self.has_type_param_in(&qself.ty) {
return true;
}
}

if let Some(segment) = ty.path.segments.first() {
if self.type_params.contains(&segment.ident) {
return true;
}
}

ty.path.segments.iter()
.any(|segment| {
if let PathArguments::AngleBracketed(arguments) = &segment.arguments {
arguments.args.iter().any(|argument| {
match argument {
GenericArgument::Type(ty) => {
self.has_type_param_in(ty)
},
GenericArgument::Constraint(constraint) => {
self.type_params.contains(&constraint.ident)
},
_ => false,
}
})
} else {
false
}
})
},

Type::Reference(ty) => {
self.has_type_param_in(&ty.elem)
},

_ => false,
}
}
fn get_type_param(&self, ty: &syn::Type) -> Option<syn::Type> {
if self.has_type_param_in(ty) {
match ty {
Type::Reference(TypeReference { elem: ty, .. }) => Some(ty.deref().clone()),
ty => Some(ty.clone())
}
} else {
None
}
false
}
}

Expand Down
122 changes: 122 additions & 0 deletions tests/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,126 @@ mod generic {
let s = UnusedGenericStruct(());
assert_eq!(s.to_string(), "12");
}

mod associated_type_field_enumerator {
use super::*;

trait Trait {
type Type;
}

struct Struct;

impl Trait for Struct {
type Type = i32;
}

#[test]
fn auto_generic_named_struct_associated() {
#[derive(Display)]
struct AutoGenericNamedStructAssociated<T: Trait> {
field: <T as Trait>::Type,
}

let s = AutoGenericNamedStructAssociated::<Struct>{ field: 10 };
assert_eq!(s.to_string(), "10");
}

#[test]
fn auto_generic_unnamed_struct_associated() {
#[derive(Display)]
struct AutoGenericUnnamedStructAssociated<T: Trait>(<T as Trait>::Type);

let s = AutoGenericUnnamedStructAssociated::<Struct>(10);
assert_eq!(s.to_string(), "10");
}

#[test]
fn auto_generic_enum_associated() {
#[derive(Display)]
enum AutoGenericEnumAssociated<T: Trait> {
Enumerator(<T as Trait>::Type),
}

let e = AutoGenericEnumAssociated::<Struct>::Enumerator(10);
assert_eq!(e.to_string(), "10");
}
}

mod complex_type_field_enumerator {
use super::*;

#[derive(Display)]
struct Struct<T>(T);

#[test]
fn auto_generic_named_struct_complex() {
#[derive(Display)]
struct AutoGenericNamedStructComplex<T> {
field: Struct<T>,
}

let s = AutoGenericNamedStructComplex { field: Struct(10) };
assert_eq!(s.to_string(), "10");
}

#[test]
fn auto_generic_unnamed_struct_complex() {
#[derive(Display)]
struct AutoGenericUnnamedStructComplex<T>(Struct<T>);

let s = AutoGenericUnnamedStructComplex(Struct(10));
assert_eq!(s.to_string(), "10");
}

#[test]
fn auto_generic_enum_complex() {
#[derive(Display)]
enum AutoGenericEnumComplex<T> {
Enumerator(Struct<T>),
}

let e = AutoGenericEnumComplex::Enumerator(Struct(10));
assert_eq!(e.to_string(), "10")
}
}

mod reference {
use super::*;

#[test]
fn auto_generic_reference() {
#[derive(Display)]
struct AutoGenericReference<'a, T>(&'a T);

let s = AutoGenericReference(&10);
assert_eq!(s.to_string(), "10");
}

#[test]
fn auto_generic_static_reference() {
#[derive(Display)]
struct AutoGenericStaticReference<T: 'static>(&'static T);

let s = AutoGenericStaticReference(&10);
assert_eq!(s.to_string(), "10");
}
}

mod indirect {
use super::*;

#[derive(Display)]
struct Struct<T>(T);

#[test]
fn auto_generic_indirect() {
#[derive(Display)]
struct AutoGenericIndirect<T: 'static>(Struct<&'static T>);

const V: i32 = 10;
let s = AutoGenericIndirect(Struct(&V));
assert_eq!(s.to_string(), "10");
}
}
}