From d838a87673480690bb2e9027bd73a2c840d37655 Mon Sep 17 00:00:00 2001 From: Joshua Liebow-Feeser Date: Sat, 9 Sep 2023 11:59:15 -0700 Subject: [PATCH] [derive] Support TryFromBytes for structs TODO: - Should DataExt::fields have a different shape? Extracting field names for enums doesn't really make sense - Add SAFETY comments in emitted `is_bit_valid` impl - Do we need to update the TryFromBytes doc comment at all? - Lots and lots of tests Makes progress on #5 --- src/byteorder.rs | 8 +++-- src/derive_util.rs | 12 +++++++ src/lib.rs | 29 +++++++++++---- src/macros.rs | 22 ++++++++---- src/wrappers.rs | 7 +++- zerocopy-derive/src/ext.rs | 56 +++++++++++++++++++++-------- zerocopy-derive/src/lib.rs | 74 +++++++++++++++++++++++++++++--------- 7 files changed, 161 insertions(+), 47 deletions(-) diff --git a/src/byteorder.rs b/src/byteorder.rs index ecee7a07c9f..639b70662fe 100644 --- a/src/byteorder.rs +++ b/src/byteorder.rs @@ -176,16 +176,20 @@ example of how it can be used for parsing UDP packets. [`AsBytes`]: crate::AsBytes [`Unaligned`]: crate::Unaligned"), #[derive(Copy, Clone, Eq, PartialEq, Hash)] - #[cfg_attr(any(feature = "derive", test), derive(FromZeroes, FromBytes, AsBytes, Unaligned))] + #[cfg_attr(any(feature = "derive", test), derive(TryFromBytes, FromZeroes, FromBytes, AsBytes, Unaligned))] #[repr(transparent)] pub struct $name([u8; $bytes], PhantomData); } + impl_known_layout!(O => $name); + safety_comment! { /// SAFETY: /// `$name` is `repr(transparent)`, and so it has the same layout /// as its only non-zero field, which is a `u8` array. `u8` arrays - /// are `FromZeroes`, `FromBytes`, `AsBytes`, and `Unaligned`. + /// are `TryFromBytes`, `FromZeroes`, `FromBytes`, `AsBytes`, and + /// `Unaligned`. + impl_or_verify!(O => TryFromBytes for $name); impl_or_verify!(O => FromZeroes for $name); impl_or_verify!(O => FromBytes for $name); impl_or_verify!(O => AsBytes for $name); diff --git a/src/derive_util.rs b/src/derive_util.rs index edf88e3bd7b..4fd76b15caf 100644 --- a/src/derive_util.rs +++ b/src/derive_util.rs @@ -66,6 +66,7 @@ macro_rules! union_has_padding { #[cfg(test)] mod tests { use crate::util::testutil::*; + use crate::*; #[test] fn test_struct_has_padding() { @@ -124,4 +125,15 @@ mod tests { // anyway. test!(#[repr(C)] #[repr(packed)] {a: u8, b: u64} => true); } + + #[test] + fn foo() { + #[derive(TryFromBytes)] + struct Foo { + f: u8, + b: bool, + } + + impl_known_layout!(Foo); + } } diff --git a/src/lib.rs b/src/lib.rs index d3555af16d2..db5455d23b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -994,6 +994,11 @@ pub unsafe trait FromBytes: FromZeroes { /// [github-repo]: https://github.com/google/zerocopy /// [`try_from_ref`]: TryFromBytes::try_from_ref pub unsafe trait TryFromBytes: KnownLayout { + #[doc(hidden)] + fn only_derive_is_allowed_to_implement_this_trait() + where + Self: Sized; + /// Does a given memory range contain a valid instance of `Self`? /// /// # Safety @@ -3613,10 +3618,16 @@ mod tests { // // This is used to test the custom derives of our traits. The `[u8]` type // gets a hand-rolled impl, so it doesn't exercise our custom derives. - #[derive(Debug, Eq, PartialEq, FromZeroes, FromBytes, AsBytes, Unaligned)] + #[derive(Debug, Eq, PartialEq, TryFromBytes, FromZeroes, FromBytes, AsBytes, Unaligned)] #[repr(transparent)] struct Unsized([u8]); + // SAFETY: `Unsized` is a `#[repr(transparent)]` wrapper around `[u8]`. + unsafe_impl_known_layout!( + #[repr([u8])] + Unsized + ); + impl Unsized { fn from_mut_slice(slc: &mut [u8]) -> &mut Unsized { // SAFETY: This *probably* sound - since the layouts of `[u8]` and @@ -4550,7 +4561,7 @@ mod tests { assert_eq!(too_many_bytes[0], 123); } - #[derive(Debug, Eq, PartialEq, FromZeroes, FromBytes, AsBytes)] + #[derive(Debug, Eq, PartialEq, TryFromBytes, FromZeroes, FromBytes, AsBytes)] #[repr(C)] struct Foo { a: u32, @@ -4558,6 +4569,8 @@ mod tests { c: Option, } + impl_known_layout!(Foo); + let expected_bytes: Vec = if cfg!(target_endian = "little") { vec![1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0] } else { @@ -4579,12 +4592,14 @@ mod tests { #[test] fn test_array() { - #[derive(FromZeroes, FromBytes, AsBytes)] + #[derive(TryFromBytes, FromZeroes, FromBytes, AsBytes)] #[repr(C)] struct Foo { a: [u16; 33], } + impl_known_layout!(Foo); + let foo = Foo { a: [0xFFFF; 33] }; let expected = [0xFFu8; 66]; assert_eq!(foo.as_bytes(), &expected[..]); @@ -4643,23 +4658,25 @@ mod tests { #[test] fn test_transparent_packed_generic_struct() { - #[derive(AsBytes, FromZeroes, FromBytes, Unaligned)] + #[derive(AsBytes, TryFromBytes, FromZeroes, FromBytes, Unaligned)] #[repr(transparent)] struct Foo { _t: T, _phantom: PhantomData<()>, } + impl_known_layout!(T => Foo); assert_impl_all!(Foo: FromZeroes, FromBytes, AsBytes); assert_impl_all!(Foo: Unaligned); - #[derive(AsBytes, FromZeroes, FromBytes, Unaligned)] + #[derive(AsBytes, TryFromBytes, FromZeroes, FromBytes, Unaligned)] #[repr(packed)] struct Bar { _t: T, _u: U, } + impl_known_layout!(T, U => Bar); assert_impl_all!(Bar: FromZeroes, FromBytes, AsBytes, Unaligned); } @@ -4923,7 +4940,7 @@ mod tests { assert_impls!(Wrapping: TryFromBytes, FromZeroes, AsBytes, Unaligned, !FromBytes); assert_impls!(Wrapping: !TryFromBytes, !FromZeroes, !FromBytes, !AsBytes, !Unaligned); - assert_impls!(Unalign: FromZeroes, FromBytes, AsBytes, Unaligned, !TryFromBytes); + assert_impls!(Unalign: TryFromBytes, FromZeroes, FromBytes, AsBytes, Unaligned, TryFromBytes); assert_impls!(Unalign: Unaligned, !TryFromBytes, !FromZeroes, !FromBytes, !AsBytes); assert_impls!([u8]: TryFromBytes, FromZeroes, FromBytes, AsBytes, Unaligned); diff --git a/src/macros.rs b/src/macros.rs index 3360899b3b9..b4a5ccd9c30 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -130,6 +130,9 @@ macro_rules! unsafe_impl { let $candidate = unsafe { &*(candidate.as_ptr() as *const $repr) }; $is_bit_valid } + + #[allow(clippy::missing_inline_in_public_items)] + fn only_derive_is_allowed_to_implement_this_trait() {} }; (@method TryFromBytes ; |$candidate:ident: NonNull<$repr:ty>| $is_bit_valid:expr) => { #[inline] @@ -141,8 +144,15 @@ macro_rules! unsafe_impl { let $candidate = unsafe { NonNull::new_unchecked(candidate.as_ptr() as *mut $repr) }; $is_bit_valid } + + #[allow(clippy::missing_inline_in_public_items)] + fn only_derive_is_allowed_to_implement_this_trait() {} + }; + (@method TryFromBytes) => { + #[inline(always)] unsafe fn is_bit_valid(_: NonNull) -> bool { true } + #[allow(clippy::missing_inline_in_public_items)] + fn only_derive_is_allowed_to_implement_this_trait() {} }; - (@method TryFromBytes) => { #[inline(always)] unsafe fn is_bit_valid(_: NonNull) -> bool { true } }; (@method $trait:ident) => { #[allow(clippy::missing_inline_in_public_items)] fn only_derive_is_allowed_to_implement_this_trait() {} @@ -253,17 +263,17 @@ macro_rules! impl_known_layout { ($(const $constvar:ident : $constty:ty, $tyvar:ident $(: ?$optbound:ident)? => $ty:ty),* $(,)?) => { $(impl_known_layout!(@inner const $constvar: $constty, $tyvar $(: ?$optbound)? => $ty);)* }; - ($($tyvar:ident $(: ?$optbound:ident)? => $ty:ty),* $(,)?) => { - $(impl_known_layout!(@inner , $tyvar $(: ?$optbound)? => $ty);)* + ($($($tyvar:ident $(: ?$optbound:ident)?),* => $ty:ty),* $(,)?) => { + $(impl_known_layout!(@inner , $($tyvar $(: ?$optbound)?),* => $ty);)* }; ($($ty:ty),*) => { $(impl_known_layout!(@inner , => $ty);)* }; - (@inner $(const $constvar:ident : $constty:ty)? , $($tyvar:ident $(: ?$optbound:ident)?)? => $ty:ty) => { + (@inner $(const $constvar:ident : $constty:ty)? , $($tyvar:ident $(: ?$optbound:ident)?),* => $ty:ty) => { const _: () = { use core::ptr::NonNull; - impl<$(const $constvar : $constty,)? $($tyvar $(: ?$optbound)?)?> sealed::KnownLayoutSealed for $ty {} + impl<$(const $constvar : $constty,)? $($tyvar $(: ?$optbound)?),*> sealed::KnownLayoutSealed for $ty {} // SAFETY: Delegates safety to `DstLayout::for_type`. - unsafe impl<$(const $constvar : $constty,)? $($tyvar $(: ?$optbound)?)?> KnownLayout for $ty { + unsafe impl<$(const $constvar : $constty,)? $($tyvar $(: ?$optbound)?),*> KnownLayout for $ty { const LAYOUT: DstLayout = DstLayout::for_type::<$ty>(); // SAFETY: `.cast` preserves address and provenance. diff --git a/src/wrappers.rs b/src/wrappers.rs index a0e6ac70edc..dfecf462748 100644 --- a/src/wrappers.rs +++ b/src/wrappers.rs @@ -54,10 +54,15 @@ use super::*; // [3] https://github.com/google/zerocopy/issues/209 #[allow(missing_debug_implementations)] #[derive(Default, Copy)] -#[cfg_attr(any(feature = "derive", test), derive(FromZeroes, FromBytes, AsBytes, Unaligned))] +#[cfg_attr( + any(feature = "derive", test), + derive(TryFromBytes, FromZeroes, FromBytes, AsBytes, Unaligned) +)] #[repr(C, packed)] pub struct Unalign(T); +impl_known_layout!(T => Unalign); + safety_comment! { /// SAFETY: /// - `Unalign` is `repr(packed)`, so it is unaligned regardless of the diff --git a/zerocopy-derive/src/ext.rs b/zerocopy-derive/src/ext.rs index ff8a3d6596a..8c482a84997 100644 --- a/zerocopy-derive/src/ext.rs +++ b/zerocopy-derive/src/ext.rs @@ -2,48 +2,74 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use syn::{Data, DataEnum, DataStruct, DataUnion, Type}; +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Index, Type}; pub trait DataExt { - /// Extract the types of all fields. For enums, extract the types of fields - /// from each variant. - fn field_types(&self) -> Vec<&Type>; + /// Extract the names and types of all fields. For enums, extract the names + /// and types of fields from each variant. For tuple structs, the names are + /// the indices used to index into the struct (ie, `0`, `1`, etc). + /// + /// TODO: Extracting field names for enums doesn't really make sense. Types + /// makes sense because we don't care about where they live - we just care + /// about transitive ownership. But for field names, we'd only use them when + /// generating is_bit_valid, which cares about where they live. + fn fields(&self) -> Vec<(TokenStream, &Type)>; } impl DataExt for Data { - fn field_types(&self) -> Vec<&Type> { + fn fields(&self) -> Vec<(TokenStream, &Type)> { match self { - Data::Struct(strc) => strc.field_types(), - Data::Enum(enm) => enm.field_types(), - Data::Union(un) => un.field_types(), + Data::Struct(strc) => strc.fields(), + Data::Enum(enm) => enm.fields(), + Data::Union(un) => un.fields(), } } } impl DataExt for DataStruct { - fn field_types(&self) -> Vec<&Type> { - self.fields.iter().map(|f| &f.ty).collect() + fn fields(&self) -> Vec<(TokenStream, &Type)> { + map_fields(&self.fields) } } impl DataExt for DataEnum { - fn field_types(&self) -> Vec<&Type> { - self.variants.iter().flat_map(|var| &var.fields).map(|f| &f.ty).collect() + fn fields(&self) -> Vec<(TokenStream, &Type)> { + map_fields(self.variants.iter().flat_map(|var| &var.fields)) } } impl DataExt for DataUnion { - fn field_types(&self) -> Vec<&Type> { - self.fields.named.iter().map(|f| &f.ty).collect() + fn fields(&self) -> Vec<(TokenStream, &Type)> { + map_fields(&self.fields.named) } } +fn map_fields<'a>( + fields: impl 'a + IntoIterator, +) -> Vec<(TokenStream, &'a Type)> { + fields + .into_iter() + .enumerate() + .map(|(idx, f)| { + ( + f.ident + .as_ref() + .map(ToTokens::to_token_stream) + .unwrap_or_else(|| Index::from(idx).to_token_stream()), + &f.ty, + ) + }) + .collect() +} + pub trait EnumExt { fn is_c_like(&self) -> bool; } impl EnumExt for DataEnum { fn is_c_like(&self) -> bool { - self.field_types().is_empty() + self.fields().is_empty() } } diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 6b9e3a40f24..8b03dd4dd1d 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -49,12 +49,28 @@ use {crate::ext::*, crate::repr::*}; // help: required by the derive of FromBytes // // Instead, we have more verbose error messages like "unsupported representation -// for deriving FromZeroes, FromBytes, AsBytes, or Unaligned on an enum" +// for deriving TryFromBytes, FromZeroes, FromBytes, AsBytes, or Unaligned on an +// enum" // // This will probably require Span::error // (https://doc.rust-lang.org/nightly/proc_macro/struct.Span.html#method.error), // which is currently unstable. Revisit this once it's stable. +#[proc_macro_derive(TryFromBytes)] +pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse_macro_input!(ts as DeriveInput); + match &ast.data { + Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct), + Data::Enum(_) => { + Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error() + } + Data::Union(_) => { + Error::new_spanned(&ast, "TryFromBytes not supported on union types").to_compile_error() + } + } + .into() +} + #[proc_macro_derive(FromZeroes)] pub fn derive_from_zeroes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = syn::parse_macro_input!(ts as DeriveInput); @@ -110,6 +126,10 @@ macro_rules! try_or_print { }; } +fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { + impl_block(ast, strct, "TryFromBytes", true, None, true) +} + const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ &[StructRepr::C], &[StructRepr::Transparent], @@ -121,7 +141,7 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ // - all fields are `FromZeroes` fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromZeroes", true, None) + impl_block(ast, strct, "FromZeroes", true, None, false) } // An enum is `FromZeroes` if: @@ -155,21 +175,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::To .to_compile_error(); } - impl_block(ast, enm, "FromZeroes", true, None) + impl_block(ast, enm, "FromZeroes", true, None, false) } // Like structs, unions are `FromZeroes` if // - all fields are `FromZeroes` fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromZeroes", true, None) + impl_block(ast, unn, "FromZeroes", true, None, false) } // A struct is `FromBytes` if: // - all fields are `FromBytes` fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromBytes", true, None) + impl_block(ast, strct, "FromBytes", true, None, false) } // An enum is `FromBytes` if: @@ -212,7 +232,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, "FromBytes", true, None) + impl_block(ast, enm, "FromBytes", true, None, false) } #[rustfmt::skip] @@ -243,7 +263,7 @@ const ENUM_FROM_BYTES_CFG: Config = { // - all fields are `FromBytes` fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromBytes", true, None) + impl_block(ast, unn, "FromBytes", true, None, false) } // A struct is `AsBytes` if: @@ -277,7 +297,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2: // any padding bytes would need to come from the fields, all of which // we require to be `AsBytes` (meaning they don't have any padding). let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) }; - impl_block(ast, strct, "AsBytes", true, padding_check) + impl_block(ast, strct, "AsBytes", true, padding_check, false) } const STRUCT_UNION_AS_BYTES_CFG: Config = Config { @@ -300,7 +320,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token // We don't care what the repr is; we only care that it is one of the // allowed ones. let _: Vec = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, "AsBytes", false, None) + impl_block(ast, enm, "AsBytes", false, None, false) } #[rustfmt::skip] @@ -342,7 +362,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union)) + impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), false) } // A struct is `Unaligned` if: @@ -355,7 +375,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2 let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, strct, "Unaligned", require_trait_bound, None) + impl_block(ast, strct, "Unaligned", require_trait_bound, None, false) } const STRUCT_UNION_UNALIGNED_CFG: Config = Config { @@ -386,7 +406,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke // for `require_trait_bounds` doesn't really do anything. But it's // marginally more future-proof in case that restriction is lifted in the // future. - impl_block(ast, enm, "Unaligned", true, None) + impl_block(ast, enm, "Unaligned", true, None, false) } #[rustfmt::skip] @@ -424,7 +444,7 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::To let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, unn, "Unaligned", require_trait_bound, None) + impl_block(ast, unn, "Unaligned", require_trait_bound, None, false) } // This enum describes what kind of padding check needs to be generated for the @@ -455,6 +475,7 @@ fn impl_block( trait_name: &str, require_trait_bound: bool, padding_check: Option, + emit_is_bit_valid: bool, ) -> proc_macro2::TokenStream { // In this documentation, we will refer to this hypothetical struct: // @@ -516,18 +537,18 @@ fn impl_block( let type_ident = &input.ident; let trait_ident = Ident::new(trait_name, Span::call_site()); - let field_types = data.field_types(); + let fields = data.fields(); let field_type_bounds = require_trait_bound - .then(|| field_types.iter().map(|ty| parse_quote!(#ty: zerocopy::#trait_ident))) + .then(|| fields.iter().map(|(_name, ty)| parse_quote!(#ty: zerocopy::#trait_ident))) .into_iter() .flatten() .collect::>(); // Don't bother emitting a padding check if there are no fields. #[allow(unstable_name_collisions)] // See `BoolExt` below - let padding_check_bound = padding_check.and_then(|check| (!field_types.is_empty()).then_some(check)).map(|check| { - let fields = field_types.iter(); + let padding_check_bound = padding_check.and_then(|check| (!fields.is_empty()).then_some(check)).map(|check| { + let fields = fields.iter().map(|(_name, ty)| ty); let validator_macro = check.validator_macro_ident(); parse_quote!( zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::#validator_macro!(#type_ident, #(#fields),*)}>: @@ -564,12 +585,31 @@ fn impl_block( GenericParam::Const(cnst) => quote!(#cnst), }); + let is_bit_valid = emit_is_bit_valid.then(|| { + let field_names = fields.iter().map(|(name, _ty)| name); + let field_tys = fields.iter().map(|(_name, ty)| ty); + quote!( + unsafe fn is_bit_valid(candidate: ::core::ptr::NonNull) -> bool { + let _c = candidate.as_ptr(); + true #(&& { + let field_candidate = ::core::ptr::addr_of_mut!((*_c).#field_names); + // SAFETY: TODO + let f = unsafe { ::core::ptr::NonNull::new_unchecked(field_candidate) }; + // SAFETY: TODO + <#field_tys as zerocopy::TryFromBytes>::is_bit_valid(f) + })* + } + ) + }); + quote! { unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > where #(#bounds,)* { fn only_derive_is_allowed_to_implement_this_trait() {} + + #is_bit_valid } } }