diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs index df1e8f537da..a1b8a166033 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs @@ -89,11 +89,15 @@ impl Binary { /// Try to simplify this binary instruction, returning the new value if possible. pub(super) fn simplify(&self, dfg: &mut DataFlowGraph) -> SimplifyResult { - let lhs_value = dfg.get_numeric_constant(self.lhs); - let rhs_value = dfg.get_numeric_constant(self.rhs); + let lhs = dfg.resolve(self.lhs); + let rhs = dfg.resolve(self.rhs); - let lhs_type = dfg.type_of_value(self.lhs).unwrap_numeric(); - let rhs_type = dfg.type_of_value(self.rhs).unwrap_numeric(); + let lhs_value = dfg.get_numeric_constant(lhs); + let rhs_value = dfg.get_numeric_constant(rhs); + eprintln!("{lhs} = {lhs_value:?}, {rhs} = {rhs_value:?}"); + + let lhs_type = dfg.type_of_value(lhs).unwrap_numeric(); + let rhs_type = dfg.type_of_value(rhs).unwrap_numeric(); let operator = self.operator; if operator != BinaryOp::Shl && operator != BinaryOp::Shr { @@ -124,7 +128,7 @@ impl Binary { }; // We never return `SimplifyResult::None` here because `operator` might have changed. - let simplified = Instruction::Binary(Binary { lhs: self.lhs, rhs: self.rhs, operator }); + let simplified = Instruction::Binary(Binary { lhs, rhs, operator }); if let (Some(lhs), Some(rhs)) = (lhs_value, rhs_value) { return match eval_constant_binary_op(lhs, rhs, operator, lhs_type) { @@ -145,58 +149,56 @@ impl Binary { match self.operator { BinaryOp::Add { .. } => { if lhs_is_zero { - return SimplifyResult::SimplifiedTo(self.rhs); + return SimplifyResult::SimplifiedTo(rhs); } if rhs_is_zero { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } } BinaryOp::Sub { .. } => { if rhs_is_zero { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } } BinaryOp::Mul { .. } => { if lhs_is_one { - return SimplifyResult::SimplifiedTo(self.rhs); + return SimplifyResult::SimplifiedTo(rhs); } if rhs_is_one { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } if lhs_is_zero || rhs_is_zero { let zero = dfg.make_constant(FieldElement::zero(), lhs_type); return SimplifyResult::SimplifiedTo(zero); } - if dfg.get_value_max_num_bits(self.lhs) == 1 { + if dfg.get_value_max_num_bits(lhs) == 1 { // Squaring a boolean value is a noop. - if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) { - return SimplifyResult::SimplifiedTo(self.lhs); + if lhs == rhs { + return SimplifyResult::SimplifiedTo(lhs); } // b*(b*x) = b*x if b is boolean - if let super::Value::Instruction { instruction, .. } = &dfg[self.rhs] { - if let Instruction::Binary(Binary { lhs, rhs, operator }) = + if let super::Value::Instruction { instruction, .. } = &dfg[rhs] { + if let Instruction::Binary(Binary { lhs: b_lhs, rhs: b_rhs, operator }) = dfg[*instruction] { if matches!(operator, BinaryOp::Mul { .. }) - && (dfg.resolve(self.lhs) == dfg.resolve(lhs) - || dfg.resolve(self.lhs) == dfg.resolve(rhs)) + && (lhs == dfg.resolve(b_lhs) || lhs == dfg.resolve(b_rhs)) { - return SimplifyResult::SimplifiedTo(self.rhs); + return SimplifyResult::SimplifiedTo(rhs); } } } } // (b*x)*b = b*x if b is boolean - if dfg.get_value_max_num_bits(self.rhs) == 1 { - if let super::Value::Instruction { instruction, .. } = &dfg[self.lhs] { - if let Instruction::Binary(Binary { lhs, rhs, operator }) = + if dfg.get_value_max_num_bits(rhs) == 1 { + if let super::Value::Instruction { instruction, .. } = &dfg[lhs] { + if let Instruction::Binary(Binary { lhs: b_lhs, rhs: b_rhs, operator }) = dfg[*instruction] { if matches!(operator, BinaryOp::Mul { .. }) - && (dfg.resolve(self.rhs) == dfg.resolve(lhs) - || dfg.resolve(self.rhs) == dfg.resolve(rhs)) + && (rhs == dfg.resolve(b_lhs) || rhs == dfg.resolve(b_rhs)) { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } } } @@ -204,7 +206,7 @@ impl Binary { } BinaryOp::Div => { if rhs_is_one { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } } BinaryOp::Mod => { @@ -221,7 +223,7 @@ impl Binary { let bit_size = modulus.ilog2(); return SimplifyResult::SimplifiedToInstruction( Instruction::Truncate { - value: self.lhs, + value: lhs, bit_size, max_bit_size: lhs_type.bit_size(), }, @@ -231,7 +233,7 @@ impl Binary { } } BinaryOp::Eq => { - if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) { + if lhs == rhs { let one = dfg.make_constant(FieldElement::one(), NumericType::bool()); return SimplifyResult::SimplifiedTo(one); } @@ -239,22 +241,22 @@ impl Binary { if lhs_type == NumericType::bool() { // Simplify forms of `(boolean == true)` into `boolean` if lhs_is_one { - return SimplifyResult::SimplifiedTo(self.rhs); + return SimplifyResult::SimplifiedTo(rhs); } if rhs_is_one { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } // Simplify forms of `(boolean == false)` into `!boolean` if lhs_is_zero { - return SimplifyResult::SimplifiedToInstruction(Instruction::Not(self.rhs)); + return SimplifyResult::SimplifiedToInstruction(Instruction::Not(rhs)); } if rhs_is_zero { - return SimplifyResult::SimplifiedToInstruction(Instruction::Not(self.lhs)); + return SimplifyResult::SimplifiedToInstruction(Instruction::Not(lhs)); } } } BinaryOp::Lt => { - if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) { + if lhs == rhs { let zero = dfg.make_constant(FieldElement::zero(), NumericType::bool()); return SimplifyResult::SimplifiedTo(zero); } @@ -267,7 +269,7 @@ impl Binary { let zero = dfg.make_constant(FieldElement::zero(), lhs_type); return SimplifyResult::SimplifiedToInstruction(Instruction::binary( BinaryOp::Eq, - self.lhs, + lhs, zero, )); } @@ -278,14 +280,14 @@ impl Binary { let zero = dfg.make_constant(FieldElement::zero(), lhs_type); return SimplifyResult::SimplifiedTo(zero); } - if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) { - return SimplifyResult::SimplifiedTo(self.lhs); + if lhs == rhs { + return SimplifyResult::SimplifiedTo(lhs); } if lhs_type == NumericType::bool() { // Boolean AND is equivalent to multiplication, which is a cheaper operation. // (mul unchecked because these are bools so it doesn't matter really) let instruction = - Instruction::binary(BinaryOp::Mul { unchecked: true }, self.lhs, self.rhs); + Instruction::binary(BinaryOp::Mul { unchecked: true }, lhs, rhs); return SimplifyResult::SimplifiedToInstruction(instruction); } if lhs_type.is_unsigned() { @@ -299,7 +301,7 @@ impl Binary { // The bitmask must then be one less than a power of 2. let bitmask_plus_one = bitmask.to_u128() + 1; if bitmask_plus_one.is_power_of_two() { - let value = if lhs_value.is_some() { self.rhs } else { self.lhs }; + let value = if lhs_value.is_some() { rhs } else { lhs }; let num_bits = bitmask_plus_one.ilog2(); return SimplifyResult::SimplifiedToInstruction( Instruction::Truncate { @@ -317,27 +319,27 @@ impl Binary { } BinaryOp::Or => { if lhs_is_zero { - return SimplifyResult::SimplifiedTo(self.rhs); + return SimplifyResult::SimplifiedTo(rhs); } if rhs_is_zero { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } if lhs_type == NumericType::bool() && (lhs_is_one || rhs_is_one) { let one = dfg.make_constant(FieldElement::one(), lhs_type); return SimplifyResult::SimplifiedTo(one); } - if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) { - return SimplifyResult::SimplifiedTo(self.lhs); + if lhs == rhs { + return SimplifyResult::SimplifiedTo(lhs); } } BinaryOp::Xor => { if lhs_is_zero { - return SimplifyResult::SimplifiedTo(self.rhs); + return SimplifyResult::SimplifiedTo(rhs); } if rhs_is_zero { - return SimplifyResult::SimplifiedTo(self.lhs); + return SimplifyResult::SimplifiedTo(lhs); } - if dfg.resolve(self.lhs) == dfg.resolve(self.rhs) { + if lhs == rhs { let zero = dfg.make_constant(FieldElement::zero(), lhs_type); return SimplifyResult::SimplifiedTo(zero); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 368dba675ca..e83f315602b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -83,18 +83,35 @@ impl DefunctionalizationContext { let mut call_target_values = HashSet::new(); for block_id in func.reachable_blocks() { - let block = &func.dfg[block_id]; - let instructions = block.instructions().to_vec(); + let block = &mut func.dfg[block_id]; - for instruction_id in instructions { - let instruction = func.dfg[instruction_id].clone(); + // Temporarily take the parameters here just to avoid cloning them + let parameters = block.take_parameters(); + for parameter in ¶meters { + if func.dfg.type_of_value(*parameter) == Type::Function { + func.dfg.set_type_of_value(*parameter, Type::field()); + } + } + + let block = &mut func.dfg[block_id]; + block.set_parameters(parameters); + + for instruction_id in block.instructions().to_vec() { + let mut instruction = func.dfg[instruction_id].clone(); let mut replacement_instruction = None; + + if remove_first_class_functions_in_instruction(func, &mut instruction) { + func.dfg[instruction_id] = instruction.clone(); + } + // Operate on call instructions let (target_func_id, arguments) = match &instruction { Instruction::Call { func: target_func_id, arguments } => { (*target_func_id, arguments) } - _ => continue, + _ => { + continue; + } }; match func.dfg[target_func_id] { @@ -130,29 +147,6 @@ impl DefunctionalizationContext { } } } - - // Change the type of all the values that are not call targets to NativeField - let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id); - for value_id in value_ids { - if let Type::Function = func.dfg[value_id].get_type().as_ref() { - match &func.dfg[value_id] { - // If the value is a static function, transform it to the function id - Value::Function(id) => { - if !call_target_values.contains(&value_id) { - let field = NumericType::NativeField; - let new_value = - func.dfg.make_constant(function_id_to_field(*id), field); - func.dfg.set_value_from_id(value_id, new_value); - } - } - // If the value is a function used as value, just change the type of it - Value::Instruction { .. } | Value::Param { .. } => { - func.dfg.set_type_of_value(value_id, Type::field()); - } - _ => {} - } - } - } } /// Returns the apply function for the given signature @@ -161,6 +155,44 @@ impl DefunctionalizationContext { } } +/// Replace any first class functions used in an instruction with a field value. +/// This applies to any function used anywhere else other than the function position +/// of a call instruction. Returns true if the instruction was modified +fn remove_first_class_functions_in_instruction( + func: &mut Function, + instruction: &mut Instruction, +) -> bool { + let mut modified = false; + let mut map_value = |value: ValueId| { + if let Type::Function = func.dfg[value].get_type().as_ref() { + match &func.dfg[value] { + // If the value is a static function, transform it to the function id + Value::Function(id) => { + let new_value = function_id_to_field(*id); + modified = true; + return func.dfg.make_constant(new_value, NumericType::NativeField); + } + // If the value is a function used as value, just change the type of it + Value::Instruction { .. } | Value::Param { .. } => { + func.dfg.set_type_of_value(value, Type::field()); + } + _ => (), + } + } + value + }; + + if let Instruction::Call { func: _, arguments } = instruction { + for arg in arguments { + *arg = map_value(*arg); + } + } else { + instruction.map_values_mut(map_value); + } + + modified +} + /// Collects all functions used as values that can be called by their signatures fn find_variants(ssa: &Ssa) -> Variants { let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new();