Skip to content

Commit

Permalink
Rework defunctionalize pass
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher committed Jan 28, 2025
1 parent ea10cf3 commit d204f44
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 72 deletions.
90 changes: 46 additions & 44 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,15 @@ impl Binary {

/// Try to simplify this binary instruction, returning the new value if possible.
pub(super) fn simplify(&self, dfg: &mut DataFlowGraph) -> SimplifyResult {
let lhs_value = dfg.get_numeric_constant(self.lhs);
let rhs_value = dfg.get_numeric_constant(self.rhs);
let lhs = dfg.resolve(self.lhs);
let rhs = dfg.resolve(self.rhs);

let lhs_type = dfg.type_of_value(self.lhs).unwrap_numeric();
let rhs_type = dfg.type_of_value(self.rhs).unwrap_numeric();
let lhs_value = dfg.get_numeric_constant(lhs);
let rhs_value = dfg.get_numeric_constant(rhs);
eprintln!("{lhs} = {lhs_value:?}, {rhs} = {rhs_value:?}");

let lhs_type = dfg.type_of_value(lhs).unwrap_numeric();
let rhs_type = dfg.type_of_value(rhs).unwrap_numeric();

let operator = self.operator;
if operator != BinaryOp::Shl && operator != BinaryOp::Shr {
Expand Down Expand Up @@ -124,7 +128,7 @@ impl Binary {
};

// We never return `SimplifyResult::None` here because `operator` might have changed.
let simplified = Instruction::Binary(Binary { lhs: self.lhs, rhs: self.rhs, operator });
let simplified = Instruction::Binary(Binary { lhs, rhs, operator });

if let (Some(lhs), Some(rhs)) = (lhs_value, rhs_value) {
return match eval_constant_binary_op(lhs, rhs, operator, lhs_type) {
Expand All @@ -145,66 +149,64 @@ impl Binary {
match self.operator {
BinaryOp::Add { .. } => {
if lhs_is_zero {
return SimplifyResult::SimplifiedTo(self.rhs);
return SimplifyResult::SimplifiedTo(rhs);
}
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
}
BinaryOp::Sub { .. } => {
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
}
BinaryOp::Mul { .. } => {
if lhs_is_one {
return SimplifyResult::SimplifiedTo(self.rhs);
return SimplifyResult::SimplifiedTo(rhs);
}
if rhs_is_one {
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
if lhs_is_zero || rhs_is_zero {
let zero = dfg.make_constant(FieldElement::zero(), lhs_type);
return SimplifyResult::SimplifiedTo(zero);
}
if dfg.get_value_max_num_bits(self.lhs) == 1 {
if dfg.get_value_max_num_bits(lhs) == 1 {
// Squaring a boolean value is a noop.
if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) {
return SimplifyResult::SimplifiedTo(self.lhs);
if lhs == rhs {
return SimplifyResult::SimplifiedTo(lhs);
}
// b*(b*x) = b*x if b is boolean
if let super::Value::Instruction { instruction, .. } = &dfg[self.rhs] {
if let Instruction::Binary(Binary { lhs, rhs, operator }) =
if let super::Value::Instruction { instruction, .. } = &dfg[rhs] {
if let Instruction::Binary(Binary { lhs: b_lhs, rhs: b_rhs, operator }) =
dfg[*instruction]
{
if matches!(operator, BinaryOp::Mul { .. })
&& (dfg.resolve(self.lhs) == dfg.resolve(lhs)
|| dfg.resolve(self.lhs) == dfg.resolve(rhs))
&& (lhs == dfg.resolve(b_lhs) || lhs == dfg.resolve(b_rhs))
{
return SimplifyResult::SimplifiedTo(self.rhs);
return SimplifyResult::SimplifiedTo(rhs);
}
}
}
}
// (b*x)*b = b*x if b is boolean
if dfg.get_value_max_num_bits(self.rhs) == 1 {
if let super::Value::Instruction { instruction, .. } = &dfg[self.lhs] {
if let Instruction::Binary(Binary { lhs, rhs, operator }) =
if dfg.get_value_max_num_bits(rhs) == 1 {
if let super::Value::Instruction { instruction, .. } = &dfg[lhs] {
if let Instruction::Binary(Binary { lhs: b_lhs, rhs: b_rhs, operator }) =
dfg[*instruction]
{
if matches!(operator, BinaryOp::Mul { .. })
&& (dfg.resolve(self.rhs) == dfg.resolve(lhs)
|| dfg.resolve(self.rhs) == dfg.resolve(rhs))
&& (rhs == dfg.resolve(b_lhs) || rhs == dfg.resolve(b_rhs))
{
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
}
}
}
}
BinaryOp::Div => {
if rhs_is_one {
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
}
BinaryOp::Mod => {
Expand All @@ -221,7 +223,7 @@ impl Binary {
let bit_size = modulus.ilog2();
return SimplifyResult::SimplifiedToInstruction(
Instruction::Truncate {
value: self.lhs,
value: lhs,
bit_size,
max_bit_size: lhs_type.bit_size(),
},
Expand All @@ -231,30 +233,30 @@ impl Binary {
}
}
BinaryOp::Eq => {
if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) {
if lhs == rhs {
let one = dfg.make_constant(FieldElement::one(), NumericType::bool());
return SimplifyResult::SimplifiedTo(one);
}

if lhs_type == NumericType::bool() {
// Simplify forms of `(boolean == true)` into `boolean`
if lhs_is_one {
return SimplifyResult::SimplifiedTo(self.rhs);
return SimplifyResult::SimplifiedTo(rhs);
}
if rhs_is_one {
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
// Simplify forms of `(boolean == false)` into `!boolean`
if lhs_is_zero {
return SimplifyResult::SimplifiedToInstruction(Instruction::Not(self.rhs));
return SimplifyResult::SimplifiedToInstruction(Instruction::Not(rhs));
}
if rhs_is_zero {
return SimplifyResult::SimplifiedToInstruction(Instruction::Not(self.lhs));
return SimplifyResult::SimplifiedToInstruction(Instruction::Not(lhs));
}
}
}
BinaryOp::Lt => {
if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) {
if lhs == rhs {
let zero = dfg.make_constant(FieldElement::zero(), NumericType::bool());
return SimplifyResult::SimplifiedTo(zero);
}
Expand All @@ -267,7 +269,7 @@ impl Binary {
let zero = dfg.make_constant(FieldElement::zero(), lhs_type);
return SimplifyResult::SimplifiedToInstruction(Instruction::binary(
BinaryOp::Eq,
self.lhs,
lhs,
zero,
));
}
Expand All @@ -278,14 +280,14 @@ impl Binary {
let zero = dfg.make_constant(FieldElement::zero(), lhs_type);
return SimplifyResult::SimplifiedTo(zero);
}
if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) {
return SimplifyResult::SimplifiedTo(self.lhs);
if lhs == rhs {
return SimplifyResult::SimplifiedTo(lhs);
}
if lhs_type == NumericType::bool() {
// Boolean AND is equivalent to multiplication, which is a cheaper operation.
// (mul unchecked because these are bools so it doesn't matter really)

Check warning on line 288 in compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (bools)
let instruction =
Instruction::binary(BinaryOp::Mul { unchecked: true }, self.lhs, self.rhs);
Instruction::binary(BinaryOp::Mul { unchecked: true }, lhs, rhs);
return SimplifyResult::SimplifiedToInstruction(instruction);
}
if lhs_type.is_unsigned() {
Expand All @@ -299,7 +301,7 @@ impl Binary {
// The bitmask must then be one less than a power of 2.
let bitmask_plus_one = bitmask.to_u128() + 1;
if bitmask_plus_one.is_power_of_two() {
let value = if lhs_value.is_some() { self.rhs } else { self.lhs };
let value = if lhs_value.is_some() { rhs } else { lhs };
let num_bits = bitmask_plus_one.ilog2();
return SimplifyResult::SimplifiedToInstruction(
Instruction::Truncate {
Expand All @@ -317,27 +319,27 @@ impl Binary {
}
BinaryOp::Or => {
if lhs_is_zero {
return SimplifyResult::SimplifiedTo(self.rhs);
return SimplifyResult::SimplifiedTo(rhs);
}
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
if lhs_type == NumericType::bool() && (lhs_is_one || rhs_is_one) {
let one = dfg.make_constant(FieldElement::one(), lhs_type);
return SimplifyResult::SimplifiedTo(one);
}
if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) {
return SimplifyResult::SimplifiedTo(self.lhs);
if lhs == rhs {
return SimplifyResult::SimplifiedTo(lhs);
}
}
BinaryOp::Xor => {
if lhs_is_zero {
return SimplifyResult::SimplifiedTo(self.rhs);
return SimplifyResult::SimplifiedTo(rhs);
}
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
return SimplifyResult::SimplifiedTo(lhs);
}
if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) {
if lhs == rhs {
let zero = dfg.make_constant(FieldElement::zero(), lhs_type);
return SimplifyResult::SimplifiedTo(zero);
}
Expand Down
88 changes: 60 additions & 28 deletions compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,35 @@ impl DefunctionalizationContext {
let mut call_target_values = HashSet::new();

for block_id in func.reachable_blocks() {
let block = &func.dfg[block_id];
let instructions = block.instructions().to_vec();
let block = &mut func.dfg[block_id];

for instruction_id in instructions {
let instruction = func.dfg[instruction_id].clone();
// Temporarily take the parameters here just to avoid cloning them
let parameters = block.take_parameters();
for parameter in &parameters {
if func.dfg.type_of_value(*parameter) == Type::Function {
func.dfg.set_type_of_value(*parameter, Type::field());
}
}

let block = &mut func.dfg[block_id];
block.set_parameters(parameters);

for instruction_id in block.instructions().to_vec() {
let mut instruction = func.dfg[instruction_id].clone();
let mut replacement_instruction = None;

if remove_first_class_functions_in_instruction(func, &mut instruction) {
func.dfg[instruction_id] = instruction.clone();
}

// Operate on call instructions
let (target_func_id, arguments) = match &instruction {
Instruction::Call { func: target_func_id, arguments } => {
(*target_func_id, arguments)
}
_ => continue,
_ => {
continue;
}
};

match func.dfg[target_func_id] {
Expand Down Expand Up @@ -130,29 +147,6 @@ impl DefunctionalizationContext {
}
}
}

// Change the type of all the values that are not call targets to NativeField
let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id);
for value_id in value_ids {
if let Type::Function = func.dfg[value_id].get_type().as_ref() {
match &func.dfg[value_id] {
// If the value is a static function, transform it to the function id
Value::Function(id) => {
if !call_target_values.contains(&value_id) {
let field = NumericType::NativeField;
let new_value =
func.dfg.make_constant(function_id_to_field(*id), field);
func.dfg.set_value_from_id(value_id, new_value);
}
}
// If the value is a function used as value, just change the type of it
Value::Instruction { .. } | Value::Param { .. } => {
func.dfg.set_type_of_value(value_id, Type::field());
}
_ => {}
}
}
}
}

/// Returns the apply function for the given signature
Expand All @@ -161,6 +155,44 @@ impl DefunctionalizationContext {
}
}

/// Replace any first class functions used in an instruction with a field value.
/// This applies to any function used anywhere else other than the function position
/// of a call instruction. Returns true if the instruction was modified
fn remove_first_class_functions_in_instruction(
func: &mut Function,
instruction: &mut Instruction,
) -> bool {
let mut modified = false;
let mut map_value = |value: ValueId| {
if let Type::Function = func.dfg[value].get_type().as_ref() {
match &func.dfg[value] {
// If the value is a static function, transform it to the function id
Value::Function(id) => {
let new_value = function_id_to_field(*id);
modified = true;
return func.dfg.make_constant(new_value, NumericType::NativeField);
}
// If the value is a function used as value, just change the type of it
Value::Instruction { .. } | Value::Param { .. } => {
func.dfg.set_type_of_value(value, Type::field());
}
_ => (),
}
}
value
};

if let Instruction::Call { func: _, arguments } = instruction {
for arg in arguments {
*arg = map_value(*arg);
}
} else {
instruction.map_values_mut(map_value);
}

modified
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> Variants {
let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new();
Expand Down

0 comments on commit d204f44

Please sign in to comment.