From e951ed1510214d09794168f1b385289359b76b1c Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Thu, 27 Jul 2023 16:24:03 +0200 Subject: [PATCH] fix(derive): `ArrayType` derive for unit structs --- narrow-derive/src/struct/mod.rs | 24 +- narrow-derive/src/struct/unit.rs | 246 +++++++++++------- narrow-derive/src/{util/mod.rs => util.rs} | 16 +- narrow-derive/src/util/bounds.rs | 1 - .../struct/unit/const_generic.expanded.rs | 58 ++++- .../unit/const_generic_default.expanded.rs | 58 ++++- .../expand/struct/unit/simple.expanded.rs | 50 +++- .../struct/unit/where_clause.expanded.rs | 80 +++++- src/array/mod.rs | 5 +- src/array/struct.rs | 13 + tests/derive.rs | 41 +-- 11 files changed, 430 insertions(+), 162 deletions(-) rename narrow-derive/src/{util/mod.rs => util.rs} (90%) delete mode 100644 narrow-derive/src/util/bounds.rs diff --git a/narrow-derive/src/struct/mod.rs b/narrow-derive/src/struct/mod.rs index ffdd40f4..45acb728 100644 --- a/narrow-derive/src/struct/mod.rs +++ b/narrow-derive/src/struct/mod.rs @@ -1,5 +1,7 @@ +use crate::util; use proc_macro2::TokenStream; -use syn::{DeriveInput, Fields}; +use quote::{format_ident, quote}; +use syn::{DeriveInput, Fields, Generics, Ident}; mod unit; @@ -9,3 +11,23 @@ pub(super) fn derive(input: &DeriveInput, fields: &Fields) -> TokenStream { _ => todo!("non unit structs derive"), } } + +fn array_type_ident(ident: &Ident) -> Ident { + format_ident!("{}Array", ident) +} + +fn array_type_impl(ident: &Ident, generics: &Generics) -> TokenStream { + let narrow = util::narrow(); + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + quote! { + impl #impl_generics #narrow::array::ArrayType for #ident #ty_generics #where_clause { + type Array = #narrow::array::StructArray<#ident #ty_generics, false, Buffer>; + } + + impl #impl_generics #narrow::array::ArrayType<#ident #ty_generics> for ::std::option::Option<#ident #ty_generics> #where_clause { + type Array = #narrow::array::StructArray<#ident #ty_generics, true, Buffer>; + } + } +} diff --git a/narrow-derive/src/struct/unit.rs b/narrow-derive/src/struct/unit.rs index e8fbc8ee..f179d25c 100644 --- a/narrow-derive/src/struct/unit.rs +++ b/narrow-derive/src/struct/unit.rs @@ -1,8 +1,12 @@ +use crate::{ + r#struct::{array_type_ident, array_type_impl}, + util, +}; use proc_macro2::TokenStream; -use quote::{format_ident, quote}; -use syn::{parse_quote, visit_mut::VisitMut, DeriveInput, Generics, Ident, Visibility}; - -use crate::util; +use quote::quote; +use syn::{ + parse_quote, DeriveInput, Generics, Ident, ImplGenerics, TypeGenerics, Visibility, WhereClause, +}; pub(super) fn derive(input: &DeriveInput) -> TokenStream { let DeriveInput { @@ -12,117 +16,179 @@ pub(super) fn derive(input: &DeriveInput) -> TokenStream { .. } = input; - // let narrow = util::narrow(); - - // Construct the raw array wrapper. - // let raw_array_def = quote! { - // #[doc = #raw_array_doc] - // #vis struct #raw_array_ident #array_generics - // (#narrow::array::null::NullArray<#ident #ty_generics>) #array_where_clause; - // }; - - // Add NullArray generics. - // let mut array_generics = generics.clone(); - // AddNullableConstGeneric.visit_generics_mut(&mut array_generics); - - // // let generics = SelfReplace::generics(&ident, &generics); - // let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - // // Replace `Self` with ident in where clauses. - // let mut where_clause = where_clause.cloned(); - // where_clause - // .as_mut() - // .map(|where_clause| - // SelfReplace(&ident).visit_where_clause_mut(where_clause)); - - // let array_generics = SelfReplace::generics(ident, generics); - // let (array_impl_generics, array_ty_generics, array_where_clause) = - // array_generics.split_for_impl(); - - // // Add a type definition for the array of this type. - // let alias_array_ident = util::alias_array_ident(&ident); - // let alias_array = quote!( - // #[doc = #raw_array_doc] - // #[automatically_derived] - // #vis type #alias_array_ident #generics #where_clause = - // #narrow::array::r#struct::StructArray<#ident #ty_generics>; ); - - // Implement ArrayType for this type. - // let array_type_impl = quote! { - // #[automatically_derived] - // impl #impl_generics #narrow::array::ArrayType for #ident #ty_generics - // #where_clause { type Array = - // #raw_array_ident #array_generics; } - // }; - - let raw_array = raw_array_def(vis, ident, generics); - - // todo impl unit for this type + // Generate the ArrayType implementation. + let array_type_impl = array_type_impl(ident, generics); + + // Generate the StructArrayType implementation. + let struct_array_type_impl = struct_array_type_impl(ident, generics); + + // Generate the Unit implementation. + let unit_impl = unit_impl(ident, generics); + + // Generate the array type definition. + let array_type_def = array_type_def(vis, ident, generics); + + // Generate the length impl. + let length_impl = length_impl(ident, generics); + + // Generate the FromIterator impl. + let from_iterator_impl = from_iterator_impl(ident, generics); + + // Generate the Extend impl. + let extend_impl = extend_impl(ident, generics); + + // Generate the Default impl. + let default_impl = default_impl(ident, generics); + + quote! { + #array_type_impl + + #struct_array_type_impl + + #unit_impl + + #array_type_def + + #from_iterator_impl + + #length_impl + + #extend_impl + + #default_impl + } +} + +fn unit_impl(ident: &Ident, generics: &Generics) -> TokenStream { + let narrow = util::narrow(); + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { - #raw_array + /// Safety: + /// - This is a unit struct. + unsafe impl #impl_generics #narrow::array::Unit for #ident #ty_generics #where_clause {} + } +} + +fn struct_array_type_impl(ident: &Ident, generics: &Generics) -> TokenStream { + let narrow = util::narrow(); + + let array_type_ident = array_type_ident(ident); - // #array_type_impl + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let array_generics = NullArrayGenerics::new(generics); + let (_, ty_generics_with_buffer, _) = array_generics.split_for_impl(); + + quote! { + impl #impl_generics #narrow::array::StructArrayType for #ident #ty_generics #where_clause { + type Array = #array_type_ident #ty_generics_with_buffer; + } } - // #alias_array - // } } -pub(super) fn raw_array_def(vis: &Visibility, ident: &Ident, generics: &Generics) -> TokenStream { +fn array_type_def(vis: &Visibility, ident: &Ident, generics: &Generics) -> TokenStream { let narrow = util::narrow(); - // Get the ident and doc for the raw array struct. - let (raw_array_ident, raw_array_doc) = util::raw_array(ident); + let array_type_ident = array_type_ident(ident); - // Get the ty_generics of the inner type. let (_, ty_generics, where_clause) = generics.split_for_impl(); - // Add the const generic trait bound for nullability. - let generics = NullArrayGenerics::generics(generics); - let nullarray_nullable = NullArrayGenerics::nullable_generic(); - let nullarray_validity_bitmap_buffer = NullArrayGenerics::validity_bitmap_buffer_generic(); - // let (_, _, _) = generics.split_for_impl(); + let array_generics = NullArrayGenerics::new(generics); + let (impl_generics_with_buffer, _, _) = array_generics.split_for_impl(); quote! { - #[doc = #raw_array_doc] - #vis struct #raw_array_ident #generics( - #narrow::array::null::NullArray< - #ident #ty_generics, - #nullarray_nullable, - #nullarray_validity_bitmap_buffer - > - ) #where_clause; + #vis struct #array_type_ident #impl_generics_with_buffer( + #narrow::array::NullArray<#ident #ty_generics, false, Buffer> + ) #where_clause; } } -struct NullArrayGenerics; +fn from_iterator_impl(ident: &Ident, generics: &Generics) -> TokenStream { + let array_type_ident = array_type_ident(ident); -impl NullArrayGenerics { - fn generics(generics: &Generics) -> Generics { - let mut generics = generics.clone(); - Self.visit_generics_mut(&mut generics); - generics + let (_, ty_generics, where_clause) = generics.split_for_impl(); + + let array_generics = NullArrayGenerics::new(generics); + let (impl_generics_with_buffer, ty_generics_with_buffer, _) = array_generics.split_for_impl(); + + quote! { + impl #impl_generics_with_buffer ::std::iter::FromIterator<#ident #ty_generics> for #array_type_ident #ty_generics_with_buffer #where_clause { + fn from_iter<_I: ::std::iter::IntoIterator>(iter: _I) -> Self { + Self(iter.into_iter().collect()) + } + } } +} + +fn length_impl(ident: &Ident, generics: &Generics) -> TokenStream { + let narrow = util::narrow(); - fn nullable_generic() -> Ident { - format_ident!("_NARROW_NULLABLE") + let array_type_ident = array_type_ident(ident); + + let (_, _ty_generics, where_clause) = generics.split_for_impl(); + + let array_generics = NullArrayGenerics::new(generics); + let (impl_generics_with_buffer, ty_generics_with_buffer, _) = array_generics.split_for_impl(); + + quote! { + impl #impl_generics_with_buffer #narrow::Length for #array_type_ident #ty_generics_with_buffer #where_clause { + #[inline] + fn len(&self) -> usize { + self.0.len() + } + } } +} + +fn extend_impl(ident: &Ident, generics: &Generics) -> TokenStream { + let array_type_ident = array_type_ident(ident); - fn validity_bitmap_buffer_generic() -> Ident { - format_ident!("_NARROW_VALIDITY_BITMAP_BUFFER") + let (_, ty_generics, where_clause) = generics.split_for_impl(); + + let array_generics = NullArrayGenerics::new(generics); + let (impl_generics_with_buffer, ty_generics_with_buffer, _) = array_generics.split_for_impl(); + + quote! { + impl #impl_generics_with_buffer ::std::iter::Extend<#ident #ty_generics> for #array_type_ident #ty_generics_with_buffer #where_clause { + fn extend<_I: ::std::iter::IntoIterator>(&mut self, iter: _I) { + self.0.extend(iter) + } + } } } -impl VisitMut for NullArrayGenerics { - fn visit_generics_mut(&mut self, generics: &mut Generics) { - let nullable = Self::nullable_generic(); - generics - .params - .push(parse_quote!(const #nullable: bool = false)); +fn default_impl(ident: &Ident, generics: &Generics) -> TokenStream { + let array_type_ident = array_type_ident(ident); + + let (_, _, where_clause) = generics.split_for_impl(); + + let array_generics = NullArrayGenerics::new(generics); + let (impl_generics_with_buffer, ty_generics_with_buffer, _) = array_generics.split_for_impl(); + + quote! { + impl #impl_generics_with_buffer ::std::default::Default for #array_type_ident #ty_generics_with_buffer #where_clause { + fn default() -> Self { + Self(::std::default::Default::default()) + } + } + } +} + +struct NullArrayGenerics(Generics); - let validity_bitmap_buffer = Self::validity_bitmap_buffer_generic(); - generics +impl NullArrayGenerics { + fn new(generics: &Generics) -> Self { + let narrow = util::narrow(); + + let mut generics_with_buffer = generics.clone(); + generics_with_buffer .params - .push(parse_quote!(#validity_bitmap_buffer = Vec)); + .push(parse_quote!(Buffer: #narrow::buffer::BufferType = #narrow::buffer::VecBuffer)); + Self(generics_with_buffer) + } + fn split_for_impl(&self) -> (ImplGenerics, TypeGenerics, Option<&WhereClause>) { + self.0.split_for_impl() } } diff --git a/narrow-derive/src/util/mod.rs b/narrow-derive/src/util.rs similarity index 90% rename from narrow-derive/src/util/mod.rs rename to narrow-derive/src/util.rs index bab78fe6..c2481eb8 100644 --- a/narrow-derive/src/util/mod.rs +++ b/narrow-derive/src/util.rs @@ -1,13 +1,10 @@ +use crate::NARROW; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use syn::{ parse_quote, visit::Visit, visit_mut::VisitMut, Generics, Ident, Type, TypePath, WherePredicate, }; -use crate::NARROW; - -mod bounds; - /// Returns the name of the `narrow` crate. Panics when the `narrow` crate is /// not found. pub(super) fn narrow() -> TokenStream { @@ -15,17 +12,6 @@ pub(super) fn narrow() -> TokenStream { quote!(#ident) } -pub(super) fn raw_array(ident: &Ident) -> (Ident, String) { - ( - format_ident!("Raw{}Array", ident), - format!(" Array with [{ident}] values."), - ) -} - -// pub(super) fn alias_array_ident(ident: &Ident) -> Ident { -// format_ident!("{}Array", ident) -// } - /// Replace Self with ident in where clauses. pub(super) struct SelfReplace<'a> { ident: &'a Ident, diff --git a/narrow-derive/src/util/bounds.rs b/narrow-derive/src/util/bounds.rs deleted file mode 100644 index 8b137891..00000000 --- a/narrow-derive/src/util/bounds.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/narrow-derive/tests/expand/struct/unit/const_generic.expanded.rs b/narrow-derive/tests/expand/struct/unit/const_generic.expanded.rs index ee2642f7..741f7f40 100644 --- a/narrow-derive/tests/expand/struct/unit/const_generic.expanded.rs +++ b/narrow-derive/tests/expand/struct/unit/const_generic.expanded.rs @@ -1,13 +1,51 @@ pub struct Foo; -/// Array with [Foo] values. -pub struct RawFooArray< - const N: usize, - const _NARROW_NULLABLE: bool = false, - _NARROW_VALIDITY_BITMAP_BUFFER = Vec, ->( - narrow::array::null::NullArray< +impl narrow::array::ArrayType for Foo { + type Array = narrow::array::StructArray< + Foo, + false, + Buffer, + >; +} +impl narrow::array::ArrayType> for ::std::option::Option> { + type Array = narrow::array::StructArray< Foo, - _NARROW_NULLABLE, - _NARROW_VALIDITY_BITMAP_BUFFER, - >, + true, + Buffer, + >; +} +impl narrow::array::StructArrayType for Foo { + type Array = FooArray; +} +/// Safety: +/// - This is a unit struct. +unsafe impl narrow::array::Unit for Foo {} +pub struct FooArray( + narrow::array::NullArray, false, Buffer>, ); +impl< + const N: usize, + Buffer: narrow::buffer::BufferType, +> ::std::iter::FromIterator> for FooArray { + fn from_iter<_I: ::std::iter::IntoIterator>>(iter: _I) -> Self { + Self(iter.into_iter().collect()) + } +} +impl narrow::Length +for FooArray { + #[inline] + fn len(&self) -> usize { + self.0.len() + } +} +impl ::std::iter::Extend> +for FooArray { + fn extend<_I: ::std::iter::IntoIterator>>(&mut self, iter: _I) { + self.0.extend(iter) + } +} +impl ::std::default::Default +for FooArray { + fn default() -> Self { + Self(::std::default::Default::default()) + } +} diff --git a/narrow-derive/tests/expand/struct/unit/const_generic_default.expanded.rs b/narrow-derive/tests/expand/struct/unit/const_generic_default.expanded.rs index 050fba07..fdd405f4 100644 --- a/narrow-derive/tests/expand/struct/unit/const_generic_default.expanded.rs +++ b/narrow-derive/tests/expand/struct/unit/const_generic_default.expanded.rs @@ -1,13 +1,51 @@ pub struct Foo; -/// Array with [Foo] values. -pub struct RawFooArray< - const N: usize = 42, - const _NARROW_NULLABLE: bool = false, - _NARROW_VALIDITY_BITMAP_BUFFER = Vec, ->( - narrow::array::null::NullArray< +impl narrow::array::ArrayType for Foo { + type Array = narrow::array::StructArray< Foo, - _NARROW_NULLABLE, - _NARROW_VALIDITY_BITMAP_BUFFER, - >, + false, + Buffer, + >; +} +impl narrow::array::ArrayType> for ::std::option::Option> { + type Array = narrow::array::StructArray< + Foo, + true, + Buffer, + >; +} +impl narrow::array::StructArrayType for Foo { + type Array = FooArray; +} +/// Safety: +/// - This is a unit struct. +unsafe impl narrow::array::Unit for Foo {} +pub struct FooArray( + narrow::array::NullArray, false, Buffer>, ); +impl< + const N: usize, + Buffer: narrow::buffer::BufferType, +> ::std::iter::FromIterator> for FooArray { + fn from_iter<_I: ::std::iter::IntoIterator>>(iter: _I) -> Self { + Self(iter.into_iter().collect()) + } +} +impl narrow::Length +for FooArray { + #[inline] + fn len(&self) -> usize { + self.0.len() + } +} +impl ::std::iter::Extend> +for FooArray { + fn extend<_I: ::std::iter::IntoIterator>>(&mut self, iter: _I) { + self.0.extend(iter) + } +} +impl ::std::default::Default +for FooArray { + fn default() -> Self { + Self(::std::default::Default::default()) + } +} diff --git a/narrow-derive/tests/expand/struct/unit/simple.expanded.rs b/narrow-derive/tests/expand/struct/unit/simple.expanded.rs index eb4cf2f8..ced55eb1 100644 --- a/narrow-derive/tests/expand/struct/unit/simple.expanded.rs +++ b/narrow-derive/tests/expand/struct/unit/simple.expanded.rs @@ -1,8 +1,46 @@ struct Foo; -/// Array with [Foo] values. -struct RawFooArray< - const _NARROW_NULLABLE: bool = false, - _NARROW_VALIDITY_BITMAP_BUFFER = Vec, ->( - narrow::array::null::NullArray, +impl narrow::array::ArrayType for Foo { + type Array = narrow::array::StructArray< + Foo, + false, + Buffer, + >; +} +impl narrow::array::ArrayType for ::std::option::Option { + type Array = narrow::array::StructArray< + Foo, + true, + Buffer, + >; +} +impl narrow::array::StructArrayType for Foo { + type Array = FooArray; +} +/// Safety: +/// - This is a unit struct. +unsafe impl narrow::array::Unit for Foo {} +struct FooArray( + narrow::array::NullArray, ); +impl ::std::iter::FromIterator +for FooArray { + fn from_iter<_I: ::std::iter::IntoIterator>(iter: _I) -> Self { + Self(iter.into_iter().collect()) + } +} +impl narrow::Length for FooArray { + #[inline] + fn len(&self) -> usize { + self.0.len() + } +} +impl ::std::iter::Extend for FooArray { + fn extend<_I: ::std::iter::IntoIterator>(&mut self, iter: _I) { + self.0.extend(iter) + } +} +impl ::std::default::Default for FooArray { + fn default() -> Self { + Self(::std::default::Default::default()) + } +} diff --git a/narrow-derive/tests/expand/struct/unit/where_clause.expanded.rs b/narrow-derive/tests/expand/struct/unit/where_clause.expanded.rs index 701d481d..5e5267dd 100644 --- a/narrow-derive/tests/expand/struct/unit/where_clause.expanded.rs +++ b/narrow-derive/tests/expand/struct/unit/where_clause.expanded.rs @@ -1,17 +1,77 @@ pub(super) struct Foo where Self: Sized; -/// Array with [Foo] values. -pub(super) struct RawFooArray< - const N: bool = false, - const _NARROW_NULLABLE: bool = false, - _NARROW_VALIDITY_BITMAP_BUFFER = Vec, ->( - narrow::array::null::NullArray< +impl narrow::array::ArrayType for Foo +where + Self: Sized, +{ + type Array = narrow::array::StructArray< + Foo, + false, + Buffer, + >; +} +impl narrow::array::ArrayType> for ::std::option::Option> +where + Self: Sized, +{ + type Array = narrow::array::StructArray< Foo, - _NARROW_NULLABLE, - _NARROW_VALIDITY_BITMAP_BUFFER, - >, + true, + Buffer, + >; +} +impl narrow::array::StructArrayType for Foo +where + Self: Sized, +{ + type Array = FooArray; +} +/// Safety: +/// - This is a unit struct. +unsafe impl narrow::array::Unit for Foo +where + Self: Sized, +{} +pub(super) struct FooArray( + narrow::array::NullArray, false, Buffer>, ) where Self: Sized; +impl ::std::iter::FromIterator> +for FooArray +where + Self: Sized, +{ + fn from_iter<_I: ::std::iter::IntoIterator>>(iter: _I) -> Self { + Self(iter.into_iter().collect()) + } +} +impl narrow::Length +for FooArray +where + Self: Sized, +{ + #[inline] + fn len(&self) -> usize { + self.0.len() + } +} +impl ::std::iter::Extend> +for FooArray +where + Self: Sized, +{ + fn extend<_I: ::std::iter::IntoIterator>>(&mut self, iter: _I) { + self.0.extend(iter) + } +} +impl ::std::default::Default +for FooArray +where + Self: Sized, +{ + fn default() -> Self { + Self(::std::default::Default::default()) + } +} diff --git a/src/array/mod.rs b/src/array/mod.rs index 60af690b..4aeb76f9 100644 --- a/src/array/mod.rs +++ b/src/array/mod.rs @@ -26,7 +26,10 @@ pub use variable_size_list::*; pub trait Array {} -pub trait ArrayType { +// Note: the default generics is required for to allow impls outside on foreign wrappers. +// See https://rust-lang.github.io/rfcs/2451-re-rebalancing-coherence.html +// The resulting behavior when using anything other than the default `Self` is undefined. +pub trait ArrayType { type Array: Array; } diff --git a/src/array/struct.rs b/src/array/struct.rs index abe45322..0eb6bf22 100644 --- a/src/array/struct.rs +++ b/src/array/struct.rs @@ -86,6 +86,9 @@ mod tests { impl<'a> ArrayType for Foo<'a> { type Array = StructArray, false, Buffer>; } + impl<'a> ArrayType for Option> { + type Array = StructArray, true, Buffer>; + } struct FooArray<'a, Buffer: BufferType> { a: ::Array, @@ -193,6 +196,15 @@ mod tests { type Array = FooArray<'a, Buffer>; } + impl<'a, Buffer: BufferType> Length for FooArray<'a, Buffer> + where + ::Array: Length, + { + fn len(&self) -> usize { + self.a.len() + } + } + // And then: let input = [ Foo { @@ -233,6 +245,7 @@ mod tests { }, ]; let array = input.into_iter().collect::>(); + assert_eq!(array.len(), 2); assert_eq!(array.0.a.into_iter().collect::>(), &[1, 2, 3, 4]); assert_eq!( array.0.b.into_iter().collect::>(), diff --git a/tests/derive.rs b/tests/derive.rs index 074f0284..d100422d 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -1,25 +1,30 @@ #[cfg(feature = "derive")] mod tests { - // use narrow::Array; + use narrow::{array::StructArray, ArrayType, Length}; - // #[test] - // fn unit_struct() { - // #[derive(Array)] - // pub struct Foo; - // } + #[test] + fn unit_struct() { + #[derive(ArrayType, Copy, Clone, Default)] + pub struct Foo; - // #[test] - // fn unit_struct_generic() { - // #[derive(Array)] - // pub struct Bar - // where - // Self: Sized; + let input = [Foo; 5]; + let array = input.into_iter().collect::>(); + assert_eq!(array.len(), 5); - // // [Bar::, Bar::] - // // .iter() - // // .collect::>(); - // } + let input = [Some(Foo); 5]; + let array = input.into_iter().collect::>(); + assert_eq!(array.len(), 5); + } - // #[test] - // fn unit_struct_lifetime() {} + #[test] + fn unit_struct_generic() { + #[derive(ArrayType, Copy, Clone, Default)] + pub struct Bar + where + Self: Sized; + + let input = [Bar, Bar]; + let array = input.into_iter().collect::>(); + assert_eq!(array.len(), 2); + } }