Skip to content

Commit

Permalink
[derive] Support TryFromBytes for structs/unions
Browse files Browse the repository at this point in the history
TODO:
- Should DataExt::fields have a different shape? Extracting field names
  for enums doesn't really make sense
- Add SAFETY comments in emitted `is_bit_valid` impl
- Do we need to update the TryFromBytes doc comment at all?
- Union bit validity should be OR of fields rather than AND
- Lots and lots of tests

Makes progress on #5
  • Loading branch information
joshlf committed Sep 11, 2023
1 parent e82b746 commit 0129c61
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 33 deletions.
12 changes: 12 additions & 0 deletions src/derive_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ macro_rules! union_has_padding {
#[cfg(test)]
mod tests {
use crate::util::testutil::*;
use crate::*;

#[test]
fn test_struct_has_padding() {
Expand Down Expand Up @@ -124,4 +125,15 @@ mod tests {
// anyway.
test!(#[repr(C)] #[repr(packed)] {a: u8, b: u64} => true);
}

#[test]
fn foo() {
#[derive(TryFromBytes)]
struct Foo {
f: u8,
b: bool,
}

impl_known_layout!(Foo);
}
}
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,11 @@ pub unsafe trait FromBytes: FromZeroes {
///
/// [`try_from_ref`]: TryFromBytes::try_from_ref
pub unsafe trait TryFromBytes: KnownLayout {
#[doc(hidden)]
fn only_derive_is_allowed_to_implement_this_trait()
where
Self: Sized;

/// Does a given memory range contain a valid instance of `Self`?
///
/// # Safety
Expand Down
12 changes: 11 additions & 1 deletion src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ macro_rules! unsafe_impl {
let $candidate = unsafe { &*(candidate.as_ptr() as *const $repr) };
$is_bit_valid
}

#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}
};
(@method TryFromBytes ; |$candidate:ident: NonNull<$repr:ty>| $is_bit_valid:expr) => {
#[inline]
Expand All @@ -141,8 +144,15 @@ macro_rules! unsafe_impl {
let $candidate = unsafe { NonNull::new_unchecked(candidate.as_ptr() as *mut $repr) };
$is_bit_valid
}

#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}
};
(@method TryFromBytes) => {
#[inline(always)] unsafe fn is_bit_valid(_: NonNull<Self>) -> bool { true }
#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}
};
(@method TryFromBytes) => { #[inline(always)] unsafe fn is_bit_valid(_: NonNull<Self>) -> bool { true } };
(@method $trait:ident) => {
#[allow(clippy::missing_inline_in_public_items)]
fn only_derive_is_allowed_to_implement_this_trait() {}
Expand Down
56 changes: 41 additions & 15 deletions zerocopy-derive/src/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,74 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use syn::{Data, DataEnum, DataStruct, DataUnion, Type};
use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Index, Type};

pub trait DataExt {
/// Extract the types of all fields. For enums, extract the types of fields
/// from each variant.
fn field_types(&self) -> Vec<&Type>;
/// Extract the names and types of all fields. For enums, extract the names
/// and types of fields from each variant. For tuple structs, the names are
/// the indices used to index into the struct (ie, `0`, `1`, etc).
///
/// TODO: Extracting field names for enums doesn't really make sense. Types
/// makes sense because we don't care about where they live - we just care
/// about transitive ownership. But for field names, we'd only use them when
/// generating is_bit_valid, which cares about where they live.
fn fields(&self) -> Vec<(TokenStream, &Type)>;
}

impl DataExt for Data {
fn field_types(&self) -> Vec<&Type> {
fn fields(&self) -> Vec<(TokenStream, &Type)> {
match self {
Data::Struct(strc) => strc.field_types(),
Data::Enum(enm) => enm.field_types(),
Data::Union(un) => un.field_types(),
Data::Struct(strc) => strc.fields(),
Data::Enum(enm) => enm.fields(),
Data::Union(un) => un.fields(),
}
}
}

impl DataExt for DataStruct {
fn field_types(&self) -> Vec<&Type> {
self.fields.iter().map(|f| &f.ty).collect()
fn fields(&self) -> Vec<(TokenStream, &Type)> {
map_fields(&self.fields)
}
}

impl DataExt for DataEnum {
fn field_types(&self) -> Vec<&Type> {
self.variants.iter().flat_map(|var| &var.fields).map(|f| &f.ty).collect()
fn fields(&self) -> Vec<(TokenStream, &Type)> {
map_fields(self.variants.iter().flat_map(|var| &var.fields))
}
}

impl DataExt for DataUnion {
fn field_types(&self) -> Vec<&Type> {
self.fields.named.iter().map(|f| &f.ty).collect()
fn fields(&self) -> Vec<(TokenStream, &Type)> {
map_fields(&self.fields.named)
}
}

fn map_fields<'a>(
fields: impl 'a + IntoIterator<Item = &'a Field>,
) -> Vec<(TokenStream, &'a Type)> {
fields
.into_iter()
.enumerate()
.map(|(idx, f)| {
(
f.ident
.as_ref()
.map(ToTokens::to_token_stream)
.unwrap_or_else(|| Index::from(idx).to_token_stream()),
&f.ty,
)
})
.collect()
}

pub trait EnumExt {
fn is_c_like(&self) -> bool;
}

impl EnumExt for DataEnum {
fn is_c_like(&self) -> bool {
self.field_types().is_empty()
self.fields().is_empty()
}
}
76 changes: 59 additions & 17 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,26 @@ use {crate::ext::*, crate::repr::*};
// help: required by the derive of FromBytes
//
// Instead, we have more verbose error messages like "unsupported representation
// for deriving FromZeroes, FromBytes, AsBytes, or Unaligned on an enum"
// for deriving TryFromBytes, FromZeroes, FromBytes, AsBytes, or Unaligned on an
// enum"
//
// This will probably require Span::error
// (https://doc.rust-lang.org/nightly/proc_macro/struct.Span.html#method.error),
// which is currently unstable. Revisit this once it's stable.

#[proc_macro_derive(TryFromBytes)]
pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct),
Data::Enum(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
}
Data::Union(unn) => derive_try_from_bytes_union(&ast, unn),
}
.into()
}

#[proc_macro_derive(FromZeroes)]
pub fn derive_from_zeroes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(ts as DeriveInput);
Expand Down Expand Up @@ -110,6 +124,14 @@ macro_rules! try_or_print {
};
}

fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, "TryFromBytes", true, None, true)
}

fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
impl_block(ast, unn, "TryFromBytes", true, None, true)
}

const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
&[StructRepr::C],
&[StructRepr::Transparent],
Expand All @@ -121,7 +143,7 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
// - all fields are `FromZeroes`

fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, "FromZeroes", true, None)
impl_block(ast, strct, "FromZeroes", true, None, false)
}

// An enum is `FromZeroes` if:
Expand Down Expand Up @@ -155,21 +177,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::To
.to_compile_error();
}

impl_block(ast, enm, "FromZeroes", true, None)
impl_block(ast, enm, "FromZeroes", true, None, false)
}

// Like structs, unions are `FromZeroes` if
// - all fields are `FromZeroes`

fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
impl_block(ast, unn, "FromZeroes", true, None)
impl_block(ast, unn, "FromZeroes", true, None, false)
}

// A struct is `FromBytes` if:
// - all fields are `FromBytes`

fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, "FromBytes", true, None)
impl_block(ast, strct, "FromBytes", true, None, false)
}

// An enum is `FromBytes` if:
Expand Down Expand Up @@ -212,7 +234,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok
.to_compile_error();
}

impl_block(ast, enm, "FromBytes", true, None)
impl_block(ast, enm, "FromBytes", true, None, false)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -243,7 +265,7 @@ const ENUM_FROM_BYTES_CFG: Config<EnumRepr> = {
// - all fields are `FromBytes`

fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
impl_block(ast, unn, "FromBytes", true, None)
impl_block(ast, unn, "FromBytes", true, None, false)
}

// A struct is `AsBytes` if:
Expand Down Expand Up @@ -277,7 +299,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2:
// any padding bytes would need to come from the fields, all of which
// we require to be `AsBytes` (meaning they don't have any padding).
let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) };
impl_block(ast, strct, "AsBytes", true, padding_check)
impl_block(ast, strct, "AsBytes", true, padding_check, false)
}

const STRUCT_UNION_AS_BYTES_CFG: Config<StructRepr> = Config {
Expand All @@ -300,7 +322,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token
// We don't care what the repr is; we only care that it is one of the
// allowed ones.
let _: Vec<repr::EnumRepr> = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast));
impl_block(ast, enm, "AsBytes", false, None)
impl_block(ast, enm, "AsBytes", false, None, false)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -342,7 +364,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok

try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast));

impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union))
impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), false)
}

// A struct is `Unaligned` if:
Expand All @@ -355,7 +377,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2
let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast));
let require_trait_bound = !reprs.contains(&StructRepr::Packed);

impl_block(ast, strct, "Unaligned", require_trait_bound, None)
impl_block(ast, strct, "Unaligned", require_trait_bound, None, false)
}

const STRUCT_UNION_UNALIGNED_CFG: Config<StructRepr> = Config {
Expand Down Expand Up @@ -386,7 +408,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke
// for `require_trait_bounds` doesn't really do anything. But it's
// marginally more future-proof in case that restriction is lifted in the
// future.
impl_block(ast, enm, "Unaligned", true, None)
impl_block(ast, enm, "Unaligned", true, None, false)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -424,7 +446,7 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::To
let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast));
let require_trait_bound = !reprs.contains(&StructRepr::Packed);

impl_block(ast, unn, "Unaligned", require_trait_bound, None)
impl_block(ast, unn, "Unaligned", require_trait_bound, None, false)
}

// This enum describes what kind of padding check needs to be generated for the
Expand Down Expand Up @@ -455,6 +477,7 @@ fn impl_block<D: DataExt>(
trait_name: &str,
require_trait_bound: bool,
padding_check: Option<PaddingCheck>,
emit_is_bit_valid: bool,
) -> proc_macro2::TokenStream {
// In this documentation, we will refer to this hypothetical struct:
//
Expand Down Expand Up @@ -516,18 +539,18 @@ fn impl_block<D: DataExt>(

let type_ident = &input.ident;
let trait_ident = Ident::new(trait_name, Span::call_site());
let field_types = data.field_types();
let fields = data.fields();

let field_type_bounds = require_trait_bound
.then(|| field_types.iter().map(|ty| parse_quote!(#ty: zerocopy::#trait_ident)))
.then(|| fields.iter().map(|(_name, ty)| parse_quote!(#ty: zerocopy::#trait_ident)))
.into_iter()
.flatten()
.collect::<Vec<_>>();

// Don't bother emitting a padding check if there are no fields.
#[allow(unstable_name_collisions)] // See `BoolExt` below
let padding_check_bound = padding_check.and_then(|check| (!field_types.is_empty()).then_some(check)).map(|check| {
let fields = field_types.iter();
let padding_check_bound = padding_check.and_then(|check| (!fields.is_empty()).then_some(check)).map(|check| {
let fields = fields.iter().map(|(_name, ty)| ty);
let validator_macro = check.validator_macro_ident();
parse_quote!(
zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::#validator_macro!(#type_ident, #(#fields),*)}>:
Expand Down Expand Up @@ -564,12 +587,31 @@ fn impl_block<D: DataExt>(
GenericParam::Const(cnst) => quote!(#cnst),
});

let is_bit_valid = emit_is_bit_valid.then(|| {
let field_names = fields.iter().map(|(name, _ty)| name);
let field_tys = fields.iter().map(|(_name, ty)| ty);
quote!(
unsafe fn is_bit_valid(candidate: ::core::ptr::NonNull<Self>) -> bool {
let _c = candidate.as_ptr();
true #(&& {
let field_candidate = ::core::ptr::addr_of_mut!((*_c).#field_names);
// SAFETY: TODO
let f = unsafe { ::core::ptr::NonNull::new_unchecked(field_candidate) };
// SAFETY: TODO
<#field_tys as zerocopy::TryFromBytes>::is_bit_valid(f)
})*
}
)
});

quote! {
unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* >
where
#(#bounds,)*
{
fn only_derive_is_allowed_to_implement_this_trait() {}

#is_bit_valid
}
}
}
Expand Down

0 comments on commit 0129c61

Please sign in to comment.