Skip to content

Commit

Permalink
Interpolate unnamed enum variant fields in to_string attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
gin-ahirsch committed Mar 13, 2024
1 parent 89da3cc commit 71001b4
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 35 deletions.
4 changes: 3 additions & 1 deletion strum_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,16 @@ pub fn to_string(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
/// 3. The name of the variant will be used if there are no `serialize` or `to_string` attributes.
/// 4. If the enum has a `strum(prefix = "some_value_")`, every variant will have that prefix prepended
/// to the serialization.
/// 5. Enums with named fields support named field interpolation. The value will be interpolated into the output string.
/// 5. Enums with fields support string interpolation.
/// Note this means the variant will not "round trip" if you then deserialize the string.
///
/// ```rust
/// #[derive(strum_macros::Display)]
/// pub enum Color {
/// #[strum(to_string = "saturation is {sat}")]
/// Red { sat: usize },
/// #[strum(to_string = "hue is {1}, saturation is {0}")]
/// Blue(usize, usize),
/// }
/// ```
///
Expand Down
113 changes: 79 additions & 34 deletions strum_macros/src/macros/strings/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,20 @@ pub fn display_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {

let params = match variant.fields {
Fields::Unit => quote! {},
Fields::Unnamed(..) => quote! { (..) },
Fields::Unnamed(ref unnamed_fields) => {
// Transform unnamed params '(String, u8)' to '(ref field0, ref field1)'
let names: Punctuated<_, Token!(,)> = unnamed_fields
.unnamed
.iter()
.enumerate()
.map(|(index, field)| {
assert!(field.ident.is_none());
let ident = syn::parse_str::<Ident>(format!("field{}", index).as_str()).unwrap();
quote! { ref #ident }
})
.collect();
quote! { (#names) }
}
Fields::Named(ref field_names) => {
// Transform named params '{ name: String, age: u8 }' to '{ ref name, ref age }'
let names: Punctuated<TokenStream, Token!(,)> = field_names
Expand Down Expand Up @@ -58,33 +71,60 @@ pub fn display_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
}
} else {
let arm = if let Fields::Named(ref field_names) = variant.fields {
let used_vars = capture_format_string_idents(&output)?;
if used_vars.is_empty() {
quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) }
} else {
// Create args like 'name = name, age = age' for format macro
let args: Punctuated<_, Token!(,)> = field_names
.named
.iter()
.filter_map(|field| {
let ident = field.ident.as_ref().unwrap();
// Only contain variables that are used in format string
if !used_vars.contains(ident) {
None
} else {
Some(quote! { #ident = #ident })
}
})
.collect();

quote! {
#[allow(unused_variables)]
#name::#ident #params => ::core::fmt::Display::fmt(&format!(#output, #args), f)
let arm = match variant.fields {
Fields::Named(ref field_names) => {
let used_vars = capture_format_string_idents(&output)?;
if used_vars.is_empty() {
quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) }
} else {
// Create args like 'name = name, age = age' for format macro
let args: Punctuated<_, Token!(,)> = field_names
.named
.iter()
.filter_map(|field| {
let ident = field.ident.as_ref().unwrap();
// Only contain variables that are used in format string
if !used_vars.contains(ident) {
None
} else {
Some(quote! { #ident = #ident })
}
})
.collect();

quote! {
#[allow(unused_variables)]
#name::#ident #params => ::core::fmt::Display::fmt(&format!(#output, #args), f)
}
}
},
Fields::Unnamed(ref unnamed_fields) => {
let used_vars = capture_format_strings(&output)?;
if used_vars.iter().any(String::is_empty) {
return Err(syn::Error::new_spanned(
&output,
"Empty {} is not allowed; Use manual numbering ({0})",
))
}
if used_vars.is_empty() {
quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) }
} else {
let args: Punctuated<_, Token!(,)> = unnamed_fields
.unnamed
.iter()
.enumerate()
.map(|(index, field)| {
assert!(field.ident.is_none());
syn::parse_str::<Ident>(format!("field{}", index).as_str()).unwrap()
})
.collect();
quote! {
#[allow(unused_variables)]
#name::#ident #params => ::core::fmt::Display::fmt(&format!(#output, #args), f)
}
}
}
} else {
quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) }
Fields::Unit => quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) }
};

arms.push(arm);
Expand All @@ -107,11 +147,22 @@ pub fn display_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}

fn capture_format_string_idents(string_literal: &LitStr) -> syn::Result<Vec<Ident>> {
capture_format_strings(string_literal)?.into_iter().map(|ident| {
syn::parse_str::<Ident>(ident.as_str()).map_err(|_| {
syn::Error::new_spanned(
string_literal,
"Invalid identifier inside format string bracket",
)
})
}).collect()
}

fn capture_format_strings(string_literal: &LitStr) -> syn::Result<Vec<String>> {
// Remove escaped brackets
let format_str = string_literal.value().replace("{{", "").replace("}}", "");

let mut new_var_start_index: Option<usize> = None;
let mut var_used: Vec<Ident> = Vec::new();
let mut var_used = Vec::new();

for (i, chr) in format_str.bytes().enumerate() {
if chr == b'{' {
Expand All @@ -133,13 +184,7 @@ fn capture_format_string_idents(string_literal: &LitStr) -> syn::Result<Vec<Iden

let inside_brackets = &format_str[start_index + 1..i];
let ident_str = inside_brackets.split(":").next().unwrap().trim_end();
let ident = syn::parse_str::<Ident>(ident_str).map_err(|_| {
syn::Error::new_spanned(
string_literal,
"Invalid identifier inside format string bracket",
)
})?;
var_used.push(ident);
var_used.push(ident_str.to_owned());
}
}

Expand Down
10 changes: 10 additions & 0 deletions strum_tests/tests/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ enum Color {
Purple { sat: usize },
#[strum(default)]
Green(String),
#[strum(to_string = "Orange({0})")]
Orange(usize),
}

#[test]
Expand Down Expand Up @@ -64,6 +66,14 @@ fn to_green_string() {
);
}

#[test]
fn to_orange_string() {
assert_eq!(
String::from("Orange(10)"),
Color::Orange(10).to_string().as_ref()
);
}

#[derive(Debug, Eq, PartialEq, EnumString, strum::Display)]
enum ColorWithDefaultAndToString {
#[strum(default, to_string = "GreenGreen")]
Expand Down

0 comments on commit 71001b4

Please sign in to comment.