diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs index a6d7ae157d3..b45fb58dbaa 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs @@ -512,6 +512,7 @@ impl AcirContext { lhs: AcirVar, rhs: AcirVar, bit_size: u32, + predicate: Option, ) -> Result { let lhs_data = &self.vars[lhs]; let rhs_data = &self.vars[rhs]; @@ -521,8 +522,13 @@ impl AcirContext { // TODO: check what happens when we do (a as u8) >= (b as u32) // TODO: The frontend should shout in this case + + let predicate = predicate.map(|acir_var| { + let predicate_data = &self.vars[acir_var]; + predicate_data.to_expression().into_owned() + }); let is_greater_than_eq = - self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size)?; + self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size, predicate)?; Ok(self.add_data(AcirVarData::Witness(is_greater_than_eq))) } @@ -534,10 +540,11 @@ impl AcirContext { lhs: AcirVar, rhs: AcirVar, bit_size: u32, + predicate: Option, ) -> Result { // Flip the result of calling more than equal method to // compute less than. - let comparison = self.more_than_eq_var(lhs, rhs, bit_size)?; + let comparison = self.more_than_eq_var(lhs, rhs, bit_size, predicate)?; let one = self.add_constant(FieldElement::one()); self.sub_var(one, comparison) // comparison_negated diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs index 5b7963b5f06..ab3ad8abd8b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs @@ -570,6 +570,7 @@ impl GeneratedAcir { a: &Expression, b: &Expression, max_bits: u32, + predicate: Option, ) -> Result { // Ensure that 2^{max_bits + 1} is less than the field size // @@ -596,7 +597,7 @@ impl GeneratedAcir { b: Expression::from_field(two_max_bits), q: q_witness, r: r_witness, - predicate: None, + predicate, }))); // Add constraint to ensure `r` is correctly bounded diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index 537b08ddd93..c34fc67b69d 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -39,6 +39,10 @@ struct Context { /// already exists for this Value, we return the `AcirVar`. ssa_values: HashMap, AcirValue>, + /// The `AcirVar` that describes the condition belonging to the most recently invoked + /// `SideEffectsEnabled` instruction. + current_side_effects_enabled_var: Option, + /// Manages and builds the `AcirVar`s to which the converted SSA values refer. acir_context: AcirContext, } @@ -218,6 +222,10 @@ impl Context { self.define_result_var(dfg, instruction_id, result_acir_var); } + Instruction::EnableSideEffects { condition } => { + let acir_var = self.convert_numeric_value(*condition, dfg); + self.current_side_effects_enabled_var = Some(acir_var); + } Instruction::ArrayGet { array, index } => { self.handle_array_operation(instruction_id, *array, *index, None, dfg); } @@ -405,7 +413,12 @@ impl Context { // Note: that this produces unnecessary constraints when // this Eq instruction is being used for a constrain statement BinaryOp::Eq => self.acir_context.eq_var(lhs, rhs), - BinaryOp::Lt => self.acir_context.less_than_var(lhs, rhs, bit_count), + BinaryOp::Lt => self.acir_context.less_than_var( + lhs, + rhs, + bit_count, + self.current_side_effects_enabled_var, + ), BinaryOp::Shl => self.acir_context.shift_left_var(lhs, rhs, binary_type), BinaryOp::Shr => self.acir_context.shift_right_var(lhs, rhs, binary_type), BinaryOp::Xor => self.acir_context.xor_var(lhs, rhs, binary_type), diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index a9767bc3777..38726c14619 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -103,6 +103,15 @@ pub(crate) enum Instruction { /// Writes a value to memory. Store { address: ValueId, value: ValueId }, + /// Provides a context for all instructions that follow up until the next + /// `EnableSideEffects` is encountered, for stating a condition that determines whether + /// such instructions are allowed to have side-effects. + /// + /// This instruction is only emitted after the cfg flattening pass, and is used to annotate + /// instruction regions with an condition that corresponds to their position in the CFG's + /// if-branching structure. + EnableSideEffects { condition: ValueId }, + /// Retrieve a value from an array at the given index ArrayGet { array: ValueId, index: ValueId }, @@ -127,7 +136,9 @@ impl Instruction { InstructionResultType::Operand(*value) } Instruction::ArraySet { array, .. } => InstructionResultType::Operand(*array), - Instruction::Constrain(_) | Instruction::Store { .. } => InstructionResultType::None, + Instruction::Constrain(_) + | Instruction::Store { .. } + | Instruction::EnableSideEffects { .. } => InstructionResultType::None, Instruction::Load { .. } | Instruction::ArrayGet { .. } | Instruction::Call { .. } => { InstructionResultType::Unknown } @@ -167,6 +178,9 @@ impl Instruction { Instruction::Store { address, value } => { Instruction::Store { address: f(*address), value: f(*value) } } + Instruction::EnableSideEffects { condition } => { + Instruction::EnableSideEffects { condition: f(*condition) } + } Instruction::ArrayGet { array, index } => { Instruction::ArrayGet { array: f(*array), index: f(*index) } } @@ -248,6 +262,7 @@ impl Instruction { Instruction::Allocate { .. } => None, Instruction::Load { .. } => None, Instruction::Store { .. } => None, + Instruction::EnableSideEffects { .. } => None, } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs index 2829c6768b8..9b45a0b272c 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs @@ -153,6 +153,9 @@ pub(crate) fn display_instruction( Instruction::Store { address, value } => { writeln!(f, "store {} at {}", show(*value), show(*address)) } + Instruction::EnableSideEffects { condition } => { + writeln!(f, "enable_side_effects {}", show(*condition)) + } Instruction::ArrayGet { array, index } => { writeln!(f, "array_get {}, index {}", show(*array), show(*index)) } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs index 51ce5601e12..14eb99e5570 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs @@ -10,6 +10,26 @@ //! while merging branches. These extra instructions can be cleaned up by a later dead instruction //! elimination (DIE) pass. //! +//! Though CFG information is lost during this pass, some key information is retained in the form +//! of `EnableSideEffect` instructions. Each time the flattening pass enters and exits a branch of +//! a jmpif, an instruction is inserted to capture a condition that is analogous to the activeness +//! of the program point. For example: +//! +//! b0(v0: u1): +//! jmpif v0, then: b1, else: b2 +//! b1(): +//! v1 = call f0 +//! jmp b3(v1) +//! ... blocks b2 & b3 ... +//! +//! Would brace the call instruction as such: +//! enable_side_effects v0 +//! v1 = call f0 +//! enable_side_effects u1 1 +//! +//! (Note: we restore to "true" to indicate that this program point is not nested within any +//! other branches.) +//! //! When we are flattening a block that was reached via a jmpif with a non-constant condition c, //! the following transformations of certain instructions within the block are expected: //! @@ -288,6 +308,8 @@ impl<'f> Context<'f> { let else_branch = self.inline_branch(block, else_block, old_condition, else_condition, zero); + self.insert_current_side_effects_enabled(); + // While there is a condition on the stack we don't compile outside the condition // until it is popped. This ensures we inline the full then and else branches // before continuing from the end of the conditional here where they can be merged properly. @@ -359,6 +381,20 @@ impl<'f> Context<'f> { self.function.dfg.insert_instruction_and_results(instruction, block, ctrl_typevars) } + /// Checks the branch condition on the top of the stack and uses it to build and insert an + /// `EnableSideEffects` instruction into the entry block. + /// + /// If the stack is empty, a "true" u1 constant is taken to be the active condition. This is + /// necessary for re-enabling side-effects when re-emerging to a branch depth of 0. + fn insert_current_side_effects_enabled(&mut self) { + let condition = match self.conditions.last() { + Some((_, cond)) => *cond, + None => self.function.dfg.make_constant(FieldElement::one(), Type::unsigned(1)), + }; + let enable_side_effects = Instruction::EnableSideEffects { condition }; + self.insert_instruction_with_typevars(enable_side_effects, None); + } + /// Merge two values a and b from separate basic blocks to a single value. This /// function would return the result of `if c { a } else { b }` as `c*a + (!c)*b`. fn merge_values( @@ -397,6 +433,7 @@ impl<'f> Context<'f> { condition_value: FieldElement, ) -> Branch { self.push_condition(jmpif_block, new_condition); + self.insert_current_side_effects_enabled(); let old_stores = std::mem::take(&mut self.store_values); // Remember the old condition value is now known to be true/false within this branch @@ -643,11 +680,13 @@ mod test { // Expected output: // fn main f0 { // b0(v0: u1): - // v4 = not v0 - // v5 = mul v0, Field 3 - // v7 = not v0 - // v8 = mul v7, Field 4 - // v9 = add v5, v8 + // enable_side_effects v0 + // v5 = not v0 + // enable_side_effects v5 + // enable_side_effects u1 1 + // v7 = mul v0, Field 3 + // v8 = mul v5, Field 4 + // v9 = add v7, v8 // return v9 // } let ssa = ssa.flatten_cfg(); @@ -686,13 +725,17 @@ mod test { let ssa = builder.finish(); assert_eq!(ssa.main().reachable_blocks().len(), 3); - // Expected output (sans useless extra 'not' instruction): + // Expected output: // fn main f0 { // b0(v0: u1, v1: u1): - // v2 = mul v1, v0 - // v3 = eq v2, v0 - // constrain v3 - // return v1 + // enable_side_effects v0 + // v3 = mul v1, v0 + // v4 = eq v3, v0 + // constrain v4 + // v5 = not v0 + // enable_side_effects v5 + // enable_side_effects u1 1 + // return // } let ssa = ssa.flatten_cfg(); assert_eq!(ssa.main().reachable_blocks().len(), 1); @@ -733,14 +776,16 @@ mod test { // Expected output: // fn main f0 { // b0(v0: u1, v1: reference): + // enable_side_effects v0 // v4 = load v1 // store Field 5 at v1 // v5 = not v0 + // enable_side_effects v5 + // enable_side_effects u1 1 // v7 = mul v0, Field 5 - // v8 = not v0 - // v9 = mul v8, v4 - // v10 = add v7, v9 - // store v10 at v1 + // v8 = mul v5, v4 + // v9 = add v7, v8 + // store v9 at v1 // return // } let ssa = ssa.flatten_cfg(); @@ -807,21 +852,24 @@ mod test { // Expected output: // fn main f0 { // b0(v0: u1, v1: reference): - // v8 = add v1, Field 1 - // v9 = load v8 - // store Field 5 at v8 - // v10 = not v0 - // v12 = add v1, Field 1 - // v13 = load v12 - // store Field 6 at v12 - // v14 = mul v0, Field 5 - // v15 = mul v10, v9 - // v16 = add v14, v15 - // store v16 at v8 - // v17 = mul v0, v13 - // v18 = mul v10, Field 6 - // v19 = add v17, v18 - // store v19 at v12 + // enable_side_effects v0 + // v7 = add v1, Field 1 + // v8 = load v7 + // store Field 5 at v7 + // v9 = not v0 + // enable_side_effects v9 + // v11 = add v1, Field 1 + // v12 = load v11 + // store Field 6 at v11 + // enable_side_effects Field 1 + // v13 = mul v0, Field 5 + // v14 = mul v9, v8 + // v15 = add v13, v14 + // store v15 at v7 + // v16 = mul v0, v12 + // v17 = mul v9, Field 6 + // v18 = add v16, v17 + // store v18 at v11 // return // } let ssa = ssa.flatten_cfg(); @@ -1013,31 +1061,38 @@ mod test { // b0(v0: u1, v1: u1): // call println(Field 0, Field 0) // call println(Field 1, Field 1) + // enable_side_effects v0 // call println(Field 2, Field 2) - // call println(Field 4, Field 2) ; block 4 does not store a value - // v45 = and v0, v1 + // call println(Field 4, Field 2) + // v29 = and v0, v1 + // enable_side_effects v29 // call println(Field 5, Field 5) - // v49 = not v1 - // v50 = and v0, v49 + // v32 = not v1 + // v33 = and v0, v32 + // enable_side_effects v33 // call println(Field 6, Field 6) - // v54 = mul v1, Field 5 - // v55 = mul v49, Field 2 - // v56 = add v54, v55 - // v57 = mul v1, Field 5 - // v58 = mul v49, Field 6 - // v59 = add v57, v58 - // call println(Field 7, v59) ; v59 = 5 and 6 merged - // v61 = not v0 + // enable_side_effects v0 + // v36 = mul v1, Field 5 + // v37 = mul v32, Field 2 + // v38 = add v36, v37 + // v39 = mul v1, Field 5 + // v40 = mul v32, Field 6 + // v41 = add v39, v40 + // call println(Field 7, v42) + // v43 = not v0 + // enable_side_effects v43 + // store Field 3 at v2 // call println(Field 3, Field 3) - // call println(Field 8, Field 3) ; block 8 does not store a value - // v66 = mul v0, v59 - // v67 = mul v61, Field 1 - // v68 = add v66, v67 ; This was from an unused store. - // v69 = mul v0, v59 - // v70 = mul v61, Field 3 - // v71 = add v69, v70 - // call println(Field 9, v71) ; v71 = 3, 5, and 6 merged - // return v71 + // call println(Field 8, Field 3) + // enable_side_effects Field 1 + // v47 = mul v0, v41 + // v48 = mul v43, Field 1 + // v49 = add v47, v48 + // v50 = mul v0, v44 + // v51 = mul v43, Field 3 + // v52 = add v50, v51 + // call println(Field 9, v53) + // return v54 // } let main = ssa.main();