diff --git a/acvm-repo/brillig_vm/src/arithmetic.rs b/acvm-repo/brillig_vm/src/arithmetic.rs index 527b0a7849e..a87635bd542 100644 --- a/acvm-repo/brillig_vm/src/arithmetic.rs +++ b/acvm-repo/brillig_vm/src/arithmetic.rs @@ -1,7 +1,9 @@ -use acir::brillig::{BinaryFieldOp, BinaryIntOp, IntegerBitSize}; +use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr}; + +use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize}; use acir::AcirField; use num_bigint::BigUint; -use num_traits::{AsPrimitive, PrimInt, WrappingAdd, WrappingMul, WrappingSub}; +use num_traits::{CheckedDiv, WrappingAdd, WrappingMul, WrappingSub, Zero}; use crate::memory::{MemoryTypeError, MemoryValue}; @@ -21,24 +23,20 @@ pub(crate) fn evaluate_binary_field_op( lhs: MemoryValue, rhs: MemoryValue, ) -> Result, BrilligArithmeticError> { - let a = match lhs { - MemoryValue::Field(a) => a, - MemoryValue::Integer(_, bit_size) => { - return Err(BrilligArithmeticError::MismatchedLhsBitSize { - lhs_bit_size: bit_size.into(), - op_bit_size: F::max_num_bits(), - }); + let a = *lhs.expect_field().map_err(|err| { + let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err; + BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: value_bit_size, + op_bit_size: expected_bit_size, } - }; - let b = match rhs { - MemoryValue::Field(b) => b, - MemoryValue::Integer(_, bit_size) => { - return Err(BrilligArithmeticError::MismatchedRhsBitSize { - rhs_bit_size: bit_size.into(), - op_bit_size: F::max_num_bits(), - }); + })?; + let b = *rhs.expect_field().map_err(|err| { + let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err; + BrilligArithmeticError::MismatchedRhsBitSize { + rhs_bit_size: value_bit_size, + op_bit_size: expected_bit_size, } - }; + })?; Ok(match op { // Perform addition, subtraction, multiplication, and division based on the BinaryOp variant. @@ -70,46 +68,120 @@ pub(crate) fn evaluate_binary_int_op( rhs: MemoryValue, bit_size: IntegerBitSize, ) -> Result, BrilligArithmeticError> { - let lhs = lhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err { - MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => { - BrilligArithmeticError::MismatchedLhsBitSize { - lhs_bit_size: value_bit_size, - op_bit_size: expected_bit_size, + match op { + BinaryIntOp::Add + | BinaryIntOp::Sub + | BinaryIntOp::Mul + | BinaryIntOp::Div + | BinaryIntOp::And + | BinaryIntOp::Or + | BinaryIntOp::Xor => match (lhs, rhs, bit_size) { + (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => { + evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1) } - } - })?; - - let rhs_bit_size = if op == &BinaryIntOp::Shl || op == &BinaryIntOp::Shr { - IntegerBitSize::U8 - } else { - bit_size - }; + (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => { + evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8) + } + (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => { + evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16) + } + (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => { + evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32) + } + (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => { + evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64) + } + (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => { + evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128) + } + (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => { + Err(BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: lhs.bit_size().to_u32::(), + op_bit_size: bit_size.into(), + }) + } + (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => { + Err(BrilligArithmeticError::MismatchedRhsBitSize { + rhs_bit_size: rhs.bit_size().to_u32::(), + op_bit_size: bit_size.into(), + }) + } + _ => unreachable!("Invalid arguments are covered by the two arms above."), + }, - let rhs = rhs.expect_integer_with_bit_size(rhs_bit_size).map_err(|err| match err { - MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => { - BrilligArithmeticError::MismatchedRhsBitSize { - rhs_bit_size: value_bit_size, - op_bit_size: expected_bit_size, + BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => { + match (lhs, rhs, bit_size) { + (MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => { + Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs))) + } + (MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => { + Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs))) + } + (MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => { + Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs))) + } + (MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => { + Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs))) + } + (MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => { + Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs))) + } + (MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => { + Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs))) + } + (lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => { + Err(BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: lhs.bit_size().to_u32::(), + op_bit_size: bit_size.into(), + }) + } + (_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => { + Err(BrilligArithmeticError::MismatchedRhsBitSize { + rhs_bit_size: rhs.bit_size().to_u32::(), + op_bit_size: bit_size.into(), + }) + } + _ => unreachable!("Invalid arguments are covered by the two arms above."), } } - })?; - - // `lhs` and `rhs` are asserted to fit within their given types when being read from memory so this is safe. - let result = match bit_size { - IntegerBitSize::U1 => evaluate_binary_int_op_u1(op, lhs != 0, rhs != 0)?.into(), - IntegerBitSize::U8 => evaluate_binary_int_op_num(op, lhs as u8, rhs as u8, 8)?.into(), - IntegerBitSize::U16 => evaluate_binary_int_op_num(op, lhs as u16, rhs as u16, 16)?.into(), - IntegerBitSize::U32 => evaluate_binary_int_op_num(op, lhs as u32, rhs as u32, 32)?.into(), - IntegerBitSize::U64 => evaluate_binary_int_op_num(op, lhs as u64, rhs as u64, 64)?.into(), - IntegerBitSize::U128 => evaluate_binary_int_op_num(op, lhs, rhs, 128)?, - }; - Ok(match op { - BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => { - MemoryValue::new_integer(result, IntegerBitSize::U1) + BinaryIntOp::Shl | BinaryIntOp::Shr => { + let rhs = rhs.expect_u8().map_err( + |MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size }| { + BrilligArithmeticError::MismatchedRhsBitSize { + rhs_bit_size: value_bit_size, + op_bit_size: expected_bit_size, + } + }, + )?; + + match (lhs, bit_size) { + (MemoryValue::U1(lhs), IntegerBitSize::U1) => { + let result = if rhs == 0 { lhs } else { false }; + Ok(MemoryValue::U1(result)) + } + (MemoryValue::U8(lhs), IntegerBitSize::U8) => { + Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs))) + } + (MemoryValue::U16(lhs), IntegerBitSize::U16) => { + Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs))) + } + (MemoryValue::U32(lhs), IntegerBitSize::U32) => { + Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs))) + } + (MemoryValue::U64(lhs), IntegerBitSize::U64) => { + Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs))) + } + (MemoryValue::U128(lhs), IntegerBitSize::U128) => { + Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs))) + } + _ => Err(BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: lhs.bit_size().to_u32::(), + op_bit_size: bit_size.into(), + }), + } } - _ => MemoryValue::new_integer(result, bit_size), - }) + } } fn evaluate_binary_int_op_u1( @@ -118,8 +190,12 @@ fn evaluate_binary_int_op_u1( rhs: bool, ) -> Result { let result = match op { - BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs, - BinaryIntOp::Mul => lhs & rhs, + BinaryIntOp::Equals => lhs == rhs, + BinaryIntOp::LessThan => !lhs & rhs, + BinaryIntOp::LessThanEquals => lhs <= rhs, + BinaryIntOp::And | BinaryIntOp::Mul => lhs & rhs, + BinaryIntOp::Or => lhs | rhs, + BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs, BinaryIntOp::Div => { if !rhs { return Err(BrilligArithmeticError::DivisionByZero); @@ -127,58 +203,70 @@ fn evaluate_binary_int_op_u1( lhs } } + _ => unreachable!("Operator not handled by this function: {op:?}"), + }; + Ok(result) +} + +fn evaluate_binary_int_op_cmp(op: &BinaryIntOp, lhs: T, rhs: T) -> bool { + match op { BinaryIntOp::Equals => lhs == rhs, - BinaryIntOp::LessThan => !lhs & rhs, + BinaryIntOp::LessThan => lhs < rhs, BinaryIntOp::LessThanEquals => lhs <= rhs, - BinaryIntOp::And => lhs & rhs, - BinaryIntOp::Or => lhs | rhs, - BinaryIntOp::Xor => lhs ^ rhs, - BinaryIntOp::Shl | BinaryIntOp::Shr => { - if rhs { - false + _ => unreachable!("Operator not handled by this function: {op:?}"), + } +} + +fn evaluate_binary_int_op_shifts + Zero + Shl + Shr>( + op: &BinaryIntOp, + lhs: T, + rhs: u8, +) -> T { + match op { + BinaryIntOp::Shl => { + let rhs_usize: usize = rhs as usize; + #[allow(unused_qualifications)] + if rhs_usize >= 8 * std::mem::size_of::() { + T::zero() } else { - lhs + lhs << rhs.into() } } - }; - Ok(result) + BinaryIntOp::Shr => { + let rhs_usize: usize = rhs as usize; + #[allow(unused_qualifications)] + if rhs_usize >= 8 * std::mem::size_of::() { + T::zero() + } else { + lhs >> rhs.into() + } + } + _ => unreachable!("Operator not handled by this function: {op:?}"), + } } -fn evaluate_binary_int_op_num< - T: PrimInt + AsPrimitive + From + WrappingAdd + WrappingSub + WrappingMul, +fn evaluate_binary_int_op_arith< + T: WrappingAdd + + WrappingSub + + WrappingMul + + CheckedDiv + + BitAnd + + BitOr + + BitXor, >( op: &BinaryIntOp, lhs: T, rhs: T, - num_bits: usize, ) -> Result { let result = match op { BinaryIntOp::Add => lhs.wrapping_add(&rhs), BinaryIntOp::Sub => lhs.wrapping_sub(&rhs), BinaryIntOp::Mul => lhs.wrapping_mul(&rhs), BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?, - BinaryIntOp::Equals => (lhs == rhs).into(), - BinaryIntOp::LessThan => (lhs < rhs).into(), - BinaryIntOp::LessThanEquals => (lhs <= rhs).into(), BinaryIntOp::And => lhs & rhs, BinaryIntOp::Or => lhs | rhs, BinaryIntOp::Xor => lhs ^ rhs, - BinaryIntOp::Shl => { - let rhs_usize = rhs.as_(); - if rhs_usize >= num_bits { - T::zero() - } else { - lhs << rhs_usize - } - } - BinaryIntOp::Shr => { - let rhs_usize = rhs.as_(); - if rhs_usize >= num_bits { - T::zero() - } else { - lhs >> rhs_usize - } - } + _ => unreachable!("Operator not handled by this function: {op:?}"), }; Ok(result) } diff --git a/acvm-repo/brillig_vm/src/black_box.rs b/acvm-repo/brillig_vm/src/black_box.rs index bc8b1f3c230..ab0584d0d80 100644 --- a/acvm-repo/brillig_vm/src/black_box.rs +++ b/acvm-repo/brillig_vm/src/black_box.rs @@ -1,4 +1,4 @@ -use acir::brillig::{BlackBoxOp, HeapArray, HeapVector, IntegerBitSize}; +use acir::brillig::{BlackBoxOp, HeapArray, HeapVector}; use acir::{AcirField, BlackBoxFunc}; use acvm_blackbox_solver::{ aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccakf1600, @@ -312,16 +312,13 @@ pub(crate) fn evaluate_black_box } BlackBoxOp::ToRadix { input, radix, output_pointer, num_limbs, output_bits } => { let input: F = *memory.read(*input).extract_field().expect("ToRadix input not a field"); - let radix = memory - .read(*radix) - .expect_integer_with_bit_size(IntegerBitSize::U32) - .expect("ToRadix opcode's radix bit size does not match expected bit size 32"); + let MemoryValue::U32(radix) = memory.read(*radix) else { + panic!("ToRadix opcode's radix bit size does not match expected bit size 32") + }; let num_limbs = memory.read(*num_limbs).to_usize(); - let output_bits = !memory - .read(*output_bits) - .expect_integer_with_bit_size(IntegerBitSize::U1) - .expect("ToRadix opcode's output_bits size does not match expected bit size 1") - .is_zero(); + let MemoryValue::U1(output_bits) = memory.read(*output_bits) else { + panic!("ToRadix opcode's output_bits size does not match expected bit size 1") + }; let mut input = BigUint::from_bytes_be(&input.to_be_bytes()); let radix = BigUint::from_bytes_be(&radix.to_be_bytes()); @@ -349,13 +346,10 @@ pub(crate) fn evaluate_black_box for i in (0..num_limbs).rev() { let limb = &input % &radix; if output_bits { - limbs[i] = MemoryValue::new_integer( - if limb.is_zero() { 0 } else { 1 }, - IntegerBitSize::U1, - ); + limbs[i] = MemoryValue::U1(!limb.is_zero()); } else { let limb: u8 = limb.try_into().unwrap(); - limbs[i] = MemoryValue::new_integer(limb as u128, IntegerBitSize::U8); + limbs[i] = MemoryValue::U8(limb); }; input /= &radix; } diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 9a40022a3ba..f3eb3211e8e 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -21,6 +21,7 @@ use black_box::{evaluate_black_box, BrilligBigIntSolver}; // Re-export `brillig`. pub use acir::brillig; +use memory::MemoryTypeError; pub use memory::{Memory, MemoryValue, MEMORY_ADDRESSING_BIT_SIZE}; mod arithmetic; @@ -248,7 +249,7 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { } Opcode::Not { destination, source, bit_size } => { if let Err(error) = self.process_not(*source, *destination, *bit_size) { - self.fail(error) + self.fail(error.to_string()) } else { self.increment_program_counter() } @@ -775,63 +776,91 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { source: MemoryAddress, destination: MemoryAddress, op_bit_size: IntegerBitSize, - ) -> Result<(), String> { - let (value, bit_size) = self - .memory - .read(source) - .extract_integer() - .ok_or("Not opcode source is not an integer")?; - - if bit_size != op_bit_size { - return Err(format!( - "Not opcode source bit size {} does not match expected bit size {}", - bit_size, op_bit_size - )); - } - - let negated_value = if let IntegerBitSize::U128 = bit_size { - !value - } else { - let bit_size: u32 = bit_size.into(); - let mask = (1_u128 << bit_size as u128) - 1; - (!value) & mask + ) -> Result<(), MemoryTypeError> { + let value = self.memory.read(source); + + let negated_value = match op_bit_size { + IntegerBitSize::U1 => MemoryValue::U1(!value.expect_u1()?), + IntegerBitSize::U8 => MemoryValue::U8(!value.expect_u8()?), + IntegerBitSize::U16 => MemoryValue::U16(!value.expect_u16()?), + IntegerBitSize::U32 => MemoryValue::U32(!value.expect_u32()?), + IntegerBitSize::U64 => MemoryValue::U64(!value.expect_u64()?), + IntegerBitSize::U128 => MemoryValue::U128(!value.expect_u128()?), }; - self.memory.write(destination, MemoryValue::new_integer(negated_value, bit_size)); + self.memory.write(destination, negated_value); Ok(()) } /// Casts a value to a different bit size. fn cast(&self, target_bit_size: BitSize, source_value: MemoryValue) -> MemoryValue { + use MemoryValue::*; + match (source_value, target_bit_size) { - // Field to field, no op - (MemoryValue::Field(_), BitSize::Field) => source_value, // Field downcast to u128 - (MemoryValue::Field(field), BitSize::Integer(IntegerBitSize::U128)) => { - MemoryValue::Integer(field.to_u128(), IntegerBitSize::U128) - } + (Field(field), BitSize::Integer(IntegerBitSize::U128)) => U128(field.to_u128()), // Field downcast to arbitrary bit size - (MemoryValue::Field(field), BitSize::Integer(target_bit_size)) => { + (Field(field), BitSize::Integer(target_bit_size)) => { let as_u128 = field.to_u128(); - let target_bit_size_u32: u32 = target_bit_size.into(); - let mask = (1_u128 << target_bit_size_u32) - 1; - MemoryValue::Integer(as_u128 & mask, target_bit_size) - } - // Integer upcast to field - (MemoryValue::Integer(integer, _), BitSize::Field) => { - MemoryValue::new_field(integer.into()) - } - // Integer upcast to integer - (MemoryValue::Integer(integer, source_bit_size), BitSize::Integer(target_bit_size)) - if source_bit_size <= target_bit_size => - { - MemoryValue::Integer(integer, target_bit_size) - } - // Integer downcast - (MemoryValue::Integer(integer, _), BitSize::Integer(target_bit_size)) => { - let target_bit_size_u32: u32 = target_bit_size.into(); - let mask = (1_u128 << target_bit_size_u32) - 1; - MemoryValue::Integer(integer & mask, target_bit_size) + match target_bit_size { + IntegerBitSize::U1 => U1(as_u128 & 0x01 == 1), + IntegerBitSize::U8 => U8(as_u128 as u8), + IntegerBitSize::U16 => U16(as_u128 as u16), + IntegerBitSize::U32 => U32(as_u128 as u32), + IntegerBitSize::U64 => U64(as_u128 as u64), + IntegerBitSize::U128 => unreachable!(), + } } + + (U1(value), BitSize::Integer(IntegerBitSize::U8)) => U8(value.into()), + (U1(value), BitSize::Integer(IntegerBitSize::U16)) => U16(value.into()), + (U1(value), BitSize::Integer(IntegerBitSize::U32)) => U32(value.into()), + (U1(value), BitSize::Integer(IntegerBitSize::U64)) => U64(value.into()), + (U1(value), BitSize::Integer(IntegerBitSize::U128)) => U128(value.into()), + (U1(value), BitSize::Field) => Field(value.into()), + + (U8(value), BitSize::Integer(IntegerBitSize::U1)) => U1(value & 0x01 == 1), + (U8(value), BitSize::Integer(IntegerBitSize::U16)) => U16(value.into()), + (U8(value), BitSize::Integer(IntegerBitSize::U32)) => U32(value.into()), + (U8(value), BitSize::Integer(IntegerBitSize::U64)) => U64(value.into()), + (U8(value), BitSize::Integer(IntegerBitSize::U128)) => U128(value.into()), + (U8(value), BitSize::Field) => Field((value as u128).into()), + + (U16(value), BitSize::Integer(IntegerBitSize::U1)) => U1(value & 0x01 == 1), + (U16(value), BitSize::Integer(IntegerBitSize::U8)) => U8(value as u8), + (U16(value), BitSize::Integer(IntegerBitSize::U32)) => U32(value.into()), + (U16(value), BitSize::Integer(IntegerBitSize::U64)) => U64(value.into()), + (U16(value), BitSize::Integer(IntegerBitSize::U128)) => U128(value.into()), + (U16(value), BitSize::Field) => Field((value as u128).into()), + + (U32(value), BitSize::Integer(IntegerBitSize::U1)) => U1(value & 0x01 == 1), + (U32(value), BitSize::Integer(IntegerBitSize::U8)) => U8(value as u8), + (U32(value), BitSize::Integer(IntegerBitSize::U16)) => U16(value as u16), + (U32(value), BitSize::Integer(IntegerBitSize::U64)) => U64(value.into()), + (U32(value), BitSize::Integer(IntegerBitSize::U128)) => U128(value.into()), + (U32(value), BitSize::Field) => Field((value as u128).into()), + + (U64(value), BitSize::Integer(IntegerBitSize::U1)) => U1(value & 0x01 == 1), + (U64(value), BitSize::Integer(IntegerBitSize::U8)) => U8(value as u8), + (U64(value), BitSize::Integer(IntegerBitSize::U16)) => U16(value as u16), + (U64(value), BitSize::Integer(IntegerBitSize::U32)) => U32(value as u32), + (U64(value), BitSize::Integer(IntegerBitSize::U128)) => U128(value.into()), + (U64(value), BitSize::Field) => Field((value as u128).into()), + + (U128(value), BitSize::Integer(IntegerBitSize::U1)) => U1(value & 0x01 == 1), + (U128(value), BitSize::Integer(IntegerBitSize::U8)) => U8(value as u8), + (U128(value), BitSize::Integer(IntegerBitSize::U16)) => U16(value as u16), + (U128(value), BitSize::Integer(IntegerBitSize::U32)) => U32(value as u32), + (U128(value), BitSize::Integer(IntegerBitSize::U64)) => U64(value as u64), + (U128(value), BitSize::Field) => Field(value.into()), + + // no ops + (Field(_), BitSize::Field) => source_value, + (U1(_), BitSize::Integer(IntegerBitSize::U1)) => source_value, + (U8(_), BitSize::Integer(IntegerBitSize::U8)) => source_value, + (U16(_), BitSize::Integer(IntegerBitSize::U16)) => source_value, + (U32(_), BitSize::Integer(IntegerBitSize::U32)) => source_value, + (U64(_), BitSize::Integer(IntegerBitSize::U64)) => source_value, + (U128(_), BitSize::Integer(IntegerBitSize::U128)) => source_value, } } } @@ -1127,10 +1156,9 @@ mod tests { let VM { memory, .. } = vm; - let (negated_value, _) = memory - .read(MemoryAddress::direct(1)) - .extract_integer() - .expect("Expected integer as the output of Not"); + let MemoryValue::U128(negated_value) = memory.read(MemoryAddress::direct(1)) else { + panic!("Expected integer as the output of Not"); + }; assert_eq!(negated_value, !1_u128); } diff --git a/acvm-repo/brillig_vm/src/memory.rs b/acvm-repo/brillig_vm/src/memory.rs index 2bf45f77b54..7a942339a3a 100644 --- a/acvm-repo/brillig_vm/src/memory.rs +++ b/acvm-repo/brillig_vm/src/memory.rs @@ -2,14 +2,18 @@ use acir::{ brillig::{BitSize, IntegerBitSize, MemoryAddress}, AcirField, }; -use num_traits::{One, Zero}; pub const MEMORY_ADDRESSING_BIT_SIZE: IntegerBitSize = IntegerBitSize::U32; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum MemoryValue { Field(F), - Integer(u128, IntegerBitSize), + U1(bool), + U8(u8), + U16(u16), + U32(u32), + U64(u64), + U128(u128), } #[derive(Debug, thiserror::Error)] @@ -26,7 +30,14 @@ impl MemoryValue { /// Builds an integer-typed memory value. pub fn new_integer(value: u128, bit_size: IntegerBitSize) -> Self { - MemoryValue::Integer(value, bit_size) + match bit_size { + IntegerBitSize::U1 => MemoryValue::U1(value != 0), + IntegerBitSize::U8 => MemoryValue::U8(value as u8), + IntegerBitSize::U16 => MemoryValue::U16(value as u16), + IntegerBitSize::U32 => MemoryValue::U32(value as u32), + IntegerBitSize::U64 => MemoryValue::U64(value as u64), + IntegerBitSize::U128 => MemoryValue::U128(value), + } } /// Extracts the field element from the memory value, if it is typed as field element. @@ -37,26 +48,21 @@ impl MemoryValue { } } - /// Extracts the integer from the memory value, if it is typed as integer. - pub fn extract_integer(&self) -> Option<(u128, IntegerBitSize)> { - match self { - MemoryValue::Integer(value, bit_size) => Some((*value, *bit_size)), - _ => None, - } - } - pub fn bit_size(&self) -> BitSize { match self { MemoryValue::Field(_) => BitSize::Field, - MemoryValue::Integer(_, bit_size) => BitSize::Integer(*bit_size), + MemoryValue::U1(_) => BitSize::Integer(IntegerBitSize::U1), + MemoryValue::U8(_) => BitSize::Integer(IntegerBitSize::U8), + MemoryValue::U16(_) => BitSize::Integer(IntegerBitSize::U16), + MemoryValue::U32(_) => BitSize::Integer(IntegerBitSize::U32), + MemoryValue::U64(_) => BitSize::Integer(IntegerBitSize::U64), + MemoryValue::U128(_) => BitSize::Integer(IntegerBitSize::U128), } } pub fn to_usize(&self) -> usize { match self { - MemoryValue::Integer(_, bit_size) if *bit_size == MEMORY_ADDRESSING_BIT_SIZE => { - self.extract_integer().unwrap().0.try_into().unwrap() - } + MemoryValue::U32(value) => (*value).try_into().unwrap(), _ => panic!("value is not typed as brillig usize"), } } @@ -87,38 +93,89 @@ impl MemoryValue { pub fn to_field(&self) -> F { match self { MemoryValue::Field(value) => *value, - MemoryValue::Integer(value, _) => F::from(*value), + MemoryValue::U1(value) => F::from(*value), + MemoryValue::U8(value) => F::from(*value as u128), + MemoryValue::U16(value) => F::from(*value as u128), + MemoryValue::U32(value) => F::from(*value as u128), + MemoryValue::U64(value) => F::from(*value as u128), + MemoryValue::U128(value) => F::from(*value), } } pub fn expect_field(&self) -> Result<&F, MemoryTypeError> { - match self { - MemoryValue::Integer(_, bit_size) => Err(MemoryTypeError::MismatchedBitSize { - value_bit_size: (*bit_size).into(), + if let MemoryValue::Field(field) = self { + Ok(field) + } else { + Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: self.bit_size().to_u32::(), expected_bit_size: F::max_num_bits(), - }), - MemoryValue::Field(field) => Ok(field), + }) } } - pub fn expect_integer_with_bit_size( - &self, - expected_bit_size: IntegerBitSize, - ) -> Result { - match self { - MemoryValue::Integer(value, bit_size) => { - if *bit_size != expected_bit_size { - return Err(MemoryTypeError::MismatchedBitSize { - value_bit_size: (*bit_size).into(), - expected_bit_size: expected_bit_size.into(), - }); - } - Ok(*value) - } - MemoryValue::Field(_) => Err(MemoryTypeError::MismatchedBitSize { - value_bit_size: F::max_num_bits(), - expected_bit_size: expected_bit_size.into(), - }), + pub(crate) fn expect_u1(&self) -> Result { + if let MemoryValue::U1(value) = self { + Ok(*value) + } else { + Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: self.bit_size().to_u32::(), + expected_bit_size: 1, + }) + } + } + + pub(crate) fn expect_u8(&self) -> Result { + if let MemoryValue::U8(value) = self { + Ok(*value) + } else { + Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: self.bit_size().to_u32::(), + expected_bit_size: 8, + }) + } + } + + pub(crate) fn expect_u16(&self) -> Result { + if let MemoryValue::U16(value) = self { + Ok(*value) + } else { + Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: self.bit_size().to_u32::(), + expected_bit_size: 16, + }) + } + } + + pub(crate) fn expect_u32(&self) -> Result { + if let MemoryValue::U32(value) = self { + Ok(*value) + } else { + Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: self.bit_size().to_u32::(), + expected_bit_size: 32, + }) + } + } + + pub(crate) fn expect_u64(&self) -> Result { + if let MemoryValue::U64(value) = self { + Ok(*value) + } else { + Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: self.bit_size().to_u32::(), + expected_bit_size: 64, + }) + } + } + + pub(crate) fn expect_u128(&self) -> Result { + if let MemoryValue::U128(value) = self { + Ok(*value) + } else { + Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: self.bit_size().to_u32::(), + expected_bit_size: 128, + }) } } } @@ -127,9 +184,12 @@ impl std::fmt::Display for MemoryValue { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> { match self { MemoryValue::Field(value) => write!(f, "{}: field", value), - MemoryValue::Integer(value, bit_size) => { - write!(f, "{}: {}", value, bit_size) - } + MemoryValue::U1(value) => write!(f, "{}: u1", value), + MemoryValue::U8(value) => write!(f, "{}: u8", value), + MemoryValue::U16(value) => write!(f, "{}: u16", value), + MemoryValue::U32(value) => write!(f, "{}: u32", value), + MemoryValue::U64(value) => write!(f, "{}: u64", value), + MemoryValue::U128(value) => write!(f, "{}: u128", value), } } } @@ -142,38 +202,37 @@ impl Default for MemoryValue { impl From for MemoryValue { fn from(value: bool) -> Self { - let value = if value { 1 } else { 0 }; - MemoryValue::new_integer(value, IntegerBitSize::U1) + MemoryValue::U1(value) } } impl From for MemoryValue { fn from(value: u8) -> Self { - MemoryValue::new_integer(value.into(), IntegerBitSize::U8) + MemoryValue::U8(value) } } impl From for MemoryValue { fn from(value: usize) -> Self { - MemoryValue::new_integer(value as u128, MEMORY_ADDRESSING_BIT_SIZE) + MemoryValue::U32(value as u32) } } impl From for MemoryValue { fn from(value: u32) -> Self { - MemoryValue::new_integer(value.into(), IntegerBitSize::U32) + MemoryValue::U32(value) } } impl From for MemoryValue { fn from(value: u64) -> Self { - MemoryValue::new_integer(value.into(), IntegerBitSize::U64) + MemoryValue::U64(value) } } impl From for MemoryValue { fn from(value: u128) -> Self { - MemoryValue::new_integer(value, IntegerBitSize::U128) + MemoryValue::U128(value) } } @@ -181,15 +240,7 @@ impl TryFrom> for bool { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - let as_integer = memory_value.expect_integer_with_bit_size(IntegerBitSize::U1)?; - - if as_integer.is_zero() { - Ok(false) - } else if as_integer.is_one() { - Ok(true) - } else { - unreachable!("value typed as bool is greater than one") - } + memory_value.expect_u1() } } @@ -197,7 +248,7 @@ impl TryFrom> for u8 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_integer_with_bit_size(IntegerBitSize::U8).map(|value| value as u8) + memory_value.expect_u8() } } @@ -205,7 +256,7 @@ impl TryFrom> for u32 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_integer_with_bit_size(IntegerBitSize::U32).map(|value| value as u32) + memory_value.expect_u32() } } @@ -213,7 +264,7 @@ impl TryFrom> for u64 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_integer_with_bit_size(IntegerBitSize::U64).map(|value| value as u64) + memory_value.expect_u64() } } @@ -221,7 +272,7 @@ impl TryFrom> for u128 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_integer_with_bit_size(IntegerBitSize::U128) + memory_value.expect_u128() } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 51df653d1ec..fe2c0d35f73 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -632,10 +632,7 @@ impl<'brillig> Context<'brillig> { let memory = memory_values[*memory_index]; *memory_index += 1; - let field_value = match memory { - MemoryValue::Field(field_value) => field_value, - MemoryValue::Integer(u128_value, _) => u128_value.into(), - }; + let field_value = memory.to_field(); dfg.make_constant(field_value, typ) } Type::Array(types, length) => {