diff --git a/src/lib.rs b/src/lib.rs index 2c7aad6dea5..0fb86ac7c31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -254,7 +254,7 @@ pub use crate::wrappers::*; #[cfg(any(feature = "derive", test))] #[cfg_attr(doc_cfg, doc(cfg(feature = "derive")))] -pub use zerocopy_derive::Unaligned; +pub use zerocopy_derive::{TryFromBytes, Unaligned}; // `pub use` separately here so that we can mark it `#[doc(hidden)]`. // @@ -1149,6 +1149,13 @@ pub use zerocopy_derive::FromZeroes; // TODO(#5): Update `try_from_ref` doc link once it exists #[doc(hidden)] pub unsafe trait TryFromBytes { + // The `Self: Sized` bound makes it so that `TryFromBytes` is still object + // safe. + #[doc(hidden)] + fn only_derive_is_allowed_to_implement_this_trait() + where + Self: Sized; + /// Does a given memory range contain a valid instance of `Self`? /// /// # Safety diff --git a/src/macros.rs b/src/macros.rs index 2da78af7df8..833ae387a7e 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -133,6 +133,9 @@ macro_rules! unsafe_impl { }; (@method TryFromBytes ; |$candidate:ident: &$repr:ty| $is_bit_valid:expr) => { + #[allow(clippy::missing_inline_in_public_items)] + fn only_derive_is_allowed_to_implement_this_trait() {} + #[inline] unsafe fn is_bit_valid(candidate: Ptr<'_, Self>) -> bool { // SAFETY: @@ -160,6 +163,9 @@ macro_rules! unsafe_impl { } }; (@method TryFromBytes ; |$candidate:ident: Ptr<$repr:ty>| $is_bit_valid:expr) => { + #[allow(clippy::missing_inline_in_public_items)] + fn only_derive_is_allowed_to_implement_this_trait() {} + #[inline] unsafe fn is_bit_valid(candidate: Ptr<'_, Self>) -> bool { // SAFETY: @@ -174,7 +180,11 @@ macro_rules! 
unsafe_impl { $is_bit_valid } }; - (@method TryFromBytes) => { #[inline(always)] unsafe fn is_bit_valid(_: Ptr<'_, Self>) -> bool { true } }; + (@method TryFromBytes) => { + #[allow(clippy::missing_inline_in_public_items)] + fn only_derive_is_allowed_to_implement_this_trait() {} + #[inline(always)] unsafe fn is_bit_valid(_: Ptr<'_, Self>) -> bool { true } + }; (@method $trait:ident) => { #[allow(clippy::missing_inline_in_public_items)] fn only_derive_is_allowed_to_implement_this_trait() {} diff --git a/src/util.rs b/src/util.rs index f018e876d7d..5b91a48e966 100644 --- a/src/util.rs +++ b/src/util.rs @@ -139,6 +139,23 @@ pub(crate) mod ptr { // - `U: 'a` Ptr { ptr, _lifetime: PhantomData } } + + /// Turns this `Ptr` into a `Ptr` to whatever `project` returns (e.g. a field of `T`). + /// + /// # Safety + /// + /// `project` must return a non-null pointer, since its result is passed to `NonNull::new_unchecked`. NOTE(review): the other `Ptr` invariants the result must uphold are still TODO. + #[doc(hidden)] + pub unsafe fn project<U: 'a + ?Sized>( + self, + project: impl FnOnce(*mut T) -> *mut U, + ) -> Ptr<'a, U> { + let field = project(self.ptr.as_ptr()); + // SAFETY: The caller promises that `project` returns a non-null pointer. + let field = unsafe { NonNull::new_unchecked(field) }; + // SAFETY: TODO + Ptr { ptr: field, _lifetime: PhantomData } + } } impl<'a> Ptr<'a, [u8]> { diff --git a/zerocopy-derive/src/ext.rs b/zerocopy-derive/src/ext.rs index 87cf838f883..def7cf12407 100644 --- a/zerocopy-derive/src/ext.rs +++ b/zerocopy-derive/src/ext.rs @@ -6,48 +6,74 @@ // This file may not be copied, modified, or distributed except according to // those terms. -use syn::{Data, DataEnum, DataStruct, DataUnion, Type}; +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Index, Type}; pub trait DataExt { - /// Extract the types of all fields. For enums, extract the types of fields - /// from each variant. - fn field_types(&self) -> Vec<&Type>; + /// Extract the names and types of all fields. For enums, extract the names + /// and types of fields from each variant. For tuple structs, the names are + /// the indices used to index into the struct (ie, `0`, `1`, etc). 
+ /// + /// TODO: Extracting field names for enums doesn't really make sense. Types + /// makes sense because we don't care about where they live - we just care + /// about transitive ownership. But for field names, we'd only use them when + /// generating is_bit_valid, which cares about where they live. + fn fields(&self) -> Vec<(TokenStream, &Type)>; } impl DataExt for Data { - fn field_types(&self) -> Vec<&Type> { + fn fields(&self) -> Vec<(TokenStream, &Type)> { match self { - Data::Struct(strc) => strc.field_types(), - Data::Enum(enm) => enm.field_types(), - Data::Union(un) => un.field_types(), + Data::Struct(strc) => strc.fields(), + Data::Enum(enm) => enm.fields(), + Data::Union(un) => un.fields(), } } } impl DataExt for DataStruct { - fn field_types(&self) -> Vec<&Type> { - self.fields.iter().map(|f| &f.ty).collect() + fn fields(&self) -> Vec<(TokenStream, &Type)> { + map_fields(&self.fields) } } impl DataExt for DataEnum { - fn field_types(&self) -> Vec<&Type> { - self.variants.iter().flat_map(|var| &var.fields).map(|f| &f.ty).collect() + fn fields(&self) -> Vec<(TokenStream, &Type)> { + map_fields(self.variants.iter().flat_map(|var| &var.fields)) } } impl DataExt for DataUnion { - fn field_types(&self) -> Vec<&Type> { - self.fields.named.iter().map(|f| &f.ty).collect() + fn fields(&self) -> Vec<(TokenStream, &Type)> { + map_fields(&self.fields.named) } } +fn map_fields<'a>( + fields: impl 'a + IntoIterator<Item = &'a Field>, +) -> Vec<(TokenStream, &'a Type)> { + fields + .into_iter() + .enumerate() + .map(|(idx, f)| { + ( + f.ident + .as_ref() + .map(ToTokens::to_token_stream) + .unwrap_or_else(|| Index::from(idx).to_token_stream()), + &f.ty, + ) + }) + .collect() +} + pub trait EnumExt { fn is_c_like(&self) -> bool; } impl EnumExt for DataEnum { fn is_c_like(&self) -> bool { - self.field_types().is_empty() + self.fields().is_empty() } } diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 9af8a28a06a..6841f3ecaf6 100644 --- 
a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -86,13 +86,16 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre Data::Enum(..) | Data::Union(..) => None, }; - let fields = ast.data.field_types(); + let fields = ast.data.fields(); let (require_self_sized, extras) = if let ( Some(reprs), Some((trailing_field, leading_fields)), ) = (is_repr_c_struct, fields.split_last()) { + let (_name, trailing_field_ty) = trailing_field; + let leading_fields_tys = leading_fields.iter().map(|(_name, ty)| ty); + let repr_align = reprs .iter() .find_map( @@ -142,8 +145,8 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre let repr_packed = #repr_packed; DstLayout::new_zst(repr_align) - #(.extend(DstLayout::for_type::<#leading_fields>(), repr_packed))* - .extend(<#trailing_field as KnownLayout>::LAYOUT, repr_packed) + #(.extend(DstLayout::for_type::<#leading_fields_tys>(), repr_packed))* + .extend(<#trailing_field_ty as KnownLayout>::LAYOUT, repr_packed) .pad_to_align() }; @@ -157,7 +160,7 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre elems: usize, ) -> ::zerocopy::macro_util::core_reexport::ptr::NonNull<Self> { use ::zerocopy::{KnownLayout}; - let trailing = <#trailing_field as KnownLayout>::raw_from_ptr_len(bytes, elems); + let trailing = <#trailing_field_ty as KnownLayout>::raw_from_ptr_len(bytes, elems); let slf = trailing.as_ptr() as *mut Self; // SAFETY: Constructed from `trailing`, which is non-null. 
unsafe { ::zerocopy::macro_util::core_reexport::ptr::NonNull::new_unchecked(slf) } @@ -243,6 +246,21 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre .into() } +#[proc_macro_derive(TryFromBytes)] +pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse_macro_input!(ts as DeriveInput); + match &ast.data { + Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct), + Data::Enum(_) => { + Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error() + } + Data::Union(_) => { + Error::new_spanned(&ast, "TryFromBytes not supported on union types").to_compile_error() + } + } + .into() +} + #[proc_macro_derive(FromZeroes)] pub fn derive_from_zeroes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = syn::parse_macro_input!(ts as DeriveInput); @@ -287,6 +305,74 @@ pub fn derive_unaligned(ts: proc_macro::TokenStream) -> proc_macro::TokenStream .into() } +// A struct is `TryFromBytes` if: +// - all fields are `TryFromBytes` + +fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { + let reprs = try_or_print!(repr::reprs::<Repr>(&ast.attrs)); + + for (meta, repr) in reprs { + if matches!(repr, Repr::Packed | Repr::PackedN(_)) { + try_or_print!(Err(vec![Error::new_spanned( + meta, + "cannot derive TryFromBytes with repr(packed)", + )])); + } + } + + let extras = Some({ + let fields = strct.fields(); + let field_names = fields.iter().map(|(name, _ty)| name); + let field_tys = fields.iter().map(|(_name, ty)| ty); + quote!( + // SAFETY: We use `is_bit_valid` to validate that each field is + // bit-valid, and only return `true` if all of them are. The bit + // validity of a struct is just the composition of the bit + // validities of its fields, so this is a sound implementation of + // `is_bit_valid`. 
+ unsafe fn is_bit_valid(_candidate: ::zerocopy::Ptr<'_, Self>) -> bool { + true #(&& { + let project = |slf: *mut Self| ::core::ptr::addr_of_mut!((*slf).#field_names); + // SAFETY: TODO + let field_candidate = unsafe { _candidate.project(project) }; + // SAFETY: TODO + // + // Old safety comment: + // + // SAFETY: + // - `f` is properly aligned for `#field_tys` because + // `candidate` is properly aligned for `Self`. + // - `f` is valid for reads because `candidate` is. + // - Total length encoded by `f` doesn't overflow `isize` + // because it's no greater than the size encoded by + // `candidate`, whose size doesn't overflow `isize`. + // - `f` addresses a range which falls inside a single + // allocation because that range is a subset of the range + // addressed by `candidate`, and that latter range falls + // inside a single allocation. + // - The bit validity property of `is_bit_valid` is + // trivially compositional for structs. In particular, in + // a struct, there is no data dependency between bit + // validity in any two byte offsets (this is notably not + // true of enums). Since we know that the bit validity + // property holds for all of `candidate`, we also know + // that it holds for `f` regardless of the contents of any + // other region of `candidate`. + // + // Note that it's possible that this call will panic - + // `is_bit_valid` does not promise that it doesn't panic, + // and in practice, we support user-defined validators, + // which could panic. This is sound because we haven't + // violated any safety invariants which we would need to fix + // before returning. 
+ <#field_tys as ::zerocopy::TryFromBytes>::is_bit_valid(field_candidate) + })* + } + ) + }); + impl_block(ast, strct, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras) +} + const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ &[StructRepr::C], &[StructRepr::Transparent], @@ -637,6 +723,7 @@ impl PaddingCheck { #[derive(Debug, Eq, PartialEq)] enum Trait { KnownLayout, + TryFromBytes, FromZeroes, FromBytes, AsBytes, @@ -734,19 +821,19 @@ fn impl_block( let type_ident = &input.ident; let trait_ident = trt.ident(); - let field_types = data.field_types(); + let fields = data.fields(); let bound_tt = |ty| parse_quote!(#ty: ::zerocopy::#trait_ident); - let field_type_bounds: Vec<_> = match (require_trait_bound_on_field_types, &field_types[..]) { - (RequireBoundedFields::Yes, _) => field_types.iter().map(bound_tt).collect(), + let field_type_bounds: Vec<_> = match (require_trait_bound_on_field_types, &fields[..]) { + (RequireBoundedFields::Yes, _) => fields.iter().map(|(_name, ty)| bound_tt(ty)).collect(), (RequireBoundedFields::No, _) | (RequireBoundedFields::Trailing, []) => vec![], - (RequireBoundedFields::Trailing, [.., last]) => vec![bound_tt(&last.1)], + (RequireBoundedFields::Trailing, [.., last]) => vec![bound_tt(&last.1)], }; // Don't bother emitting a padding check if there are no fields. #[allow(unstable_name_collisions)] // See `BoolExt` below - let padding_check_bound = padding_check.and_then(|check| (!field_types.is_empty()).then_some(check)).map(|check| { - let fields = field_types.iter(); + let padding_check_bound = padding_check.and_then(|check| (!fields.is_empty()).then_some(check)).map(|check| { + let fields = fields.iter().map(|(_name, ty)| ty); let validator_macro = check.validator_macro_ident(); parse_quote!( ::zerocopy::macro_util::HasPadding<#type_ident, {::zerocopy::#validator_macro!(#type_ident, #(#fields),*)}>: