diff --git a/README.md b/README.md index 491c24f2..b3d77901 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ Strum has implemented the following macros: | [IntoStaticStr] | Implements `From for &'static str` on an enum | | [EnumVariantNames] | Adds an associated `VARIANTS` constant which is an array of discriminant names | | [EnumIter] | Creates a new type that iterates of the variants of an enum. | +| [EnumMap] | Creates a new type that stores an item of a specified type for each variant of the enum. | | [EnumProperty] | Add custom properties to enum variants. | | [EnumMessage] | Add a verbose message to an enum variant. | | [EnumDiscriminants] | Generate a new type with only the discriminant names. | diff --git a/strum_macros/src/helpers/mod.rs b/strum_macros/src/helpers/mod.rs index 142ea0b8..724fce2f 100644 --- a/strum_macros/src/helpers/mod.rs +++ b/strum_macros/src/helpers/mod.rs @@ -1,4 +1,4 @@ -pub use self::case_style::{CaseStyleHelpers, snakify}; +pub use self::case_style::{snakify, CaseStyleHelpers}; pub use self::type_props::HasTypeProperties; pub use self::variant_props::HasStrumVariantProperties; diff --git a/strum_macros/src/lib.rs b/strum_macros/src/lib.rs index 82db12ad..cf6a837a 100644 --- a/strum_macros/src/lib.rs +++ b/strum_macros/src/lib.rs @@ -444,6 +444,45 @@ pub fn enum_try_as(input: proc_macro::TokenStream) -> proc_macro::TokenStream { toks.into() } +/// Creates a new type that maps all the variants of an enum to another generic value. +/// +/// This macro does not support any additional data on your variants. +/// The macro creates a new type called `YourEnumTable`. +/// The table has a field of type `T` for each variant of `YourEnum`. The table automatically implements `Index` and `IndexMut`. +/// ``` +/// use strum_macros::EnumTable; +/// +/// #[derive(EnumTable)] +/// enum Color { +/// Red, +/// Yellow, +/// Green, +/// Blue, +/// } +/// +/// assert_eq!(ColorTable::default(), ColorTable::new(0, 0, 0, 0)); +/// assert_eq!(ColorTable::filled(2), ColorTable::new(2, 2, 2, 2)); +/// assert_eq!(ColorTable::from_closure(|_| 3), ColorTable::new(3, 3, 3, 3)); +/// assert_eq!(ColorTable::default().transform(|_, val| val + 2), ColorTable::new(2, 2, 2, 2)); +/// +/// let mut complex_map = ColorTable::from_closure(|color| match color { +/// Color::Red => 0, +/// _ => 3 +/// }); +/// complex_map[Color::Green] = complex_map[Color::Red]; +/// assert_eq!(complex_map, ColorTable::new(0, 3, 0, 3)); +/// +/// ``` +#[proc_macro_derive(EnumTable, attributes(strum))] +pub fn enum_table(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse_macro_input!(input as DeriveInput); + + let toks = + macros::enum_table::enum_table_inner(&ast).unwrap_or_else(|err| err.to_compile_error()); + debug_print_generated(&ast, &toks); + toks.into() +} + /// Add a function to enum that allows accessing variants by its discriminant /// /// This macro adds a standalone function to obtain an enum variant by its discriminant. The macro adds diff --git a/strum_macros/src/macros/enum_table.rs b/strum_macros/src/macros/enum_table.rs new file mode 100644 index 00000000..f9d4e81d --- /dev/null +++ b/strum_macros/src/macros/enum_table.rs @@ -0,0 +1,204 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use syn::{spanned::Spanned, Data, DeriveInput, Fields}; + +use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties}; + +pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result { + let name = &ast.ident; + let gen = &ast.generics; + let vis = &ast.vis; + let mut doc_comment = format!("A map over the variants of `{}`", name); + + if gen.lifetimes().count() > 0 { + return Err(syn::Error::new( + Span::call_site(), + "`EnumTable` doesn't support enums with lifetimes.", + )); + } + + let variants = match &ast.data { + Data::Enum(v) => &v.variants, + _ => return Err(non_enum_error()), + }; + + let table_name = format_ident!("{}Table", name); + + // the identifiers of each variant, in PascalCase + let mut pascal_idents = Vec::new(); + // the identifiers of each struct field, in snake_case + let mut snake_idents = Vec::new(); + // match arms in the form `MyEnumTable::Variant => &self.variant,` + let mut get_matches = Vec::new(); + // match arms in the form `MyEnumTable::Variant => &mut self.variant,` + let mut get_matches_mut = Vec::new(); + // match arms in the form `MyEnumTable::Variant => self.variant = new_value` + let mut set_matches = Vec::new(); + // struct fields of the form `variant: func(MyEnum::Variant),* + let mut closure_fields = Vec::new(); + // struct fields of the form `variant: func(MyEnum::Variant, self.variant),` + let mut transform_fields = Vec::new(); + + // identifiers for disabled variants + let mut disabled_variants = Vec::new(); + // match arms for disabled variants + let mut disabled_matches = Vec::new(); + + for variant in variants { + // skip disabled variants + if variant.get_variant_properties()?.disabled.is_some() { + let disabled_ident = &variant.ident; + let panic_message = format!( + "Can't use `{}` with `{}` - variant is disabled for Strum features", + disabled_ident, table_name + ); + disabled_variants.push(disabled_ident); + disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),)); + continue; + } + + // Error on variants with data + if variant.fields != Fields::Unit { + return Err(syn::Error::new( + variant.fields.span(), + "`EnumTable` doesn't support enums with non-unit variants", + )); + }; + + let pascal_case = &variant.ident; + let snake_case = format_ident!("_{}", snakify(&pascal_case.to_string())); + + get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,}); + get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,}); + set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,}); + closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),}); + transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),}); + pascal_idents.push(pascal_case); + snake_idents.push(snake_case); + } + + // Error on empty enums + if pascal_idents.is_empty() { + return Err(syn::Error::new( + variants.span(), + "`EnumTable` requires at least one non-disabled variant", + )); + } + + // if the index operation can panic, add that to the documentation + if !disabled_variants.is_empty() { + doc_comment.push_str(&format!( + "\n# Panics\nIndexing `{}` with any of the following variants will cause a panic:", + table_name + )); + for variant in disabled_variants { + doc_comment.push_str(&format!("\n\n- `{}::{}`", name, variant)); + } + } + + let doc_new = format!( + "Create a new {} with a value for each variant of {}", + table_name, name + ); + let doc_closure = format!( + "Create a new {} by running a function on each variant of `{}`", + table_name, name + ); + let doc_transform = format!("Create a new `{}` by running a function on each variant of `{}` and the corresponding value in the current `{0}`", table_name, name); + let doc_filled = format!( + "Create a new `{}` with the same value in each field.", + table_name + ); + let doc_option_all = format!("Converts `{}>` into `Option<{0}>`. Returns `Some` if all fields are `Some`, otherwise returns `None`.", table_name); + let doc_result_all_ok = format!("Converts `{}>` into `Result<{0}, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`.", table_name); + + Ok(quote! { + #[doc = #doc_comment] + #[allow( + missing_copy_implementations, + )] + #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] + #vis struct #table_name { + #(#snake_idents: T,)* + } + + impl #table_name { + #[doc = #doc_filled] + #vis fn filled(value: T) -> #table_name { + #table_name { + #(#snake_idents: value.clone(),)* + } + } + } + + impl #table_name { + #[doc = #doc_new] + #vis fn new( + #(#snake_idents: T,)* + ) -> #table_name { + #table_name { + #(#snake_idents,)* + } + } + + #[doc = #doc_closure] + #vis fn from_closureT>(func: F) -> #table_name { + #table_name { + #(#closure_fields)* + } + } + + #[doc = #doc_transform] + #vis fn transformU>(&self, func: F) -> #table_name { + #table_name { + #(#transform_fields)* + } + } + + } + + impl ::core::ops::Index<#name> for #table_name { + type Output = T; + + fn index(&self, idx: #name) -> &T { + match idx { + #(#get_matches)* + #(#disabled_matches)* + } + } + } + + impl ::core::ops::IndexMut<#name> for #table_name { + fn index_mut(&mut self, idx: #name) -> &mut T { + match idx { + #(#get_matches_mut)* + #(#disabled_matches)* + } + } + } + + impl #table_name<::core::option::Option> { + #[doc = #doc_option_all] + #vis fn all(self) -> ::core::option::Option<#table_name> { + if let #table_name { + #(#snake_idents: ::core::option::Option::Some(#snake_idents),)* + } = self { + ::core::option::Option::Some(#table_name { + #(#snake_idents,)* + }) + } else { + ::core::option::Option::None + } + } + } + + impl #table_name<::core::result::Result> { + #[doc = #doc_result_all_ok] + #vis fn all_ok(self) -> ::core::result::Result<#table_name, E> { + ::core::result::Result::Ok(#table_name { + #(#snake_idents: self.#snake_idents?,)* + }) + } + } + }) +} diff --git a/strum_macros/src/macros/mod.rs b/strum_macros/src/macros/mod.rs index 8df8cd6d..70a20750 100644 --- a/strum_macros/src/macros/mod.rs +++ b/strum_macros/src/macros/mod.rs @@ -2,6 +2,7 @@ pub mod enum_count; pub mod enum_discriminants; pub mod enum_is; pub mod enum_iter; +pub mod enum_table; pub mod enum_messages; pub mod enum_properties; pub mod enum_try_as; diff --git a/strum_tests/tests/enum_table.rs b/strum_tests/tests/enum_table.rs new file mode 100644 index 00000000..25e854fd --- /dev/null +++ b/strum_tests/tests/enum_table.rs @@ -0,0 +1,101 @@ +use strum::EnumTable; + +#[derive(EnumTable)] +enum Color { + Red, + Yellow, + Green, + #[strum(disabled)] + Teal, + Blue, + #[strum(disabled)] + Indigo, +} + +// even though this isn't used, it needs to be a test +// because if it doesn't compile, enum variants that conflict with keywords won't work +#[derive(EnumTable)] +enum Keyword { + Const, +} + +#[test] +fn default() { + assert_eq!(ColorTable::default(), ColorTable::new(0, 0, 0, 0)); +} + +#[test] +#[should_panic] +fn disabled() { + let _ = ColorTable::::default()[Color::Indigo]; +} + +#[test] +fn filled() { + assert_eq!(ColorTable::filled(42), ColorTable::new(42, 42, 42, 42)); +} + +#[test] +fn from_closure() { + assert_eq!( + ColorTable::from_closure(|color| match color { + Color::Red => 1, + _ => 2, + }), + ColorTable::new(1, 2, 2, 2) + ); +} + +#[test] +fn clone() { + let cm = ColorTable::filled(String::from("Some Text Data")); + assert_eq!(cm.clone(), cm); +} + +#[test] +fn index() { + let map = ColorTable::new(18, 25, 7, 2); + assert_eq!(map[Color::Red], 18); + assert_eq!(map[Color::Yellow], 25); + assert_eq!(map[Color::Green], 7); + assert_eq!(map[Color::Blue], 2); +} + +#[test] +fn index_mut() { + let mut map = ColorTable::new(18, 25, 7, 2); + map[Color::Green] = 5; + map[Color::Red] *= 4; + assert_eq!(map[Color::Green], 5); + assert_eq!(map[Color::Red], 72); +} + +#[test] +fn option_all() { + let mut map: ColorTable> = ColorTable::filled(None); + map[Color::Red] = Some(64); + map[Color::Green] = Some(32); + map[Color::Blue] = Some(16); + + assert_eq!(map.clone().all(), None); + + map[Color::Yellow] = Some(8); + assert_eq!(map.all(), Some(ColorTable::new(64, 8, 32, 16))); +} + +#[test] +fn result_all_ok() { + let mut map: ColorTable> = ColorTable::filled(Ok(4)); + assert_eq!(map.clone().all_ok(), Ok(ColorTable::filled(4))); + map[Color::Red] = Err(22); + map[Color::Yellow] = Err(100); + assert_eq!(map.clone().all_ok(), Err(22)); + map[Color::Red] = Ok(1); + assert_eq!(map.all_ok(), Err(100)); +} + +#[test] +fn transform() { + let all_two = ColorTable::filled(2); + assert_eq!(all_two.transform(|_, n| *n * 2), ColorTable::filled(4)); +}