Keep inc_rc during preprocessing
aakoshh committed Jan 23, 2025
1 parent 53abd0e commit f1342c9
Showing 3 changed files with 84 additions and 19 deletions.
99 changes: 82 additions & 17 deletions compiler/noirc_evaluator/src/ssa/opt/die.rs
@@ -26,14 +26,20 @@ impl Ssa {
/// This step should come after the flattening of the CFG and mem2reg.
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn dead_instruction_elimination(self) -> Ssa {
self.dead_instruction_elimination_inner(true)
self.dead_instruction_elimination_inner(true, false)
}

fn dead_instruction_elimination_inner(mut self, flattened: bool) -> Ssa {
fn dead_instruction_elimination_inner(
mut self,
flattened: bool,
keep_rcs_of_parameters: bool,
) -> Ssa {
let mut used_global_values: HashSet<_> = self
.functions
.par_iter_mut()
.flat_map(|(_, func)| func.dead_instruction_elimination(true, flattened))
.flat_map(|(_, func)| {
func.dead_instruction_elimination(true, flattened, keep_rcs_of_parameters)
})
.collect();

let globals = &self.functions[&self.main_id].dfg.globals;
@@ -69,6 +75,7 @@ impl Function {
&mut self,
insert_out_of_bounds_checks: bool,
flattened: bool,
keep_rcs_of_parameters: bool,
) -> HashSet<ValueId> {
let mut context = Context { flattened, ..Default::default() };

@@ -84,14 +91,15 @@
self,
*block,
insert_out_of_bounds_checks,
keep_rcs_of_parameters,
);
}

// If we inserted out of bounds checks, let's run the pass again with those new
// instructions (we don't want to remove those checks, or instructions that are
// dependencies of those checks).
if inserted_out_of_bounds_checks {
return self.dead_instruction_elimination(false, flattened);
return self.dead_instruction_elimination(false, flattened, keep_rcs_of_parameters);
}

context.remove_rc_instructions(&mut self.dfg);
@@ -140,6 +148,7 @@ impl Context {
function: &mut Function,
block_id: BasicBlockId,
insert_out_of_bounds_checks: bool,
keep_rcs_of_parameters: bool,
) -> bool {
let block = &function.dfg[block_id];
self.mark_terminator_values_as_used(function, block);
@@ -148,10 +157,17 @@

let mut rc_tracker = RcTracker::default();

// During the preprocessing of functions in isolation we don't want to
// get rid of IncRCs of arrays that can potentially be mutated outside the function.
if keep_rcs_of_parameters {
rc_tracker.track_function_parameters(function);
}

// Indexes of instructions that might be out of bounds.
// We'll remove those, but before that we'll insert bounds checks for them.
let mut possible_index_out_of_bounds_indexes = Vec::new();

// Going in reverse so we know if a result of an instruction was used.
for (instruction_index, instruction_id) in block.instructions().iter().rev().enumerate() {
let instruction = &function.dfg[*instruction_id];

Expand Down Expand Up @@ -241,6 +257,8 @@ impl Context {
}
}

/// Go through the RC instructions collected when we figured out which values were unused;
/// for each RC that refers to an unused value, remove the RC as well.
fn remove_rc_instructions(&self, dfg: &mut DataFlowGraph) {
let unused_rc_values_by_block: HashMap<BasicBlockId, HashSet<InstructionId>> =
self.rc_instructions.iter().fold(HashMap::default(), |mut acc, (rc, block)| {
@@ -580,10 +598,12 @@ struct RcTracker {
// with the same value but no array set in between.
// If we see an inc/dec RC pair within a block we can safely remove both instructions.
rcs_with_possible_pairs: HashMap<Type, Vec<RcInstruction>>,
// Tracks repeated RC instructions: if there are two `inc_rc` for the same value in a row, the second one is redundant.
rc_pairs_to_remove: HashSet<InstructionId>,
// We also separately track all IncrementRc instructions and all array types which have been mutably borrowed.
// If an array's type does not match any of those mutably borrowed array types, we can safely remove all IncrementRc instructions on that array.
inc_rcs: HashMap<ValueId, HashSet<InstructionId>>,
// When tracking mutations we consider all arrays of the same type as possibly mutated.
mutated_array_types: HashSet<Type>,
// The SSA often creates patterns where after simplifications we end up with repeat
// IncrementRc instructions on the same value. We track whether the previous instruction was an IncrementRc,
@@ -593,9 +613,23 @@ struct RcTracker {
}
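
For illustration (an editorial sketch, not part of this commit): the pair removal described above, written in the same textual SSA format the tests below use (the `dec_rc` spelling is assumed from that format). An `inc_rc`/`dec_rc` on the same value with no `array_set` in between can be dropped entirely:

    brillig(inline) fn main f0 {
      b0(v0: [Field; 2]):
        inc_rc v0
        v2 = array_get v0, index u32 0 -> Field
        dec_rc v0
        return v2
    }

After the pass only the `array_get` and the `return` would remain.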

impl RcTracker {
/// Mark any array parameters of the function itself as possibly mutated,
/// so we don't get rid of their RC instructions just because this function
/// doesn't mutate them: removing the RCs could allow them to be
/// mutated outside the function in ways the caller doesn't expect.
fn track_function_parameters(&mut self, function: &Function) {
for parameter in function.parameters() {
let typ = function.dfg.type_of_value(*parameter);
if typ.contains_an_array() {
self.mutated_array_types.insert(typ);
}
}
}

fn track_inc_rcs_to_remove(&mut self, instruction_id: InstructionId, function: &Function) {
let instruction = &function.dfg[instruction_id];

// Deduplicate IncRC instructions.
if let Instruction::IncrementRc { value } = instruction {
if let Some(previous_value) = self.previous_inc_rc {
if previous_value == *value {
@@ -604,13 +638,16 @@ impl RcTracker {
}
self.previous_inc_rc = Some(*value);
} else {
// Reset the deduplication.
self.previous_inc_rc = None;
}

// DIE loops over a block in reverse order, so we insert an RC instruction for possible removal
// when we see a DecrementRc and check whether it was possibly mutated when we see an IncrementRc.
match instruction {
Instruction::IncrementRc { value } => {
// Get any RC instruction recorded further down the block for this array;
// if it exists and is not marked as mutated, then both RCs can be removed.
if let Some(inc_rc) =
pop_rc_for(*value, function, &mut self.rcs_with_possible_pairs)
{
@@ -619,7 +656,7 @@ impl RcTracker {
self.rc_pairs_to_remove.insert(instruction_id);
}
}

// Remember that this array was RC'd by this instruction.
self.inc_rcs.entry(*value).or_default().insert(instruction_id);
}
Instruction::DecrementRc { value } => {
@@ -632,12 +669,12 @@ impl RcTracker {
}
Instruction::ArraySet { array, .. } => {
let typ = function.dfg.type_of_value(*array);
// We mark all RCs that refer to arrays of the same type as the one being set as possibly mutated.
if let Some(dec_rcs) = self.rcs_with_possible_pairs.get_mut(&typ) {
for dec_rc in dec_rcs {
dec_rc.possibly_mutated = true;
}
}

self.mutated_array_types.insert(typ);
}
Instruction::Store { value, .. } => {
@@ -648,6 +685,7 @@ impl RcTracker {
}
}
Instruction::Call { arguments, .. } => {
// Treat any array-type arguments to calls as possible sources of mutation.
for arg in arguments {
let typ = function.dfg.type_of_value(*arg);
if matches!(&typ, Type::Array(..) | Type::Slice(..)) {
@@ -659,6 +697,7 @@ impl RcTracker {
}
}

/// Get all RC instructions which work on arrays whose type has not been marked as mutated.
fn get_non_mutated_arrays(&self, dfg: &DataFlowGraph) -> HashSet<InstructionId> {
self.inc_rcs
.keys()
@@ -857,16 +896,6 @@ mod test {

#[test]
fn keep_inc_rc_on_borrowed_array_set() {
// brillig(inline) fn main f0 {
// b0(v0: [u32; 2]):
// inc_rc v0
// v3 = array_set v0, index u32 0, value u32 1
// inc_rc v0
// inc_rc v0
// inc_rc v0
// v4 = array_get v3, index u32 1
// return v4
// }
let src = "
brillig(inline) fn main f0 {
b0(v0: [u32; 2]):
@@ -919,6 +948,42 @@ mod test {
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn not_remove_inc_rcs_for_input_parameters() {
let src = "
brillig(inline) fn main f0 {
b0(v0: [Field; 2]):
inc_rc v0
inc_rc v0
inc_rc v0
v2 = array_get v0, index u32 0 -> Field
inc_rc v0
return v2
}
";

let ssa = Ssa::from_str(src).unwrap();
let main = ssa.main();

// The instruction count never includes the terminator instruction
assert_eq!(main.dfg[main.entry_block()].instructions().len(), 5);

let expected = "
brillig(inline) fn main f0 {
b0(v0: [Field; 2]):
inc_rc v0
v2 = array_get v0, index u32 0 -> Field
inc_rc v0
return v2
}
";

// We want to be able to switch this on during preprocessing.
// Note that back-to-back duplicate `inc_rc`s are still deduplicated; only the
// wholesale removal of RCs on the parameter array is prevented, which is why
// the expected output keeps two of the original four `inc_rc`s.
let keep_rcs_of_parameters = true;
let ssa = ssa.dead_instruction_elimination_inner(true, keep_rcs_of_parameters);
assert_normalized_ssa_equals(ssa, expected);
}
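
For contrast (an editorial sketch, not a test from this commit): with the default keep_rcs_of_parameters = false, the same function would lose all of its `inc_rc` instructions, because v0 is never mutably borrowed inside the function, leaving roughly:

    brillig(inline) fn main f0 {
      b0(v0: [Field; 2]):
        v2 = array_get v0, index u32 0 -> Field
        return v2
    }

This is the behaviour exercised by `remove_inc_rcs_that_are_never_mutably_borrowed` below.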

#[test]
fn remove_inc_rcs_that_are_never_mutably_borrowed() {
let src = "
@@ -1010,7 +1075,7 @@ mod test {
let ssa = Ssa::from_str(src).unwrap();

// Even though these ACIR functions only have 1 block, we have not inlined and flattened anything yet.
let ssa = ssa.dead_instruction_elimination_inner(false);
let ssa = ssa.dead_instruction_elimination_inner(false, false);

let expected = "
acir(inline) fn main f0 {
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs
@@ -58,7 +58,7 @@ impl Ssa {
// Try to reduce the number of blocks.
function.simplify_function();
// Remove leftover instructions.
function.dead_instruction_elimination(true, false);
function.dead_instruction_elimination(true, false, true);

// Put it back into the SSA, so the next functions can pick it up.
self.functions.insert(id, function);
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
@@ -1035,7 +1035,7 @@ fn brillig_bytecode_size(
simplify_between_unrolls(&mut temp);

// This is to try to prevent hitting an ICE (internal compiler error).
temp.dead_instruction_elimination(false, true);
temp.dead_instruction_elimination(false, true, false);

convert_ssa_function(&temp, false, globals).byte_code.len()
}
