diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs index 5b7963b5f06..ea5ef3ea686 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs @@ -95,19 +95,49 @@ impl GeneratedAcir { } impl GeneratedAcir { - /// Computes lhs mod 2^rhs + /// Computes lhs = 2^{rhs_bit_size} * q + r /// - /// `max_bits` is the upper-bound on the bit_size of the object that `lhs` is representing. - - /// An example; max_bits would be 32, if lhs was representing a u32 at a higher level. + /// For example, if we had a u32: + /// - `rhs` would be `32` + /// - `max_bits` would be the size of `lhs` + /// + /// Take the following code: + /// `` + /// fn main(x : u32) -> u32 { + /// let a = x + x; (L1) + /// let b = a * a; (L2) + /// b + b (L3) + /// } + /// `` + /// + /// Call truncate only on L1: + /// - `rhs` would be `32` + /// - `max_bits` would be `33` due to the addition of two u32s + /// Call truncate only on L2: + /// - `rhs` would be `32` + /// - `max_bits` would be `66` due to the multiplication of two u33s `a` + /// Call truncate only on L3: + /// - `rhs` would be `32` + /// - `max_bits` would be `67` due to the addition of two u66s `b` + /// + /// Truncation is done via the euclidean division formula: + /// + /// a = b * q + r + /// + /// where: + /// - a = `lhs` + /// - b = 2^{max_bits} + /// The prover will supply the quotient and the remainder, where the remainder + /// is the truncated value that we will return since it is enforced to be + /// in the range: 0 <= r < 2^{rhs_bit_size} pub(crate) fn truncate( &mut self, lhs: &Expression, - rhs: u32, + rhs_bit_size: u32, max_bits: u32, ) -> Result { - assert!(max_bits > rhs, "max_bits = {max_bits}, rhs = {rhs}"); - let exp_big = BigUint::from(2_u32).pow(rhs); + assert!(max_bits > rhs_bit_size, "max_bits = {max_bits}, rhs = {rhs_bit_size} -- The caller should ensure that truncation is only called when the value needs to be truncated"); + let exp_big = BigUint::from(2_u32).pow(rhs_bit_size); // 0. Check for constant expression. if let Some(a_c) = lhs.to_const() { @@ -115,37 +145,47 @@ impl GeneratedAcir { a_big %= exp_big; return Ok(Expression::from(FieldElement::from_be_bytes_reduce(&a_big.to_bytes_be()))); } + // Note: This is doing a reduction. However, since the compiler will call + // `max_bits` before it overflows the modulus, this line should never do a reduction. + // + // For example, if the modulus is a 254 bit number. + // `max_bits` will never be 255 since `exp` will be 2^255, which will cause a reduction in the following line. + // TODO: We should change this from `from_be_bytes_reduce` to `from_be_bytes` + // TODO: the latter will return an option that we can unwrap in the compiler let exp = FieldElement::from_be_bytes_reduce(&exp_big.to_bytes_be()); // 1. Generate witnesses a,b,c - let b_witness = self.next_witness_index(); - let c_witness = self.next_witness_index(); + let remainder_witness = self.next_witness_index(); + let quotient_witness = self.next_witness_index(); self.push_opcode(AcirOpcode::Directive(Directive::Quotient(QuotientDirective { a: lhs.clone(), b: Expression::from_field(exp), - q: c_witness, - r: b_witness, + q: quotient_witness, + r: remainder_witness, predicate: None, }))); - self.range_constraint(b_witness, rhs)?; - self.range_constraint(c_witness, max_bits - rhs)?; + // According to the division theorem, the remainder needs to be 0 <= r < 2^{rhs_bit_size} + self.range_constraint(remainder_witness, rhs_bit_size)?; + + // According to the formula above, the quotient should be within the range 0 <= q < 2^{max_bits - rhs} + self.range_constraint(quotient_witness, max_bits - rhs_bit_size)?; - // 2. Add the constraint a = b + 2^{rhs} * c + // 2. Add the constraint a == r + (q * 2^{rhs}) // // 2^{rhs} let mut two_pow_rhs_bits = FieldElement::from(2_i128); - two_pow_rhs_bits = two_pow_rhs_bits.pow(&FieldElement::from(rhs as i128)); + two_pow_rhs_bits = two_pow_rhs_bits.pow(&FieldElement::from(rhs_bit_size as i128)); - let b_arith = Expression::from(b_witness); - let c_arith = Expression::from(c_witness); + let remainder_expr = Expression::from(remainder_witness); + let quotient_expr = Expression::from(quotient_witness); - let res = &b_arith + &(two_pow_rhs_bits * &c_arith); - let my_constraint = &res - lhs; + let res = &remainder_expr + &(two_pow_rhs_bits * "ient_expr); + let euclidean_division = &res - lhs; - self.push_opcode(AcirOpcode::Arithmetic(my_constraint)); + self.push_opcode(AcirOpcode::Arithmetic(euclidean_division)); - Ok(Expression::from(b_witness)) + Ok(Expression::from(remainder_witness)) } /// Calls a black box function and returns the output @@ -264,13 +304,15 @@ impl GeneratedAcir { &mut self, lhs: &Expression, rhs: &Expression, - bit_size: u32, + max_bit_size: u32, predicate: &Expression, ) -> Result<(Witness, Witness), AcirGenError> { let q_witness = self.next_witness_index(); let r_witness = self.next_witness_index(); - let pa = lhs * predicate; + // lhs = rhs * q + r + // + // If predicate is zero, `q_witness` and `r_witness` will be 0 self.push_opcode(AcirOpcode::Directive(Directive::Quotient(QuotientDirective { a: lhs.clone(), b: rhs.clone(), @@ -279,17 +321,25 @@ impl GeneratedAcir { predicate: Some(predicate.clone()), }))); - //r predicate * ( a - b * q - r) == 0 + // When the predicate is 0, the equation always passes. + // When the predicate is 1, the euclidean division needs to be + // true. + let mut rhs_constraint = rhs * &Expression::from(q_witness); + rhs_constraint = &rhs_constraint + r_witness; + rhs_constraint = &rhs_constraint * predicate; + let lhs_constraint = lhs * predicate; + let div_euclidean = &lhs_constraint - &rhs_constraint; self.push_opcode(AcirOpcode::Arithmetic(div_euclidean)); @@ -576,7 +626,7 @@ impl GeneratedAcir { // TODO: perhaps this should be a user error, instead of an assert assert!(max_bits + 1 < FieldElement::max_num_bits()); - // Compute : 2^max_bits + a - b + // Compute : 2^{max_bits} + a - b let mut comparison_evaluation = a - b; let two = FieldElement::from(2_i128); let two_max_bits = two.pow(&FieldElement::from(max_bits as i128)); @@ -586,6 +636,25 @@ impl GeneratedAcir { let r_witness = self.next_witness_index(); // Add constraint : 2^{max_bits} + a - b = q * 2^{max_bits} + r + // + // case: a == b + // + // let k = 0; + // - 2^{max_bits} == q * 2^{max_bits} + r + // - This is only the case when q == 1 and r == 0 (assuming r is bounded to be less than 2^{max_bits}) + // + // case: a > b + // + // let k = a - b; + // - k + 2^{max_bits} == q * 2^{max_bits} + r + // - This is the case when q == 1 and r = k + // + // case: a < b + // + // let k = b - a + // - 2^{max_bits} - k == q * 2^{max_bits} + r + // - This is only the case when q == 0 and r == 2^{max_bits} - k + // let mut expr = Expression::default(); expr.push_addition_term(two_max_bits, q_witness); expr.push_addition_term(FieldElement::one(), r_witness);