diff --git a/Cargo.toml b/Cargo.toml index 0704d71..767c9f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,18 @@ version = "^1.0.0" default-features = false optional = true +[dependencies.zerocopy] +version = "0.8.9" +default-features = false +optional = true +[dependencies.zerocopy-derive] +version = "0.8.9" +default-features = false +optional = true + [features] std = [] +zerocopy = ["dep:zerocopy", "dep:zerocopy-derive"] [workspace] members = [ diff --git a/src/lib.rs b/src/lib.rs index 077d65a..a3cedc9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,6 +46,8 @@ //! //! - [`serde`](https://serde.rs/) implements `Serialize` and `Deserialize` //! for `BitFlags`. +//! - [`zerocopy`](https://github.com/google/zerocopy/) implements `Immutable`, `IntoBytes`, +//! `FromZeros`, `TryFromBytes`, and `KnownLayout` for all `BitFlags` and `Unaligned` if the value type is unaligned. //! - `std` implements `std::error::Error` for `FromBitsError`. //! //! ## `const fn`-compatible APIs @@ -253,7 +255,7 @@ pub trait BitFlag: Copy + Clone + 'static + _internal::RawBitFlags { /// /// All bits set in `val` must correspond to a value of the enum. /// - /// # Example + /// # Example /// /// This is a convenience reexport of [`BitFlags::from_bits_unchecked`]. It can be /// called with `MyFlag::from_bits_unchecked(bits)`, thus bypassing the need for @@ -322,8 +324,8 @@ pub mod _internal { } use ::core::fmt; - use ::core::ops::{BitAnd, BitOr, BitXor, Not, Sub}; use ::core::hash::Hash; + use ::core::ops::{BitAnd, BitOr, BitXor, Not, Sub}; pub trait BitFlagNum: Default @@ -523,6 +525,14 @@ pub use crate::const_api::ConstToken; /// `BitFlags` value where that isn't the case is only possible with /// incorrect unsafe code. #[derive(Copy, Clone)] +#[cfg_attr( + feature = "zerocopy", + derive( + zerocopy_derive::Immutable, + zerocopy_derive::KnownLayout, + zerocopy_derive::IntoBytes, + ) +)] #[repr(transparent)] pub struct BitFlags::Numeric> { val: N, @@ -657,6 +667,33 @@ where unsafe { BitFlags::from_bits_unchecked(bits & T::ALL_BITS) } } + /// Validate if an underlying bitwise value can safely be converted to `BitFlags`. + /// Returns false if any invalid bits are set. + /// + /// ``` + /// # use enumflags2::{bitflags, BitFlags}; + /// #[bitflags] + /// #[repr(u8)] + /// #[derive(Clone, Copy, PartialEq, Eq)] + /// enum MyFlag { + /// One = 0b0001, + /// Two = 0b0010, + /// Three = 0b1000, + /// } + /// + /// assert_eq!(BitFlags::::validate_bits(0b1011), true); + /// assert_eq!(BitFlags::::validate_bits(0b0000), true); + /// assert_eq!(BitFlags::::validate_bits(0b0100), false); + /// assert_eq!(BitFlags::::validate_bits(0b1111), false); + /// ``` + #[must_use] + #[inline(always)] + pub fn validate_bits(bits: T::Numeric) -> bool { + // SAFETY: We're truncating out all the invalid bits so it will + // only be different if there are invalid bits set. + (bits & T::ALL_BITS) == bits + } + /// Create a new BitFlags unsafely, without checking if the bits form /// a valid bit pattern for the type. /// @@ -1032,3 +1069,84 @@ mod impl_serde { } } } + +#[cfg(feature = "zerocopy")] +mod impl_zerocopy { + use super::{BitFlag, BitFlags}; + use zerocopy::{FromZeros, Immutable, TryFromBytes, Unaligned}; + + // All zeros is always valid + unsafe impl FromZeros for BitFlags + where + T: BitFlag, + T::Numeric: Immutable, + T::Numeric: FromZeros, + { + // We are actually allowed to implement this trait. The scary name is just meant + // to convey that "this is dangerous and you'd better know what you're doing and + // be sure that you need to do this and can't just use the derives". (https://github.com/google/zerocopy/issues/287) + // We can not use the derives for this, because they dont support validation. + fn only_derive_is_allowed_to_implement_this_trait() {} + } + + // Mark all BitFlags as Unaligned if the underlying number type is unaligned + unsafe impl Unaligned for BitFlags + where + T: BitFlag, + T::Numeric: Unaligned, + { + // We are actually allowed to implement this trait. The scary name is just meant + // to convey that "this is dangerous and you'd better know what you're doing and + // be sure that you need to do this and can't just use the derives". (https://github.com/google/zerocopy/issues/287) + // We can not use the derives for this, because they dont support validation. + fn only_derive_is_allowed_to_implement_this_trait() {} + } + + // Assert that there are no invalid bytes set + unsafe impl TryFromBytes for BitFlags + where + T: BitFlag, + T::Numeric: Immutable, + T::Numeric: TryFromBytes, + { + // We are actually allowed to implement this trait. The scary name is just meant + // to convey that "this is dangerous and you'd better know what you're doing and + // be sure that you need to do this and can't just use the derives". (https://github.com/google/zerocopy/issues/287) + // We can not use the derives for this, because they dont support validation. + fn only_derive_is_allowed_to_implement_this_trait() + where + Self: Sized, + { + } + + #[inline] + fn is_bit_valid< + ZerocopyAliasing: zerocopy::pointer::invariant::Aliasing + + zerocopy::pointer::invariant::AtLeast, + >( + candidate: zerocopy::Maybe<'_, Self, ZerocopyAliasing>, + ) -> bool { + // SAFETY: + // - The cast preserves address. The caller has promised that the + // cast results in an object of equal or lesser size, and so the + // cast returns a pointer which references a subset of the bytes + // of `p`. + // - The cast preserves provenance. + // - The caller has promised that the destination type has + // `UnsafeCell`s at the same byte ranges as the source type. + let candidate = unsafe { candidate.cast_unsized::(|p| p as *mut _) }; + + // SAFETY: The caller has promised that the referenced memory region + // will contain a valid `$repr`. + let my_candidate = + unsafe { candidate.assume_validity::() }; + { + // TODO: Currently this assumes that the candidate is aligned. We actually need to check this beforehand + // Dereference the pointer to the candidate + let candidate = + my_candidate.read_unaligned::(); + return BitFlags::::validate_bits(candidate); + } + } + } +} diff --git a/test_suite/Cargo.toml b/test_suite/Cargo.toml index 72f3526..c085f81 100644 --- a/test_suite/Cargo.toml +++ b/test_suite/Cargo.toml @@ -6,12 +6,16 @@ edition = "2018" [dependencies.enumflags2] path = "../" -features = ["serde"] +features = ["serde", "zerocopy"] [dependencies.serde] version = "1" features = ["derive"] +[dependencies.zerocopy] +version = "0.8.9" +features = ["derive"] + [dev-dependencies] trybuild = "1.0" glob = "0.3" @@ -65,3 +69,8 @@ edition = "2018" name = "not_literal" path = "tests/not_literal.rs" edition = "2018" + +[[test]] +name = "zerocopy" +path = "tests/zerocopy.rs" +edition = "2018" diff --git a/test_suite/tests/zerocopy.rs b/test_suite/tests/zerocopy.rs new file mode 100644 index 0000000..3392715 --- /dev/null +++ b/test_suite/tests/zerocopy.rs @@ -0,0 +1,32 @@ +use enumflags2::{bitflags, BitFlags}; +use zerocopy::{Immutable, IntoBytes, KnownLayout, TryFromBytes}; + +#[test] +fn zerocopy_compile() { + #[bitflags] + #[derive(Copy, Clone, Debug, KnownLayout)] + #[repr(u8)] + enum TestU8 { + A, + B, + C, + D, + } + + #[bitflags] + #[derive(Copy, Clone, Debug, KnownLayout)] + #[repr(u16)] + enum TestU16 { + A, + B, + C, + D, + } + + #[derive(Clone, Debug, Immutable, TryFromBytes, IntoBytes, KnownLayout)] + #[repr(packed)] + struct Other { + flags2: BitFlags, + flags: BitFlags, + } +}