From feb8d0e1840d2f297de53e0aaa3587ab6d7c55d6 Mon Sep 17 00:00:00 2001 From: Alex Vitkov <44268717+alexvitkov@users.noreply.github.com> Date: Tue, 29 Aug 2023 14:51:24 +0000 Subject: [PATCH] feat: Standard library functions can now be called with closure args (#2471) --- .../compile_success_empty/option/src/main.nr | 11 ++++ .../execution_success/generators/Nargo.toml | 7 +++ .../execution_success/generators/src/main.nr | 57 +++++++++++++++++++ .../higher_order_functions/src/main.nr | 10 +++- noir_stdlib/src/array.nr | 12 ++-- noir_stdlib/src/option.nr | 14 ++--- 6 files changed, 97 insertions(+), 14 deletions(-) create mode 100644 crates/nargo_cli/tests/execution_success/generators/Nargo.toml create mode 100644 crates/nargo_cli/tests/execution_success/generators/src/main.nr diff --git a/crates/nargo_cli/tests/compile_success_empty/option/src/main.nr b/crates/nargo_cli/tests/compile_success_empty/option/src/main.nr index 0a41b9a629c..22229014eef 100644 --- a/crates/nargo_cli/tests/compile_success_empty/option/src/main.nr +++ b/crates/nargo_cli/tests/compile_success_empty/option/src/main.nr @@ -1,6 +1,8 @@ use dep::std::option::Option; fn main() { + let ten = 10; // giving this a name, to ensure that the Option functions work with closures + let none = Option::none(); let some = Option::some(3); @@ -14,15 +16,22 @@ fn main() { assert(none.unwrap_or_else(|| 5) == 5); assert(some.unwrap_or_else(|| 5) == 3); + assert(none.unwrap_or_else(|| ten + 5) == 15); + assert(some.unwrap_or_else(|| ten + 5) == 3); assert(none.map(|x| x * 2).is_none()); assert(some.map(|x| x * 2).unwrap() == 6); + assert(some.map(|x| x * ten).unwrap() == 30); assert(none.map_or(0, |x| x * 2) == 0); assert(some.map_or(0, |x| x * 2) == 6); + assert(none.map_or(0, |x| x * ten) == 0); + assert(some.map_or(0, |x| x * ten) == 30); assert(none.map_or_else(|| 0, |x| x * 2) == 0); assert(some.map_or_else(|| 0, |x| x * 2) == 6); + assert(none.map_or_else(|| 0, |x| x * ten) == 0); + assert(some.map_or_else(|| ten, |x| x * 2) == 6); assert(none.and(none).is_none()); assert(none.and(some).is_none()); @@ -35,6 +44,7 @@ fn main() { assert(none.and_then(add1_u64).is_none()); assert(some.and_then(|_value| Option::none()).is_none()); assert(some.and_then(add1_u64).unwrap() == 4); + assert(some.and_then(|x| Option::some(x + ten)).unwrap() == 13); assert(none.or(none).is_none()); assert(none.or(some).is_some()); @@ -45,6 +55,7 @@ fn main() { assert(none.or_else(|| Option::some(5)).is_some()); assert(some.or_else(|| Option::none()).is_some()); assert(some.or_else(|| Option::some(5)).is_some()); + assert(some.or_else(|| Option::some(ten)).is_some()); assert(none.xor(none).is_none()); assert(none.xor(some).is_some()); diff --git a/crates/nargo_cli/tests/execution_success/generators/Nargo.toml b/crates/nargo_cli/tests/execution_success/generators/Nargo.toml new file mode 100644 index 00000000000..0f05b6e5759 --- /dev/null +++ b/crates/nargo_cli/tests/execution_success/generators/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "generators" +type = "bin" +authors = [""] +compiler_version = "0.10.3" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/execution_success/generators/src/main.nr b/crates/nargo_cli/tests/execution_success/generators/src/main.nr new file mode 100644 index 00000000000..2f6f90a8c57 --- /dev/null +++ b/crates/nargo_cli/tests/execution_success/generators/src/main.nr @@ -0,0 +1,57 @@ +// TODO? +// the syntax for these return types is very difficult to get right :/ +// for arguments this can be handled with a generic Env (or with Fn traits when we add them) +// but for return types neither fo these will help, you need to type out the exact type +fn make_counter() -> fn[(&mut Field,)]() -> Field { + let mut x = &mut 0; + + || { + *x = *x + 1; + *x + } +} + +fn fibonacci_generator() -> fn[(&mut Field, &mut Field)]() -> Field { + let mut x = &mut 1; + let mut y = &mut 2; + + || { + let old_x = *x; + let old_y = *y; + + *y = *x + *y; + *x = old_y; + + old_x + } +} + +// we'll be able to un-hardcode the array length if we have the ::<> syntax proposed in https://github.com/noir-lang/noir/issues/2458 +fn get_some(generator: fn[Env]() -> Field) -> [Field; 5] { + [0,0,0,0,0].map(|_| generator()) +} + +fn test_fib() { + let fib = fibonacci_generator(); + + assert(fib() == 1); + assert(fib() == 2); + assert(fib() == 3); + assert(fib() == 5); + + assert(get_some(fib) == [8, 13, 21, 34, 55]); +} + +fn test_counter() { + let counter = make_counter(); + assert(counter() == 1); + assert(counter() == 2); + assert(counter() == 3); + + assert(get_some(counter) == [4, 5, 6, 7, 8]); +} + +fn main() { + test_fib(); + test_counter(); +} diff --git a/crates/nargo_cli/tests/execution_success/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/execution_success/higher_order_functions/src/main.nr index 0216b4070fb..ce61a4d572d 100644 --- a/crates/nargo_cli/tests/execution_success/higher_order_functions/src/main.nr +++ b/crates/nargo_cli/tests/execution_success/higher_order_functions/src/main.nr @@ -44,14 +44,21 @@ fn main() -> pub Field { /// Test the array functions in std::array fn test_array_functions() { + let two = 2; // giving this a name, to ensure that the Option functions work with closures + let myarray: [i32; 3] = [1, 2, 3]; assert(myarray.any(|n| n > 2)); + assert(myarray.any(|n| n > two)); + + let evens: [i32; 3] = myarray.map(|n| n * two); // [2, 4, 6] - let evens: [i32; 3] = [2, 4, 6]; assert(evens.all(|n| n > 1)); + assert(evens.all(|n| n >= two)); assert(evens.fold(0, |a, b| a + b) == 12); + assert(evens.fold(0, |a, b| a + b + two) == 18); assert(evens.reduce(|a, b| a + b) == 12); + assert(evens.reduce(|a, b| a + b + two) == 16); // TODO: is this a sort_via issue with the new backend, // or something more general? @@ -68,6 +75,7 @@ fn test_array_functions() { // assert(descending == [3, 2, 1]); assert(evens.map(|n| n / 2) == myarray); + assert(evens.map(|n| n / two) == myarray); } fn foo() -> [u32; 2] { diff --git a/noir_stdlib/src/array.nr b/noir_stdlib/src/array.nr index 9082e161e91..c1e7cfdcfe6 100644 --- a/noir_stdlib/src/array.nr +++ b/noir_stdlib/src/array.nr @@ -9,7 +9,7 @@ impl [T; N] { fn sort(_self: Self) -> Self {} // Sort with a custom sorting function. - fn sort_via(mut a: Self, ordering: fn(T, T) -> bool) -> Self { + fn sort_via(mut a: Self, ordering: fn[Env](T, T) -> bool) -> Self { for i in 1 .. a.len() { for j in 0..i { if ordering(a[i], a[j]) { @@ -33,7 +33,7 @@ impl [T; N] { // Apply a function to each element of an array, returning a new array // containing the mapped elements. - fn map(self, f: fn(T) -> U) -> [U; N] { + fn map(self, f: fn[Env](T) -> U) -> [U; N] { let first_elem = f(self[0]); let mut ret = [first_elem; N]; @@ -47,7 +47,7 @@ impl [T; N] { // Apply a function to each element of the array and an accumulator value, // returning the final accumulated value. This function is also sometimes // called `foldl`, `fold_left`, `reduce`, or `inject`. - fn fold(self, mut accumulator: U, f: fn(U, T) -> U) -> U { + fn fold(self, mut accumulator: U, f: fn[Env](U, T) -> U) -> U { for elem in self { accumulator = f(accumulator, elem); } @@ -57,7 +57,7 @@ impl [T; N] { // Apply a function to each element of the array and an accumulator value, // returning the final accumulated value. Unlike fold, reduce uses the first // element of the given array as its starting accumulator value. - fn reduce(self, f: fn(T, T) -> T) -> T { + fn reduce(self, f: fn[Env](T, T) -> T) -> T { let mut accumulator = self[0]; for i in 1 .. self.len() { accumulator = f(accumulator, self[i]); @@ -66,7 +66,7 @@ impl [T; N] { } // Returns true if all elements in the array satisfy the predicate - fn all(self, predicate: fn(T) -> bool) -> bool { + fn all(self, predicate: fn[Env](T) -> bool) -> bool { let mut ret = true; for elem in self { ret &= predicate(elem); @@ -75,7 +75,7 @@ impl [T; N] { } // Returns true if any element in the array satisfies the predicate - fn any(self, predicate: fn(T) -> bool) -> bool { + fn any(self, predicate: fn[Env](T) -> bool) -> bool { let mut ret = false; for elem in self { ret |= predicate(elem); diff --git a/noir_stdlib/src/option.nr b/noir_stdlib/src/option.nr index 919c40fd9e0..11a632011b0 100644 --- a/noir_stdlib/src/option.nr +++ b/noir_stdlib/src/option.nr @@ -48,7 +48,7 @@ impl Option { /// Returns the wrapped value if `self.is_some()`. Otherwise, calls the given function to return /// a default value. - fn unwrap_or_else(self, default: fn() -> T) -> T { + fn unwrap_or_else(self, default: fn[Env]() -> T) -> T { if self._is_some { self._value } else { @@ -57,7 +57,7 @@ impl Option { } /// If self is `Some(x)`, this returns `Some(f(x))`. Otherwise, this returns `None`. - fn map(self, f: fn(T) -> U) -> Option { + fn map(self, f: fn[Env](T) -> U) -> Option { if self._is_some { Option::some(f(self._value)) } else { @@ -66,7 +66,7 @@ impl Option { } /// If self is `Some(x)`, this returns `f(x)`. Otherwise, this returns the given default value. - fn map_or(self, default: U, f: fn(T) -> U) -> U { + fn map_or(self, default: U, f: fn[Env](T) -> U) -> U { if self._is_some { f(self._value) } else { @@ -75,7 +75,7 @@ impl Option { } /// If self is `Some(x)`, this returns `f(x)`. Otherwise, this returns `default()`. - fn map_or_else(self, default: fn() -> U, f: fn(T) -> U) -> U { + fn map_or_else(self, default: fn[Env1]() -> U, f: fn[Env2](T) -> U) -> U { if self._is_some { f(self._value) } else { @@ -96,7 +96,7 @@ impl Option { /// with the Some value contained within self, and returns the result of that call. /// /// In some languages this function is called `flat_map` or `bind`. - fn and_then(self, f: fn(T) -> Option) -> Option { + fn and_then(self, f: fn[Env](T) -> Option) -> Option { if self._is_some { f(self._value) } else { @@ -114,7 +114,7 @@ impl Option { } /// If self is Some, return self. Otherwise, return `default()`. - fn or_else(self, default: fn() -> Self) -> Self { + fn or_else(self, default: fn[Env]() -> Self) -> Self { if self._is_some { self } else { @@ -140,7 +140,7 @@ impl Option { /// Returns `Some(x)` if self is `Some(x)` and `predicate(x)` is true. /// Otherwise, this returns `None` - fn filter(self, predicate: fn(T) -> bool) -> Self { + fn filter(self, predicate: fn[Env](T) -> bool) -> Self { if self._is_some { if predicate(self._value) { self