From 771904ed57e08ea730415d5587f01af99f5423d9 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Wed, 5 Jul 2023 22:53:13 +0000 Subject: [PATCH 01/20] feat: defunctionalization pass --- .../7_function/src/main.nr | 10 +- .../brillig_fns_as_values/Nargo.toml | 5 + .../brillig_fns_as_values/Prover.toml | 1 + .../brillig_fns_as_values/src/main.nr | 32 ++ .../src/brillig/brillig_gen/brillig_fn.rs | 13 +- .../src/brillig/brillig_ir/artifact.rs | 5 +- crates/noirc_evaluator/src/ssa_refactor.rs | 6 +- .../src/ssa_refactor/acir_gen/mod.rs | 21 +- .../src/ssa_refactor/ir/dfg.rs | 11 + .../src/ssa_refactor/ir/function.rs | 18 +- .../src/ssa_refactor/ir/map.rs | 5 + .../src/ssa_refactor/ir/printer.rs | 6 +- .../src/ssa_refactor/opt/defunctionalize.rs | 277 ++++++++++++++++++ .../src/ssa_refactor/opt/mod.rs | 1 + .../src/ssa_refactor/ssa_gen/program.rs | 21 ++ 15 files changed, 393 insertions(+), 39 deletions(-) create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr create mode 100644 crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr index 5a23b493871..6c10cfca788 100644 --- a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr @@ -88,12 +88,14 @@ fn test_multiple6(a: my2, b: my_struct, c: (my2, my_struct)) { } -fn foo(a: [Field]) -> [Field] { + +fn foo(a: [Field; N]) -> [Field; N] { a } -fn bar() -> [Field] { - foo([0]) +fn bar(unused: [Field; N]) -> [Field; N] { + foo(unused) } + fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { let mut ss: my_struct = my_struct { b: x, a: x+2, }; test_multiple4(ss); @@ -115,7 +117,7 @@ fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { test0(a); test1(a); test2(x as Field, y); - assert(bar()[0] == 0); + assert(bar([0])[0] == 0); let mut b = [0 as u8, 5 as u8, 2 as u8]; let c = test3(b); diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml new file mode 100644 index 00000000000..e0b467ce5da --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml @@ -0,0 +1,5 @@ +[package] +authors = [""] +compiler_version = "0.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml new file mode 100644 index 00000000000..11497a473bc --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml @@ -0,0 +1 @@ +x = "0" diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr new file mode 100644 index 00000000000..08addd44b0d --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr @@ -0,0 +1,32 @@ +// Tests a very simple program. +// +// The features being tested are brillig functions as values + +struct MyStruct { + operation: fn (u32) -> u32, +} + +fn main(x: u32) { + assert(wrapper(increment, x) == x + 1); + assert(wrapper(decrement, x) == x - 1); + assert(wrapper_with_struct(MyStruct { operation: increment }, x) == x + 1); + assert(wrapper_with_struct(MyStruct { operation: decrement }, x) == x - 1); +} + +unconstrained fn wrapper(func: fn (u32) -> u32, param: u32) -> u32 { + func(param) +} + +unconstrained fn increment(x: u32) -> u32 { + x + 1 +} + +unconstrained fn decrement(x: u32) -> u32 { + x - 1 +} + +unconstrained fn wrapper_with_struct(my_struct: MyStruct, param: u32) -> u32 { + let func = my_struct.operation; + func(param) +} + diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 819f0ae26c7..a501e9117a2 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -9,7 +9,6 @@ use crate::{ }, ssa_refactor::ir::{ function::{Function, FunctionId}, - instruction::TerminatorInstruction, types::Type, value::ValueId, }, @@ -71,17 +70,7 @@ impl FunctionContext { /// Collects the return values of a given function pub(crate) fn return_values(func: &Function) -> Vec { - let blocks = func.reachable_blocks(); - let mut function_return_values = None; - for block in blocks { - let terminator = func.dfg[block].terminator(); - if let Some(TerminatorInstruction::Return { return_values }) = terminator { - function_return_values = Some(return_values); - break; - } - } - function_return_values - .expect("Expected a return instruction, as block is finished construction") + func.returns() .iter() .map(|&value_id| { let typ = func.dfg.type_of_value(value_id); diff --git a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index 2eaeee8f636..a8a61d985c6 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -79,6 +79,7 @@ impl BrilligArtifact { return_parameters: Vec, target_function: Label, ) -> BrilligArtifact { + println!("Creating entry point artifact for function {}", target_function); let mut entry_point_artifact = BrilligArtifact::new(arguments, return_parameters); entry_point_artifact.entry_point_instruction(); @@ -154,6 +155,7 @@ impl BrilligArtifact { /// Brillig artifact (self). pub(crate) fn link_with(&mut self, func_label: Label, obj: &BrilligArtifact) { // Add the unresolved jumps of the linked function to this artifact. + println!("Linking with {}", func_label); self.add_unresolved_jumps_and_calls(obj); let mut byte_code = obj.byte_code.clone(); @@ -169,7 +171,7 @@ impl BrilligArtifact { self.byte_code.append(&mut byte_code); // Remove all resolved external calls and transform them to jumps - let is_resolved = |label: &Label| label == &func_label; + let is_resolved = |label: &Label| self.labels.get(label).is_some(); let resolved_external_calls = self .unresolved_external_call_labels @@ -183,6 +185,7 @@ impl BrilligArtifact { } self.unresolved_external_call_labels.retain(|(_, label)| !is_resolved(label)); + println!("Unresolved external calls: {:?}", self.unresolved_external_call_labels); } /// Adds unresolved jumps & function calls from another artifact offset by the current opcode count in the artifact. diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index 4f11e147809..a2beae3a0d6 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -34,7 +34,11 @@ pub(crate) fn optimize_into_acir( print_ssa_passes: bool, ) -> GeneratedAcir { let abi_distinctness = program.return_distinctness; - let mut ssa = ssa_gen::generate_ssa(program).print(print_ssa_passes, "Initial SSA:"); + let mut ssa = ssa_gen::generate_ssa(program) + .print(print_ssa_passes, "Initial SSA:") + .defunctionalize() + .print(print_ssa_passes, "After Defunctionalization:"); + let brillig = ssa.to_brillig(); if let RuntimeType::Acir = ssa.main().runtime() { ssa = ssa diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index 307bcce5a35..85e59a5219f 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -148,9 +148,8 @@ impl Context { self.create_value_from_type(&typ, &mut |this, _| this.acir_context.add_variable()) }); - let outputs: Vec = vecmap(self.get_return_values(main_func), |result_id| { - dfg.type_of_value(result_id).into() - }); + let outputs: Vec = + vecmap(main_func.returns(), |result_id| dfg.type_of_value(*result_id).into()); let code = self.gen_brillig_for(main_func, &brillig); @@ -423,22 +422,6 @@ impl Context { self.define_result(dfg, instruction, AcirValue::Var(result, typ)); } - /// Finds the return values of a given function - fn get_return_values(&self, func: &Function) -> Vec { - let blocks = func.reachable_blocks(); - let mut function_return_values = None; - for block in blocks { - let terminator = func.dfg[block].terminator(); - if let Some(TerminatorInstruction::Return { return_values }) = terminator { - function_return_values = Some(return_values); - break; - } - } - function_return_values - .expect("Expected a return instruction, as block is finished construction") - .clone() - } - /// Converts an SSA terminator's return values into their ACIR representations fn convert_ssa_return(&mut self, terminator: &TerminatorInstruction, dfg: &DataFlowGraph) { let return_values = match terminator { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 393b85fdd2f..a5bbfc1856a 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -109,6 +109,11 @@ impl DataFlowGraph { self.blocks.iter() } + /// Gets a vec of the value ids of the function + pub(crate) fn value_ids(&self) -> Vec { + self.values.iter().map(|(id, _)| id).collect() + } + /// Returns the parameters of the given block pub(crate) fn block_parameters(&self, block: BasicBlockId) -> &[ValueId] { self.blocks[block].parameters() @@ -384,6 +389,12 @@ impl std::ops::Index for DataFlowGraph { } } +impl std::ops::IndexMut for DataFlowGraph { + fn index_mut(&mut self, id: ValueId) -> &mut Self::Output { + &mut self.values[id] + } +} + impl std::ops::Index for DataFlowGraph { type Output = BasicBlock; fn index(&self, id: BasicBlockId) -> &Self::Output { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index 8fe2fe745ff..c01ada12b6b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -2,11 +2,12 @@ use std::collections::HashSet; use super::basic_block::BasicBlockId; use super::dfg::DataFlowGraph; +use super::instruction::TerminatorInstruction; use super::map::Id; use super::types::Type; use super::value::ValueId; -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] pub(crate) enum RuntimeType { // A noir function, to be compiled in ACIR and executed by ACVM Acir, @@ -84,6 +85,21 @@ impl Function { self.dfg.block_parameters(self.entry_block) } + /// Returns the return values of this function. + pub(crate) fn returns(&self) -> &[ValueId] { + let blocks = self.reachable_blocks(); + let mut function_return_values = None; + for block in blocks { + let terminator = self.dfg[block].terminator(); + if let Some(TerminatorInstruction::Return { return_values }) = terminator { + function_return_values = Some(return_values); + break; + } + } + function_return_values + .expect("Expected a return instruction, as block is finished construction") + } + /// Collects all the reachable blocks of this function. /// /// Note that self.dfg.basic_blocks_iter() iterates over all blocks, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs index e7f9d812de3..bb0da6a8558 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs @@ -25,6 +25,11 @@ impl Id { Self { index, _marker: std::marker::PhantomData } } + /// Returns the underlying index of this Id. + pub(crate) fn to_usize(self) -> usize { + self.index + } + /// Creates a test Id with the given index. /// The name of this function makes it apparent it should only /// be used for testing. Obtaining Ids in this way should be avoided diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs index 071f1a16029..fa8cc674d27 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs @@ -6,6 +6,8 @@ use std::{ use iter_extended::vecmap; +use crate::ssa_refactor::ir::function::RuntimeType; + use super::{ basic_block::BasicBlockId, function::Function, @@ -15,7 +17,9 @@ use super::{ /// Helper function for Function's Display impl to pretty-print the function with the given formatter. pub(crate) fn display_function(function: &Function, f: &mut Formatter) -> Result { - writeln!(f, "fn {} {} {{", function.name(), function.id())?; + let runtime: &str = + if let RuntimeType::Brillig = function.runtime() { "brillig" } else { "acir" }; + writeln!(f, "fn {} {} {} {{", runtime, function.name(), function.id())?; display_block_with_successors(function, function.entry_block(), &mut HashSet::new(), f)?; write!(f, "}}") } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs new file mode 100644 index 00000000000..5f4b0c020e7 --- /dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -0,0 +1,277 @@ +use std::collections::{HashMap, HashSet}; + +use iter_extended::vecmap; + +use crate::ssa_refactor::{ + ir::{ + basic_block::BasicBlockId, + function::{Function, FunctionId, RuntimeType}, + instruction::{BinaryOp, Instruction}, + types::{NumericType, Type}, + value::Value, + }, + ssa_builder::FunctionBuilder, + ssa_gen::Ssa, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct FunctionSignature { + parameters: Vec, + returns: Vec, + runtime: RuntimeType, +} + +impl FunctionSignature { + fn from(function: &Function) -> Self { + let parameters = vecmap(function.parameters(), |param| function.dfg.type_of_value(*param)); + let returns = vecmap(function.returns(), |ret| function.dfg.type_of_value(*ret)); + let runtime = function.runtime(); + Self { parameters, returns, runtime } + } +} + +/// Performs defunctionalization on all functions +/// This is done by changing all functions as value to be a number (FieldElement) +/// And creating apply functions that dispatch to the correct target by runtime comparisons with constants +#[derive(Debug, Clone)] +struct DefunctionalizationContext { + fn_to_runtime: HashMap, + variants: HashMap>, + apply_functions: HashMap, +} + +impl DefunctionalizationContext { + /// Returns the new ssa with the function defunctionalized + pub(crate) fn defunctionalize_ssa(mut ssa: Ssa) -> Ssa { + // Find all functions that share the signature + let variants = find_variants(&ssa); + // Create apply functions + let apply_functions = create_apply_functions(&mut ssa, &variants); + let fn_to_runtime = ssa + .functions + .iter() + .map(|(func_id, func)| (*func_id, func.runtime())) + .collect::>(); + + let context = DefunctionalizationContext { fn_to_runtime, variants, apply_functions }; + + context.defunctionalize_all(ssa) + } + + /// Defunctionalize all functions in the Ssa + fn defunctionalize_all(mut self, mut ssa: Ssa) -> Ssa { + let func_ids = ssa.functions.keys().copied().collect::>(); + for func_id in func_ids { + ssa = self.defunctionalize(func_id, ssa); + } + ssa + } + + /// Defunctionalize a single function + fn defunctionalize(&mut self, func_id: FunctionId, mut ssa: Ssa) -> Ssa { + let func = ssa.get_fn_mut(func_id); + let mut target_function_ids = HashSet::new(); + + for block_id in func.reachable_blocks() { + let block = &func.dfg[block_id]; + let instructions = block.instructions().to_vec(); + + for instruction_id in instructions { + let instruction = func.dfg[instruction_id].clone(); + let mut new_instruction = None; + // Operate on call instructions + if let Instruction::Call { func: target_func_id, arguments } = instruction { + match func.dfg[target_func_id] { + // If the target is a function used as value + Value::Param { .. } | Value::Instruction { .. } => { + // Collect the argument types + let argument_types = + vecmap(arguments.to_owned(), |arg| func.dfg.type_of_value(arg)); + + // Collect the result types + let result_types = vecmap( + func.dfg.instruction_results(instruction_id).to_owned(), + |result| func.dfg.type_of_value(result), + ); + // Find the correct apply function + let apply_function = self.get_apply_function(&FunctionSignature { + parameters: argument_types, + returns: result_types, + runtime: func.runtime(), + }); + target_function_ids.insert(apply_function); + + // Replace the instruction with a call to apply + let apply_function = func.dfg.import_function(apply_function); + let mut new_arguments = vec![target_func_id]; + new_arguments.extend(arguments); + new_instruction = Some(Instruction::Call { + func: apply_function, + arguments: new_arguments, + }); + } + Value::Function(id) => { + target_function_ids.insert(id); + } + _ => {} + } + } + if let Some(new_instruction) = new_instruction { + func.dfg[instruction_id] = new_instruction; + } + } + } + + // Change the type of all the values that are not call targets to NativeField + for value_id in func.dfg.value_ids() { + let value = &mut func.dfg[value_id]; + if let Type::Function = value.get_type() { + // If the value is a static function, transform it to the function id + if let Value::Function(id) = value { + let id = *id; + if !target_function_ids.contains(&id) { + *value = Value::NumericConstant { + constant: (id.to_usize() as u128).into(), + typ: Type::Numeric(NumericType::NativeField), + } + } + } + // If it is a dynamic function, just change the type + if let Value::Instruction { typ, .. } | Value::Param { typ, .. } = value { + *typ = Type::Numeric(NumericType::NativeField); + } + } + } + + ssa + } + + /// Returns the apply function for the given signature + fn get_apply_function(&self, signature: &FunctionSignature) -> FunctionId { + *self.apply_functions.get(signature).expect("Could not find apply function") + } +} + +fn find_variants(ssa: &Ssa) -> HashMap> { + let mut variants: HashMap> = HashMap::new(); + for function in ssa.functions.values() { + if function.id() != ssa.main_id { + let signature = FunctionSignature::from(function); + variants.entry(signature).or_default().push(function.id()); + } + } + variants +} + +fn create_apply_functions( + ssa: &mut Ssa, + variants_map: &HashMap>, +) -> HashMap { + let mut apply_functions = HashMap::new(); + for (signature, variants) in variants_map.iter() { + let apply_function = create_apply_function(ssa, signature, variants); + apply_functions.insert(signature.clone(), apply_function); + } + apply_functions +} + +/// Creates an apply function for the given signature and variants +fn create_apply_function( + ssa: &mut Ssa, + signature: &FunctionSignature, + function_ids: &[FunctionId], +) -> FunctionId { + assert!(!function_ids.is_empty()); + ssa.add_fn(|id| { + let mut function_builder = FunctionBuilder::new("apply".to_string(), id, signature.runtime); + let target_id = function_builder.add_parameter(Type::Numeric(NumericType::NativeField)); + let params_ids = + vecmap(signature.parameters.clone(), |typ| function_builder.add_parameter(typ)); + + let mut previous_target_block = None; + for (index, function_id) in function_ids.iter().enumerate() { + let is_last = index == function_ids.len() - 1; + let mut next_function_block = None; + + // If it's not the last function to dispatch, crate an if statement + if !is_last { + next_function_block = Some(function_builder.insert_block()); + let executor_block = function_builder.insert_block(); + let function_id_constant = function_builder.numeric_constant( + function_id.to_usize() as u128, + Type::Numeric(NumericType::NativeField), + ); + let condition = + function_builder.insert_binary(target_id, BinaryOp::Eq, function_id_constant); + + function_builder.terminate_with_jmpif( + condition, + executor_block, + next_function_block.unwrap(), + ); + function_builder.switch_to_block(executor_block); + } + // Find the target block or build it if necessary + let target_block = match previous_target_block { + Some(block) => { + let current_block = function_builder.current_block(); + build_return_block( + &mut function_builder, + current_block, + signature.returns.clone(), + Some(block), + ) + } + None => { + let current_block = function_builder.current_block(); + build_return_block( + &mut function_builder, + current_block, + signature.returns.clone(), + None, + ) + } + }; + previous_target_block = Some(target_block); + + // Call the function + let target_function_value = function_builder.import_function(*function_id); + let call_results = function_builder + .insert_call(target_function_value, params_ids.clone(), signature.returns.clone()) + .to_vec(); + + // Jump to the target block for returning + function_builder.terminate_with_jmp(target_block, call_results); + + if let Some(next_block) = next_function_block { + // Switch to the next block for the else branch + function_builder.switch_to_block(next_block); + } + } + function_builder.current_function + }) +} + +fn build_return_block( + builder: &mut FunctionBuilder, + previous_block: BasicBlockId, + passed_types: Vec, + target: Option, +) -> BasicBlockId { + let return_block = builder.insert_block(); + builder.switch_to_block(return_block); + + let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ)); + match target { + None => builder.terminate_with_return(params), + Some(target) => builder.terminate_with_jmp(target, params), + } + builder.switch_to_block(previous_block); + return_block +} + +impl Ssa { + pub(crate) fn defunctionalize(self) -> Ssa { + DefunctionalizationContext::defunctionalize_ssa(self) + } +} diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs index 56c5fa689ad..0d4ad594486 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs @@ -4,6 +4,7 @@ //! simpler form until the IR only has a single function remaining with 1 block within it. //! Generally, these passes are also expected to minimize the final amount of instructions. mod constant_folding; +mod defunctionalize; mod die; mod flatten_cfg; mod inlining; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs index ba98c658505..9e7cfd7a25b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs @@ -34,10 +34,31 @@ impl Ssa { &self.functions[&self.main_id] } + /// Returns the function with the given ID + pub(crate) fn get_fn(&self, id: FunctionId) -> &Function { + &self.functions[&id] + } + /// Returns the entry-point function of the program as a mutable reference pub(crate) fn main_mut(&mut self) -> &mut Function { self.functions.get_mut(&self.main_id).expect("ICE: Ssa should have a main function") } + + /// Returns the function with the given ID as a mutable reference + pub(crate) fn get_fn_mut(&mut self, id: FunctionId) -> &mut Function { + self.functions.get_mut(&id).expect("ICE: Could not find function") + } + + /// Adds a new function to the program + pub(crate) fn add_fn( + &mut self, + build_with_id: impl FnOnce(FunctionId) -> Function, + ) -> FunctionId { + let new_id = self.next_id.next(); + let function = build_with_id(new_id); + self.functions.insert(new_id, function); + new_id + } } impl Display for Ssa { From 6180c3da194e6fc58169fbe48eee8a4f92d3e548 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 00:55:10 +0200 Subject: [PATCH 02/20] Apply suggestions from self code review --- crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index a8a61d985c6..cd0502a888e 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -79,7 +79,6 @@ impl BrilligArtifact { return_parameters: Vec, target_function: Label, ) -> BrilligArtifact { - println!("Creating entry point artifact for function {}", target_function); let mut entry_point_artifact = BrilligArtifact::new(arguments, return_parameters); entry_point_artifact.entry_point_instruction(); @@ -155,7 +154,6 @@ impl BrilligArtifact { /// Brillig artifact (self). pub(crate) fn link_with(&mut self, func_label: Label, obj: &BrilligArtifact) { // Add the unresolved jumps of the linked function to this artifact. - println!("Linking with {}", func_label); self.add_unresolved_jumps_and_calls(obj); let mut byte_code = obj.byte_code.clone(); @@ -185,7 +183,6 @@ impl BrilligArtifact { } self.unresolved_external_call_labels.retain(|(_, label)| !is_resolved(label)); - println!("Unresolved external calls: {:?}", self.unresolved_external_call_labels); } /// Adds unresolved jumps & function calls from another artifact offset by the current opcode count in the artifact. From ad7d53da32503a46b385c31239d8b09742019bc9 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 08:11:06 +0000 Subject: [PATCH 03/20] feat: optimize apply function generation & usage --- .../7_function/src/main.nr | 21 ++++- .../src/ssa_refactor/ir/function.rs | 2 +- .../src/ssa_refactor/opt/defunctionalize.rs | 76 ++++++++++++++++--- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr index 6c10cfca788..000e202110d 100644 --- a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr @@ -96,6 +96,15 @@ fn bar(unused: [Field; N]) -> [Field; N] { foo(unused) } +struct FilterWrapper { + skip: u8, + filter: fn ([u32; 2]) -> u32, +} + +fn filterImpl(fields: [u32; 2]) -> u32 { + fields[0] +} + fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { let mut ss: my_struct = my_struct { b: x, a: x+2, }; test_multiple4(ss); @@ -134,8 +143,18 @@ fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { //Regression test for issue #628: let result = first(arr_to_field(arr1), arr_to_field(arr2)); assert(result[0] == arr1[0] as Field); -} + // Regression test for issue #1844 + let mut instance = FilterWrapper { + skip: 0, + filter: filterImpl , + }; + + instance.skip = x as u8; + + let filter = instance.filter; + assert(x == filter([x; 2])); +} // Issue #628 fn arr_to_field(arr: [u32; 9]) -> [Field; 9] { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index c01ada12b6b..11c9b5adb36 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -61,7 +61,7 @@ impl Function { /// Runtime type of the function. pub(crate) fn runtime(&self) -> RuntimeType { - self.runtime.clone() + self.runtime } /// Set runtime type of the function. diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 5f4b0c020e7..a6d7ad24c23 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -6,9 +6,9 @@ use crate::ssa_refactor::{ ir::{ basic_block::BasicBlockId, function::{Function, FunctionId, RuntimeType}, - instruction::{BinaryOp, Instruction}, + instruction::{BinaryOp, Instruction, TerminatorInstruction}, types::{NumericType, Type}, - value::Value, + value::{Value, ValueId}, }, ssa_builder::FunctionBuilder, ssa_gen::Ssa, @@ -152,17 +152,63 @@ impl DefunctionalizationContext { } } +/// Collects all functions used as a value by their signatures fn find_variants(ssa: &Ssa) -> HashMap> { let mut variants: HashMap> = HashMap::new(); + let mut functions_used_as_values = HashSet::new(); + for function in ssa.functions.values() { - if function.id() != ssa.main_id { - let signature = FunctionSignature::from(function); - variants.entry(signature).or_default().push(function.id()); - } + functions_used_as_values.extend(functions_as_values(function)); + } + + for function_id in functions_used_as_values { + let function = ssa.get_fn(function_id); + let signature = FunctionSignature::from(function); + variants.entry(signature).or_default().push(function_id); } + variants } +/// Finds all literal functions used as values in the given function +fn functions_as_values(func: &Function) -> HashSet { + let mut functions = HashSet::new(); + + let mut append_functions = |values: &[ValueId]| { + for value in values { + if let Value::Function(id) = func.dfg[*value] { + functions.insert(id); + } + } + }; + + for block_id in func.reachable_blocks() { + let block = &func.dfg[block_id]; + for instruction_id in block.instructions() { + let instruction = &func.dfg[*instruction_id]; + match instruction { + Instruction::Call { arguments, .. } => { + append_functions(arguments); + } + Instruction::Store { value, .. } => { + append_functions(&[*value]); + } + _ => {} + } + } + match block.terminator() { + Some(TerminatorInstruction::Jmp { arguments, .. }) => { + append_functions(arguments); + } + Some(TerminatorInstruction::Return { return_values }) => { + append_functions(return_values); + } + _ => {} + } + } + functions +} + fn create_apply_functions( ssa: &mut Ssa, variants_map: &HashMap>, @@ -193,16 +239,17 @@ fn create_apply_function( let is_last = index == function_ids.len() - 1; let mut next_function_block = None; + let function_id_constant = function_builder.numeric_constant( + function_id.to_usize() as u128, + Type::Numeric(NumericType::NativeField), + ); + let condition = + function_builder.insert_binary(target_id, BinaryOp::Eq, function_id_constant); + // If it's not the last function to dispatch, crate an if statement if !is_last { next_function_block = Some(function_builder.insert_block()); let executor_block = function_builder.insert_block(); - let function_id_constant = function_builder.numeric_constant( - function_id.to_usize() as u128, - Type::Numeric(NumericType::NativeField), - ); - let condition = - function_builder.insert_binary(target_id, BinaryOp::Eq, function_id_constant); function_builder.terminate_with_jmpif( condition, @@ -210,6 +257,9 @@ fn create_apply_function( next_function_block.unwrap(), ); function_builder.switch_to_block(executor_block); + } else { + // Else just constrain the condition + function_builder.insert_constrain(condition); } // Find the target block or build it if necessary let target_block = match previous_target_block { @@ -252,6 +302,8 @@ fn create_apply_function( }) } +/// Crates a return block, if no previous return exists, it will create a final return +/// Else, it will create a bypass return block that points to the previous return block fn build_return_block( builder: &mut FunctionBuilder, previous_block: BasicBlockId, From b3a9fa56bb6b19be3ca4d2f9ccaf2c3827ac20bd Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 08:28:21 +0000 Subject: [PATCH 04/20] feat: avoid unary apply fns --- .../src/ssa_refactor/opt/defunctionalize.rs | 49 ++++++++++++++----- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index a6d7ad24c23..6548cf25ff2 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -30,6 +30,15 @@ impl FunctionSignature { } } +/// Describes an apply function existing in the SSA +#[derive(Debug, Clone, Copy)] +struct ApplyFunction { + // The function id of the apply function + id: FunctionId, + // Whether the apply function dispatches to other functions or not + is_multiple: bool, +} + /// Performs defunctionalization on all functions /// This is done by changing all functions as value to be a number (FieldElement) /// And creating apply functions that dispatch to the correct target by runtime comparisons with constants @@ -37,7 +46,7 @@ impl FunctionSignature { struct DefunctionalizationContext { fn_to_runtime: HashMap, variants: HashMap>, - apply_functions: HashMap, + apply_functions: HashMap, } impl DefunctionalizationContext { @@ -99,16 +108,24 @@ impl DefunctionalizationContext { returns: result_types, runtime: func.runtime(), }); - target_function_ids.insert(apply_function); + target_function_ids.insert(apply_function.id); // Replace the instruction with a call to apply - let apply_function = func.dfg.import_function(apply_function); - let mut new_arguments = vec![target_func_id]; - new_arguments.extend(arguments); - new_instruction = Some(Instruction::Call { - func: apply_function, - arguments: new_arguments, - }); + let apply_function_value_id = + func.dfg.import_function(apply_function.id); + if apply_function.is_multiple { + let mut new_arguments = vec![target_func_id]; + new_arguments.extend(arguments); + new_instruction = Some(Instruction::Call { + func: apply_function_value_id, + arguments: new_arguments, + }); + } else { + new_instruction = Some(Instruction::Call { + func: apply_function_value_id, + arguments, + }); + } } Value::Function(id) => { target_function_ids.insert(id); @@ -147,7 +164,7 @@ impl DefunctionalizationContext { } /// Returns the apply function for the given signature - fn get_apply_function(&self, signature: &FunctionSignature) -> FunctionId { + fn get_apply_function(&self, signature: &FunctionSignature) -> ApplyFunction { *self.apply_functions.get(signature).expect("Could not find apply function") } } @@ -212,11 +229,17 @@ fn functions_as_values(func: &Function) -> HashSet { fn create_apply_functions( ssa: &mut Ssa, variants_map: &HashMap>, -) -> HashMap { +) -> HashMap { let mut apply_functions = HashMap::new(); for (signature, variants) in variants_map.iter() { - let apply_function = create_apply_function(ssa, signature, variants); - apply_functions.insert(signature.clone(), apply_function); + if variants.len() > 1 { + let apply_function = create_apply_function(ssa, signature, variants); + apply_functions + .insert(signature.clone(), ApplyFunction { id: apply_function, is_multiple: true }); + } else { + apply_functions + .insert(signature.clone(), ApplyFunction { id: variants[0], is_multiple: false }); + } } apply_functions } From c7ba2845564a4ec6efb58a5af3b56ed11d3bf623 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 08:34:36 +0000 Subject: [PATCH 05/20] fix: clippy --- crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs | 2 +- crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index cd0502a888e..71b06537bd5 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -152,7 +152,7 @@ impl BrilligArtifact { /// This method will offset the positions in the Brillig artifact to /// account for the fact that it is being appended to the end of this /// Brillig artifact (self). - pub(crate) fn link_with(&mut self, func_label: Label, obj: &BrilligArtifact) { + pub(crate) fn link_with(&mut self, obj: &BrilligArtifact) { // Add the unresolved jumps of the linked function to this artifact. self.add_unresolved_jumps_and_calls(obj); diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index 85e59a5219f..be7d7afeb82 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -350,7 +350,7 @@ impl Context { let artifact = &brillig .find_by_function_label(unresolved_fn_label.clone()) .expect("Cannot find linked fn {unresolved_fn_label}"); - entry_point.link_with(unresolved_fn_label, artifact); + entry_point.link_with(artifact); } // Generate the final bytecode entry_point.finish() From 6a44c1924c566700ad8d6ba13fab68e9406a7b5d Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 10:06:35 +0000 Subject: [PATCH 06/20] style: cleanup after peer review --- .../src/ssa_refactor/acir_gen/mod.rs | 2 +- .../src/ssa_refactor/ir/function.rs | 13 +- .../src/ssa_refactor/ir/printer.rs | 6 +- .../src/ssa_refactor/opt/defunctionalize.rs | 161 +++++++++--------- .../src/ssa_refactor/ssa_gen/program.rs | 2 +- cspell.json | 3 + 6 files changed, 96 insertions(+), 91 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index be7d7afeb82..2b49bae9e80 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -769,7 +769,7 @@ impl Context { } /// Convert a Vec into a Vec using the given result ids. - /// If the type of a result id is an array, several acirvars are collected into + /// If the type of a result id is an array, several acir vars are collected into /// a single AcirValue::Array of the same length. fn convert_vars_to_values( vars: Vec, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index 11c9b5adb36..76395ea74ab 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -85,7 +85,7 @@ impl Function { self.dfg.block_parameters(self.entry_block) } - /// Returns the return values of this function. + /// Returns the return types of this function. pub(crate) fn returns(&self) -> &[ValueId] { let blocks = self.reachable_blocks(); let mut function_return_values = None; @@ -97,7 +97,7 @@ impl Function { } } function_return_values - .expect("Expected a return instruction, as block is finished construction") + .expect("Expected a return instruction, as function construction is finished") } /// Collects all the reachable blocks of this function. @@ -118,6 +118,15 @@ impl Function { } } +impl std::fmt::Display for RuntimeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RuntimeType::Acir => write!(f, "acir"), + RuntimeType::Brillig => write!(f, "brillig"), + } + } +} + /// FunctionId is a reference for a function /// /// This Id is how each function refers to other functions diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs index fa8cc674d27..84e43ac6d8b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs @@ -6,8 +6,6 @@ use std::{ use iter_extended::vecmap; -use crate::ssa_refactor::ir::function::RuntimeType; - use super::{ basic_block::BasicBlockId, function::Function, @@ -17,9 +15,7 @@ use super::{ /// Helper function for Function's Display impl to pretty-print the function with the given formatter. pub(crate) fn display_function(function: &Function, f: &mut Formatter) -> Result { - let runtime: &str = - if let RuntimeType::Brillig = function.runtime() { "brillig" } else { "acir" }; - writeln!(f, "fn {} {} {} {{", runtime, function.name(), function.id())?; + writeln!(f, "fn {} {} {} {{", function.runtime(), function.name(), function.id())?; display_block_with_successors(function, function.entry_block(), &mut HashSet::new(), f)?; write!(f, "}}") } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 6548cf25ff2..03d050853ce 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -1,14 +1,15 @@ use std::collections::{HashMap, HashSet}; +use acvm::FieldElement; use iter_extended::vecmap; use crate::ssa_refactor::{ ir::{ basic_block::BasicBlockId, function::{Function, FunctionId, RuntimeType}, - instruction::{BinaryOp, Instruction, TerminatorInstruction}, + instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, - value::{Value, ValueId}, + value::Value, }, ssa_builder::FunctionBuilder, ssa_gen::Ssa, @@ -50,17 +51,14 @@ struct DefunctionalizationContext { } impl DefunctionalizationContext { - /// Returns the new ssa with the function defunctionalized + /// Returns a defunctionalized ssa pub(crate) fn defunctionalize_ssa(mut ssa: Ssa) -> Ssa { - // Find all functions that share the signature + // Find all functions used as value that share the same signature let variants = find_variants(&ssa); // Create apply functions let apply_functions = create_apply_functions(&mut ssa, &variants); - let fn_to_runtime = ssa - .functions - .iter() - .map(|(func_id, func)| (*func_id, func.runtime())) - .collect::>(); + let fn_to_runtime = + ssa.functions.iter().map(|(func_id, func)| (*func_id, func.runtime())).collect(); let context = DefunctionalizationContext { fn_to_runtime, variants, apply_functions }; @@ -69,7 +67,7 @@ impl DefunctionalizationContext { /// Defunctionalize all functions in the Ssa fn defunctionalize_all(mut self, mut ssa: Ssa) -> Ssa { - let func_ids = ssa.functions.keys().copied().collect::>(); + let func_ids: Vec<_> = ssa.functions.keys().copied().collect(); for func_id in func_ids { ssa = self.defunctionalize(func_id, ssa); } @@ -87,53 +85,57 @@ impl DefunctionalizationContext { for instruction_id in instructions { let instruction = func.dfg[instruction_id].clone(); - let mut new_instruction = None; + let mut replacement_instruction = None; // Operate on call instructions - if let Instruction::Call { func: target_func_id, arguments } = instruction { - match func.dfg[target_func_id] { - // If the target is a function used as value - Value::Param { .. } | Value::Instruction { .. } => { - // Collect the argument types - let argument_types = - vecmap(arguments.to_owned(), |arg| func.dfg.type_of_value(arg)); - - // Collect the result types - let result_types = vecmap( - func.dfg.instruction_results(instruction_id).to_owned(), - |result| func.dfg.type_of_value(result), - ); - // Find the correct apply function - let apply_function = self.get_apply_function(&FunctionSignature { - parameters: argument_types, - returns: result_types, - runtime: func.runtime(), + let (target_func_id, arguments) = match instruction { + Instruction::Call { func: target_func_id, arguments } => { + (target_func_id, arguments) + } + _ => continue, + }; + + match func.dfg[target_func_id] { + // If the target is a function used as value + Value::Param { .. } | Value::Instruction { .. } => { + // Collect the argument types + let argument_types = + vecmap(arguments.to_owned(), |arg| func.dfg.type_of_value(arg)); + + // Collect the result types + let result_types = vecmap( + func.dfg.instruction_results(instruction_id).to_owned(), + |result| func.dfg.type_of_value(result), + ); + // Find the correct apply function + let apply_function = self.get_apply_function(&FunctionSignature { + parameters: argument_types, + returns: result_types, + runtime: func.runtime(), + }); + target_function_ids.insert(apply_function.id); + + // Replace the instruction with a call to apply + let apply_function_value_id = func.dfg.import_function(apply_function.id); + if apply_function.is_multiple { + let mut new_arguments = vec![target_func_id]; + new_arguments.extend(arguments); + replacement_instruction = Some(Instruction::Call { + func: apply_function_value_id, + arguments: new_arguments, + }); + } else { + replacement_instruction = Some(Instruction::Call { + func: apply_function_value_id, + arguments, }); - target_function_ids.insert(apply_function.id); - - // Replace the instruction with a call to apply - let apply_function_value_id = - func.dfg.import_function(apply_function.id); - if apply_function.is_multiple { - let mut new_arguments = vec![target_func_id]; - new_arguments.extend(arguments); - new_instruction = Some(Instruction::Call { - func: apply_function_value_id, - arguments: new_arguments, - }); - } else { - new_instruction = Some(Instruction::Call { - func: apply_function_value_id, - arguments, - }); - } - } - Value::Function(id) => { - target_function_ids.insert(id); } - _ => {} } + Value::Function(id) => { + target_function_ids.insert(id); + } + _ => {} } - if let Some(new_instruction) = new_instruction { + if let Some(new_instruction) = replacement_instruction { func.dfg[instruction_id] = new_instruction; } } @@ -148,7 +150,7 @@ impl DefunctionalizationContext { let id = *id; if !target_function_ids.contains(&id) { *value = Value::NumericConstant { - constant: (id.to_usize() as u128).into(), + constant: function_id_to_field(id), typ: Type::Numeric(NumericType::NativeField), } } @@ -189,41 +191,32 @@ fn find_variants(ssa: &Ssa) -> HashMap> { /// Finds all literal functions used as values in the given function fn functions_as_values(func: &Function) -> HashSet { - let mut functions = HashSet::new(); - - let mut append_functions = |values: &[ValueId]| { - for value in values { - if let Value::Function(id) = func.dfg[*value] { - functions.insert(id); - } - } - }; + let mut literal_functions: HashSet<_> = func + .dfg + .value_ids() + .iter() + .filter_map(|id| match func.dfg[*id] { + Value::Function(id) => Some(id), + _ => None, + }) + .collect(); for block_id in func.reachable_blocks() { let block = &func.dfg[block_id]; for instruction_id in block.instructions() { let instruction = &func.dfg[*instruction_id]; - match instruction { - Instruction::Call { arguments, .. } => { - append_functions(arguments); - } - Instruction::Store { value, .. } => { - append_functions(&[*value]); - } - _ => {} - } - } - match block.terminator() { - Some(TerminatorInstruction::Jmp { arguments, .. }) => { - append_functions(arguments); - } - Some(TerminatorInstruction::Return { return_values }) => { - append_functions(return_values); - } - _ => {} + let target_value = match instruction { + Instruction::Call { func, .. } => func, + _ => continue, + }; + let target_id = match func.dfg[*target_value] { + Value::Function(id) => id, + _ => continue, + }; + literal_functions.remove(&target_id); } } - functions + literal_functions } fn create_apply_functions( @@ -244,6 +237,10 @@ fn create_apply_functions( apply_functions } +fn function_id_to_field(function_id: FunctionId) -> FieldElement { + (function_id.to_usize() as u128).into() +} + /// Creates an apply function for the given signature and variants fn create_apply_function( ssa: &mut Ssa, @@ -263,7 +260,7 @@ fn create_apply_function( let mut next_function_block = None; let function_id_constant = function_builder.numeric_constant( - function_id.to_usize() as u128, + function_id_to_field(*function_id), Type::Numeric(NumericType::NativeField), ); let condition = diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs index 9e7cfd7a25b..0c1d337cc1e 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs @@ -36,7 +36,7 @@ impl Ssa { /// Returns the function with the given ID pub(crate) fn get_fn(&self, id: FunctionId) -> &Function { - &self.functions[&id] + self.functions.get(&id).expect("ICE: Could not find function") } /// Returns the entry-point function of the program as a mutable reference diff --git a/cspell.json b/cspell.json index e10e700cdb6..c067e3880e7 100644 --- a/cspell.json +++ b/cspell.json @@ -15,6 +15,9 @@ "combinators", "comptime", "cranelift", + "defunctionalize", + "defunctionalized", + "defunctionalization", "desugared", "endianness", "forall", From 64081c86c09a56bc7dc00a532c7b1ca3226f00f2 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 10:12:17 +0000 Subject: [PATCH 07/20] docs: updated comments on defunctionalize --- .../noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 03d050853ce..9a379a1d4c3 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -1,3 +1,9 @@ +//! This module defines the defunctionalization pass for the SSA IR. +//! The purpose of this pass is to transforms all functions used as values into +//! constant numbers (fields) that represent the function id. That way all calls +//! with a non-literal target can be replaced with a call to an apply function. +//! The apply function is a dispatch function that takes the function id as a parameter +//! and dispatches to the correct target. use std::collections::{HashMap, HashSet}; use acvm::FieldElement; From ecf51147640e542c6385e88c85c629a0aca835b7 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 12:22:11 +0000 Subject: [PATCH 08/20] style: apply suggestions from peer review --- .../7_function/src/main.nr | 25 +----- .../brillig_fns_as_values/src/main.nr | 4 - .../src/ssa_refactor/ir/dfg.rs | 10 +-- .../src/ssa_refactor/ir/printer.rs | 2 +- .../src/ssa_refactor/opt/defunctionalize.rs | 76 +++++++++++-------- 5 files changed, 51 insertions(+), 66 deletions(-) diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr index 000e202110d..26ecf6dda28 100644 --- a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr @@ -92,17 +92,9 @@ fn test_multiple6(a: my2, b: my_struct, c: (my2, my_struct)) { fn foo(a: [Field; N]) -> [Field; N] { a } -fn bar(unused: [Field; N]) -> [Field; N] { - foo(unused) -} - -struct FilterWrapper { - skip: u8, - filter: fn ([u32; 2]) -> u32, -} -fn filterImpl(fields: [u32; 2]) -> u32 { - fields[0] +fn bar() -> [Field; 1] { + foo([0]) } fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { @@ -126,7 +118,7 @@ fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { test0(a); test1(a); test2(x as Field, y); - assert(bar([0])[0] == 0); + assert(bar()[0] == 0); let mut b = [0 as u8, 5 as u8, 2 as u8]; let c = test3(b); @@ -143,17 +135,6 @@ fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { //Regression test for issue #628: let result = first(arr_to_field(arr1), arr_to_field(arr2)); assert(result[0] == arr1[0] as Field); - - // Regression test for issue #1844 - let mut instance = FilterWrapper { - skip: 0, - filter: filterImpl , - }; - - instance.skip = x as u8; - - let filter = instance.filter; - assert(x == filter([x; 2])); } // Issue #628 diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr index 08addd44b0d..5af542301ec 100644 --- a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr @@ -1,7 +1,3 @@ -// Tests a very simple program. -// -// The features being tested are brillig functions as values - struct MyStruct { operation: fn (u32) -> u32, } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index a5bbfc1856a..56fd918c910 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -110,8 +110,8 @@ impl DataFlowGraph { } /// Gets a vec of the value ids of the function - pub(crate) fn value_ids(&self) -> Vec { - self.values.iter().map(|(id, _)| id).collect() + pub(crate) fn values_iter(&self) -> impl ExactSizeIterator { + self.values.iter() } /// Returns the parameters of the given block @@ -389,12 +389,6 @@ impl std::ops::Index for DataFlowGraph { } } -impl std::ops::IndexMut for DataFlowGraph { - fn index_mut(&mut self, id: ValueId) -> &mut Self::Output { - &mut self.values[id] - } -} - impl std::ops::Index for DataFlowGraph { type Output = BasicBlock; fn index(&self, id: BasicBlockId) -> &Self::Output { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs index 84e43ac6d8b..f2fb90b3464 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs @@ -15,7 +15,7 @@ use super::{ /// Helper function for Function's Display impl to pretty-print the function with the given formatter. pub(crate) fn display_function(function: &Function, f: &mut Formatter) -> Result { - writeln!(f, "fn {} {} {} {{", function.runtime(), function.name(), function.id())?; + writeln!(f, "{} fn {} {} {{", function.runtime(), function.name(), function.id())?; display_block_with_successors(function, function.entry_block(), &mut HashSet::new(), f)?; write!(f, "}}") } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 9a379a1d4c3..18a4faa0bdb 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -15,7 +15,7 @@ use crate::ssa_refactor::{ function::{Function, FunctionId, RuntimeType}, instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, - value::Value, + value::{Value, ValueId}, }, ssa_builder::FunctionBuilder, ssa_gen::Ssa, @@ -68,21 +68,19 @@ impl DefunctionalizationContext { let context = DefunctionalizationContext { fn_to_runtime, variants, apply_functions }; - context.defunctionalize_all(ssa) + context.defunctionalize_all(&mut ssa); + ssa } /// Defunctionalize all functions in the Ssa - fn defunctionalize_all(mut self, mut ssa: Ssa) -> Ssa { - let func_ids: Vec<_> = ssa.functions.keys().copied().collect(); - for func_id in func_ids { - ssa = self.defunctionalize(func_id, ssa); + fn defunctionalize_all(mut self, ssa: &mut Ssa) { + for function in ssa.functions.values_mut() { + self.defunctionalize(function); } - ssa } /// Defunctionalize a single function - fn defunctionalize(&mut self, func_id: FunctionId, mut ssa: Ssa) -> Ssa { - let func = ssa.get_fn_mut(func_id); + fn defunctionalize(&mut self, func: &mut Function) { let mut target_function_ids = HashSet::new(); for block_id in func.reachable_blocks() { @@ -104,14 +102,14 @@ impl DefunctionalizationContext { // If the target is a function used as value Value::Param { .. } | Value::Instruction { .. } => { // Collect the argument types - let argument_types = - vecmap(arguments.to_owned(), |arg| func.dfg.type_of_value(arg)); + let argument_types: Vec = + arguments.iter().map(|arg| func.dfg.type_of_value(*arg)).collect(); // Collect the result types - let result_types = vecmap( - func.dfg.instruction_results(instruction_id).to_owned(), - |result| func.dfg.type_of_value(result), - ); + let result_types = + vecmap(func.dfg.instruction_results(instruction_id), |result| { + func.dfg.type_of_value(*result) + }); // Find the correct apply function let apply_function = self.get_apply_function(&FunctionSignature { parameters: argument_types, @@ -148,27 +146,44 @@ impl DefunctionalizationContext { } // Change the type of all the values that are not call targets to NativeField - for value_id in func.dfg.value_ids() { - let value = &mut func.dfg[value_id]; + let value_ids: Vec = func.dfg.values_iter().map(|(id, _)| id).collect(); + for value_id in value_ids { + let value = &func.dfg[value_id]; if let Type::Function = value.get_type() { // If the value is a static function, transform it to the function id - if let Value::Function(id) = value { - let id = *id; - if !target_function_ids.contains(&id) { - *value = Value::NumericConstant { - constant: function_id_to_field(id), - typ: Type::Numeric(NumericType::NativeField), + let mut replacement_value_id = None; + + match value { + Value::Function(id) => { + if !target_function_ids.contains(id) { + replacement_value_id = Some(func.dfg.make_constant( + function_id_to_field(*id), + Type::Numeric(NumericType::NativeField), + )); } } + Value::Instruction { instruction, position, .. } => { + replacement_value_id = Some(func.dfg.make_value(Value::Instruction { + instruction: *instruction, + position: *position, + typ: Type::Numeric(NumericType::NativeField), + })); + } + Value::Param { block, position, .. } => { + replacement_value_id = Some(func.dfg.make_value(Value::Param { + block: *block, + position: *position, + typ: Type::Numeric(NumericType::NativeField), + })); + } + _ => {} } - // If it is a dynamic function, just change the type - if let Value::Instruction { typ, .. } | Value::Param { typ, .. } = value { - *typ = Type::Numeric(NumericType::NativeField); + + if let Some(new_value_id) = replacement_value_id { + func.dfg.set_value_from_id(value_id, new_value_id); } } } - - ssa } /// Returns the apply function for the given signature @@ -199,9 +214,8 @@ fn find_variants(ssa: &Ssa) -> HashMap> { fn functions_as_values(func: &Function) -> HashSet { let mut literal_functions: HashSet<_> = func .dfg - .value_ids() - .iter() - .filter_map(|id| match func.dfg[*id] { + .values_iter() + .filter_map(|(id, _)| match func.dfg[id] { Value::Function(id) => Some(id), _ => None, }) From c6f9d9fa61d115eb63d7535542a1c6221ae8c4a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 16:28:30 +0200 Subject: [PATCH 09/20] Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher --- .../src/ssa_refactor/opt/defunctionalize.rs | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 18a4faa0bdb..c50a3bd3846 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -56,21 +56,23 @@ struct DefunctionalizationContext { apply_functions: HashMap, } -impl DefunctionalizationContext { - /// Returns a defunctionalized ssa - pub(crate) fn defunctionalize_ssa(mut ssa: Ssa) -> Ssa { +impl Ssa { + pub(crate) fn defunctionalize(mut self) -> Ssa { // Find all functions used as value that share the same signature - let variants = find_variants(&ssa); - // Create apply functions - let apply_functions = create_apply_functions(&mut ssa, &variants); + let variants = find_variants(&self); + + let apply_functions = create_apply_functions(&mut self, &variants); let fn_to_runtime = - ssa.functions.iter().map(|(func_id, func)| (*func_id, func.runtime())).collect(); + self.functions.iter().map(|(func_id, func)| (*func_id, func.runtime())).collect(); let context = DefunctionalizationContext { fn_to_runtime, variants, apply_functions }; - context.defunctionalize_all(&mut ssa); - ssa + context.defunctionalize_all(&mut self); + self } +} + +impl DefunctionalizationContext { /// Defunctionalize all functions in the Ssa fn defunctionalize_all(mut self, ssa: &mut Ssa) { From dcfa9e852251e8c6becd49c340dd72dfc4934e0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 16:28:51 +0200 Subject: [PATCH 10/20] Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher --- .../noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index c50a3bd3846..5cdb2c6ca92 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -363,9 +363,3 @@ fn build_return_block( builder.switch_to_block(previous_block); return_block } - -impl Ssa { - pub(crate) fn defunctionalize(self) -> Ssa { - DefunctionalizationContext::defunctionalize_ssa(self) - } -} From cf86a38282bc78ebaaa78c890a639062262458b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 16:29:00 +0200 Subject: [PATCH 11/20] Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher --- crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 5cdb2c6ca92..1a7d4825089 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -40,7 +40,6 @@ impl FunctionSignature { /// Describes an apply function existing in the SSA #[derive(Debug, Clone, Copy)] struct ApplyFunction { - // The function id of the apply function id: FunctionId, // Whether the apply function dispatches to other functions or not is_multiple: bool, From 949f60be1df505bbe5540bea5e93ca89650af162 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 16:32:55 +0200 Subject: [PATCH 12/20] Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher --- .../src/ssa_refactor/opt/defunctionalize.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 1a7d4825089..03fef05a9a6 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -37,7 +37,21 @@ impl FunctionSignature { } } -/// Describes an apply function existing in the SSA +/// Represents an 'apply' function created by this pass to dispatch higher order functions to. +/// Pseudocode of an `apply` function is given below: +/// ``` +/// fn apply(function_id: Field, arg1: Field, arg2: Field) -> Field { +/// match function_id { +/// 0 -> function0(arg1, arg2), +/// 1 -> function0(arg1, arg2), +/// ... +/// N -> functionN(arg1, arg2), +/// } +/// } +/// ``` +/// Apply functions generally take the function to apply as their first parameter. This is a Field value +/// obtained by converting the FunctionId into a Field. The remaining parameters of apply are the +/// arguments to forward to this function when calling it internally. #[derive(Debug, Clone, Copy)] struct ApplyFunction { id: FunctionId, From 66513e68036169b4e7087b1298b740dbb9b091c8 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 14:34:15 +0000 Subject: [PATCH 13/20] style: rename field --- .../src/ssa_refactor/opt/defunctionalize.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 1a7d4825089..3158dcaf5ef 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -41,8 +41,7 @@ impl FunctionSignature { #[derive(Debug, Clone, Copy)] struct ApplyFunction { id: FunctionId, - // Whether the apply function dispatches to other functions or not - is_multiple: bool, + dispatches_to_multiple_functions: bool, } /// Performs defunctionalization on all functions @@ -72,7 +71,6 @@ impl Ssa { } impl DefunctionalizationContext { - /// Defunctionalize all functions in the Ssa fn defunctionalize_all(mut self, ssa: &mut Ssa) { for function in ssa.functions.values_mut() { @@ -121,7 +119,7 @@ impl DefunctionalizationContext { // Replace the instruction with a call to apply let apply_function_value_id = func.dfg.import_function(apply_function.id); - if apply_function.is_multiple { + if apply_function.dispatches_to_multiple_functions { let mut new_arguments = vec![target_func_id]; new_arguments.extend(arguments); replacement_instruction = Some(Instruction::Call { @@ -248,11 +246,15 @@ fn create_apply_functions( for (signature, variants) in variants_map.iter() { if variants.len() > 1 { let apply_function = create_apply_function(ssa, signature, variants); - apply_functions - .insert(signature.clone(), ApplyFunction { id: apply_function, is_multiple: true }); + apply_functions.insert( + signature.clone(), + ApplyFunction { id: apply_function, dispatches_to_multiple_functions: true }, + ); } else { - apply_functions - .insert(signature.clone(), ApplyFunction { id: variants[0], is_multiple: false }); + apply_functions.insert( + signature.clone(), + ApplyFunction { id: variants[0], dispatches_to_multiple_functions: false }, + ); } } apply_functions From 7728c65f0e0d3660d9cac1efbf173fe338d69ce5 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 14:50:06 +0000 Subject: [PATCH 14/20] docs: fixed doc to avoid doctest --- crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs | 2 +- cspell.json | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index c327a2e7f46..a2e352ba7a1 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -39,7 +39,7 @@ impl FunctionSignature { /// Represents an 'apply' function created by this pass to dispatch higher order functions to. /// Pseudocode of an `apply` function is given below: -/// ``` +/// ```text /// fn apply(function_id: Field, arg1: Field, arg2: Field) -> Field { /// match function_id { /// 0 -> function0(arg1, arg2), diff --git a/cspell.json b/cspell.json index c067e3880e7..92c3154f2b3 100644 --- a/cspell.json +++ b/cspell.json @@ -48,6 +48,7 @@ "pedersen", "peekable", "preprocess", + "pseudocode", "schnorr", "sdiv", "signedness", From a6f683e6c9dbf3910a948dd4ddd9fcef02235f7a Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 15:07:28 +0000 Subject: [PATCH 15/20] style: addressed pr comments --- .../src/ssa_refactor/opt/defunctionalize.rs | 50 +++++++------------ 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index a2e352ba7a1..9232d46a930 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -258,18 +258,14 @@ fn create_apply_functions( ) -> HashMap { let mut apply_functions = HashMap::new(); for (signature, variants) in variants_map.iter() { - if variants.len() > 1 { - let apply_function = create_apply_function(ssa, signature, variants); - apply_functions.insert( - signature.clone(), - ApplyFunction { id: apply_function, dispatches_to_multiple_functions: true }, - ); + let dispatches_to_multiple_functions = variants.len() > 1; + let id = if dispatches_to_multiple_functions { + create_apply_function(ssa, signature, variants) } else { - apply_functions.insert( - signature.clone(), - ApplyFunction { id: variants[0], dispatches_to_multiple_functions: false }, - ); - } + variants[0] + }; + apply_functions + .insert(signature.clone(), ApplyFunction { id, dispatches_to_multiple_functions }); } apply_functions } @@ -287,7 +283,7 @@ fn create_apply_function( assert!(!function_ids.is_empty()); ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id, signature.runtime); - let target_id = function_builder.add_parameter(Type::Numeric(NumericType::NativeField)); + let target_id = function_builder.add_parameter(Type::field()); let params_ids = vecmap(signature.parameters.clone(), |typ| function_builder.add_parameter(typ)); @@ -303,7 +299,7 @@ fn create_apply_function( let condition = function_builder.insert_binary(target_id, BinaryOp::Eq, function_id_constant); - // If it's not the last function to dispatch, crate an if statement + // If it's not the last function to dispatch, create an if statement if !is_last { next_function_block = Some(function_builder.insert_block()); let executor_block = function_builder.insert_block(); @@ -319,26 +315,14 @@ fn create_apply_function( function_builder.insert_constrain(condition); } // Find the target block or build it if necessary - let target_block = match previous_target_block { - Some(block) => { - let current_block = function_builder.current_block(); - build_return_block( - &mut function_builder, - current_block, - signature.returns.clone(), - Some(block), - ) - } - None => { - let current_block = function_builder.current_block(); - build_return_block( - &mut function_builder, - current_block, - signature.returns.clone(), - None, - ) - } - }; + let current_block = function_builder.current_block(); + + let target_block = build_return_block( + &mut function_builder, + current_block, + signature.returns.clone(), + previous_target_block, + ); previous_target_block = Some(target_block); // Call the function From 6f014bdc5ba526727c63d79af280997d87480f9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 17:57:55 +0200 Subject: [PATCH 16/20] Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher --- crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 9232d46a930..35ad88e7f2e 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -115,8 +115,7 @@ impl DefunctionalizationContext { // If the target is a function used as value Value::Param { .. } | Value::Instruction { .. } => { // Collect the argument types - let argument_types: Vec = - arguments.iter().map(|arg| func.dfg.type_of_value(*arg)).collect(); + let argument_types = vecmap(&arguments, |arg| func.dfg.type_of_value(*arg)); // Collect the result types let result_types = From 6be4cebd71b760fc3a107a11cc9f58792b4ff5f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 17:58:23 +0200 Subject: [PATCH 17/20] Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher --- .../src/ssa_refactor/opt/defunctionalize.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 35ad88e7f2e..02481ea2ac7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -133,18 +133,10 @@ impl DefunctionalizationContext { // Replace the instruction with a call to apply let apply_function_value_id = func.dfg.import_function(apply_function.id); if apply_function.dispatches_to_multiple_functions { - let mut new_arguments = vec![target_func_id]; - new_arguments.extend(arguments); - replacement_instruction = Some(Instruction::Call { - func: apply_function_value_id, - arguments: new_arguments, - }); - } else { - replacement_instruction = Some(Instruction::Call { - func: apply_function_value_id, - arguments, - }); + arguments.insert(0, target_func_id); } + let func = apply_function_value_id; + replacement_instruction = Some(Instruction::Call { func, arguments }); } Value::Function(id) => { target_function_ids.insert(id); From 95109750dcb3417514c412f038bab74268f1c20b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 6 Jul 2023 17:58:44 +0200 Subject: [PATCH 18/20] Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher --- crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index 02481ea2ac7..ce2819d160e 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -150,7 +150,7 @@ impl DefunctionalizationContext { } // Change the type of all the values that are not call targets to NativeField - let value_ids: Vec = func.dfg.values_iter().map(|(id, _)| id).collect(); + let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id); for value_id in value_ids { let value = &func.dfg[value_id]; if let Type::Function = value.get_type() { From 35edb84a539134d8e2de92dcfc9c6db9fafec232 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 6 Jul 2023 17:16:24 +0000 Subject: [PATCH 19/20] refactor: extract set type of value to the dfg --- .../src/ssa_refactor/ir/dfg.rs | 15 +++++++ .../src/ssa_refactor/opt/defunctionalize.rs | 42 ++++++------------- .../src/ssa_refactor/ssa_gen/program.rs | 10 ----- 3 files changed, 27 insertions(+), 40 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 56fd918c910..922dad17c6d 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -174,6 +174,21 @@ impl DataFlowGraph { } } + /// Set the type of value_id to the target_type. + pub(crate) fn set_type_of_value(&mut self, value_id: ValueId, target_type: Type) { + let value = &mut self.values[value_id]; + match value { + Value::Instruction { typ, .. } + | Value::Param { typ, .. } + | Value::NumericConstant { typ, .. } => { + *typ = target_type; + } + _ => { + unreachable!("ICE: Cannot set type of {:?}", value); + } + } + } + /// If `original_value_id`'s underlying `Value` has been substituted for that of another /// `ValueId`, this function will return the `ValueId` from which the substitution was taken. /// If `original_value_id`'s underlying `Value` has not been substituted, the same `ValueId` diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs index ce2819d160e..c31d0c58deb 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -15,7 +15,7 @@ use crate::ssa_refactor::{ function::{Function, FunctionId, RuntimeType}, instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, - value::{Value, ValueId}, + value::Value, }, ssa_builder::FunctionBuilder, ssa_gen::Ssa, @@ -104,7 +104,7 @@ impl DefunctionalizationContext { let instruction = func.dfg[instruction_id].clone(); let mut replacement_instruction = None; // Operate on call instructions - let (target_func_id, arguments) = match instruction { + let (target_func_id, mut arguments) = match instruction { Instruction::Call { func: target_func_id, arguments } => { (target_func_id, arguments) } @@ -152,40 +152,22 @@ impl DefunctionalizationContext { // Change the type of all the values that are not call targets to NativeField let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id); for value_id in value_ids { - let value = &func.dfg[value_id]; - if let Type::Function = value.get_type() { - // If the value is a static function, transform it to the function id - let mut replacement_value_id = None; - - match value { + if let Type::Function = &func.dfg[value_id].get_type() { + match &func.dfg[value_id] { + // If the value is a static function, transform it to the function id Value::Function(id) => { if !target_function_ids.contains(id) { - replacement_value_id = Some(func.dfg.make_constant( - function_id_to_field(*id), - Type::Numeric(NumericType::NativeField), - )); + let new_value = + func.dfg.make_constant(function_id_to_field(*id), Type::field()); + func.dfg.set_value_from_id(value_id, new_value); } } - Value::Instruction { instruction, position, .. } => { - replacement_value_id = Some(func.dfg.make_value(Value::Instruction { - instruction: *instruction, - position: *position, - typ: Type::Numeric(NumericType::NativeField), - })); - } - Value::Param { block, position, .. } => { - replacement_value_id = Some(func.dfg.make_value(Value::Param { - block: *block, - position: *position, - typ: Type::Numeric(NumericType::NativeField), - })); + // If the value is a function used as value, just change the type of it + Value::Instruction { .. } | Value::Param { .. } => { + func.dfg.set_type_of_value(value_id, Type::field()); } _ => {} } - - if let Some(new_value_id) = replacement_value_id { - func.dfg.set_value_from_id(value_id, new_value_id); - } } } } @@ -206,7 +188,7 @@ fn find_variants(ssa: &Ssa) -> HashMap> { } for function_id in functions_used_as_values { - let function = ssa.get_fn(function_id); + let function = &ssa.functions[&function_id]; let signature = FunctionSignature::from(function); variants.entry(signature).or_default().push(function_id); } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs index 0c1d337cc1e..aec0e4262c8 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs @@ -34,21 +34,11 @@ impl Ssa { &self.functions[&self.main_id] } - /// Returns the function with the given ID - pub(crate) fn get_fn(&self, id: FunctionId) -> &Function { - self.functions.get(&id).expect("ICE: Could not find function") - } - /// Returns the entry-point function of the program as a mutable reference pub(crate) fn main_mut(&mut self) -> &mut Function { self.functions.get_mut(&self.main_id).expect("ICE: Ssa should have a main function") } - /// Returns the function with the given ID as a mutable reference - pub(crate) fn get_fn_mut(&mut self, id: FunctionId) -> &mut Function { - self.functions.get_mut(&id).expect("ICE: Could not find function") - } - /// Adds a new function to the program pub(crate) fn add_fn( &mut self, From b31edb1a5eea204167c79be3a5c2b6eef0eefde9 Mon Sep 17 00:00:00 2001 From: jfecher Date: Fri, 7 Jul 2023 10:01:02 +0200 Subject: [PATCH 20/20] Update crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs --- crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 922dad17c6d..9104b65d16f 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -109,7 +109,7 @@ impl DataFlowGraph { self.blocks.iter() } - /// Gets a vec of the value ids of the function + /// Iterate over every Value in this DFG in no particular order, including unused Values pub(crate) fn values_iter(&self) -> impl ExactSizeIterator { self.values.iter() }