Skip to content

Commit

Permalink
Merge pull request #149 from ralexstokes/proc-macro-fixes
Browse files Browse the repository at this point in the history
Fixes and extensions for the derive proc macro
  • Loading branch information
ralexstokes authored Apr 2, 2024
2 parents 9a238ce + 016236a commit 71f84ff
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 22 deletions.
27 changes: 18 additions & 9 deletions ssz-rs-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
//! Provides a set of macros to derive implementations for the core SSZ traits given in the `ssz_rs`
//! crate. Suppports native Rust structs and enums, subject to conditions compatible with SSZ
//! containers and unions.
//!
//! Refer to the `examples` in the `ssz_rs` crate for a better idea on how to use this derive macro.
//!
//! This proc macro supports one attribute `ssz(transparent)` to pass through calls on a wrapping
//! Rust enum to the underlying data. Refers to this crate's tests for example usage.
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned};
use syn::{
Expand Down Expand Up @@ -619,14 +623,21 @@ fn derive_prove_impl(data: &Data, name: &Ident, generics: &Generics) -> TokenStr
(chunks_impl, prove_element_impl, None)
}
Fields::Unnamed(..) => {
// NOTE: new type pattern, proxy to wrapped type...
let chunks_impl = quote! {
self.0.chunks()
};

let prove_element_impl = quote! {
self.0.prove_element(index, prover)
};
(chunks_impl, prove_element_impl, None)

let decoration_impl = quote! {
fn decoration(&self) -> Option<usize> {
self.0.decoration()
}
};
(chunks_impl, prove_element_impl, Some(decoration_impl))
}
Fields::Unit => unreachable!("validated to exclude this type"),
},
Expand Down Expand Up @@ -663,7 +674,7 @@ fn derive_prove_impl(data: &Data, name: &Ident, generics: &Generics) -> TokenStr
},
)
}
_ => unreachable!(),
_ => unreachable!("other variants validated to not exist"),
}
});
let (impl_by_variant, decoration_by_variant): (Vec<_>, Vec<_>) =
Expand Down Expand Up @@ -731,7 +742,7 @@ fn filter_ssz_attrs<'a>(
fn validate_no_attrs<'a>(fields: impl Iterator<Item = &'a Field>) {
let mut ssz_attrs = fields.flat_map(|field| filter_ssz_attrs(field.attrs.iter()));
if ssz_attrs.next().is_some() {
panic!("macro attribute `{SSZ_HELPER_ATTRIBUTE}` is only allowed at struct or enum level")
panic!("macro attribute `{SSZ_HELPER_ATTRIBUTE}` is only allowed at enum level")
}
}

Expand Down Expand Up @@ -981,21 +992,19 @@ pub fn derive_prove(input: proc_macro::TokenStream) -> proc_macro::TokenStream {

/// Derive `SimpleSerialize` for the attached item, including the relevant additional traits
/// required by the trait bound. Most common macro used from this crate.
#[proc_macro_derive(SimpleSerialize, attributes(ssz))]
#[proc_macro_derive(SimpleSerialize)]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);

let data = &input.data;
let helper_attrs = extract_helper_attrs(&input);
validate_derive_input(data, &helper_attrs);
let helper_attr = helper_attrs.first();
validate_derive_input(data, &[]);

let name = &input.ident;
let generics = &input.generics;

let serializable_impl = derive_serializable_impl(data, name, generics, helper_attr);
let serializable_impl = derive_serializable_impl(data, name, generics, None);

let merkleization_impl = derive_merkleization_impl(data, name, generics, helper_attr);
let merkleization_impl = derive_merkleization_impl(data, name, generics, None);

let generalized_indexable_impl = derive_generalized_indexable_impl(data, name, generics);

Expand Down
201 changes: 189 additions & 12 deletions ssz-rs-derive/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,210 @@
use ssz_rs::prelude::*;
use ssz_rs::{prelude::*, proofs::ProofAndWitness};
use ssz_rs_derive::SimpleSerialize;
use std::fmt;

#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
#[derive(Debug, Clone, SimpleSerialize, PartialEq, Eq)]
struct Foo {
a: u8,
b: u32,
c: List<usize, 45>,
d: U256,
}

#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Serializable, HashTreeRoot)]
#[ssz(transparent)]
enum Bar {
A(u8),
B(Foo),
}

#[derive(Debug, SimpleSerialize)]
fn generalized_index_for_bar(
object: &mut Bar,
path: Path,
) -> Result<GeneralizedIndex, MerkleizationError> {
match object {
Bar::A(_) => u8::generalized_index(path),
Bar::B(_) => Foo::generalized_index(path),
}
}

fn prove_for_bar(object: &mut Bar, path: Path) -> Result<ProofAndWitness, MerkleizationError> {
match object {
Bar::A(value) => value.prove(path),
Bar::B(value) => value.prove(path),
}
}

#[derive(Debug, PartialEq, Eq, SimpleSerialize)]
struct Wrapper(Foo);

#[derive(Debug, PartialEq, Eq, SimpleSerialize)]
struct WrappedList(List<u8, 23>);

fn can_serde<T: Serializable + Eq + fmt::Debug>(data: &mut T) {
let mut buf = vec![];
let _ = data.serialize(&mut buf).unwrap();
let recovered = T::deserialize(&buf).unwrap();
assert_eq!(data, &recovered);
}

#[test]
fn test_transparent_helper() {
let mut f = Foo { a: 23, b: 445 };
let f_root = f.hash_tree_root().unwrap();
let mut bar = Bar::B(f);
// derive traits for "regular" types
let mut container = Foo {
a: 23,
b: 445,
c: List::<usize, 45>::try_from(vec![9, 8, 7, 6, 5, 4]).unwrap(),
d: U256::from(234234),
};
can_serde(&mut container);

let mut buf = vec![];
let _ = bar.serialize(&mut buf).unwrap();
let recovered_bar = Bar::deserialize(&buf).unwrap();
assert_eq!(bar, recovered_bar);
let container_root = container.hash_tree_root().unwrap();

let mut container_indices = vec![];
let container_paths = [
(vec!["a".into()], 4),
(vec!["b".into()], 5),
(vec!["c".into()], 6),
(vec!["c".into(), 1.into()], 192),
(vec!["c".into(), 43.into()], 202),
(vec!["d".into()], 7),
];
for (path, expected) in &container_paths {
let index = Foo::generalized_index(path).unwrap();
assert_eq!(index, *expected);
container_indices.push(index);
}

let mut container_proofs = vec![];
for pair in &container_paths {
let path = &pair.0;
let (proof, witness) = container.prove(path).unwrap();
assert_eq!(witness, container_root);
assert!(proof.verify(witness).is_ok());
container_proofs.push((proof, witness));
}

// derive traits in "transparent" mode
let mut inner = 22;
let mut bar = Bar::A(inner);
can_serde(&mut bar);

let inner_root = inner.hash_tree_root().unwrap();
let bar_root = bar.hash_tree_root().unwrap();
assert_eq!(inner_root, bar_root);

// `bar` just wraps a primitive type, so `path` is empty.
let index = generalized_index_for_bar(&mut bar, &[]).unwrap();
assert_eq!(index, 1);
let result = generalized_index_for_bar(&mut bar, &["a".into()]);
assert!(result.is_err());

let path = &[];
let (proof, witness) = prove_for_bar(&mut bar, path).unwrap();
assert_eq!(witness, inner_root);
assert_eq!(witness, bar_root);
assert!(proof.verify(witness).is_ok());

// repeat transparent with other variant
let mut inner = container.clone();
let inner_root = inner.hash_tree_root().unwrap();
let mut bar = Bar::B(inner);
can_serde(&mut bar);

let bar_root = bar.hash_tree_root().unwrap();
assert_eq!(f_root, bar_root);
assert_eq!(inner_root, bar_root);

for (i, (path, _)) in container_paths.iter().enumerate() {
let index = generalized_index_for_bar(&mut bar, path).unwrap();
assert_eq!(index, container_indices[i]);
}

for (i, pair) in container_paths.iter().enumerate() {
let path = &pair.0;
let (proof, witness) = prove_for_bar(&mut bar, path).unwrap();
assert_eq!(witness, container_root);
assert!(proof.verify(witness).is_ok());
assert_eq!((proof, witness), container_proofs[i]);
}

// derive traits for "new type" pattern
// for a wrapped type without "decoration"
let mut buf = vec![];
let container_serialization = container.serialize(&mut buf).unwrap();
let mut wrapped = Wrapper(container);
can_serde(&mut wrapped);
buf.clear();
let wrapped_serialization = wrapped.serialize(&mut buf).unwrap();
assert_eq!(container_serialization, wrapped_serialization);

let wrapped_root = wrapped.hash_tree_root().unwrap();
assert_eq!(wrapped_root, container_root);

let wrapped_paths = [
(vec!["a".into()], 4),
(vec!["b".into()], 5),
(vec!["c".into()], 6),
(vec!["c".into(), 1.into()], 192),
(vec!["c".into(), 43.into()], 202),
(vec!["d".into()], 7),
];
for (i, (path, expected)) in container_paths.iter().enumerate() {
let index = Wrapper::generalized_index(path).unwrap();
assert_eq!(index, *expected);
assert_eq!(index, container_indices[i]);
}

for (i, pair) in wrapped_paths.iter().enumerate() {
let path = &pair.0;
let (proof, witness) = wrapped.prove(path).unwrap();
assert_eq!(witness, container_root);
assert!(proof.verify(witness).is_ok());
assert_eq!((proof, witness), container_proofs[i]);
}

// for a wrapped type with "decoration"
let mut buf = vec![];
let mut inner = List::<u8, 23>::try_from(vec![10, 11, 12]).unwrap();
let inner_serialization = inner.serialize(&mut buf).unwrap();
let inner_root = inner.hash_tree_root().unwrap();
let inner_paths =
[(vec![0.into()], 2), (vec![1.into()], 2), (vec![21.into()], 2), (vec![22.into()], 2)];
let mut inner_indices = vec![];
for (path, expected) in &inner_paths {
let index = List::<u8, 23>::generalized_index(path).unwrap();
assert_eq!(index, *expected);
inner_indices.push(index);
}
let mut inner_proofs = vec![];
for (path, _) in &inner_paths {
let (proof, witness) = inner.prove(path).unwrap();
assert_eq!(witness, inner_root);
assert!(proof.verify(witness).is_ok());
inner_proofs.push((proof, witness));
}

let mut wrapped = WrappedList(inner);
can_serde(&mut wrapped);
buf.clear();
let wrapped_serialization = wrapped.serialize(&mut buf).unwrap();
assert_eq!(inner_serialization, wrapped_serialization);

let wrapped_root = wrapped.hash_tree_root().unwrap();
assert_eq!(wrapped_root, inner_root);

let wrapped_paths =
[(vec![0.into()], 2), (vec![3.into()], 2), (vec![21.into()], 2), (vec![22.into()], 2)];
for (i, (path, expected)) in wrapped_paths.iter().enumerate() {
let index = WrappedList::generalized_index(path).unwrap();
assert_eq!(index, *expected);
assert_eq!(index, inner_indices[i]);
}

for (i, pair) in wrapped_paths.iter().enumerate() {
let path = &pair.0;
let (proof, witness) = wrapped.prove(path).unwrap();
assert_eq!(witness, inner_root);
assert!(proof.verify(witness).is_ok());
assert_eq!((proof, witness), inner_proofs[i]);
}
}
2 changes: 1 addition & 1 deletion ssz-rs/src/merkleization/generalized_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
};

/// Describes part of a `GeneralizedIndexable` type.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PathElement {
// Refers to either an element in a SSZ collection
// or a particular variant of a SSZ union.
Expand Down

0 comments on commit 71f84ff

Please sign in to comment.