Skip to content

Commit

Permalink
[derive] Support derive(TryFromBytes) for structs
Browse files Browse the repository at this point in the history
Supersedes #370.

Makes progress on #5.

Co-authored-by: Joshua Liebow-Feeser <hello@joshlf.com>
  • Loading branch information
jswrenn and joshlf committed Dec 1, 2023
1 parent 839e210 commit 149bdcd
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 27 deletions.
9 changes: 8 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ pub use crate::wrappers::*;

#[cfg(any(feature = "derive", test))]
#[cfg_attr(doc_cfg, doc(cfg(feature = "derive")))]
pub use zerocopy_derive::Unaligned;
pub use zerocopy_derive::{TryFromBytes, Unaligned};

// `pub use` separately here so that we can mark it `#[doc(hidden)]`.
//
Expand Down Expand Up @@ -1149,6 +1149,13 @@ pub use zerocopy_derive::FromZeroes;
// TODO(#5): Update `try_from_ref` doc link once it exists
#[doc(hidden)]
pub unsafe trait TryFromBytes {
// The `Self: Sized` bound makes it so that `TryFromBytes` is still object
// safe.
#[doc(hidden)]
fn only_derive_is_allowed_to_implement_this_trait()
where
Self: Sized;

/// Does a given memory range contain a valid instance of `Self`?
///
/// # Safety
Expand Down
12 changes: 11 additions & 1 deletion src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ macro_rules! unsafe_impl {
};

(@method TryFromBytes ; |$candidate:ident: &$repr:ty| $is_bit_valid:expr) => {
#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}

#[inline]
unsafe fn is_bit_valid(candidate: Ptr<'_, Self>) -> bool {
// SAFETY:
Expand Down Expand Up @@ -160,6 +163,9 @@ macro_rules! unsafe_impl {
}
};
(@method TryFromBytes ; |$candidate:ident: Ptr<$repr:ty>| $is_bit_valid:expr) => {
#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}

#[inline]
unsafe fn is_bit_valid(candidate: Ptr<'_, Self>) -> bool {
// SAFETY:
Expand All @@ -174,7 +180,11 @@ macro_rules! unsafe_impl {
$is_bit_valid
}
};
(@method TryFromBytes) => { #[inline(always)] unsafe fn is_bit_valid(_: Ptr<'_, Self>) -> bool { true } };
(@method TryFromBytes) => {
#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}
#[inline(always)] unsafe fn is_bit_valid(_: Ptr<'_, Self>) -> bool { true }
};
(@method $trait:ident) => {
#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}
Expand Down
17 changes: 17 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,23 @@ pub(crate) mod ptr {
// - `U: 'a`
Ptr { ptr, _lifetime: PhantomData }
}

/// Maps this `Ptr` to a `Ptr<'a, U>` by applying `project` to the
/// underlying raw pointer.
///
/// `project` receives the raw `*mut T` held by `self` and returns a
/// `*mut U`; the result is wrapped in a new `Ptr` carrying the same
/// lifetime `'a`.
///
/// # Safety
///
/// The pointer returned by `project` must be non-null, because it is
/// passed to `NonNull::new_unchecked` without a check.
///
/// TODO: Spell out the remaining obligations. NOTE(review): the returned
/// pointer presumably must also uphold every invariant that `Ptr<'a, U>`
/// documents (e.g. by being a field projection of `self`'s referent) -
/// confirm against the invariants on `Ptr`.
#[doc(hidden)]
pub unsafe fn project<U: 'a + ?Sized>(
self,
project: impl FnOnce(*mut T) -> *mut U,
) -> Ptr<'a, U> {
let field = project(self.ptr.as_ptr());
// SAFETY: The caller promises that `project` returns a non-null pointer.
let field = unsafe { NonNull::new_unchecked(field) };
// SAFETY: TODO - this relies on the caller guaranteeing that the
// projected pointer satisfies `Ptr`'s invariants for lifetime `'a`; that
// contract is not yet written down precisely.
Ptr { ptr: field, _lifetime: PhantomData }
}
}

impl<'a> Ptr<'a, [u8]> {
Expand Down
56 changes: 41 additions & 15 deletions zerocopy-derive/src/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,74 @@
// This file may not be copied, modified, or distributed except according to
// those terms.

use syn::{Data, DataEnum, DataStruct, DataUnion, Type};
use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Index, Type};

pub trait DataExt {
/// Extract the types of all fields. For enums, extract the types of fields
/// from each variant.
fn field_types(&self) -> Vec<&Type>;
/// Extract the names and types of all fields. For enums, extract the names
/// and types of fields from each variant. For tuple structs, the names are
/// the indices used to index into the struct (i.e., `0`, `1`, etc.).
///
/// TODO: Extracting field names for enums doesn't really make sense. Types
/// makes sense because we don't care about where they live - we just care
/// about transitive ownership. But for field names, we'd only use them when
/// generating is_bit_valid, which cares about where they live.
fn fields(&self) -> Vec<(TokenStream, &Type)>;
}

impl DataExt for Data {
fn field_types(&self) -> Vec<&Type> {
fn fields(&self) -> Vec<(TokenStream, &Type)> {
match self {
Data::Struct(strc) => strc.field_types(),
Data::Enum(enm) => enm.field_types(),
Data::Union(un) => un.field_types(),
Data::Struct(strc) => strc.fields(),
Data::Enum(enm) => enm.fields(),
Data::Union(un) => un.fields(),
}
}
}

impl DataExt for DataStruct {
fn field_types(&self) -> Vec<&Type> {
self.fields.iter().map(|f| &f.ty).collect()
fn fields(&self) -> Vec<(TokenStream, &Type)> {
map_fields(&self.fields)
}
}

impl DataExt for DataEnum {
fn field_types(&self) -> Vec<&Type> {
self.variants.iter().flat_map(|var| &var.fields).map(|f| &f.ty).collect()
fn fields(&self) -> Vec<(TokenStream, &Type)> {
map_fields(self.variants.iter().flat_map(|var| &var.fields))
}
}

impl DataExt for DataUnion {
fn field_types(&self) -> Vec<&Type> {
self.fields.named.iter().map(|f| &f.ty).collect()
fn fields(&self) -> Vec<(TokenStream, &Type)> {
map_fields(&self.fields.named)
}
}

// Pairs every field with the tokens used to access it and a reference to its
// type. Named fields are accessed by identifier; unnamed (tuple) fields fall
// back to their position in the iteration order, rendered as a `syn::Index`.
fn map_fields<'a>(
    fields: impl 'a + IntoIterator<Item = &'a Field>,
) -> Vec<(TokenStream, &'a Type)> {
    fields
        .into_iter()
        .enumerate()
        .map(|(position, field)| {
            let accessor = match field.ident.as_ref() {
                Some(ident) => ident.to_token_stream(),
                // No identifier: use the running position as the accessor.
                None => Index::from(position).to_token_stream(),
            };
            (accessor, &field.ty)
        })
        .collect()
}

pub trait EnumExt {
fn is_c_like(&self) -> bool;
}

impl EnumExt for DataEnum {
fn is_c_like(&self) -> bool {
self.field_types().is_empty()
self.fields().is_empty()
}
}
107 changes: 97 additions & 10 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
Data::Enum(..) | Data::Union(..) => None,
};

let fields = ast.data.field_types();
let fields = ast.data.fields();

let (require_self_sized, extras) = if let (
Some(reprs),
Some((trailing_field, leading_fields)),
) = (is_repr_c_struct, fields.split_last())
{
let (_name, trailing_field_ty) = trailing_field;
let leading_fields_tys = leading_fields.iter().map(|(_name, ty)| ty);

let repr_align = reprs
.iter()
.find_map(
Expand Down Expand Up @@ -142,8 +145,8 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
let repr_packed = #repr_packed;

DstLayout::new_zst(repr_align)
#(.extend(DstLayout::for_type::<#leading_fields>(), repr_packed))*
.extend(<#trailing_field as KnownLayout>::LAYOUT, repr_packed)
#(.extend(DstLayout::for_type::<#leading_fields_tys>(), repr_packed))*
.extend(<#trailing_field_ty as KnownLayout>::LAYOUT, repr_packed)
.pad_to_align()
};

Expand All @@ -157,7 +160,7 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
elems: usize,
) -> ::zerocopy::macro_util::core_reexport::ptr::NonNull<Self> {
use ::zerocopy::{KnownLayout};
let trailing = <#trailing_field as KnownLayout>::raw_from_ptr_len(bytes, elems);
let trailing = <#trailing_field_ty as KnownLayout>::raw_from_ptr_len(bytes, elems);
let slf = trailing.as_ptr() as *mut Self;
// SAFETY: Constructed from `trailing`, which is non-null.
unsafe { ::zerocopy::macro_util::core_reexport::ptr::NonNull::new_unchecked(slf) }
Expand Down Expand Up @@ -243,6 +246,21 @@ pub fn derive_known_layout(ts: proc_macro::TokenStream) -> proc_macro::TokenStre
.into()
}

#[proc_macro_derive(TryFromBytes)]
pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct),
Data::Enum(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
}
Data::Union(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on union types").to_compile_error()
}
}
.into()
}

#[proc_macro_derive(FromZeroes)]
pub fn derive_from_zeroes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(ts as DeriveInput);
Expand Down Expand Up @@ -287,6 +305,74 @@ pub fn derive_unaligned(ts: proc_macro::TokenStream) -> proc_macro::TokenStream
.into()
}

// A struct is `TryFromBytes` if:
// - all fields are `TryFromBytes`

// Generates the `TryFromBytes` impl for a struct.
//
// Steps:
// 1. Parse the `repr` attributes and reject `repr(packed)` /
//    `repr(packed(N))`. NOTE(review): presumably because field pointers of a
//    packed struct are not guaranteed to be aligned - confirm.
// 2. Emit an `is_bit_valid` method which projects the candidate pointer to
//    each field in turn and short-circuits (`&&`) on the first field whose
//    own `is_bit_valid` returns `false`. A struct with no fields yields
//    `true`.
// 3. Hand that method to `impl_block` as `extras`, requiring every field type
//    to be `TryFromBytes`.
//
// All paths in the generated code are anchored with a leading `::` so that
// they always resolve to the `zerocopy` crate rather than to an item named
// `zerocopy` that happens to be in scope at the derive site; this matches the
// other derives in this file.
fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
    let reprs = try_or_print!(repr::reprs::<Repr>(&ast.attrs));

    for (meta, repr) in reprs {
        if matches!(repr, Repr::Packed | Repr::PackedN(_)) {
            // Feeding an `Err` to `try_or_print!` reports the error spanned
            // on the offending `repr` attribute (presumably returning early -
            // see the macro's definition).
            try_or_print!(Err(vec![Error::new_spanned(
                meta,
                "cannot derive TryFromBytes with repr(packed)",
            )]));
        }
    }

    let extras = Some({
        let fields = strct.fields();
        let field_names = fields.iter().map(|(name, _ty)| name);
        let field_tys = fields.iter().map(|(_name, ty)| ty);
        quote!(
            // SAFETY: We use `is_bit_valid` to validate that each field is
            // bit-valid, and only return `true` if all of them are. The bit
            // validity of a struct is just the composition of the bit
            // validities of its fields, so this is a sound implementation of
            // `is_bit_valid`.
            unsafe fn is_bit_valid(_candidate: ::zerocopy::Ptr<Self>) -> bool {
                true #(&& {
                    // `addr_of_mut!` computes the field's address without
                    // materializing an intermediate reference.
                    let project = |slf: *mut Self| ::core::ptr::addr_of_mut!((*slf).#field_names);
                    // SAFETY: TODO
                    let field_candidate = unsafe { _candidate.project(project) };
                    // SAFETY: TODO
                    //
                    // Old safety comment (written when the projected pointer
                    // was named `f` and the argument `candidate`; they are
                    // now `field_candidate` and `_candidate`):
                    //
                    // SAFETY:
                    // - `f` is properly aligned for `#field_tys` because
                    //   `candidate` is properly aligned for `Self`.
                    // - `f` is valid for reads because `candidate` is.
                    // - Total length encoded by `f` doesn't overflow `isize`
                    //   because it's no greater than the size encoded by
                    //   `candidate`, whose size doesn't overflow `isize`.
                    // - `f` addresses a range which falls inside a single
                    //   allocation because that range is a subset of the range
                    //   addressed by `candidate`, and that latter range falls
                    //   inside a single allocation.
                    // - The bit validity property of `is_bit_valid` is
                    //   trivially compositional for structs. In particular, in
                    //   a struct, there is no data dependency between bit
                    //   validity in any two byte offsets (this is notably not
                    //   true of enums). Since we know that the bit validity
                    //   property holds for all of `candidate`, we also know
                    //   that it holds for `f` regardless of the contents of any
                    //   other region of `candidate`.
                    //
                    // Note that it's possible that this call will panic -
                    // `is_bit_valid` does not promise that it doesn't panic,
                    // and in practice, we support user-defined validators,
                    // which could panic. This is sound because we haven't
                    // violated any safety invariants which we would need to fix
                    // before returning.
                    <#field_tys as ::zerocopy::TryFromBytes>::is_bit_valid(field_candidate)
                })*
            }
        )
    });
    impl_block(ast, strct, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
&[StructRepr::C],
&[StructRepr::Transparent],
Expand Down Expand Up @@ -637,6 +723,7 @@ impl PaddingCheck {
#[derive(Debug, Eq, PartialEq)]
enum Trait {
KnownLayout,
TryFromBytes,
FromZeroes,
FromBytes,
AsBytes,
Expand Down Expand Up @@ -734,19 +821,19 @@ fn impl_block<D: DataExt>(

let type_ident = &input.ident;
let trait_ident = trt.ident();
let field_types = data.field_types();
let fields = data.fields();

let bound_tt = |ty| parse_quote!(#ty: ::zerocopy::#trait_ident);
let field_type_bounds: Vec<_> = match (require_trait_bound_on_field_types, &field_types[..]) {
(RequireBoundedFields::Yes, _) => field_types.iter().map(bound_tt).collect(),
let field_type_bounds: Vec<_> = match (require_trait_bound_on_field_types, &fields[..]) {
(RequireBoundedFields::Yes, _) => fields.iter().map(|(_name, ty)| bound_tt(ty)).collect(),
(RequireBoundedFields::No, _) | (RequireBoundedFields::Trailing, []) => vec![],
(RequireBoundedFields::Trailing, [.., last]) => vec![bound_tt(last)],
(RequireBoundedFields::Trailing, [.., last]) => vec![bound_tt(&last.1)],
};

// Don't bother emitting a padding check if there are no fields.
#[allow(unstable_name_collisions)] // See `BoolExt` below
let padding_check_bound = padding_check.and_then(|check| (!field_types.is_empty()).then_some(check)).map(|check| {
let fields = field_types.iter();
let padding_check_bound = padding_check.and_then(|check| (!fields.is_empty()).then_some(check)).map(|check| {
let fields = fields.iter().map(|(_name, ty)| ty);
let validator_macro = check.validator_macro_ident();
parse_quote!(
::zerocopy::macro_util::HasPadding<#type_ident, {::zerocopy::#validator_macro!(#type_ident, #(#fields),*)}>:
Expand Down

0 comments on commit 149bdcd

Please sign in to comment.