Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: inline simple functions #7160

Merged
merged 15 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,16 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ss
Ok(builder
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)")
.run_pass(Ssa::defunctionalize, "Defunctionalization")
.run_pass(Ssa::inline_simple_functions, "Inlining simple functions")
.run_pass(Ssa::mem2reg, "Mem2Reg (1st)")
.run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs")
.run_pass(
|ssa| ssa.preprocess_functions(options.inliner_aggressiveness),
"Preprocessing Functions",
)
.run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)")
// Run mem2reg with the CFG separated into blocks
.run_pass(Ssa::mem2reg, "Mem2Reg (1st)")
.run_pass(Ssa::mem2reg, "Mem2Reg (2nd)")
.run_pass(Ssa::simplify_cfg, "Simplifying (1st)")
.run_pass(Ssa::as_slice_optimization, "`as_slice` optimization")
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (2nd)")
Expand All @@ -173,11 +175,11 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ss
"Unrolling",
)?
.run_pass(Ssa::simplify_cfg, "Simplifying (2nd)")
.run_pass(Ssa::mem2reg, "Mem2Reg (2nd)")
.run_pass(Ssa::mem2reg, "Mem2Reg (3rd)")
.run_pass(Ssa::flatten_cfg, "Flattening")
.run_pass(Ssa::remove_bit_shifts, "Removing Bit Shifts")
// Run mem2reg once more with the flattened CFG to catch any remaining loads/stores
.run_pass(Ssa::mem2reg, "Mem2Reg (3rd)")
.run_pass(Ssa::mem2reg, "Mem2Reg (4th)")
// Run the inlining pass again to handle functions with `InlineType::NoPredicates`.
// Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point.
// This pass must come immediately following `mem2reg` as the succeeding passes
Expand Down
103 changes: 102 additions & 1 deletion compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,34 @@ impl Ssa {
});
self
}

pub(crate) fn inline_simple_functions(mut self: Ssa) -> Ssa {
let should_inline_call = |callee: &Function| {
if let RuntimeType::Acir(_) = callee.runtime() {
// Functions marked to not have predicates should be preserved.
if callee.is_no_predicates() {
return false;
}
}

let entry_block_id = callee.entry_block();
let entry_block = &callee.dfg[entry_block_id];

// Only inline functions with a single block
if entry_block.successors().next().is_some() {
return false;
}

// Only inline functions with 0 or 1 instructions
entry_block.instructions().len() <= 1
};

self.functions = btree_map(self.functions.iter(), |(id, function)| {
(*id, function.inlined(&self, &should_inline_call))
});

self
}
}

impl Function {
Expand Down Expand Up @@ -185,7 +213,7 @@ pub(super) struct InlineInfo {
is_brillig_entry_point: bool,
is_acir_entry_point: bool,
is_recursive: bool,
should_inline: bool,
pub(super) should_inline: bool,
weight: i64,
cost: i64,
}
Expand Down Expand Up @@ -1123,6 +1151,7 @@ mod test {
map::Id,
types::{NumericType, Type},
},
opt::assert_normalized_ssa_equals,
Ssa,
};

Expand Down Expand Up @@ -1597,4 +1626,76 @@ mod test {
);
assert!(tws[3] > max(tws[1], tws[2]), "ideally 'main' has the most weight");
}

#[test]
fn inline_simple_functions_with_zero_instructions() {
let src = "
acir(inline) fn main f0 {
b0(v0: Field):
v2 = call f1(v0) -> Field
v3 = call f1(v0) -> Field
v4 = add v2, v3
return v4
}

acir(inline) fn foo f1 {
b0(v0: Field):
return v0
}
";
let ssa = Ssa::from_str(src).unwrap();

let expected = "
acir(inline) fn main f0 {
b0(v0: Field):
v1 = add v0, v0
return v1
}
acir(inline) fn foo f1 {
b0(v0: Field):
return v0
}
";

let ssa = ssa.inline_simple_functions();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn inline_simple_functions_with_one_instruction() {
let src = "
acir(inline) fn main f0 {
b0(v0: Field):
v2 = call f1(v0) -> Field
v3 = call f1(v0) -> Field
v4 = add v2, v3
return v4
}

acir(inline) fn foo f1 {
b0(v0: Field):
v2 = add v0, Field 1
return v2
}
";
let ssa = Ssa::from_str(src).unwrap();

let expected = "
acir(inline) fn main f0 {
b0(v0: Field):
v2 = add v0, Field 1
v3 = add v0, Field 1
v4 = add v2, v3
return v4
}
acir(inline) fn foo f1 {
b0(v0: Field):
v2 = add v0, Field 1
return v2
}
";

let ssa = ssa.inline_simple_functions();
assert_normalized_ssa_equals(ssa, expected);
}
}
Loading