diff --git a/compiler/noirc_evaluator/src/lib.rs b/compiler/noirc_evaluator/src/lib.rs
index 3a5825699b2..193cf7bb387 100644
--- a/compiler/noirc_evaluator/src/lib.rs
+++ b/compiler/noirc_evaluator/src/lib.rs
@@ -28,3 +28,22 @@ pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String {
     }
     result
 }
+
+/// Trim comments from the lines, ie. content starting with `//`.
+#[cfg(test)]
+pub(crate) fn trim_comments_from_lines(src: &str) -> String {
+    let mut result = String::new();
+    let mut first = true;
+    for line in src.lines() {
+        if !first {
+            result.push('\n');
+        }
+        if let Some(comment) = line.find("//") {
+            result.push_str(line[..comment].trim_end());
+        } else {
+            result.push_str(line);
+        }
+        first = false;
+    }
+    result
+}
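A quick behavior sketch for the new helper (a hypothetical test, not part of the change): the comment and any whitespace directly before it are dropped, lines without `//` pass through unchanged, and the line structure is preserved.

```rust
#[cfg(test)]
mod trim_comments_sketch {
    use super::trim_comments_from_lines;

    #[test]
    fn trims_comments_and_trailing_whitespace() {
        let src = "v1 = add v0, u32 1 // increment\nreturn v1";
        // The comment and the spaces before it are removed; the second line is untouched.
        assert_eq!(trim_comments_from_lines(src), "v1 = add v0, u32 1\nreturn v1");
    }
}
```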
diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs
index e8245ff6036..b1233e3063e 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/function.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs
@@ -46,6 +46,14 @@ impl RuntimeType {
             | RuntimeType::Brillig(InlineType::NoPredicates)
         )
     }
+
+    pub(crate) fn is_brillig(&self) -> bool {
+        matches!(self, RuntimeType::Brillig(_))
+    }
+
+    pub(crate) fn is_acir(&self) -> bool {
+        matches!(self, RuntimeType::Acir(_))
+    }
 }
 
 /// A function holds a list of instructions.
diff --git a/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs b/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs
index 708b02b9102..5e133072067 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs
@@ -71,6 +71,7 @@ impl<'f> FunctionInserter<'f> {
         }
     }
 
+    /// Get an instruction and make sure all the values in it are freshly resolved.
    pub(crate) fn map_instruction(&mut self, id: InstructionId) -> (Instruction, CallStack) {
        (
            self.function.dfg[id].clone().map_values(|id| self.resolve(id)),
diff --git a/compiler/noirc_evaluator/src/ssa/ir/post_order.rs b/compiler/noirc_evaluator/src/ssa/ir/post_order.rs
index 94ff96ba1d7..398ce887b96 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/post_order.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/post_order.rs
@@ -22,7 +22,7 @@ impl PostOrder {
 }
 
 impl PostOrder {
-    /// Allocate and compute a function's block post-order. Pos
+    /// Allocate and compute a function's block post-order.
    pub(crate) fn with_function(func: &Function) -> Self {
        PostOrder(Self::compute_post_order(func))
    }
diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs
index 098f62bceba..10e86c6601a 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs
@@ -35,13 +35,14 @@ pub(crate) fn assert_normalized_ssa_equals(mut ssa: super::Ssa, expected: &str)
         panic!("`expected` argument of `assert_ssa_equals` is not valid SSA:\n{:?}", err);
     }
 
-    use crate::{ssa::Ssa, trim_leading_whitespace_from_lines};
+    use crate::{ssa::Ssa, trim_comments_from_lines, trim_leading_whitespace_from_lines};
 
     ssa.normalize_ids();
 
     let ssa = ssa.to_string();
     let ssa = trim_leading_whitespace_from_lines(&ssa);
     let expected = trim_leading_whitespace_from_lines(expected);
+    let expected = trim_comments_from_lines(&expected);
 
     if ssa != expected {
         println!("Expected:\n~~~\n{expected}\n~~~\nGot:\n~~~\n{ssa}\n~~~");
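The practical effect of the `trim_comments_from_lines` hook is that expected SSA snippets in tests may now carry inline annotations. A hypothetical sketch using the helpers visible in this diff (`Ssa::from_str`, `assert_normalized_ssa_equals`); the SSA body itself is illustrative:

```rust
#[test]
fn annotated_expected_ssa() {
    let src = "
        acir(inline) fn main f0 {
          b0(v0: Field):
            return v0
        }
        ";
    let ssa = Ssa::from_str(src).unwrap();
    // The comments in `expected` are stripped before comparison,
    // so they don't have to appear in the printed SSA.
    let expected = "
        acir(inline) fn main f0 {
          b0(v0: Field): // entry block
            return v0    // pass-through
        }
        ";
    assert_normalized_ssa_equals(ssa, expected);
}
```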
diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
index 3ff0e630a69..89f1b2b2d7d 100644
--- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
+++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
@@ -7,16 +7,20 @@
 //!    b. If we have previously modified any of the blocks in the loop,
 //!       restart from step 1 to refresh the context.
 //!    c. If not, try to unroll the loop. If successful, remember the modified
-//!       blocks. If unsuccessfully either error if the abort_on_error flag is set,
+//!       blocks. If unsuccessful, either error if the abort_on_error flag is set,
 //!       or otherwise remember that the loop failed to unroll and leave it unmodified.
 //!
 //! Note that this pass also often creates superfluous jmp instructions in the
-//! program that will need to be removed by a later simplify cfg pass.
-//! Note also that unrolling is skipped for Brillig runtime and as a result
-//! we remove reference count instructions because they are only used by Brillig bytecode
+//! program that will need to be removed by a later simplify CFG pass.
+//!
+//! Note also that unrolling is skipped for Brillig runtime, unless the loops are deemed
+//! sufficiently small that inlining can be done without increasing the bytecode.
+//!
+//! When unrolling ACIR code, we remove reference count instructions because they are
+//! only used by Brillig bytecode.
 use std::collections::HashSet;
 
-use acvm::acir::AcirField;
+use acvm::{acir::AcirField, FieldElement};
 
 use crate::{
     errors::RuntimeError,
@@ -26,9 +30,9 @@ use crate::{
         cfg::ControlFlowGraph,
         dfg::{CallStack, DataFlowGraph},
         dom::DominatorTree,
-        function::{Function, RuntimeType},
+        function::Function,
         function_inserter::{ArrayCache, FunctionInserter},
-        instruction::{Instruction, TerminatorInstruction},
+        instruction::{Binary, BinaryOp, Instruction, InstructionId, TerminatorInstruction},
         post_order::PostOrder,
         value::ValueId,
     },
@@ -42,16 +46,9 @@ impl Ssa {
     /// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
     #[tracing::instrument(level = "trace", skip(ssa))]
     pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
-        let acir_functions = ssa.functions.iter_mut().filter(|(_, func)| {
-            // Loop unrolling in brillig can lead to a code explosion currently. This can
-            // also be true for ACIR, but we have no alternative to unrolling in ACIR.
-            // Brillig also generally prefers smaller code rather than faster code.
-            !matches!(func.runtime(), RuntimeType::Brillig(_))
-        });
-
-        for (_, function) in acir_functions {
+        for (_, function) in ssa.functions.iter_mut() {
             // Try to unroll loops first:
-            let mut unroll_errors = function.try_to_unroll_loops();
+            let mut unroll_errors = function.try_unroll_loops();
 
             // Keep unrolling until no more errors are found
             while !unroll_errors.is_empty() {
@@ -66,21 +63,24 @@ impl Ssa {
                 function.mem2reg();
 
                 // Unroll again
-                unroll_errors = function.try_to_unroll_loops();
+                unroll_errors = function.try_unroll_loops();
 
                 // If we didn't manage to unroll any more loops, exit
                 if unroll_errors.len() >= prev_unroll_err_count {
                     return Err(unroll_errors.swap_remove(0));
                 }
             }
         }
-
         Ok(ssa)
     }
 }
 
 impl Function {
-    fn try_to_unroll_loops(&mut self) -> Vec<RuntimeError> {
-        find_all_loops(self).unroll_each_loop(self)
+    // Loop unrolling in brillig can lead to a code explosion currently.
+    // This can also be true for ACIR, but we have no alternative to unrolling in ACIR.
+    // Brillig also generally prefers smaller code rather than faster code,
+    // so we only attempt to unroll small loops, which we decide on a case-by-case basis.
+    fn try_unroll_loops(&mut self) -> Vec<RuntimeError> {
+        Loops::find_all(self).unroll_each(self)
     }
 }
 
@@ -94,7 +94,7 @@ struct Loop {
     back_edge_start: BasicBlockId,
 
     /// All the blocks contained within the loop, including `header` and `back_edge_start`.
-    pub(crate) blocks: HashSet<BasicBlockId>,
+    blocks: HashSet<BasicBlockId>,
 }
 
 struct Loops {
@@ -107,60 +107,88 @@ struct Loops {
     cfg: ControlFlowGraph,
 }
-/// Find a loop in the program by finding a node that dominates any predecessor node.
-/// The edge where this happens will be the back-edge of the loop.
-fn find_all_loops(function: &Function) -> Loops {
-    let cfg = ControlFlowGraph::with_function(function);
-    let post_order = PostOrder::with_function(function);
-    let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);
-
-    let mut loops = vec![];
-
-    for (block, _) in function.dfg.basic_blocks_iter() {
-        // These reachable checks wouldn't be needed if we only iterated over reachable blocks
-        if dom_tree.is_reachable(block) {
-            for predecessor in cfg.predecessors(block) {
-                if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor) {
-                    // predecessor -> block is the back-edge of a loop
-                    loops.push(find_blocks_in_loop(block, predecessor, &cfg));
-                }
-            }
-        }
-    }
-
-    // Sort loops by block size so that we unroll the larger, outer loops of nested loops first.
-    // This is needed because inner loops may use the induction variable from their outer loops in
-    // their loop range.
-    loops.sort_by_key(|loop_| loop_.blocks.len());
-
-    Loops {
-        failed_to_unroll: HashSet::new(),
-        yet_to_unroll: loops,
-        modified_blocks: HashSet::new(),
-        cfg,
-    }
-}
-
-impl Loops {
+impl Loops {
+    /// Find a loop in the program by finding a node that dominates any predecessor node.
+    /// The edge where this happens will be the back-edge of the loop.
+    ///
+    /// For example consider the following SSA of a basic loop:
+    /// ```text
+    /// main():
+    ///   v0 = ... start ...
+    ///   v1 = ... end ...
+    ///   jmp loop_entry(v0)
+    /// loop_entry(i: Field):
+    ///   v2 = lt i v1
+    ///   jmpif v2, then: loop_body, else: loop_end
+    /// loop_body():
+    ///   v3 = ... body ...
+    ///   v4 = add 1, i
+    ///   jmp loop_entry(v4)
+    /// loop_end():
+    /// ```
+    ///
+    /// The CFG will look something like this:
+    /// ```text
+    ///   main
+    ///    ↓
+    ///   loop_entry ←---↰
+    ///    ↓        ↘    |
+    ///   loop_end   loop_body
+    /// ```
+    /// `loop_entry` has two predecessors: `main` and `loop_body`, and it dominates `loop_body`.
+    fn find_all(function: &Function) -> Self {
+        let cfg = ControlFlowGraph::with_function(function);
+        let post_order = PostOrder::with_function(function);
+        let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);
+
+        let mut loops = vec![];
+
+        for (block, _) in function.dfg.basic_blocks_iter() {
+            // These reachable checks wouldn't be needed if we only iterated over reachable blocks
+            if dom_tree.is_reachable(block) {
+                for predecessor in cfg.predecessors(block) {
+                    // In the above example, we're looking for when `block` is `loop_entry` and `predecessor` is `loop_body`.
+                    if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor)
+                    {
+                        // predecessor -> block is the back-edge of a loop
+                        loops.push(Loop::find_blocks_in_loop(block, predecessor, &cfg));
+                    }
+                }
+            }
+        }
+
+        // Sort loops by block size so that we unroll the larger, outer loops of nested loops first.
+        // This is needed because inner loops may use the induction variable from their outer loops in
+        // their loop range. We will start popping loops from the back.
+        loops.sort_by_key(|loop_| loop_.blocks.len());
+
+        Self {
+            failed_to_unroll: HashSet::new(),
+            yet_to_unroll: loops,
+            modified_blocks: HashSet::new(),
+            cfg,
+        }
+    }
 
     /// Unroll all loops within a given function.
     /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
-    fn unroll_each_loop(mut self, function: &mut Function) -> Vec<RuntimeError> {
+    fn unroll_each(mut self, function: &mut Function) -> Vec<RuntimeError> {
         let mut unroll_errors = vec![];
         while let Some(next_loop) = self.yet_to_unroll.pop() {
+            if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) {
+                continue;
+            }
             // If we've previously modified a block in this loop we need to refresh the context.
             // This happens any time we have nested loops.
             if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) {
-                let mut new_context = find_all_loops(function);
-                new_context.failed_to_unroll = self.failed_to_unroll;
-                return unroll_errors
-                    .into_iter()
-                    .chain(new_context.unroll_each_loop(function))
-                    .collect();
+                let mut new_loops = Self::find_all(function);
+                new_loops.failed_to_unroll = self.failed_to_unroll;
+                return unroll_errors.into_iter().chain(new_loops.unroll_each(function)).collect();
             }
             // Don't try to unroll the loop again if it is known to fail
            if !self.failed_to_unroll.contains(&next_loop.header) {
-                match unroll_loop(function, &self.cfg, &next_loop) {
+                match next_loop.unroll(function, &self.cfg) {
                    Ok(_) => self.modified_blocks.extend(next_loop.blocks),
                    Err(call_stack) => {
                        self.failed_to_unroll.insert(next_loop.header);
@@ -173,83 +201,522 @@ impl Loops {
         }
     }
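The `pop` in `unroll_each`, combined with the ascending sort in `find_all`, means the outermost loop of a nested pair is processed first. A minimal standalone sketch of that invariant, with plain integers standing in for the real `Loop` values (this is illustrative, not the actual API):

```rust
fn main() {
    // Block counts of a nested pair: the outer loop contains all of the
    // inner loop's blocks, so it always has the larger count.
    let mut loops = vec![2usize, 4]; // [inner, outer]
    loops.sort();                     // ascending, as in `Loops::find_all`
    assert_eq!(loops.pop(), Some(4)); // `pop` yields the outer loop first
    assert_eq!(loops.pop(), Some(2)); // then the inner loop
}
```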
-/// Return each block that is in a loop starting in the given header block.
-/// Expects back_edge_start -> header to be the back edge of the loop.
-fn find_blocks_in_loop(
-    header: BasicBlockId,
-    back_edge_start: BasicBlockId,
-    cfg: &ControlFlowGraph,
-) -> Loop {
-    let mut blocks = HashSet::new();
-    blocks.insert(header);
-
-    let mut insert = |block, stack: &mut Vec<BasicBlockId>| {
-        if !blocks.contains(&block) {
-            blocks.insert(block);
-            stack.push(block);
-        }
-    };
-
-    // Starting from the back edge of the loop, each predecessor of this block until
-    // the header is within the loop.
-    let mut stack = vec![];
-    insert(back_edge_start, &mut stack);
-
-    while let Some(block) = stack.pop() {
-        for predecessor in cfg.predecessors(block) {
-            insert(predecessor, &mut stack);
-        }
-    }
-
-    Loop { header, back_edge_start, blocks }
-}
+impl Loop {
+    /// Return each block that is in a loop starting in the given header block.
+    /// Expects back_edge_start -> header to be the back edge of the loop.
+    fn find_blocks_in_loop(
+        header: BasicBlockId,
+        back_edge_start: BasicBlockId,
+        cfg: &ControlFlowGraph,
+    ) -> Self {
+        let mut blocks = HashSet::new();
+        blocks.insert(header);
+
+        let mut insert = |block, stack: &mut Vec<BasicBlockId>| {
+            if !blocks.contains(&block) {
+                blocks.insert(block);
+                stack.push(block);
+            }
+        };
+
+        // Starting from the back edge of the loop, each predecessor of this block until
+        // the header is within the loop.
+        let mut stack = vec![];
+        insert(back_edge_start, &mut stack);
+
+        while let Some(block) = stack.pop() {
+            for predecessor in cfg.predecessors(block) {
+                insert(predecessor, &mut stack);
+            }
+        }
+
+        Self { header, back_edge_start, blocks }
+    }
+
+    /// Find the lower bound of the loop in the pre-header and return it
+    /// if it's a numeric constant, which it will be if the previous SSA
+    /// steps managed to inline it.
+    ///
+    /// Consider the following example of a `for i in 0..4` loop:
+    /// ```text
+    /// brillig(inline) fn main f0 {
+    ///   b0(v0: u32):          // Pre-header
+    ///     ...
+    ///     jmp b1(u32 0)       // Lower-bound
+    ///   b1(v1: u32):          // Induction variable
+    ///     v5 = lt v1, u32 4
+    ///     jmpif v5 then: b3, else: b2
+    /// ```
+    fn get_const_lower_bound(
+        &self,
+        function: &Function,
+        cfg: &ControlFlowGraph,
+    ) -> Result<Option<FieldElement>, CallStack> {
+        let pre_header = self.get_pre_header(function, cfg)?;
+        let jump_value = get_induction_variable(function, pre_header)?;
+        Ok(function.dfg.get_numeric_constant(jump_value))
+    }
+
+    /// Find the upper bound of the loop in the loop header and return it
+    /// if it's a numeric constant, which it will be if the previous SSA
+    /// steps managed to inline it.
+    ///
+    /// Consider the following example of a `for i in 0..4` loop:
+    /// ```text
+    /// brillig(inline) fn main f0 {
+    ///   b0(v0: u32):
+    ///     ...
+    ///     jmp b1(u32 0)
+    ///   b1(v1: u32):          // Loop header
+    ///     v5 = lt v1, u32 4   // Upper bound
+    ///     jmpif v5 then: b3, else: b2
+    /// ```
+    fn get_const_upper_bound(&self, function: &Function) -> Option<FieldElement> {
+        let block = &function.dfg[self.header];
+        let instructions = block.instructions();
+        assert_eq!(
+            instructions.len(),
+            1,
+            "The header should just compare the induction variable and jump"
+        );
+        match &function.dfg[instructions[0]] {
+            Instruction::Binary(Binary { lhs: _, operator: BinaryOp::Lt, rhs }) => {
+                function.dfg.get_numeric_constant(*rhs)
+            }
+            Instruction::Binary(Binary { lhs: _, operator: BinaryOp::Eq, rhs }) => {
+                // `for i in 0..1` is turned into:
+                // b1(v0: u32):
+                //   v12 = eq v0, u32 0
+                //   jmpif v12 then: b3, else: b2
+                function.dfg.get_numeric_constant(*rhs).map(|c| c + FieldElement::one())
+            }
+            other => panic!("Unexpected instruction in header: {other:?}"),
+        }
+    }
+
+    /// Get the lower and upper bounds of the loop if both are constant numeric values.
+    fn get_const_bounds(
+        &self,
+        function: &Function,
+        cfg: &ControlFlowGraph,
+    ) -> Result<Option<(FieldElement, FieldElement)>, CallStack> {
+        let Some(lower) = self.get_const_lower_bound(function, cfg)? else {
+            return Ok(None);
+        };
+        let Some(upper) = self.get_const_upper_bound(function) else {
+            return Ok(None);
+        };
+        Ok(Some((lower, upper)))
+    }
+
+    /// Unroll a single loop in the function.
+    /// Returns Ok(()) if it succeeded, Err(callstack) if it failed,
+    /// where the callstack indicates the location of the instruction
+    /// that could not be processed, or empty if such information was
+    /// not available.
+    ///
+    /// Consider this example:
+    /// ```text
+    /// main():
+    ///   v0 = 0
+    ///   v1 = 2
+    ///   jmp loop_entry(v0)
+    /// loop_entry(i: Field):
+    ///   v2 = lt i v1
+    ///   jmpif v2, then: loop_body, else: loop_end
+    /// ```
+    ///
+    /// The first step is to unroll the header by recognizing that the jump condition
+    /// is a constant, which means it will go to `loop_body`:
+    /// ```text
+    /// main():
+    ///   v0 = 0
+    ///   v1 = 2
+    ///   v2 = lt v0 v1
+    ///   // jmpif v2, then: loop_body, else: loop_end
+    ///   jmp dest: loop_body
+    /// ```
+    ///
+    /// Following that we unroll the loop body, which is the next source block, replace
+    /// the induction variable with the new value created in the body, and have
+    /// another go at the header.
+    /// ```text
+    /// main():
+    ///   v0 = 0
+    ///   v1 = 2
+    ///   v2 = lt v0 v1
+    ///   v3 = ... body ...
+    ///   v4 = add 1, 0
+    ///   jmp loop_entry(v4)
+    /// ```
+    ///
+    /// At the end we reach a point where the condition evaluates to 0 and we jump to the end.
+    /// ```text
+    /// main():
+    ///   v0 = 0
+    ///   v1 = 2
+    ///   v2 = lt v0 v1
+    ///   v3 = ... body ...
+    ///   v4 = add 1, v0
+    ///   v5 = lt v4 v1
+    ///   v6 = ... body ...
+    ///   v7 = add v4, 1
+    ///   v8 = lt v7 v1
+    ///   jmp loop_end
+    /// ```
+    ///
+    /// When e.g. `v8 = lt v7 v1` cannot be evaluated to a constant, the loop signals by returning `Err`
+    /// that a few SSA passes are required to evaluate and simplify these values.
+    fn unroll(&self, function: &mut Function, cfg: &ControlFlowGraph) -> Result<(), CallStack> {
+        let mut unroll_into = self.get_pre_header(function, cfg)?;
+        let mut jump_value = get_induction_variable(function, unroll_into)?;
+        let mut array_cache = Some(ArrayCache::default());
+
+        while let Some(mut context) = self.unroll_header(function, unroll_into, jump_value)? {
+            // The inserter's array cache must be explicitly enabled. This is to
+            // confirm that we're inserting in insertion order. This is true here since:
+            // 1. We have a fresh inserter for each loop
+            // 2. Each loop is unrolled in iteration order
+            //
+            // Within a loop we do not insert in insertion order. This is fine however since the
+            // array cache is buffered with a separate fresh_array_cache which collects arrays
+            // but does not deduplicate. When we later call `into_array_cache`, that will merge
+            // the fresh cache in with the old one so that each iteration of the loop can cache
+            // from previous iterations but not the current iteration.
+            context.inserter.set_array_cache(array_cache, unroll_into);
+            (unroll_into, jump_value, array_cache) = context.unroll_loop_iteration();
+        }
+
+        Ok(())
+    }
+
+    /// The loop pre-header is the block that comes before the loop begins. Generally a header block
+    /// is expected to have 2 predecessors: the pre-header and the final block of the loop which jumps
+    /// back to the beginning.
+    /// Other predecessors can come from `break` or `continue`.
+    fn get_pre_header(
+        &self,
+        function: &Function,
+        cfg: &ControlFlowGraph,
+    ) -> Result<BasicBlockId, CallStack> {
+        let mut pre_header = cfg
+            .predecessors(self.header)
+            .filter(|predecessor| *predecessor != self.back_edge_start)
+            .collect::<Vec<_>>();
+
+        if function.runtime().is_acir() {
+            assert_eq!(pre_header.len(), 1);
+            Ok(pre_header.remove(0))
+        } else if pre_header.len() == 1 {
+            Ok(pre_header.remove(0))
+        } else {
+            // We can come back into the header from multiple blocks, so we can't unroll this.
+            Err(CallStack::new())
+        }
+    }
+
+    /// Unrolls the header block of the loop. This is the block that dominates all other blocks in the
+    /// loop and contains the jmpif instruction that lets us know if we should continue looping.
+    /// Returns Some(iteration context) if we should perform another iteration.
+    fn unroll_header<'a>(
+        &'a self,
+        function: &'a mut Function,
+        unroll_into: BasicBlockId,
+        induction_value: ValueId,
+    ) -> Result<Option<LoopIteration<'a>>, CallStack> {
+        // We insert into a fresh block first and move instructions into the unroll_into block later
+        // only once we verify the jmpif instruction has a constant condition. If it does not, we can
+        // just discard this fresh block and leave the loop unmodified.
+        let fresh_block = function.dfg.make_block();
+
+        let mut context = LoopIteration::new(function, self, fresh_block, self.header);
+        let source_block = &context.dfg()[context.source_block];
+        assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header");
+
+        // Insert the current value of the loop induction variable into our context.
+        let first_param = source_block.parameters()[0];
+        context.inserter.try_map_value(first_param, induction_value);
+        // Copy over all instructions and a fresh terminator.
+        context.inline_instructions_from_block();
+        // Mutate the terminator if possible so that it points at the iteration block.
+        match context.dfg()[fresh_block].unwrap_terminator() {
+            TerminatorInstruction::JmpIf { condition, then_destination, else_destination, call_stack } => {
+                let condition = *condition;
+                let next_blocks = context.handle_jmpif(condition, *then_destination, *else_destination, call_stack.clone());
+
+                // If there is only 1 next block the jmpif evaluated to a single known block.
+                // This is the expected case and lets us know if we should loop again or not.
+                if next_blocks.len() == 1 {
+                    context.dfg_mut().inline_block(fresh_block, unroll_into);
+
+                    // The fresh block is gone now so we're committing to insert into the original
+                    // unroll_into block from now on.
+                    context.insert_block = unroll_into;
+
+                    // In the last iteration, `handle_jmpif` will have replaced `context.source_block`
+                    // with the `else_destination`, that is, the `loop_end`, which signals that we
+                    // have no more loops to unroll, because that block was not part of the loop itself,
+                    // ie. it wasn't between `loop_header` and `loop_body`. Otherwise we have the `loop_body`
+                    // in `source_block` and can unroll that into the destination.
+                    Ok(self.blocks.contains(&context.source_block).then_some(context))
+                } else {
+                    // If this case is reached the loop either uses non-constant indices or we need
+                    // another pass, such as mem2reg to resolve them to constants.
+                    Err(context.inserter.function.dfg.get_value_call_stack(condition))
+                }
+            }
+            other => unreachable!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"),
+        }
+    }
+
+    /// Find all reference values which were allocated before the pre-header.
+    ///
+    /// These are accessible inside the loop body, and they can be involved
+    /// in load/store operations that could be eliminated if we unrolled the
+    /// body into the pre-header.
+    ///
+    /// Consider this loop:
+    /// ```text
+    /// let mut sum = 0;
+    /// let mut arr = &[];
+    /// for i in 0..3 {
+    ///     sum = sum + i;
+    ///     arr.push_back(sum)
+    /// }
+    /// sum
+    /// ```
+    ///
+    /// The SSA has a load+store for the `sum` and a load+push for the `arr`:
+    /// ```text
+    /// b0(v0: u32):
+    ///   v2 = allocate -> &mut u32     // reference allocated for `sum`
+    ///   store u32 0 at v2             // initial value for `sum`
+    ///   v4 = allocate -> &mut u32     // reference allocated for the length of `arr`
+    ///   store u32 0 at v4             // initial length of `arr`
+    ///   inc_rc [] of u32              // storage for `arr`
+    ///   v6 = allocate -> &mut [u32]   // reference allocated to point at the storage of `arr`
+    ///   store [] of u32 at v6         // initial value for the storage of `arr`
+    ///   jmp b1(u32 0)                 // start looping from 0
+    /// b1(v1: u32):                    // `i` induction variable
+    ///   v8 = lt v1, u32 3             // loop until 3
+    ///   jmpif v8 then: b3, else: b2
+    /// b3():
+    ///   v11 = load v2 -> u32          // load `sum`
+    ///   v12 = add v11, v1             // add `i` to `sum`
+    ///   store v12 at v2               // store updated `sum`
+    ///   v13 = load v4 -> u32          // load length of `arr`
+    ///   v14 = load v6 -> [u32]        // load storage of `arr`
+    ///   v16, v17 = call slice_push_back(v13, v14, v12) -> (u32, [u32]) // builtin to push, will store to storage and length references
+    ///   v19 = add v1, u32 1           // increase `i`
+    ///   jmp b1(v19)                   // back-edge of the loop
+    /// b2():                           // after the loop
+    ///   v9 = load v2 -> u32           // read final value of `sum`
+    /// ```
+    ///
+    /// We won't always find load _and_ store ops (e.g. the push above doesn't come with a store),
+    /// but it's likely that mem2reg could eliminate a lot of the loads we can find, so we can
+    /// use this as an approximation of the gains we would see.
+    fn find_pre_header_reference_values(
+        &self,
+        function: &Function,
+        cfg: &ControlFlowGraph,
+    ) -> Result<HashSet<ValueId>, CallStack> {
+        // We need to traverse blocks from the pre-header up to the block entry point.
+        let pre_header = self.get_pre_header(function, cfg)?;
+        let function_entry = function.entry_block();
+
+        // The algorithm in `find_blocks_in_loop` expects to collect the blocks between the header and the back-edge of the loop,
+        // but technically works the same if we go from the pre-header up to the function entry as well.
+        let blocks = Self::find_blocks_in_loop(function_entry, pre_header, cfg).blocks;
+
+        // Collect allocations in all blocks above the header.
+        let allocations = blocks.iter().flat_map(|b| {
+            function.dfg[*b]
+                .instructions()
+                .iter()
+                .filter(|i| matches!(&function.dfg[**i], Instruction::Allocate))
+                .map(|i| {
+                    // Get the value into which the allocation was stored.
+                    function.dfg.instruction_results(*i)[0]
+                })
+        });
+
+        // Collect reference parameters of the function itself.
+        let params =
+            function.parameters().iter().filter(|p| function.dfg.value_is_reference(**p)).copied();
+
+        Ok(params.chain(allocations).collect())
+    }
+
+    /// Count the number of load and store instructions of specific variables in the loop.
+    ///
+    /// Returns `(loads, stores)` in case we want to differentiate in the estimates.
+    fn count_loads_and_stores(
+        &self,
+        function: &Function,
+        refs: &HashSet<ValueId>,
+    ) -> (usize, usize) {
+        let mut loads = 0;
+        let mut stores = 0;
+        for block in &self.blocks {
+            for instruction in function.dfg[*block].instructions() {
+                match &function.dfg[*instruction] {
+                    Instruction::Load { address } if refs.contains(address) => {
+                        loads += 1;
+                    }
+                    Instruction::Store { address, .. } if refs.contains(address) => {
+                        stores += 1;
+                    }
+                    _ => {}
+                }
+            }
+        }
+        (loads, stores)
+    }
+
+    /// Count the number of instructions in the loop, including the terminating jumps.
+    fn count_all_instructions(&self, function: &Function) -> usize {
+        self.blocks
+            .iter()
+            .map(|block| {
+                let block = &function.dfg[*block];
+                block.instructions().len() + block.terminator().map(|_| 1).unwrap_or_default()
+            })
+            .sum()
+    }
+
+    /// Count the number of increments to the induction variable.
+    /// It should be one, but it can be duplicated.
+    /// The increment should be in the block where the back-edge was found.
+    fn count_induction_increments(&self, function: &Function) -> usize {
+        let back = &function.dfg[self.back_edge_start];
+        let header = &function.dfg[self.header];
+        let induction_var = header.parameters()[0];
+
+        back.instructions().iter().filter(|instruction| {
+            let instruction = &function.dfg[**instruction];
+            matches!(instruction, Instruction::Binary(Binary { lhs, operator: BinaryOp::Add, rhs: _ }) if *lhs == induction_var)
+        }).count()
+    }
+
+    /// Decide if this loop is small enough that it can be inlined in a way that the number
+    /// of unrolled instructions times the number of iterations would result in smaller bytecode
+    /// than if we keep the loops with their overheads.
+    fn is_small_loop(&self, function: &Function, cfg: &ControlFlowGraph) -> bool {
+        self.boilerplate_stats(function, cfg).map(|s| s.is_small()).unwrap_or_default()
+    }
+
+    /// Collect boilerplate stats if we can figure out the upper and lower bounds of the loop,
+    /// and the loop doesn't have multiple back-edges from breaks and continues.
+    fn boilerplate_stats(
+        &self,
+        function: &Function,
+        cfg: &ControlFlowGraph,
+    ) -> Option<BoilerplateStats> {
+        let Ok(Some((lower, upper))) = self.get_const_bounds(function, cfg) else {
+            return None;
+        };
+        let Some(lower) = lower.try_to_u64() else {
+            return None;
+        };
+        let Some(upper) = upper.try_to_u64() else {
+            return None;
+        };
+        let Ok(refs) = self.find_pre_header_reference_values(function, cfg) else {
+            return None;
+        };
+        let (loads, stores) = self.count_loads_and_stores(function, &refs);
+        let increments = self.count_induction_increments(function);
+        let all_instructions = self.count_all_instructions(function);
+
+        Some(BoilerplateStats {
+            iterations: (upper - lower) as usize,
+            loads,
+            stores,
+            increments,
+            all_instructions,
+        })
+    }
+}
 
-/// Unroll a single loop in the function.
-/// Returns Err(()) if it failed to unroll and Ok(()) otherwise.
-fn unroll_loop(
-    function: &mut Function,
-    cfg: &ControlFlowGraph,
-    loop_: &Loop,
-) -> Result<(), CallStack> {
-    let mut unroll_into = get_pre_header(cfg, loop_);
-    let mut jump_value = get_induction_variable(function, unroll_into)?;
-    let mut array_cache = Some(ArrayCache::default());
-
-    while let Some(mut context) = unroll_loop_header(function, loop_, unroll_into, jump_value)? {
-        // The inserter's array cache must be explicitly enabled. This is to
-        // confirm that we're inserting in insertion order. This is true here since:
-        // 1. We have a fresh inserter for each loop
-        // 2. Each loop is unrolled in iteration order
-        //
-        // Within a loop we do not insert in insertion order. This is fine however since the
-        // array cache is buffered with a separate fresh_array_cache which collects arrays
-        // but does not deduplicate. When we later call `into_array_cache`, that will merge
-        // the fresh cache in with the old one so that each iteration of the loop can cache
-        // from previous iterations but not the current iteration.
-        context.inserter.set_array_cache(array_cache, unroll_into);
-        (unroll_into, jump_value, array_cache) = context.unroll_loop_iteration();
-    }
-
-    Ok(())
-}
+/// All the instructions in the following example are boilerplate:
+/// ```text
+/// brillig(inline) fn main f0 {
+///   b0(v0: u32):
+///     ...
+///     jmp b1(u32 0)
+///   b1(v1: u32):
+///     v5 = lt v1, u32 4
+///     jmpif v5 then: b3, else: b2
+///   b3():
+///     ...
+///     v11 = add v1, u32 1
+///     jmp b1(v11)
+///   b2():
+///     ...
+/// }
+/// ```
+#[derive(Debug)]
+struct BoilerplateStats {
+    /// Number of iterations in the loop.
+    iterations: usize,
+    /// Number of loads of pre-header references in the loop.
+    loads: usize,
+    /// Number of stores into pre-header references in the loop.
+    stores: usize,
+    /// Number of increments to the induction variable (might be duplicated).
+    increments: usize,
+    /// Number of instructions in the loop, including boilerplate,
+    /// but excluding the boilerplate which is outside the loop.
+    all_instructions: usize,
+}
 
-/// The loop pre-header is the block that comes before the loop begins. Generally a header block
-/// is expected to have 2 predecessors: the pre-header and the final block of the loop which jumps
-/// back to the beginning.
-fn get_pre_header(cfg: &ControlFlowGraph, loop_: &Loop) -> BasicBlockId {
-    let mut pre_header = cfg
-        .predecessors(loop_.header)
-        .filter(|predecessor| *predecessor != loop_.back_edge_start)
-        .collect::<Vec<_>>();
-
-    assert_eq!(pre_header.len(), 1);
-    pre_header.remove(0)
-}
+impl BoilerplateStats {
+    /// Instruction count if we leave the loop as-is.
+    /// It's the instructions in the loop, plus the one to kick it off in the pre-header.
+    fn baseline_instructions(&self) -> usize {
+        self.all_instructions + 1
+    }
+
+    /// Estimated number of _useful_ instructions, which are the ones in the loop
+    /// minus all the in-loop boilerplate.
+    fn useful_instructions(&self) -> usize {
+        // Two jumps plus the comparison with the upper bound
+        let boilerplate = 3;
+        // Be conservative and only assume that mem2reg gets rid of load followed by store.
+        // NB we have not checked that these are actual pairs.
+        let load_and_store = self.loads.min(self.stores) * 2;
+        self.all_instructions - self.increments - load_and_store - boilerplate
+    }
+
+    /// Estimated number of instructions if we unroll the loop.
+    fn unrolled_instructions(&self) -> usize {
+        self.useful_instructions() * self.iterations
+    }
+
+    /// A small loop is where if we unroll it into the pre-header then considering the
+    /// number of iterations we still end up with a smaller bytecode than if we leave
+    /// the blocks intact with all the boilerplate involved in jumping, and the extra
+    /// reference access instructions.
+    fn is_small(&self) -> bool {
+        self.unrolled_instructions() < self.baseline_instructions()
+    }
+}
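To make the estimate concrete, here is a worked example using the numbers from the small summing loop exercised by `test_boilerplate_stats` further down; constructing `BoilerplateStats` by hand like this is purely illustrative:

```rust
// Stats for `for i in 0..4 { sum = sum + i; }`:
let stats = BoilerplateStats {
    iterations: 4,       // upper bound 4 minus lower bound 0
    loads: 1,            // load of `sum`
    stores: 1,           // store of updated `sum`
    increments: 1,       // i = i + 1
    all_instructions: 7, // everything in b1 and b3, terminators included
};
// useful = 7 - 1 (increment) - 2 (load+store pair) - 3 (two jumps + comparison) = 1
assert_eq!(stats.useful_instructions(), 1);
// Unrolled: 1 useful instruction * 4 iterations = 4, versus a baseline of 7 + 1 = 8.
assert!(stats.unrolled_instructions() < stats.baseline_instructions());
assert!(stats.is_small());
```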
 
 /// Return the induction value of the current iteration of the loop, from the given block's jmp arguments.
 ///
 /// Expects the current block to terminate in `jmp h(N)` where h is the loop header and N is
-/// a Field value.
+/// a Field value. Returns an `Err` if this isn't the case.
+///
+/// Consider the following example:
+/// ```text
+/// main():
+///   v0 = ... start ...
+///   v1 = ... end ...
+///   jmp loop_entry(v0)
+/// loop_entry(i: Field):
+///   ...
+/// ```
+/// We're looking for the terminating jump of the `main` predecessor of `loop_entry`.
 fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result<ValueId, CallStack> {
     match function.dfg[block].terminator() {
         Some(TerminatorInstruction::Jmp { arguments, call_stack: location, .. }) => {
@@ -270,54 +737,6 @@ fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result<ValueId, CallStack> {
     }
 }
 
-/// Unrolls the header block of the loop. This is the block that dominates all other blocks in the
-/// loop and contains the jmpif instruction that lets us know if we should continue looping.
-/// Returns Some(iteration context) if we should perform another iteration.
-fn unroll_loop_header<'a>(
-    function: &'a mut Function,
-    loop_: &'a Loop,
-    unroll_into: BasicBlockId,
-    induction_value: ValueId,
-) -> Result<Option<LoopIteration<'a>>, CallStack> {
-    // We insert into a fresh block first and move instructions into the unroll_into block later
-    // only once we verify the jmpif instruction has a constant condition. If it does not, we can
-    // just discard this fresh block and leave the loop unmodified.
-    let fresh_block = function.dfg.make_block();
-
-    let mut context = LoopIteration::new(function, loop_, fresh_block, loop_.header);
-    let source_block = &context.dfg()[context.source_block];
-    assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header");
-
-    // Insert the current value of the loop induction variable into our context.
-    let first_param = source_block.parameters()[0];
-    context.inserter.try_map_value(first_param, induction_value);
-    context.inline_instructions_from_block();
-
-    match context.dfg()[fresh_block].unwrap_terminator() {
-        TerminatorInstruction::JmpIf { condition, then_destination, else_destination, call_stack } => {
-            let condition = *condition;
-            let next_blocks = context.handle_jmpif(condition, *then_destination, *else_destination, call_stack.clone());
-
-            // If there is only 1 next block the jmpif evaluated to a single known block.
-            // This is the expected case and lets us know if we should loop again or not.
-            if next_blocks.len() == 1 {
-                context.dfg_mut().inline_block(fresh_block, unroll_into);
-
-                // The fresh block is gone now so we're committing to insert into the original
-                // unroll_into block from now on.
-                context.insert_block = unroll_into;
-
-                Ok(loop_.blocks.contains(&context.source_block).then_some(context))
-            } else {
-                // If this case is reached the loop either uses non-constant indices or we need
-                // another pass, such as mem2reg to resolve them to constants.
-                Err(context.inserter.function.dfg.get_value_call_stack(condition))
-            }
-        }
-        other => unreachable!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"),
-    }
-}
-
 /// The context object for each loop iteration.
 /// Notably each loop iteration maps each loop block to a fresh, unrolled block.
 struct LoopIteration<'f> {
@@ -379,7 +798,8 @@ impl<'f> LoopIteration<'f> {
                 next_blocks.append(&mut blocks);
             }
         }
-
+        // After having unrolled all blocks in the loop body, we must know how to get back to the header;
+        // this is also the block into which we have to unroll next.
         let (end_block, induction_value) = self
             .induction_value
             .expect("Expected to find the induction variable by end of loop iteration");
@@ -390,6 +810,8 @@ impl<'f> LoopIteration<'f> {
     /// Unroll a single block in the current iteration of the loop
     fn unroll_loop_block(&mut self) -> Vec<BasicBlockId> {
         let mut next_blocks = self.unroll_loop_block_helper();
+        // Guarantee that the next blocks we set up to be unrolled are actually part of the loop,
+        // which we recorded while inlining the instructions of the blocks already processed.
         next_blocks.retain(|block| {
             let b = self.get_original_block(*block);
             self.loop_.blocks.contains(&b)
@@ -399,6 +821,7 @@ impl<'f> LoopIteration<'f> {
 
     /// Unroll a single block in the current iteration of the loop
     fn unroll_loop_block_helper(&mut self) -> Vec<BasicBlockId> {
+        // Copy instructions from the loop body to the unroll destination, replacing the terminator.
         self.inline_instructions_from_block();
         self.visited_blocks.insert(self.source_block);
 
@@ -416,6 +839,7 @@ impl<'f> LoopIteration<'f> {
             ),
             TerminatorInstruction::Jmp { destination, arguments, call_stack: _ } => {
                 if self.get_original_block(*destination) == self.loop_.header {
+                    // We found the back-edge of the loop.
                     assert_eq!(arguments.len(), 1);
                     self.induction_value = Some((self.insert_block, arguments[0]));
                 }
@@ -427,7 +851,10 @@ impl<'f> LoopIteration<'f> {
 
     /// Find the next branch(es) to take from a jmpif terminator and return them.
     /// If only one block is returned, it means the jmpif condition evaluated to a known
-    /// constant and we can safely take only the given branch.
+    /// constant and we can safely take only the given branch. In this case the method
+    /// also replaces the terminator of the insert block (a.k.a. the fresh block) with a `Jmp`,
+    /// and changes the source block in the context for the next iteration to be the
+    /// destination indicated by the constant condition (ie. the `then` or the `else`).
     fn handle_jmpif(
         &mut self,
         condition: ValueId,
@@ -473,10 +900,13 @@ impl<'f> LoopIteration<'f> {
         }
     }
 
+    /// Find the original ID of a block that replaced it.
     fn get_original_block(&self, block: BasicBlockId) -> BasicBlockId {
         self.original_blocks.get(&block).copied().unwrap_or(block)
     }
 
+    /// Copy over instructions from the source into the insert block,
+    /// while simplifying instructions and keeping track of original block IDs.
     fn inline_instructions_from_block(&mut self) {
         let source_block = &self.dfg()[self.source_block];
         let instructions = source_block.instructions().to_vec();
@@ -485,23 +915,31 @@ impl<'f> LoopIteration<'f> {
         // instances of the induction variable or any values that were changed as a result
         // of the new induction variable value.
         for instruction in instructions {
-            // Skip reference count instructions since they are only used for brillig, and brillig code is not unrolled
-            if !matches!(
-                self.dfg()[instruction],
-                Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. }
-            ) {
-                self.inserter.push_instruction(instruction, self.insert_block);
+            // Reference counting instructions are only used by Brillig; ACIR doesn't need them.
+            if self.inserter.function.runtime().is_acir() && self.is_refcount(instruction) {
+                continue;
             }
+            self.inserter.push_instruction(instruction, self.insert_block);
         }
 
         let mut terminator = self.dfg()[self.source_block]
             .unwrap_terminator()
             .clone()
             .map_values(|value| self.inserter.resolve(value));
 
+        // Replace the blocks in the terminator with fresh ones with the same parameters,
+        // while remembering which were the original block IDs.
         terminator.mutate_blocks(|block| self.get_or_insert_block(block));
         self.inserter.function.dfg.set_block_terminator(self.insert_block, terminator);
     }
 
+    /// Is the instruction an `IncrementRc` or `DecrementRc`?
+    fn is_refcount(&self, instruction: InstructionId) -> bool {
+        matches!(
+            self.dfg()[instruction],
+            Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. }
+        )
+    }
+
     fn dfg(&self) -> &DataFlowGraph {
         &self.inserter.function.dfg
     }
@@ -513,22 +951,19 @@
 
 #[cfg(test)]
 mod tests {
-    use crate::{
-        errors::RuntimeError,
-        ssa::{
-            function_builder::FunctionBuilder,
-            ir::{instruction::BinaryOp, map::Id, types::Type},
-        },
-    };
+    use acvm::FieldElement;
 
-    use super::Ssa;
+    use crate::errors::RuntimeError;
+    use crate::ssa::{ir::value::ValueId, opt::assert_normalized_ssa_equals, Ssa};
+
+    use super::{BoilerplateStats, Loops};
 
     /// Tries to unroll all loops in each SSA function.
     /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
-    fn try_to_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec<RuntimeError>) {
+    fn try_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec<RuntimeError>) {
         let mut errors = vec![];
         for function in ssa.functions.values_mut() {
-            errors.extend(function.try_to_unroll_loops());
+            errors.extend(function.try_unroll_loops());
         }
         (ssa, errors)
     }
@@ -542,166 +977,406 @@ mod tests {
         //     }
        //   }
        // }
-        //
-        // fn main f0 {
-        //   b0():
-        //     jmp b1(Field 0)
-        //   b1(v0: Field):  // header of outer loop
-        //     v1 = lt v0, Field 3
-        //     jmpif v1, then: b2, else: b3
-        //   b2():
-        //     jmp b4(Field 0)
-        //   b4(v2: Field):  // header of inner loop
-        //     v3 = lt v2, Field 4
-        //     jmpif v3, then: b5, else: b6
-        //   b5():
-        //     v4 = add v0, v2
-        //     v5 = lt Field 10, v4
-        //     constrain v5
-        //     v6 = add v2, Field 1
-        //     jmp b4(v6)
-        //   b6(): // end of inner loop
-        //     v7 = add v0, Field 1
-        //     jmp b1(v7)
-        //   b3(): // end of outer loop
-        //     return Field 0
-        // }
-        let main_id = Id::test_new(0);
-
-        // Compiling main
-        let mut builder = FunctionBuilder::new("main".into(), main_id);
-
-        let b1 = builder.insert_block();
-        let b2 = builder.insert_block();
-        let b3 = builder.insert_block();
-        let b4 = builder.insert_block();
-        let b5 = builder.insert_block();
-        let b6 = builder.insert_block();
-
-        let v0 = builder.add_block_parameter(b1, Type::field());
-        let v2 = builder.add_block_parameter(b4, Type::field());
-
-        let zero = builder.field_constant(0u128);
-        let one = builder.field_constant(1u128);
-        let three = builder.field_constant(3u128);
-        let four = builder.field_constant(4u128);
-        let ten = builder.field_constant(10u128);
-
-        builder.terminate_with_jmp(b1, vec![zero]);
-
-        // b1
-        builder.switch_to_block(b1);
-        let v1 = builder.insert_binary(v0, BinaryOp::Lt, three);
-        builder.terminate_with_jmpif(v1, b2, b3);
-
-        // b2
-        builder.switch_to_block(b2);
-        builder.terminate_with_jmp(b4, vec![zero]);
-
-        // b3
-        builder.switch_to_block(b3);
-        builder.terminate_with_return(vec![zero]);
-
-        // b4
-        builder.switch_to_block(b4);
-        let v3 = builder.insert_binary(v2, BinaryOp::Lt, four);
-        builder.terminate_with_jmpif(v3, b5, b6);
-
-        // b5
-        builder.switch_to_block(b5);
-        let v4 = builder.insert_binary(v0, BinaryOp::Add, v2);
-        let v5 = builder.insert_binary(ten, BinaryOp::Lt, v4);
-        builder.insert_constrain(v5, one, None);
-        let v6 = builder.insert_binary(v2, BinaryOp::Add, one);
-        builder.terminate_with_jmp(b4, vec![v6]);
-
-        // b6
-        builder.switch_to_block(b6);
-        let v7 = builder.insert_binary(v0, BinaryOp::Add, one);
-        builder.terminate_with_jmp(b1, vec![v7]);
-
-        let ssa = builder.finish();
-
-        assert_eq!(ssa.main().reachable_blocks().len(), 7);
-
-        // Expected output:
-        //
-        // fn main f0 {
-        //   b0():
-        //     constrain Field 0
-        //     constrain Field 0
-        //     constrain Field 0
-        //     constrain Field 0
-        //     jmp b23()
-        //   b23():
-        //     constrain Field 0
-        //     constrain Field 0
-        //     constrain Field 0
-        //     constrain Field 0
-        //     jmp b27()
-        //   b27():
-        //     constrain Field 0
-        //     constrain Field 0
-        //     constrain Field 0
-        //     constrain Field 0
-        //     jmp b31()
-        //   b31():
-        //     jmp b3()
-        //   b3():
-        //     return Field 0
-        // }
+        let src = "
+        acir(inline) fn main f0 {
+          b0():
+            jmp b1(Field 0)
+          b1(v0: Field): // header of outer loop
+            v1 = lt v0, Field 3
+            jmpif v1 then: b2, else: b3
+          b2():
+            jmp b4(Field 0)
+          b4(v2: Field): // header of inner loop
+            v3 = lt v2, Field 4
+            jmpif v3 then: b5, else: b6
+          b5():
+            v4 = add v0, v2
+            v5 = lt Field 10, v4
+            constrain v5 == Field 1
+            v6 = add v2, Field 1
+            jmp b4(v6)
+          b6(): // end of inner loop
+            v7 = add v0, Field 1
+            jmp b1(v7)
+          b3(): // end of outer loop
+            return Field 0
+        }
+        ";
+        let ssa = Ssa::from_str(src).unwrap();
+
+        let expected = "
+        acir(inline) fn main f0 {
+          b0():
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            jmp b1()
+          b1():
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            jmp b2()
+          b2():
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            constrain u1 0 == Field 1
+            jmp b3()
+          b3():
+            jmp b4()
+          b4():
+            return Field 0
+        }
+        ";
 
         // The final block count is not 1 because unrolling creates some unnecessary jmps.
         // If a simplify cfg pass is ran afterward, the expected block count will be 1.
-        let (ssa, errors) = try_to_unroll_loops(ssa);
+        let (ssa, errors) = try_unroll_loops(ssa);
         assert_eq!(errors.len(), 0, "All loops should be unrolled");
         assert_eq!(ssa.main().reachable_blocks().len(), 5);
+
+        assert_normalized_ssa_equals(ssa, expected);
     }
 
     // Test that the pass can still be run on loops which fail to unroll properly
     #[test]
     fn fail_to_unroll_loop() {
-        // fn main f0 {
-        //   b0(v0: Field):
-        //     jmp b1(v0)
-        //   b1(v1: Field):
-        //     v2 = lt v1, 5
-        //     jmpif v2, then: b2, else: b3
-        //   b2():
-        //     v3 = add v1, Field 1
-        //     jmp b1(v3)
-        //   b3():
-        //     return Field 0
-        // }
-        let main_id = Id::test_new(0);
-        let mut builder = FunctionBuilder::new("main".into(), main_id);
-
-        let b1 = builder.insert_block();
-        let b2 = builder.insert_block();
-        let b3 = builder.insert_block();
-
-        let v0 = builder.add_parameter(Type::field());
-        let v1 = builder.add_block_parameter(b1, Type::field());
-
-        builder.terminate_with_jmp(b1, vec![v0]);
-
-        builder.switch_to_block(b1);
-        let five = builder.field_constant(5u128);
-        let v2 = builder.insert_binary(v1, BinaryOp::Lt, five);
-        builder.terminate_with_jmpif(v2, b2, b3);
-
-        builder.switch_to_block(b2);
-        let one = builder.field_constant(1u128);
-        let v3 = builder.insert_binary(v1, BinaryOp::Add, one);
-        builder.terminate_with_jmp(b1, vec![v3]);
-
-        builder.switch_to_block(b3);
-        let zero = builder.field_constant(0u128);
-        builder.terminate_with_return(vec![zero]);
-
-        let ssa = builder.finish();
-        assert_eq!(ssa.main().reachable_blocks().len(), 4);
-
-        // Expected that we failed to unroll the loop
-        let (_, errors) = try_to_unroll_loops(ssa);
-        assert_eq!(errors.len(), 1, "Expected to fail to unroll loop");
+        let src = "
+        acir(inline) fn main f0 {
+          b0(v0: Field):
+            jmp b1(v0)
+          b1(v1: Field):
+            v2 = lt v1, Field 5
+            jmpif v2 then: b2, else: b3
+          b2():
+            v3 = add v1, Field 1
+            jmp b1(v3)
+          b3():
+            return Field 0
+        }
+        ";
+        let ssa = Ssa::from_str(src).unwrap();
+
+        // Sanity check
+        assert_eq!(ssa.main().reachable_blocks().len(), 4);
+
+        // Expected that we failed to unroll the loop
+        let (_, errors) = try_unroll_loops(ssa);
+        assert_eq!(errors.len(), 1, "Expected to fail to unroll loop");
     }
 
+    #[test]
+    fn test_get_const_bounds() {
+        let ssa = brillig_unroll_test_case();
+        let function = ssa.main();
+        let loops = Loops::find_all(function);
+        assert_eq!(loops.yet_to_unroll.len(), 1);
+
+        let (lower, upper) = loops.yet_to_unroll[0]
+            .get_const_bounds(function, &loops.cfg)
+            .expect("should find bounds")
+            .expect("bounds are numeric const");
+
+        assert_eq!(lower, FieldElement::from(0u32));
+        assert_eq!(upper, FieldElement::from(4u32));
+    }
+
+    #[test]
+    fn test_find_pre_header_reference_values() {
+        let ssa = brillig_unroll_test_case();
+        let function = ssa.main();
+        let mut loops = Loops::find_all(function);
+        let loop0 = loops.yet_to_unroll.pop().unwrap();
+
+        let refs = loop0.find_pre_header_reference_values(function, &loops.cfg).unwrap();
+        assert_eq!(refs.len(), 1);
+        assert!(refs.contains(&ValueId::new(2)));
+
+        let (loads, stores) = loop0.count_loads_and_stores(function, &refs);
+        assert_eq!(loads, 1);
+        assert_eq!(stores, 1);
+
+        let all = loop0.count_all_instructions(function);
+        assert_eq!(all, 7);
+    }
+
+    #[test]
+    fn test_boilerplate_stats() {
+        let ssa = brillig_unroll_test_case();
+        let stats = loop0_stats(&ssa);
+        assert_eq!(stats.iterations, 4);
+        assert_eq!(stats.all_instructions, 2 + 5); // Instructions in b1 and b3
+        assert_eq!(stats.increments, 1);
+        assert_eq!(stats.loads, 1);
+        assert_eq!(stats.stores, 1);
+        assert_eq!(stats.useful_instructions(), 1); // Adding to sum
+        assert_eq!(stats.baseline_instructions(), 8);
+        assert!(stats.is_small());
+    }
+
+    #[test]
+    fn test_boilerplate_stats_6470() {
+        let ssa = brillig_unroll_test_case_6470(3);
+        let stats = loop0_stats(&ssa);
+        assert_eq!(stats.iterations, 3);
+        assert_eq!(stats.all_instructions, 2 + 8); // Instructions in b1 and b3
+        assert_eq!(stats.increments, 2);
+        assert_eq!(stats.loads, 1);
+        assert_eq!(stats.stores, 1);
+        assert_eq!(stats.useful_instructions(), 3); // array get, add, array set
+        assert_eq!(stats.baseline_instructions(), 11);
+        assert!(stats.is_small());
+    }
+
+    /// Test that we can unroll a small loop.
+    #[test]
+    fn test_brillig_unroll_small_loop() {
+        let ssa = brillig_unroll_test_case();
+
+        // Expectation taken by compiling the Noir program as ACIR,
+        // ie. by removing the `unconstrained` from `main`.
+        let expected = "
+        brillig(inline) fn main f0 {
+          b0(v0: u32):
+            v1 = allocate -> &mut u32
+            store u32 0 at v1
+            v3 = load v1 -> u32
+            store v3 at v1
+            v4 = load v1 -> u32
+            v6 = add v4, u32 1
+            store v6 at v1
+            v7 = load v1 -> u32
+            v9 = add v7, u32 2
+            store v9 at v1
+            v10 = load v1 -> u32
+            v12 = add v10, u32 3
+            store v12 at v1
+            jmp b1()
+          b1():
+            v13 = load v1 -> u32
+            v14 = eq v13, v0
+            constrain v13 == v0
+            return
+        }
+        ";
+
+        let (ssa, errors) = try_unroll_loops(ssa);
+        assert_eq!(errors.len(), 0, "Unroll should have no errors");
+        assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");
+
+        assert_normalized_ssa_equals(ssa, expected);
+    }
+
+    /// Test that we can unroll the loop in the ticket if we don't have too many iterations.
+    #[test]
+    fn test_brillig_unroll_6470_small() {
+        // Few enough iterations so that we can perform the unroll.
+        let ssa = brillig_unroll_test_case_6470(3);
+        let (ssa, errors) = try_unroll_loops(ssa);
+        assert_eq!(errors.len(), 0, "Unroll should have no errors");
+        assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");
+
+        // The IDs are shifted by one compared to what the ACIR version printed.
+        let expected = "
+        brillig(inline) fn main f0 {
+          b0(v0: [u64; 6]):
+            inc_rc v0
+            v2 = make_array [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] : [u64; 6]
+            inc_rc v2
+            v3 = allocate -> &mut [u64; 6]
+            store v2 at v3
+            v4 = load v3 -> [u64; 6]
+            v6 = array_get v0, index u32 0 -> u64
+            v8 = add v6, u64 1
+            v9 = array_set v4, index u32 0, value v8
+            store v9 at v3
+            v10 = load v3 -> [u64; 6]
+            v12 = array_get v0, index u32 1 -> u64
+            v13 = add v12, u64 1
+            v14 = array_set v10, index u32 1, value v13
+            store v14 at v3
+            v15 = load v3 -> [u64; 6]
+            v17 = array_get v0, index u32 2 -> u64
+            v18 = add v17, u64 1
+            v19 = array_set v15, index u32 2, value v18
+            store v19 at v3
+            jmp b1()
+          b1():
+            v20 = load v3 -> [u64; 6]
+            dec_rc v0
+            return v20
+        }
+        ";
+        assert_normalized_ssa_equals(ssa, expected);
+    }
+
+    /// Test that with more iterations it's not unrolled.
+    #[test]
+    fn test_brillig_unroll_6470_large() {
+        // More iterations than it can unroll
+        let parse_ssa = || brillig_unroll_test_case_6470(6);
+        let ssa = parse_ssa();
+        let stats = loop0_stats(&ssa);
+        assert!(!stats.is_small(), "the loop should be considered large");
+
+        let (ssa, errors) = try_unroll_loops(ssa);
+        assert_eq!(errors.len(), 0, "Unroll should have no errors");
+        assert_normalized_ssa_equals(ssa, parse_ssa().to_string().as_str());
+    }
+
+    /// Test that `break` and `continue` stop unrolling without any panic.
+    #[test]
+    fn test_brillig_unroll_break_and_continue() {
+        // unconstrained fn main() {
+        //     let mut count = 0;
+        //     for i in 0..10 {
+        //         if i == 2 {
+        //             continue;
+        //         }
+        //         if i == 5 {
+        //             break;
+        //         }
+        //         count += 1;
+        //     }
+        //     assert(count == 4);
+        // }
+        let src = "
+        brillig(inline) fn main f0 {
+          b0():
+            v1 = allocate -> &mut Field
+            store Field 0 at v1
+            jmp b1(u32 0)
+          b1(v0: u32):
+            v5 = lt v0, u32 10
+            jmpif v5 then: b2, else: b6
+          b2():
+            v7 = eq v0, u32 2
+            jmpif v7 then: b7, else: b3
+          b7():
+            v18 = add v0, u32 1
+            jmp b1(v18)
+          b3():
+            v9 = eq v0, u32 5
+            jmpif v9 then: b5, else: b4
+          b5():
+            jmp b6()
+          b6():
+            v15 = load v1 -> Field
+            v17 = eq v15, Field 4
+            constrain v15 == Field 4
+            return
+          b4():
+            v10 = load v1 -> Field
+            v12 = add v10, Field 1
+            store v12 at v1
+            v14 = add v0, u32 1
+            jmp b1(v14)
+        }
+        ";
+        let ssa = Ssa::from_str(src).unwrap();
+        let (ssa, errors) = try_unroll_loops(ssa);
+        assert_eq!(errors.len(), 0, "Unroll should have no errors");
+        assert_normalized_ssa_equals(ssa, src);
+    }
+
+    /// Simple test loop:
+    /// ```text
+    /// unconstrained fn main(sum: u32) {
+    ///     assert(loop(0, 4) == sum);
+    /// }
+    ///
+    /// fn loop(from: u32, to: u32) -> u32 {
+    ///     let mut sum = 0;
+    ///     for i in from..to {
+    ///         sum = sum + i;
+    ///     }
+    ///     sum
+    /// }
+    /// ```
+    /// We can check what the ACIR unrolling behavior would be by
+    /// removing the `unconstrained` from the `main` function and
+    /// compiling the program with `nargo --test-program . compile --show-ssa`.
+    fn brillig_unroll_test_case() -> Ssa {
+        let src = "
+        // After `static_assert` and `assert_constant`:
+        brillig(inline) fn main f0 {
+          b0(v0: u32):
+            v2 = allocate -> &mut u32
+            store u32 0 at v2
+            jmp b1(u32 0)
+          b1(v1: u32):
+            v5 = lt v1, u32 4
+            jmpif v5 then: b3, else: b2
+          b3():
+            v8 = load v2 -> u32
+            v9 = add v8, v1
+            store v9 at v2
+            v11 = add v1, u32 1
+            jmp b1(v11)
+          b2():
+            v6 = load v2 -> u32
+            v7 = eq v6, v0
+            constrain v6 == v0
+            return
+        }
+        ";
+        Ssa::from_str(src).unwrap()
+    }
+
+    /// Test case from #6470:
+    /// ```text
+    /// unconstrained fn __validate_gt_remainder(a_u60: [u64; 6]) -> [u64; 6] {
+    ///     let mut result_u60: [u64; 6] = [0; 6];
+    ///
+    ///     for i in 0..6 {
+    ///         result_u60[i] = a_u60[i] + 1;
+    ///     }
+    ///
+    ///     result_u60
+    /// }
+    /// ```
+    /// The `num_iterations` parameter can be used to make it more costly to inline.
+    fn brillig_unroll_test_case_6470(num_iterations: usize) -> Ssa {
+        let src = format!(
+            "
+        // After `static_assert` and `assert_constant`:
+        brillig(inline) fn main f0 {{
+          b0(v0: [u64; 6]):
+            inc_rc v0
+            v3 = make_array [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] : [u64; 6]
+            inc_rc v3
+            v4 = allocate -> &mut [u64; 6]
+            store v3 at v4
+            jmp b1(u32 0)
+          b1(v1: u32):
+            v7 = lt v1, u32 {num_iterations}
+            jmpif v7 then: b3, else: b2
+          b3():
+            v9 = load v4 -> [u64; 6]
+            v10 = array_get v0, index v1 -> u64
+            v12 = add v10, u64 1
+            v13 = array_set v9, index v1, value v12
+            v15 = add v1, u32 1
+            store v13 at v4
+            v16 = add v1, u32 1 // duplicate
+            jmp b1(v16)
+          b2():
+            v8 = load v4 -> [u64; 6]
+            dec_rc v0
+            return v8
+        }}
+        "
+        );
+        Ssa::from_str(&src).unwrap()
+    }
+
+    // Boilerplate stats of the first loop in the SSA.
+    fn loop0_stats(ssa: &Ssa) -> BoilerplateStats {
+        let function = ssa.main();
+        let mut loops = Loops::find_all(function);
+        let loop0 = loops.yet_to_unroll.pop().expect("there should be a loop");
+        loop0.boilerplate_stats(function, &loops.cfg).expect("there should be stats")
+    }
 }
diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs
index 9ca4f52cb14..552ac0781c7 100644
--- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs
+++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs
@@ -25,7 +25,11 @@ struct Translator {
     /// Maps block names to their IDs
     blocks: HashMap<FunctionId, HashMap<String, BasicBlockId>>,
 
-    /// Maps variable names to their IDs
+    /// Maps variable names to their IDs.
+    ///
+    /// This is necessary because the SSA we parse might have undergone some
+    /// passes already which replaced some of the original IDs. The translator
+    /// will recreate the SSA step by step, which can result in a new ID layout.
     variables: HashMap<FunctionId, HashMap<String, ValueId>>,
 }
 
@@ -307,7 +311,13 @@ impl Translator {
     }
 
     fn finish(self) -> Ssa {
-        self.builder.finish()
+        let mut ssa = self.builder.finish();
+        // Normalize the IDs so we have a better chance of matching the SSA we parsed
+        // after the step-by-step reconstruction done during translation. This assumes
+        // that the SSA we parsed was printed by the `SsaBuilder`, which normalizes
+        // before each print.
+        ssa.normalize_ids();
+        ssa
     }
 
     fn current_function_id(&self) -> FunctionId {
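A sketch of why `finish` normalizes, with hypothetical value numbering: SSA printed after earlier passes can contain gaps in its IDs, while the translator rebuilds each function from scratch and hands out fresh ones, so both sides need a canonical numbering before any textual comparison.

```rust
// Hypothetical example: after dead-code elimination the printed SSA may
// skip IDs (v0, v3). Re-parsing builds the function from scratch, so the
// same code would otherwise come back as v0, v1 and fail a string match.
let src = "
    acir(inline) fn main f0 {
      b0(v0: Field):
        v3 = add v0, Field 1
        return v3
    }
    ";
let ssa = Ssa::from_str(src).unwrap();
// `Ssa::from_str` runs the translator, whose `finish` now normalizes,
// so printing this SSA yields `v1 = add v0, Field 1` and `return v1`.
println!("{ssa}");
```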
diff --git a/compiler/noirc_evaluator/src/ssa/parser/tests.rs b/compiler/noirc_evaluator/src/ssa/parser/tests.rs
index f318e317473..60d398bf9d5 100644
--- a/compiler/noirc_evaluator/src/ssa/parser/tests.rs
+++ b/compiler/noirc_evaluator/src/ssa/parser/tests.rs
@@ -106,8 +106,8 @@ fn test_multiple_blocks_and_jmp() {
     acir(inline) fn main f0 {
       b0():
         jmp b1(Field 1)
-      b1(v1: Field):
-        return v1
+      b1(v0: Field):
+        return v0
     }
     ";
     assert_ssa_roundtrip(src);
@@ -118,11 +118,11 @@ fn test_jmpif() {
     let src = "
     acir(inline) fn main f0 {
       b0(v0: Field):
-        jmpif v0 then: b1, else: b2
-      b1():
-        return
+        jmpif v0 then: b2, else: b1
       b2():
         return
+      b1():
+        return
     }
     ";
     assert_ssa_roundtrip(src);
diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
index aac6e33bf5b..38a0565866f 100644
--- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
+++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
@@ -466,6 +466,7 @@ impl<'a> FunctionContext<'a> {
     ///
     /// For example, the loop `for i in start .. end { body }` is codegen'd as:
     ///
+    /// ```text
     ///   v0 = ... codegen start ...
     ///   v1 = ... codegen end ...
     ///   br loop_entry(v0)
@@ -478,6 +479,7 @@ impl<'a> FunctionContext<'a> {
     ///     br loop_entry(v4)
     ///   loop_end():
     ///     ... This is the current insert point after codegen_for finishes ...
+    /// ```
     fn codegen_for(&mut self, for_expr: &ast::For) -> Result<Values, RuntimeError> {
         let loop_entry = self.builder.insert_block();
         let loop_body = self.builder.insert_block();
@@ -529,6 +531,7 @@ impl<'a> FunctionContext<'a> {
     ///
     /// For example, the expression `if cond { a } else { b }` is codegen'd as:
     ///
+    /// ```text
     ///   v0 = ... codegen cond ...
     ///   brif v0, then: then_block, else: else_block
     ///   then_block():
@@ -539,16 +542,19 @@ impl<'a> FunctionContext<'a> {
     ///     br end_if(v2)
     ///   end_if(v3: ?):  // Type of v3 matches the type of a and b
     ///     ... This is the current insert point after codegen_if finishes ...
+    /// ```
     ///
     /// As another example, the expression `if cond { a }` is codegen'd as:
     ///
+    /// ```text
     ///   v0 = ... codegen cond ...
-    ///   brif v0, then: then_block, else: end_block
+    ///   brif v0, then: then_block, else: end_if
     ///   then_block:
     ///     v1 = ... codegen a ...
     ///     br end_if()
     ///   end_if:  // No block parameter is needed. Without an else, the unit value is always returned.
     ///     ... This is the current insert point after codegen_if finishes ...
+    /// ```
     fn codegen_if(&mut self, if_expr: &ast::If) -> Result<Values, RuntimeError> {
         let condition = self.codegen_non_tuple_expression(&if_expr.condition)?;