-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: validate inputs of fn main (#4)
Related noir-lang/noir#7181 Related noir-lang/noir#4218
- Loading branch information
Showing
7 changed files
with
276 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
pub comptime fn validate_inputs(f: FunctionDefinition) { | ||
let validated_inputs = f | ||
.parameters() | ||
.map(|(name, _typ): (Quoted, Type)| quote {{ nodash::ValidateInput::validate($name); }}) | ||
.join(quote {;}); | ||
let checks_body = quote {{ $validated_inputs }}.as_expr().expect( | ||
f"failed to parse ValidateInput checks code", | ||
); // should never fail | ||
|
||
let old_body = f.body(); | ||
let checked_body = quote {{ | ||
$checks_body; | ||
$old_body | ||
}}; | ||
f.set_body(checked_body.as_expr().expect(f"failed to concatenate body with checks")); | ||
} | ||
|
||
#[derive_via(derive_validate_input)] | ||
pub trait ValidateInput { | ||
fn validate(self); | ||
} | ||
|
||
comptime fn derive_validate_input(s: StructDefinition) -> Quoted { | ||
let name = quote { nodash::ValidateInput }; | ||
let signature = quote { fn validate(self) }; | ||
let for_each_field = |name| quote { nodash::ValidateInput::validate(self.$name); }; | ||
let body = |fields| quote { $fields }; | ||
std::meta::make_trait_impl(s, name, signature, for_each_field, quote { , }, body) | ||
} | ||
|
||
impl ValidateInput for u8 { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for u16 { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for u32 { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for u64 { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for i8 { | ||
fn validate(self) {} | ||
} | ||
impl ValidateInput for i16 { | ||
fn validate(self) {} | ||
} | ||
impl ValidateInput for i32 { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for i64 { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for Field { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for bool { | ||
fn validate(self) {} | ||
} | ||
|
||
impl ValidateInput for U128 { | ||
fn validate(self) {} | ||
} | ||
|
||
impl<let N: u32> ValidateInput for str<N> { | ||
fn validate(self) {} | ||
} | ||
|
||
impl<T, let N: u32> ValidateInput for [T; N] | ||
where | ||
T: ValidateInput, | ||
{ | ||
fn validate(mut self) { | ||
for i in 0..N { | ||
self[i].validate(); | ||
} | ||
} | ||
} | ||
|
||
impl<T, let MaxLen: u32> ValidateInput for BoundedVec<T, MaxLen> | ||
where | ||
T: ValidateInput, | ||
{ | ||
fn validate(mut self) { | ||
for i in 0..MaxLen { | ||
if i < self.len() { | ||
self.get_unchecked(i).validate() | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[package] | ||
name = "tests" | ||
type = "lib" | ||
|
||
[dependencies] | ||
nodash = { path = "../" } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
mod validate_inputs; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#[nodash::validate_inputs] | ||
fn my_main(a: Field, b: u64) -> Field { | ||
a + b as Field | ||
} | ||
|
||
#[test] | ||
fn test_validate_inputs() { | ||
let result = my_main(1, 2); | ||
assert(result == 3); | ||
} | ||
|
||
#[nodash::validate_inputs] | ||
fn main_collections(a: [U120; 1], b: BoundedVec<U120, 2>) -> Field { | ||
a[0].inner + b.get(0).inner | ||
} | ||
|
||
#[test] | ||
fn test_validate_collections() { | ||
let result = main_collections( | ||
[U120::new(1)], | ||
BoundedVec::from_parts([U120::new(2), U120 { inner: 2.pow_32(120) }], 1), | ||
); | ||
assert(result == 3); | ||
} | ||
|
||
#[test(should_fail_with = "call to assert_max_bit_size")] | ||
fn test_validate_array_fail() { | ||
let _ = main_collections([U120 { inner: 2.pow_32(120) }], BoundedVec::new()); | ||
} | ||
|
||
#[test(should_fail_with = "call to assert_max_bit_size")] | ||
fn test_validate_bounded_vec_fail() { | ||
let _ = main_collections( | ||
[U120::new(1)], | ||
BoundedVec::from_parts([U120::new(2), U120 { inner: 2.pow_32(120) }], 2), | ||
); | ||
} | ||
|
||
#[nodash::validate_inputs] | ||
fn main_u120(a: U120) -> Field { | ||
a.inner | ||
} | ||
|
||
#[test] | ||
fn test_validate_u120() { | ||
let inner = 2.pow_32(120) - 1; | ||
let result = main_u120(U120 { inner }); | ||
assert(result == inner); | ||
} | ||
|
||
#[test(should_fail_with = "call to assert_max_bit_size")] | ||
fn test_validate_u120_fail() { | ||
let inner = 2.pow_32(120); | ||
let _ = main_u120(U120 { inner }); | ||
} | ||
|
||
#[nodash::validate_inputs] | ||
fn main_struct_derive(a: NestedStruct) -> Field { | ||
a.value.inner | ||
} | ||
|
||
#[test] | ||
fn test_validate_struct_derive() { | ||
let inner = 2.pow_32(120) - 1; | ||
let result = main_struct_derive(NestedStruct { value: U120 { inner } }); | ||
assert(result == inner); | ||
} | ||
|
||
#[test(should_fail_with = "call to assert_max_bit_size")] | ||
fn test_validate_struct_derive_fail() { | ||
let inner = 2.pow_32(120); | ||
let _ = main_struct_derive(NestedStruct { value: U120 { inner } }); | ||
} | ||
|
||
struct U120 { | ||
inner: Field, | ||
} | ||
|
||
impl U120 { | ||
fn new(inner: Field) -> Self { | ||
inner.assert_max_bit_size::<120>(); | ||
Self { inner } | ||
} | ||
} | ||
|
||
impl nodash::ValidateInput for U120 { | ||
fn validate(self) { | ||
let _ = U120::new(self.inner); | ||
} | ||
} | ||
|
||
#[derive(nodash::ValidateInput)] | ||
struct NestedStruct { | ||
value: U120, | ||
} |