Skip to content

Commit

Permalink
chore: remove unnecessary dereferencing within brillig vm (#7375)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Feb 14, 2025
1 parent 8f20392 commit 3878037
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 56 deletions.
4 changes: 2 additions & 2 deletions acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ pub(crate) fn evaluate_binary_field_op<F: AcirField>(
lhs: MemoryValue<F>,
rhs: MemoryValue<F>,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
let a = *lhs.expect_field().map_err(|err| {
let a = lhs.expect_field().map_err(|err| {
let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err;
BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
}
})?;
let b = *rhs.expect_field().map_err(|err| {
let b = rhs.expect_field().map_err(|err| {
let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err;
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
Expand Down
58 changes: 29 additions & 29 deletions acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn read_heap_array<'a, F: AcirField>(
fn to_u8_vec<F: AcirField>(inputs: &[MemoryValue<F>]) -> Vec<u8> {
let mut result = Vec::with_capacity(inputs.len());
for &input in inputs {
result.push(input.try_into().unwrap());
result.push(input.expect_u8().unwrap());
}
result
}
Expand Down Expand Up @@ -81,7 +81,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
BlackBoxOp::Keccakf1600 { input, output } => {
let state_vec: Vec<u64> = read_heap_array(memory, input)
.iter()
.map(|&memory_value| memory_value.try_into().unwrap())
.map(|&memory_value| memory_value.expect_u64().unwrap())
.collect();
let state: [u64; 25] = state_vec.try_into().unwrap();

Expand Down Expand Up @@ -145,18 +145,18 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
let points: Vec<F> = read_heap_vector(memory, points)
.iter()
.enumerate()
.map(|(i, &x)| {
.map(|(i, x)| {
if i % 3 == 2 {
let is_infinite: bool = x.try_into().unwrap();
F::from(is_infinite as u128)
let is_infinite: bool = x.expect_u1().unwrap();
F::from(is_infinite)
} else {
*x.extract_field().unwrap()
x.expect_field().unwrap()
}
})
.collect();
let scalars: Vec<F> = read_heap_vector(memory, scalars)
.iter()
.map(|x| *x.extract_field().unwrap())
.map(|x| x.expect_field().unwrap())
.collect();
let mut scalars_lo = Vec::with_capacity(scalars.len() / 2);
let mut scalars_hi = Vec::with_capacity(scalars.len() / 2);
Expand Down Expand Up @@ -187,12 +187,12 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
input1_infinite,
input2_infinite,
} => {
let input1_x = *memory.read(*input1_x).extract_field().unwrap();
let input1_y = *memory.read(*input1_y).extract_field().unwrap();
let input1_infinite: bool = memory.read(*input1_infinite).try_into().unwrap();
let input2_x = *memory.read(*input2_x).extract_field().unwrap();
let input2_y = *memory.read(*input2_y).extract_field().unwrap();
let input2_infinite: bool = memory.read(*input2_infinite).try_into().unwrap();
let input1_x = memory.read(*input1_x).expect_field().unwrap();
let input1_y = memory.read(*input1_y).expect_field().unwrap();
let input1_infinite: bool = memory.read(*input1_infinite).expect_u1().unwrap();
let input2_x = memory.read(*input2_x).expect_field().unwrap();
let input2_y = memory.read(*input2_y).expect_field().unwrap();
let input2_infinite: bool = memory.read(*input2_infinite).expect_u1().unwrap();
let (x, y, infinite) = solver.ec_add(
&input1_x,
&input1_y,
Expand All @@ -212,50 +212,50 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
Ok(())
}
BlackBoxOp::BigIntAdd { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let lhs = memory.read(*lhs).expect_u32().unwrap();
let rhs = memory.read(*rhs).expect_u32().unwrap();

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntAdd)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntSub { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let lhs = memory.read(*lhs).expect_u32().unwrap();
let rhs = memory.read(*rhs).expect_u32().unwrap();

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntSub)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntMul { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let lhs = memory.read(*lhs).expect_u32().unwrap();
let rhs = memory.read(*rhs).expect_u32().unwrap();

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntMul)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntDiv { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let lhs = memory.read(*lhs).expect_u32().unwrap();
let rhs = memory.read(*rhs).expect_u32().unwrap();

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntDiv)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntFromLeBytes { inputs, modulus, output } => {
let input = read_heap_vector(memory, inputs);
let input: Vec<u8> = input.iter().map(|&x| x.try_into().unwrap()).collect();
let input: Vec<u8> = input.iter().map(|x| x.expect_u8().unwrap()).collect();
let modulus = read_heap_vector(memory, modulus);
let modulus: Vec<u8> = modulus.iter().map(|&x| x.try_into().unwrap()).collect();
let modulus: Vec<u8> = modulus.iter().map(|x| x.expect_u8().unwrap()).collect();

let new_id = bigint_solver.bigint_from_bytes(&input, &modulus)?;
memory.write(*output, new_id.into());

Ok(())
}
BlackBoxOp::BigIntToLeBytes { input, output } => {
let input: u32 = memory.read(*input).try_into().unwrap();
let input: u32 = memory.read(*input).expect_u32().unwrap();
let bytes = bigint_solver.bigint_to_bytes(input)?;
let mut values = Vec::new();
for i in 0..32 {
Expand All @@ -270,8 +270,8 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
}
BlackBoxOp::Poseidon2Permutation { message, output, len } => {
let input = read_heap_vector(memory, message);
let input: Vec<F> = input.iter().map(|x| *x.extract_field().unwrap()).collect();
let len = memory.read(*len).try_into().unwrap();
let input: Vec<F> = input.iter().map(|x| x.expect_field().unwrap()).collect();
let len = memory.read(*len).expect_u32().unwrap();
let result = solver.poseidon2_permutation(&input, len)?;
let mut values = Vec::new();
for i in result {
Expand All @@ -290,7 +290,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
));
}
for (i, &input) in inputs.iter().enumerate() {
message[i] = input.try_into().unwrap();
message[i] = input.expect_u32().unwrap();
}
let mut state = [0; 8];
let values = read_heap_array(memory, hash_values);
Expand All @@ -301,7 +301,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
));
}
for (i, &value) in values.iter().enumerate() {
state[i] = value.try_into().unwrap();
state[i] = value.expect_u32().unwrap();
}

sha256_compression(&mut state, &message);
Expand All @@ -311,7 +311,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
Ok(())
}
BlackBoxOp::ToRadix { input, radix, output_pointer, num_limbs, output_bits } => {
let input: F = *memory.read(*input).extract_field().expect("ToRadix input not a field");
let input: F = memory.read(*input).expect_field().expect("ToRadix input not a field");
let MemoryValue::U32(radix) = memory.read(*radix) else {
panic!("ToRadix opcode's radix bit size does not match expected bit size 32")
};
Expand Down
6 changes: 3 additions & 3 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,14 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
// Check if condition is true
// We use 0 to mean false and any other value to mean true
let condition_value = self.memory.read(*condition);
if condition_value.try_into().expect("condition value is not a boolean") {
if condition_value.expect_u1().expect("condition value is not a boolean") {
return self.set_program_counter(*destination);
}
self.increment_program_counter()
}
Opcode::JumpIfNot { condition, location: destination } => {
let condition_value = self.memory.read(*condition);
if condition_value.try_into().expect("condition value is not a boolean") {
if condition_value.expect_u1().expect("condition value is not a boolean") {
return self.increment_program_counter();
}
self.set_program_counter(*destination)
Expand Down Expand Up @@ -340,7 +340,7 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
}
Opcode::ConditionalMov { destination, source_a, source_b, condition } => {
let condition_value = self.memory.read(*condition);
if condition_value.try_into().expect("condition value is not a boolean") {
if condition_value.expect_u1().expect("condition value is not a boolean") {
self.memory.write(*destination, self.memory.read(*source_a));
} else {
self.memory.write(*destination, self.memory.read(*source_b));
Expand Down
36 changes: 14 additions & 22 deletions acvm-repo/brillig_vm/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,6 @@ impl<F> MemoryValue<F> {
}
}

/// Extracts the field element from the memory value, if it is typed as field element.
pub fn extract_field(&self) -> Option<&F> {
match self {
MemoryValue::Field(value) => Some(value),
_ => None,
}
}

pub fn bit_size(&self) -> BitSize {
match self {
MemoryValue::Field(_) => BitSize::Field,
Expand Down Expand Up @@ -102,7 +94,8 @@ impl<F: AcirField> MemoryValue<F> {
}
}

pub fn expect_field(&self) -> Result<&F, MemoryTypeError> {
/// Extracts the field element from the memory value, if it is typed as field element.
pub fn expect_field(self) -> Result<F, MemoryTypeError> {
if let MemoryValue::Field(field) = self {
Ok(field)
} else {
Expand All @@ -112,10 +105,9 @@ impl<F: AcirField> MemoryValue<F> {
})
}
}

pub(crate) fn expect_u1(&self) -> Result<bool, MemoryTypeError> {
pub(crate) fn expect_u1(self) -> Result<bool, MemoryTypeError> {
if let MemoryValue::U1(value) = self {
Ok(*value)
Ok(value)
} else {
Err(MemoryTypeError::MismatchedBitSize {
value_bit_size: self.bit_size().to_u32::<F>(),
Expand All @@ -124,9 +116,9 @@ impl<F: AcirField> MemoryValue<F> {
}
}

pub(crate) fn expect_u8(&self) -> Result<u8, MemoryTypeError> {
pub(crate) fn expect_u8(self) -> Result<u8, MemoryTypeError> {
if let MemoryValue::U8(value) = self {
Ok(*value)
Ok(value)
} else {
Err(MemoryTypeError::MismatchedBitSize {
value_bit_size: self.bit_size().to_u32::<F>(),
Expand All @@ -135,9 +127,9 @@ impl<F: AcirField> MemoryValue<F> {
}
}

pub(crate) fn expect_u16(&self) -> Result<u16, MemoryTypeError> {
pub(crate) fn expect_u16(self) -> Result<u16, MemoryTypeError> {
if let MemoryValue::U16(value) = self {
Ok(*value)
Ok(value)
} else {
Err(MemoryTypeError::MismatchedBitSize {
value_bit_size: self.bit_size().to_u32::<F>(),
Expand All @@ -146,9 +138,9 @@ impl<F: AcirField> MemoryValue<F> {
}
}

pub(crate) fn expect_u32(&self) -> Result<u32, MemoryTypeError> {
pub(crate) fn expect_u32(self) -> Result<u32, MemoryTypeError> {
if let MemoryValue::U32(value) = self {
Ok(*value)
Ok(value)
} else {
Err(MemoryTypeError::MismatchedBitSize {
value_bit_size: self.bit_size().to_u32::<F>(),
Expand All @@ -157,9 +149,9 @@ impl<F: AcirField> MemoryValue<F> {
}
}

pub(crate) fn expect_u64(&self) -> Result<u64, MemoryTypeError> {
pub(crate) fn expect_u64(self) -> Result<u64, MemoryTypeError> {
if let MemoryValue::U64(value) = self {
Ok(*value)
Ok(value)
} else {
Err(MemoryTypeError::MismatchedBitSize {
value_bit_size: self.bit_size().to_u32::<F>(),
Expand All @@ -168,9 +160,9 @@ impl<F: AcirField> MemoryValue<F> {
}
}

pub(crate) fn expect_u128(&self) -> Result<u128, MemoryTypeError> {
pub(crate) fn expect_u128(self) -> Result<u128, MemoryTypeError> {
if let MemoryValue::U128(value) = self {
Ok(*value)
Ok(value)
} else {
Err(MemoryTypeError::MismatchedBitSize {
value_bit_size: self.bit_size().to_u32::<F>(),
Expand Down

0 comments on commit 3878037

Please sign in to comment.