diff --git a/CHANGELOG.md b/CHANGELOG.md index 7935bc6a..f5307b7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ Changelog for `odra`. +## [0.9.1] - 2024-04-xx +### Changed +- `#[odra::odra_type]` attribute can be applied to all flavors of enums (before applicable to unit-only enums). + ## [0.9.0] - 2024-04-02 ### Added - `Maybe` - a type that represents an entrypoint arg that may or may not be present. diff --git a/examples/Odra.toml b/examples/Odra.toml index c9516eac..ec6c6b1a 100644 --- a/examples/Odra.toml +++ b/examples/Odra.toml @@ -63,3 +63,6 @@ fqn = "features::livenet::LivenetContract" [[contracts]] fqn = "features::optional_args::Token" + +[[contracts]] +fqn = "features::custom_types::MyContract" diff --git a/examples/src/features/custom_types.rs b/examples/src/features/custom_types.rs new file mode 100644 index 00000000..f0e473b6 --- /dev/null +++ b/examples/src/features/custom_types.rs @@ -0,0 +1,115 @@ +#![allow(missing_docs)] +use odra::{prelude::*, Var}; + +type IPv4 = [u8; 4]; +type IPv6 = [u8; 16]; + +/// An enum representing an IP address. +#[odra::odra_type] +#[derive(Default)] +pub enum IP { + /// No data. + #[default] + Unknown, + /// Single unnamed element. + IPv4(IPv4), + /// multiple unnamed elements. + IPv4WithDescription(IPv4, String), + /// single named element. + IPv6 { ip: IPv6 }, + /// multiple named elements. + IPv6WithDescription { ip: IPv6, description: String } +} + +#[odra::odra_type] +pub enum Fieldless { + /// Tuple variant. + Tuple(), + /// Struct variant. + Struct {}, + /// Unit variant. + Unit +} + +/// Unit-only enum. +#[odra::odra_type] +#[derive(Default)] +pub enum Unit { + #[default] + A = 10, + B = 20, + C +} + +/// A struct with named elements. +#[odra::odra_type] +pub struct MyStruct { + a: u32, + b: u32 +} + +// A struct with unnamed elements cannot be an Odra type. +// #[odra::odra_type] +// pub struct TupleStruct(u32, u32); + +#[odra::module] +pub struct MyContract { + ip: Var, + fieldless: Var, + unit: Var, + my_struct: Var +} + +#[odra::odra_error] +pub enum Errors { + NotFound = 1 +} + +#[odra::module] +impl MyContract { + pub fn init(&mut self, ip: IP, fieldless: Fieldless, unit: Unit, my_struct: MyStruct) { + self.ip.set(ip); + self.fieldless.set(fieldless); + self.unit.set(unit); + self.my_struct.set(my_struct); + } + + pub fn get_ip(&self) -> IP { + self.ip.get_or_default() + } + + pub fn get_fieldless(&self) -> Fieldless { + self.fieldless.get_or_revert_with(Errors::NotFound) + } + + pub fn get_unit(&self) -> Unit { + self.unit.get_or_default() + } + + pub fn get_struct(&self) -> MyStruct { + self.my_struct.get_or_revert_with(Errors::NotFound) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use odra::host::Deployer; + + #[test] + fn test_contract() { + let test_env = odra_test::env(); + let init_args = MyContractInitArgs { + ip: IP::IPv4([192, 168, 0, 1]), + fieldless: Fieldless::Tuple(), + unit: Unit::C, + my_struct: MyStruct { a: 10, b: 20 } + }; + let contract = MyContractHostRef::deploy(&test_env, init_args); + + assert_eq!(contract.get_ip(), IP::IPv4([192, 168, 0, 1])); + assert_eq!(contract.get_fieldless(), Fieldless::Tuple()); + assert_eq!(contract.get_unit(), Unit::C); + assert_eq!(contract.get_struct(), MyStruct { a: 10, b: 20 }); + } +} diff --git a/examples/src/features/mod.rs b/examples/src/features/mod.rs index 44370e6f..9731a0bc 100644 --- a/examples/src/features/mod.rs +++ b/examples/src/features/mod.rs @@ -2,6 +2,7 @@ pub mod access_control; pub mod collecting_events; pub mod cross_calls; +pub mod custom_types; pub mod events; pub mod handling_errors; pub mod host_functions; diff --git a/odra-macros/src/ast/odra_type_item.rs b/odra-macros/src/ast/odra_type_item.rs index a5f9c772..eace2936 100644 --- a/odra-macros/src/ast/odra_type_item.rs +++ b/odra-macros/src/ast/odra_type_item.rs @@ -1,11 +1,13 @@ use crate::ast::events_item::HasEventsImplItem; use crate::ast::fn_utils::{FnItem, SingleArgFnItem}; -use crate::ast::utils::{ImplItem, Named}; -use crate::ir::TypeIR; +use crate::ast::utils::ImplItem; +use crate::ir::{TypeIR, TypeKind}; use crate::utils; use crate::utils::misc::AsBlock; use derive_try_from_ref::TryFromRef; -use syn::parse_quote; +use quote::{format_ident, ToTokens}; +use syn::{parse_quote, Token}; +use syn::punctuated::Punctuated; use crate::ast::schema::SchemaCustomTypeItem; macro_rules! impl_from_ir { @@ -14,13 +16,10 @@ macro_rules! impl_from_ir { type Error = syn::Error; fn try_from(ir: &TypeIR) -> Result { - match ir { - x if x.is_enum() => Self::from_enum(ir), - x if x.is_struct() => Self::from_struct(ir), - _ => Err(syn::Error::new_spanned( - ir.self_code(), - "Only support enum or struct" - )) + match ir.kind()? { + TypeKind::UnitEnum { variants } => Self::from_unit_enum(variants), + TypeKind::Enum { variants } => Self::from_enum(variants), + TypeKind::Struct { fields } => Self::from_struct(fields), } } } @@ -117,9 +116,10 @@ impl TryFrom<&'_ TypeIR> for CLTypedItem { type Error = syn::Error; fn try_from(ir: &TypeIR) -> Result { - let ret_ty_cl_type_any = match ir.is_enum() { - true => utils::ty::cl_type_u8(), - false => utils::ty::cl_type_any() + let ret_ty_cl_type_any = match ir.kind()? { + TypeKind::UnitEnum { variants: _ } => utils::ty::cl_type_u8(), + TypeKind::Enum { variants: _ } => utils::ty::cl_type_any(), + TypeKind::Struct { fields: _ } => utils::ty::cl_type_any(), } .as_block(); let ty_cl_type = utils::ty::cl_type(); @@ -144,8 +144,7 @@ struct FromBytesFnItem { impl_from_ir!(FromBytesFnItem); impl FromBytesFnItem { - fn from_enum(ir: &TypeIR) -> syn::Result { - let ident = ir.name()?; + fn from_enum(variants: Vec) -> syn::Result { let ident_bytes = utils::ident::bytes(); let ident_from_bytes = utils::ident::from_bytes(); let ident_result = utils::ident::result(); @@ -154,9 +153,71 @@ impl FromBytesFnItem { let read_stmt: syn::Stmt = parse_quote!(let (#ident_result, #ident_bytes): (#ty_u8, _) = #from_bytes_expr;); - let deser = ir.map_fields( - |i| quote::quote!(x if x == #ident::#i as #ty_u8 => Ok((#ident::#i, #ident_bytes))) - )?; + let arms = variants + .iter() + .enumerate() + .map(|(v_idx, v)| { + let v_idx: u8 = v_idx as u8; + let ident = &v.ident; + let fields = variant_ident_vec(v); + let deser = fields.iter() + .map(|f| quote::quote!(let (#f, bytes) = odra::casper_types::bytesrepr::FromBytes::from_bytes(bytes)?;)) + .collect::>(); + let code = match &v.fields { + syn::Fields::Unit => { + quote::quote!(Ok((Self::#ident, bytes))) + }, + syn::Fields::Named(_) => { + quote::quote!( + #(#deser)* + Ok((Self::#ident { #(#fields,)* }, bytes)) + ) + }, + syn::Fields::Unnamed(_) => { + quote::quote!( + #(#deser)* + Ok((Self::#ident(#(#fields,)*), bytes)) + ) + } + }; + quote::quote!(#v_idx => { #code }) + }) + .collect::>(); + + let arg = Self::arg(); + let ret_ty = Self::ret_ty(); + let block = parse_quote!({ + #read_stmt + match #ident_result { + #arms + _ => Err(odra::casper_types::bytesrepr::Error::Formatting), + } + }); + Ok(Self { + fn_item: SingleArgFnItem::new(&ident_from_bytes, arg, ret_ty, block) + }) + } + + fn from_unit_enum(variants: Vec) -> syn::Result { + let ident_bytes = utils::ident::bytes(); + let ident_from_bytes = utils::ident::from_bytes(); + let ident_result = utils::ident::result(); + let ty_u8 = utils::ty::u8(); + let from_bytes_expr = utils::expr::failable_from_bytes(&ident_bytes); + + let read_stmt: syn::Stmt = + parse_quote!(let (#ident_result, #ident_bytes): (#ty_u8, _) = #from_bytes_expr;); + let deser = variants.iter() + .map(|v| { + let i = &v.ident; + let self_ty = match &v.fields { + syn::Fields::Unit => quote::quote!(Self::#i), + syn::Fields::Named(_) => quote::quote!(Self::#i { }), + syn::Fields::Unnamed(_) => quote::quote!(Self::#i()) + }; + quote::quote!(x if x == #self_ty as #ty_u8 => Ok((#self_ty, #ident_bytes))) + }) + .collect::>(); let arg = Self::arg(); let ret_ty = Self::ret_ty(); let block = parse_quote!({ @@ -171,16 +232,18 @@ impl FromBytesFnItem { }) } - fn from_struct(ir: &TypeIR) -> syn::Result { + fn from_struct(fields: Vec<(syn::Ident, syn::Type)>) -> syn::Result { let ident_bytes = utils::ident::bytes(); let ident_from_bytes = utils::ident::from_bytes(); let from_bytes_expr = utils::expr::failable_from_bytes(&ident_bytes); - let fields = ir - .fields()? + let fields = fields .into_iter() + .map(|(i, _)| i) .collect::>(); - let deser = ir.map_fields(|i| quote::quote!(let (#i, #ident_bytes) = #from_bytes_expr;))?; + let deser = fields.iter() + .map(|i| quote::quote!(let (#i, #ident_bytes) = #from_bytes_expr;)) + .collect::>(); let arg = Self::arg(); let ret_ty = Self::ret_ty(); let block = parse_quote!({ @@ -216,7 +279,7 @@ struct ToBytesFnItem { impl_from_ir!(ToBytesFnItem); impl ToBytesFnItem { - fn from_struct(ir: &TypeIR) -> syn::Result { + fn from_struct(fields: Vec<(syn::Ident, syn::Type)>) -> syn::Result { let ty_bytes_vec = utils::ty::bytes_vec(); let ty_ret = utils::ty::bytes_result(&ty_bytes_vec); let ty_self = utils::ty::_self(); @@ -227,11 +290,11 @@ impl ToBytesFnItem { let init_vec_stmt = utils::stmt::new_mut_vec_with_capacity(&ident_result, &serialized_length_expr); - let serialize = ir.map_fields(|i| { + let serialize = fields.iter().map(|(i, _)| { let member = utils::member::_self(i); let expr_to_bytes = utils::expr::failable_to_bytes(&member); quote::quote!(#ident_result.extend(#expr_to_bytes);) - })?; + }).collect::>(); let name = utils::ident::to_bytes(); let ret_ty = utils::misc::ret_ty(&ty_ret); @@ -245,7 +308,37 @@ impl ToBytesFnItem { }) } - fn from_enum(_ir: &TypeIR) -> syn::Result { + fn from_enum(variants: Vec) -> syn::Result { + let ty_bytes_vec = utils::ty::bytes_vec(); + let ty_ret = utils::ty::bytes_result(&ty_bytes_vec); + let name = utils::ident::to_bytes(); + let ret_ty = utils::misc::ret_ty(&ty_ret); + let ident_result = utils::ident::result(); + + let arms = variants.iter() + .enumerate() + .map(|(idx, v)| { + let idx = idx as u8; + let ident = &v.ident; + let fields = variant_ident_vec(v); + let left = match &v.fields { + syn::Fields::Unit => quote::quote!(Self::#ident), + syn::Fields::Named(_) => quote::quote!(Self::#ident { #(#fields),* }), + syn::Fields::Unnamed(_) => quote::quote!(Self::#ident( #(#fields),* )) + }; + quote::quote!(#left => { + let mut #ident_result = odra::prelude::vec![#idx]; + #(#ident_result.extend_from_slice(&#fields.to_bytes()?);)* + Ok(#ident_result) + }) + }) + .collect::>(); + Ok(Self { + fn_item: FnItem::new(&name, vec![], ret_ty, match_self_expr(arms).as_block()).instanced() + }) + } + + fn from_unit_enum(_variants: Vec) -> syn::Result { let ty_bytes_vec = utils::ty::bytes_vec(); let ty_ret = utils::ty::bytes_result(&ty_bytes_vec); let name = utils::ident::to_bytes(); @@ -267,16 +360,16 @@ struct SerializedLengthFnItem { impl_from_ir!(SerializedLengthFnItem); impl SerializedLengthFnItem { - fn from_struct(ir: &TypeIR) -> syn::Result { + fn from_struct(fields: Vec<(syn::Ident, syn::Type)>) -> syn::Result { let ty_usize = utils::ty::usize(); let ident_result = utils::ident::result(); - let stmts = ir.map_fields(|i| { + let stmts = fields.iter().map(|(i, _)| { let member = utils::member::_self(i); let expr = utils::expr::serialized_length(&member); let stmt: syn::Stmt = parse_quote!(#ident_result += #expr;); stmt - })?; + }).collect::>(); let name = utils::ident::serialized_length(); let ret_ty = utils::misc::ret_ty(&ty_usize); @@ -290,7 +383,30 @@ impl SerializedLengthFnItem { }) } - fn from_enum(_ir: &TypeIR) -> syn::Result { + fn from_enum(variants: Vec) -> syn::Result { + let ty_usize = utils::ty::usize(); + let name = utils::ident::serialized_length(); + let ret_ty = utils::misc::ret_ty(&ty_usize); + let expr_u8_serialized_len = utils::expr::u8_serialized_len(); + + let arms = variants.iter() + .map(|v| { + let ident = &v.ident; + let fields = variant_ident_vec(v); + let left = match &v.fields { + syn::Fields::Unit => quote::quote!(Self::#ident), + syn::Fields::Named(_) => quote::quote!(Self::#ident { #(#fields),* }), + syn::Fields::Unnamed(_) => quote::quote!(Self::#ident( #(#fields),* )) + }; + quote::quote!(#left => #expr_u8_serialized_len #(+ #fields.serialized_length())* ) + }) + .collect::>(); + Ok(Self { + fn_item: FnItem::new(&name, vec![], ret_ty, match_self_expr(arms).as_block()).instanced() + }) + } + + fn from_unit_enum(_variants: Vec) -> syn::Result { let ty_usize = utils::ty::usize(); let name = utils::ident::serialized_length(); let ret_ty = utils::misc::ret_ty(&ty_usize); @@ -301,6 +417,22 @@ impl SerializedLengthFnItem { } } +fn variant_ident_vec(variant: &syn::Variant) -> Vec { + variant.fields + .clone() + .iter() + .enumerate() + .map(|(idx, i)| match &i.ident { + Some(ident) => ident.clone(), + None => format_ident!("f{}", idx) + }) + .collect::>() +} + +fn match_self_expr(arms: T) -> syn::Expr { + parse_quote!(match self { #arms }) +} + #[cfg(test)] mod tests { use super::*; @@ -368,7 +500,7 @@ mod tests { } #[test] - fn test_enum() { + fn test_unit_enum() { let ir = test_utils::mock::custom_enum(); let item = OdraTypeItem::try_from(&ir).unwrap(); let expected = quote!( @@ -384,8 +516,8 @@ mod tests { fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), odra::casper_types::bytesrepr::Error> { let (result, bytes): (u8, _) = odra::casper_types::bytesrepr::FromBytes::from_bytes(bytes)?; match result { - x if x == MyType::A as u8 => Ok((MyType::A, bytes)), - x if x == MyType::B as u8 => Ok((MyType::B, bytes)), + x if x == Self::A as u8 => Ok((Self::A, bytes)), + x if x == Self::B as u8 => Ok((Self::B, bytes)), _ => Err(odra::casper_types::bytesrepr::Error::Formatting), } } @@ -421,4 +553,106 @@ mod tests { test_utils::assert_eq(item, expected); } + + #[test] + fn test_complex_enum() { + let ir = test_utils::mock::custom_complex_enum(); + let item = OdraTypeItem::try_from(&ir).unwrap(); + let expected = quote!( + #[derive(Clone, PartialEq, Eq, Debug)] + enum MyType { + /// Description of A + A { a: String, b: u32 }, + /// Description of B + B(u32, String), + /// Description of C + C(), + /// Description of D + D {}, + } + + impl odra::casper_types::bytesrepr::FromBytes for MyType { + fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), odra::casper_types::bytesrepr::Error> { + let (result, bytes): (u8, _) = odra::casper_types::bytesrepr::FromBytes::from_bytes(bytes)?; + match result { + 0u8 => { + let (a, bytes) = odra::casper_types::bytesrepr::FromBytes::from_bytes(bytes)?; + let (b, bytes) = odra::casper_types::bytesrepr::FromBytes::from_bytes(bytes)?; + Ok((Self::A { a, b }, bytes)) + }, + 1u8 => { + let (f0, bytes) = odra::casper_types::bytesrepr::FromBytes::from_bytes(bytes)?; + let (f1, bytes) = odra::casper_types::bytesrepr::FromBytes::from_bytes(bytes)?; + Ok((Self::B(f0, f1), bytes)) + }, + 2u8 => Ok((Self::C(), bytes)), + 3u8 => Ok((Self::D {}, bytes)), + _ => Err(odra::casper_types::bytesrepr::Error::Formatting), + } + } + } + + impl odra::casper_types::bytesrepr::ToBytes for MyType { + fn to_bytes(&self) -> Result, odra::casper_types::bytesrepr::Error> { + match self { + Self::A { a, b } => { + let mut result = odra::prelude::vec![0u8]; + result.extend_from_slice(&a.to_bytes()?); + result.extend_from_slice(&b.to_bytes()?); + Ok(result) + }, + Self::B(f0, f1) => { + let mut result = odra::prelude::vec![1u8]; + result.extend_from_slice(&f0.to_bytes()?); + result.extend_from_slice(&f1.to_bytes()?); + Ok(result) + }, + Self::C() => { + let mut result = odra::prelude::vec![2u8]; + Ok(result) + }, + Self::D {} => { + let mut result = odra::prelude::vec![3u8]; + Ok(result) + } + } + } + + fn serialized_length(&self) -> usize { + match self { + Self::A { a, b } => odra::casper_types::bytesrepr::U8_SERIALIZED_LENGTH + a.serialized_length() + b.serialized_length(), + Self::B(f0, f1) => odra::casper_types::bytesrepr::U8_SERIALIZED_LENGTH + f0.serialized_length() + f1.serialized_length(), + Self::C() => odra::casper_types::bytesrepr::U8_SERIALIZED_LENGTH, + Self::D {} => odra::casper_types::bytesrepr::U8_SERIALIZED_LENGTH, + } + } + } + + impl odra::casper_types::CLTyped for MyType { + fn cl_type() -> odra::casper_types::CLType { + odra::casper_types::CLType::Any + } + } + + impl odra::contract_def::HasEvents for MyType { + fn events() -> odra::prelude::vec::Vec { + odra::prelude::vec::Vec::new() + } + + #[cfg(target_arch = "wasm32")] + fn event_schemas() -> odra::prelude::BTreeMap { + odra::prelude::BTreeMap::new() + } + } + ); + + test_utils::assert_eq(item, expected); + } + + #[test] + fn test_union() { + let ir = test_utils::mock::custom_union(); + let item = OdraTypeItem::try_from(&ir); + assert!(item.is_err()); + } } diff --git a/odra-macros/src/ast/schema/custom_item.rs b/odra-macros/src/ast/schema/custom_item.rs index beee40fd..f2f2d0f6 100644 --- a/odra-macros/src/ast/schema/custom_item.rs +++ b/odra-macros/src/ast/schema/custom_item.rs @@ -1,11 +1,10 @@ use quote::ToTokens; -use crate::{ast::utils::Named, ir::TypeIR, utils}; +use syn::Fields; +use crate::{ast::utils::Named, ir::{TypeIR, TypeKind}, utils}; pub struct SchemaCustomTypeItem { ty_ident: syn::Ident, - is_enum: bool, - fields: Vec, - variants: Vec + kind: TypeKind, } impl ToTokens for SchemaCustomTypeItem { @@ -13,18 +12,40 @@ impl ToTokens for SchemaCustomTypeItem { let name = &self.ty_ident.to_string(); let ident = &self.ty_ident; - let custom_item = match self.is_enum { - true => custom_enum(name, &self.variants), - false => custom_struct(name, &self.fields) + let custom_item = match &self.kind { + TypeKind::UnitEnum { variants } => custom_enum(name, variants), + TypeKind::Enum { variants } => custom_complex_enum(name, variants), + TypeKind::Struct { fields } => custom_struct(name, fields), }; - let sub_types = self.fields - .iter() - .map(|f| { - let ty = &f.ty; - quote::quote!(.chain(<#ty as odra::schema::SchemaCustomTypes>::schema_types())) - }) - .collect::>(); + let sub_types = match &self.kind { + TypeKind::Struct { fields } => fields + .iter() + .map(|(_, ty)| { + quote::quote!(.chain(<#ty as odra::schema::SchemaCustomTypes>::schema_types())) + }) + .collect::>(), + _ => Vec::new(), + }; + + let enum_sub_types = match &self.kind { + TypeKind::Enum { variants } => variants.iter().filter_map(|v| { + match &v.fields { + Fields::Named(f) => { + let fields = f.named.iter().map(|f| { + let name = f.ident.as_ref().unwrap().to_string(); + let ty = &f.ty; + quote::quote!(odra::schema::struct_member::<#ty>(#name)) + }).collect::>(); + let ty_name = format!("{}::{}", name, v.ident); + Some(quote::quote!(odra::schema::custom_struct(#ty_name, odra::prelude::vec![#(#fields),*]))) + } + Fields::Unnamed(_) => None, + Fields::Unit => None + } + }).collect::>(), + _ => Vec::new() + }; let item = quote::quote! { #[automatically_derived] @@ -34,6 +55,7 @@ impl ToTokens for SchemaCustomTypeItem { odra::prelude::BTreeSet::>::new() .into_iter() .chain(odra::prelude::vec![Some(#custom_item)]) + .chain(odra::prelude::vec![#(Some(#enum_sub_types)),*]) #(#sub_types)* .collect::>() } @@ -57,55 +79,67 @@ impl ToTokens for SchemaCustomTypeItem { } fn custom_enum(name: &str, variants: &[syn::Variant]) -> proc_macro2::TokenStream { - let variants = utils::syn::transform_variants(variants, |name, discriminant, _| { + let variants = utils::syn::transform_variants(variants, |name, _, discriminant, _| { quote::quote!(odra::schema::enum_variant(#name, #discriminant),) }); quote::quote!(odra::schema::custom_enum(#name, #variants)) } -fn custom_struct(name: &str, fields: &[syn::Field]) -> proc_macro2::TokenStream { - let members = fields.iter().map(|f| { - let name = f.ident.as_ref().unwrap().to_string(); - let ty = &f.ty; - quote::quote! { - odra::schema::struct_member::<#ty>(#name), +fn custom_complex_enum(enum_name: &str, variants: &[syn::Variant]) -> proc_macro2::TokenStream { + let variants = utils::syn::transform_variants(variants, |name, fields, discriminant, _| match fields { + Fields::Named(_) => { + match fields.len() { + 0 => quote::quote!(odra::schema::enum_variant(#name, #discriminant),), + _ => { + let ty_name = format!("{}::{}", enum_name, name); + quote::quote!(odra::schema::enum_custom_type_variant(#name, #discriminant, #ty_name),) + } + } } + Fields::Unnamed(_) => { + match fields.len() { + 0 => quote::quote!(odra::schema::enum_variant(#name, #discriminant),), + 1 => { + let ty = fields.iter().next().unwrap().ty.clone(); + let ty = quote::quote!(#ty); + quote::quote!(odra::schema::enum_typed_variant::<#ty>(#name, #discriminant),) + } + _ => { + let mut ty = proc_macro2::TokenStream::new(); + syn::token::Paren::default().surround(&mut ty, |tokens| { + fields.iter().for_each(|f| { + let ty = &f.ty; + tokens.extend(quote::quote!(#ty,)) + }); + }); + quote::quote!(odra::schema::enum_typed_variant::<#ty>(#name, #discriminant),) + } + } + } + Fields::Unit => quote::quote!(odra::schema::enum_variant(#name, #discriminant),), }); + quote::quote!(odra::schema::custom_enum(#enum_name, #variants)) +} +fn custom_struct(name: &str, fields: &[(syn::Ident, syn::Type)]) -> proc_macro2::TokenStream { + let members = fields + .iter() + .map(|(ident, ty)| { + let name = ident.to_string(); + quote::quote!(odra::schema::struct_member::<#ty>(#name)) + }); - quote::quote!(odra::schema::custom_struct(#name, odra::prelude::vec![#(#members)*])) + quote::quote!(odra::schema::custom_struct(#name, odra::prelude::vec![#(#members,)*])) } impl TryFrom<&TypeIR> for SchemaCustomTypeItem { type Error = syn::Error; fn try_from(ir: &TypeIR) -> Result { - let item = ir.self_code(); - if matches!(item, syn::Item::Struct(_) | syn::Item::Enum(_)) { - let fields = if let syn::Item::Struct(s) = item { - utils::syn::extract_named_field(s)? - } else { - vec![] - }; - - let variants = if let syn::Item::Enum(e) = item { - utils::syn::extract_unit_variants(e)? - } else { - vec![] - }; - - Ok(Self { - ty_ident: ir.name()?, - is_enum: ir.is_enum(), - fields, - variants - }) - } else { - Err(syn::Error::new_spanned( - item, - "Struct with named fields or a unit variants enum expected" - )) - } + Ok(Self { + ty_ident: ir.name()?, + kind: ir.kind()?, + }) } } @@ -135,6 +169,7 @@ mod tests { ] )) ]) + .chain(odra::prelude::vec![]) .chain(::schema_types()) .chain(::schema_types()) .collect::>() @@ -160,7 +195,7 @@ mod tests { } #[test] - fn test_enum() { + fn test_unit_enum() { let ir = test_utils::mock::custom_enum(); let item = SchemaCustomTypeItem::try_from(&ir).unwrap(); let expected = quote!( @@ -179,6 +214,7 @@ mod tests { ] )) ]) + .chain(odra::prelude::vec![]) .collect::>() } } @@ -200,4 +236,65 @@ mod tests { test_utils::assert_eq(item, expected); } + + #[test] + fn test_complex_enum() { + let ir = test_utils::mock::custom_complex_enum(); + let item = SchemaCustomTypeItem::try_from(&ir).unwrap(); + let expected = quote!( + #[automatically_derived] + #[cfg(not(target_arch = "wasm32"))] + impl odra::schema::SchemaCustomTypes for MyType { + fn schema_types() -> odra::prelude::vec::Vec> { + odra::prelude::BTreeSet::>::new() + .into_iter() + .chain(odra::prelude::vec![ + Some(odra::schema::custom_enum( + "MyType", + odra::prelude::vec![ + odra::schema::enum_custom_type_variant("A", 0u16, "MyType::A"), + odra::schema::enum_typed_variant::<(u32, String,)>("B", 1u16), + odra::schema::enum_variant("C", 2u16), + odra::schema::enum_variant("D", 3u16), + ] + )) + ]) + .chain(odra::prelude::vec![ + Some(odra::schema::custom_struct( + "MyType::A", + odra::prelude::vec![ + odra::schema::struct_member::("a"), + odra::schema::struct_member::("b") + ] + )), + Some(odra::schema::custom_struct("MyType::D", odra::prelude::vec![])) + ]) + .collect::>() + } + } + + #[automatically_derived] + #[cfg(not(target_arch = "wasm32"))] + impl odra::schema::NamedCLTyped for MyType { + fn ty() -> odra::schema::casper_contract_schema::NamedCLType { + odra::schema::casper_contract_schema::NamedCLType::Custom(String::from( + "MyType" + )) + } + } + + #[automatically_derived] + #[cfg(not(target_arch = "wasm32"))] + impl odra::schema::SchemaCustomElement for MyType {} + ); + + test_utils::assert_eq(item, expected); + } + + #[test] + fn test_union() { + let ir = test_utils::mock::custom_union(); + let item = SchemaCustomTypeItem::try_from(&ir); + assert!(item.is_err()); + } } diff --git a/odra-macros/src/ast/schema/errors.rs b/odra-macros/src/ast/schema/errors.rs index 6bbccb1d..8ae5212f 100644 --- a/odra-macros/src/ast/schema/errors.rs +++ b/odra-macros/src/ast/schema/errors.rs @@ -93,7 +93,7 @@ impl TryFrom<&TypeIR> for SchemaErrorItem { } fn enum_variants(variants: &[syn::Variant]) -> proc_macro2::TokenStream { - utils::syn::transform_variants(variants, |name, discriminant, docs| { + utils::syn::transform_variants(variants, |name, _, discriminant, docs| { let description = docs.first().cloned().unwrap_or_default().trim().to_string(); quote::quote!(odra::schema::error(#name, #description, #discriminant),) }) diff --git a/odra-macros/src/ir/mod.rs b/odra-macros/src/ir/mod.rs index 494cc5d7..b4d17d65 100644 --- a/odra-macros/src/ir/mod.rs +++ b/odra-macros/src/ir/mod.rs @@ -5,7 +5,7 @@ use crate::utils; use config::ConfigItem; use proc_macro2::Ident; use quote::{format_ident, ToTokens}; -use syn::{parse_quote, ImplItem}; +use syn::{parse_quote, spanned::Spanned, ImplItem}; use self::attr::OdraAttribute; @@ -653,22 +653,45 @@ impl TypeIR { &self.code } - pub fn fields(&self) -> syn::Result> { - utils::syn::derive_item_variants(&self.code) - } - - pub fn map_fields(&self, func: F) -> syn::Result> - where - F: FnMut(&syn::Ident) -> R - { - Ok(self.fields()?.iter().map(func).collect::>()) - } - - pub fn is_enum(&self) -> bool { - matches!(self.code, syn::Item::Enum(_)) + pub fn kind(&self) -> syn::Result { + match &self.code { + syn::Item::Enum(e) => { + let is_unit = e.variants.iter().all(|v| v.fields.is_empty()); + let variants = e.variants.iter().cloned().collect(); + if is_unit { + Ok(TypeKind::UnitEnum { variants }) + } else { + Ok(TypeKind::Enum { variants }) + } + } + syn::Item::Struct(syn::ItemStruct { fields, .. }) => { + let fields = fields + .iter() + .map(|f| { + f.ident + .clone() + .map(|i| (i, f.ty.clone())) + .ok_or(syn::Error::new(f.span(), "Unnamed field")) + }) + .collect::, _>>()?; + Ok(TypeKind::Struct { fields }) + } + _ => Err(syn::Error::new_spanned( + &self.code, + "Invalid type. Only enums and structs are supported" + )) + } } +} - pub fn is_struct(&self) -> bool { - matches!(self.code, syn::Item::Struct(_)) +pub enum TypeKind { + UnitEnum { + variants: Vec + }, + Enum { + variants: Vec + }, + Struct { + fields: Vec<(syn::Ident, syn::Type)> } } diff --git a/odra-macros/src/test_utils.rs b/odra-macros/src/test_utils.rs index f03fa923..deaffb61 100644 --- a/odra-macros/src/test_utils.rs +++ b/odra-macros/src/test_utils.rs @@ -150,6 +150,32 @@ pub mod mock { TypeIR::try_from(&ty).unwrap() } + pub fn custom_complex_enum() -> TypeIR { + let ty = quote!( + enum MyType { + /// Description of A + A { a: String, b: u32 }, + /// Description of B + B(u32, String), + /// Description of C + C(), + /// Description of D + D {} + } + ); + TypeIR::try_from(&ty).unwrap() + } + + pub fn custom_union() -> TypeIR { + let ty = quote!( + union MyUnion { + f1: u32, + f2: f32, + } + ); + TypeIR::try_from(&ty).unwrap() + } + pub fn ext_contract() -> ModuleImplIR { let ext = quote!( pub trait Token { diff --git a/odra-macros/src/utils/syn.rs b/odra-macros/src/utils/syn.rs index 8938e005..0701023e 100644 --- a/odra-macros/src/utils/syn.rs +++ b/odra-macros/src/utils/syn.rs @@ -1,5 +1,5 @@ use proc_macro2::TokenStream; -use syn::{parse_quote, spanned::Spanned}; +use syn::{parse_quote, spanned::Spanned, Fields}; pub fn ident_from_impl(impl_code: &syn::ItemImpl) -> syn::Result { last_segment_ident(&impl_code.self_ty) @@ -100,36 +100,6 @@ fn map_fields syn::Result>( } } -pub fn derive_item_variants(item: &syn::Item) -> syn::Result> { - match &item { - syn::Item::Struct(syn::ItemStruct { fields, .. }) => fields - .iter() - .map(|f| { - f.ident - .clone() - .ok_or(syn::Error::new(f.span(), "Unnamed field")) - }) - .collect::, _>>(), - syn::Item::Enum(syn::ItemEnum { variants, .. }) => { - let is_valid = variants - .iter() - .all(|v| matches!(v.fields, syn::Fields::Unit)); - if is_valid { - Ok(variants.iter().map(|v| v.ident.clone()).collect::>()) - } else { - Err(syn::Error::new_spanned( - variants, - "Expected a unit enum variant." - )) - } - } - _ => Err(syn::Error::new_spanned( - item, - "Struct with named fields expected" - )) - } -} - pub fn visibility_pub() -> syn::Visibility { parse_quote!(pub) } @@ -211,20 +181,6 @@ pub fn is_ref(ty: &syn::Type) -> bool { matches!(ty, syn::Type::Reference(_)) } -pub fn extract_named_field(input: &syn::ItemStruct) -> syn::Result> { - let fields = &input.fields; - fields - .iter() - .map(|f| { - if f.ident.is_none() { - Err(syn::Error::new(f.span(), "Unnamed field")) - } else { - Ok(f.clone()) - } - }) - .collect() -} - pub fn extract_unit_variants(input: &syn::ItemEnum) -> syn::Result> { let variants = &input.variants; let is_valid = variants @@ -240,7 +196,7 @@ pub fn extract_unit_variants(input: &syn::ItemEnum) -> syn::Result) -> TokenStream>( +pub fn transform_variants) -> TokenStream>( variants: &[syn::Variant], f: F ) -> TokenStream { @@ -248,12 +204,14 @@ pub fn transform_variants) -> TokenStream>( let variants = variants.iter().map(|v| { let docs = string_docs(&v.attrs); let name = v.ident.to_string(); + let fields = v.fields.clone(); + if let Some((_, syn::Expr::Lit(lit))) = &v.discriminant { if let syn::Lit::Int(int) = &lit.lit { discriminant = int.base10_parse().unwrap(); } }; - let result = f(name, discriminant, docs); + let result = f(name, fields, discriminant, docs); discriminant += 1; result }); diff --git a/odra-schema/src/lib.rs b/odra-schema/src/lib.rs index 549633f7..593e9e5d 100644 --- a/odra-schema/src/lib.rs +++ b/odra-schema/src/lib.rs @@ -112,11 +112,16 @@ pub fn enum_typed_variant(name: &str, discriminant: u16) -> Enu /// Creates a new enum variant of type [NamedCLType::Unit]. pub fn enum_variant(name: &str, discriminant: u16) -> EnumVariant { + enum_typed_variant::<()>(name, discriminant) +} + +/// Creates a new enum variant of type [NamedCLType::Custom]. +pub fn enum_custom_type_variant(name: &str, discriminant: u16, custom_type: &str) -> EnumVariant { EnumVariant { name: name.to_string(), description: None, discriminant, - ty: NamedCLType::Unit.into() + ty: NamedCLType::Custom(custom_type.into()).into() } } @@ -323,12 +328,17 @@ mod test { #[test] fn test_custom_enum() { - let variant = super::enum_variant("variant1", 1); - let custom_enum = super::custom_enum("enum1", vec![variant]); + let variant1 = super::enum_variant("variant1", 1); + let variant2 = super::enum_typed_variant::("v2", 2); + let variant3 = super::enum_custom_type_variant("v3", 3, "Type1"); + let custom_enum = super::custom_enum("enum1", vec![variant1, variant2, variant3]); match custom_enum { casper_contract_schema::CustomType::Enum { name, variants, .. } => { assert_eq!(name, "enum1".into()); - assert_eq!(variants.len(), 1); + assert_eq!(variants.len(), 3); + assert_eq!(variants[0].ty, NamedCLType::Unit.into()); + assert_eq!(variants[1].ty, NamedCLType::String.into()); + assert_eq!(variants[2].ty, NamedCLType::Custom("Type1".into()).into()); } _ => panic!("Expected CustomType::Enum") }