diff --git a/acvm/src/pwg/mod.rs b/acvm/src/pwg/mod.rs index 0d455bad8..5a8774a4b 100644 --- a/acvm/src/pwg/mod.rs +++ b/acvm/src/pwg/mod.rs @@ -310,7 +310,7 @@ mod tests { use std::collections::BTreeMap; use acir::{ - brillig_vm::{self, BinaryFieldOp, RegisterIndex, RegisterValueOrArray, Value}, + brillig_vm::{self, BinaryFieldOp, RegisterIndex, RegisterOrMemory, Value}, circuit::{ brillig::{Brillig, BrilligInputs, BrilligOutputs}, directives::Directive, @@ -483,8 +483,8 @@ mod tests { // Oracles are named 'foreign calls' in brillig brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))], - inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))], + destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(1))], + inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(0))], }, ], predicate: None, @@ -535,9 +535,8 @@ mod tests { "Should be waiting for a single input" ); // As caller of VM, need to resolve foreign calls - let foreign_call_result = vec![Value::from( - foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse(), - )]; + let foreign_call_result = + Value::from(foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse()); // Alter Brillig oracle opcode with foreign call resolution let brillig: Brillig = foreign_call.resolve(foreign_call_result.into()); let mut next_opcodes_for_solving = vec![Opcode::Brillig(brillig)]; @@ -611,13 +610,13 @@ mod tests { // Oracles are named 'foreign calls' in brillig brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))], - inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))], + destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(1))], + inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(0))], }, brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(3))], - inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(2))], + destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(3))], + inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(2))], }, ], predicate: None, @@ -673,7 +672,7 @@ mod tests { let x_plus_y_inverse = foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse(); // Alter Brillig oracle opcode - let brillig: Brillig = foreign_call.resolve(vec![Value::from(x_plus_y_inverse)].into()); + let brillig: Brillig = foreign_call.resolve(Value::from(x_plus_y_inverse).into()); let mut next_opcodes_for_solving = vec![Opcode::Brillig(brillig)]; next_opcodes_for_solving.extend_from_slice(&unsolved_opcodes[..]); @@ -699,7 +698,7 @@ mod tests { foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse(); assert_ne!(x_plus_y_inverse, i_plus_j_inverse); // Alter Brillig oracle opcode - let brillig = foreign_call.resolve(vec![Value::from(i_plus_j_inverse)].into()); + let brillig = foreign_call.resolve(Value::from(i_plus_j_inverse).into()); let mut next_opcodes_for_solving = vec![Opcode::Brillig(brillig)]; next_opcodes_for_solving.extend_from_slice(&unsolved_opcodes[..]); @@ -759,8 +758,8 @@ mod tests { // Oracles are named 'foreign calls' in brillig brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))], - inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))], + destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(1))], + inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(0))], }, ], predicate: Some(Expression::default()), diff --git a/brillig_vm/src/lib.rs b/brillig_vm/src/lib.rs index f92ebdc60..42c55113c 100644 --- a/brillig_vm/src/lib.rs +++ b/brillig_vm/src/lib.rs @@ -11,7 +11,7 @@ mod registers; mod value; pub use opcodes::Opcode; -pub use opcodes::{BinaryFieldOp, BinaryIntOp, RegisterValueOrArray}; +pub use opcodes::{BinaryFieldOp, BinaryIntOp, RegisterOrMemory}; pub use registers::{RegisterIndex, Registers}; use serde::{Deserialize, Serialize}; pub use value::Typ; @@ -39,25 +39,31 @@ pub enum VMStatus { }, } -/// Represents the output of a [foreign call][Opcode::ForeignCall]. +/// Single output of a [foreign call][Opcode::ForeignCall]. +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)] +pub enum ForeignCallOutput { + Single(Value), + Array(Vec), +} + +/// Represents the full output of a [foreign call][Opcode::ForeignCall]. /// /// See [`VMStatus::ForeignCallWait`] for more information. #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)] pub struct ForeignCallResult { /// Resolved output values of the foreign call. - /// Each output is its own list of values as an output can be either a single value or a memory pointer - pub values: Vec>, + pub values: Vec, } -impl From> for ForeignCallResult { - fn from(values: Vec) -> Self { - ForeignCallResult { values: vec![values] } +impl From for ForeignCallResult { + fn from(value: Value) -> Self { + ForeignCallResult { values: vec![ForeignCallOutput::Single(value)] } } } -impl From>> for ForeignCallResult { - fn from(values: Vec>) -> Self { - ForeignCallResult { values } +impl From> for ForeignCallResult { + fn from(values: Vec) -> Self { + ForeignCallResult { values: vec![ForeignCallOutput::Array(values)] } } } @@ -204,50 +210,72 @@ impl VM { &self.foreign_call_results[self.foreign_call_counter]; let mut invalid_foreign_call_result = false; - for (destination, values) in destinations.iter().zip(values) { + for (destination, output) in destinations.iter().zip(values) { match destination { - RegisterValueOrArray::RegisterIndex(index) => { - if values.len() != 1 { - invalid_foreign_call_result = true; - break; + RegisterOrMemory::RegisterIndex(value_index) => match output { + ForeignCallOutput::Single(value) => { + self.registers.set(*value_index, *value) } - - self.registers.set(*index, values[0]) - } - RegisterValueOrArray::HeapArray(index, size) => { - if values.len() != *size { - invalid_foreign_call_result = true; - break; - } - - // Convert the destination pointer to a usize - let destination = self.registers.get(*index).to_usize(); - // Expand memory if the array to be written - // will overtake the maximum memory pointer - if (destination + size) >= self.memory.len() { - self.memory.append(&mut vec![ - Value::from(0_usize); - (destination + size) - - self.memory.len() - ]); + _ => unreachable!( + "Function result size does not match brillig bytecode (expected 1 result)" + ), + }, + RegisterOrMemory::HeapArray(pointer_index, size) => { + match output { + ForeignCallOutput::Array(values) => { + if values.len() != *size { + invalid_foreign_call_result = true; + break; + } + // Convert the destination pointer to a usize + let destination = self.registers.get(*pointer_index).to_usize(); + // Calculate new memory size + let new_size = + std::cmp::max(self.memory.len(), destination + size); + // Expand memory to new size with default values if needed + self.memory.resize(new_size, Value::from(0_usize)); + // Write to our destination memory + for (i, value) in values.iter().enumerate() { + self.memory[destination + i] = *value; + } + } + _ => { + unreachable!("Function result size does not match brillig bytecode size") + } } - - for (i, value) in values.iter().enumerate() { - self.memory[destination + i] = *value; + } + RegisterOrMemory::HeapVector(pointer_index, size_index) => { + match output { + ForeignCallOutput::Array(values) => { + // Set our size in the size register + self.registers.set(*size_index, Value::from(values.len())); + // Convert the destination pointer to a usize + let destination = self.registers.get(*pointer_index).to_usize(); + // Calculate new memory size + let new_size = + std::cmp::max(self.memory.len(), destination + values.len()); + // Expand memory to new size with default values if needed + self.memory.resize(new_size, Value::from(0_usize)); + // Write to our destination memory + for (i, value) in values.iter().enumerate() { + self.memory[destination + i] = *value; + } + } + _ => { + unreachable!("Function result size does not match brillig bytecode size") + } } } } } // These checks must come after resolving the foreign call outputs as `fail` uses a mutable reference - if invalid_foreign_call_result { - return VMStatus::Failure { - message: "Function result size does not match brillig bytecode".to_owned(), - }; - } if destinations.len() != values.len() { self.fail(format!("{} output values were provided as a foreign call result for {} destination slots", values.len(), destinations.len())); } + if invalid_foreign_call_result { + self.fail("Function result size does not match brillig bytecode".to_owned()); + } self.foreign_call_counter += 1; self.increment_program_counter() @@ -314,13 +342,20 @@ impl VM { self.status.clone() } - fn get_register_value_or_memory_values(&self, input: RegisterValueOrArray) -> Vec { + fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> Vec { match input { - RegisterValueOrArray::RegisterIndex(index) => vec![self.registers.get(index)], - RegisterValueOrArray::HeapArray(index, size) => { - let start = self.registers.get(index); + RegisterOrMemory::RegisterIndex(value_index) => { + vec![self.registers.get(value_index)] + } + RegisterOrMemory::HeapArray(pointer_index, size) => { + let start = self.registers.get(pointer_index); self.memory[start.to_usize()..(start.to_usize() + size)].to_vec() } + RegisterOrMemory::HeapVector(pointer_index, size_index) => { + let start = self.registers.get(pointer_index); + let size = self.registers.get(size_index); + self.memory[start.to_usize()..(start.to_usize() + size.to_usize())].to_vec() + } } } @@ -355,7 +390,6 @@ impl VM { let rhs_value = self.registers.get(rhs); let result_value = op.evaluate_int(lhs_value.to_u128(), rhs_value.to_u128(), bit_size); - self.registers.set(result, result_value.into()); } } @@ -841,8 +875,8 @@ mod tests { // Call foreign function "double" with the input register Opcode::ForeignCall { function: "double".into(), - destinations: vec![RegisterValueOrArray::RegisterIndex(r_result)], - inputs: vec![RegisterValueOrArray::RegisterIndex(r_input)], + destinations: vec![RegisterOrMemory::RegisterIndex(r_result)], + inputs: vec![RegisterOrMemory::RegisterIndex(r_input)], }, ]; @@ -858,9 +892,9 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(ForeignCallResult { - values: vec![vec![Value::from(10u128)]], // Result of doubling 5u128 - }); + vm.foreign_call_results.push( + Value::from(10u128).into(), // Result of doubling 5u128 + ); // Resume VM brillig_execute(&mut vm); @@ -896,8 +930,8 @@ mod tests { // *output = matrix_2x2_transpose(*input) Opcode::ForeignCall { function: "matrix_2x2_transpose".into(), - destinations: vec![RegisterValueOrArray::HeapArray(r_output, initial_matrix.len())], - inputs: vec![RegisterValueOrArray::HeapArray(r_input, initial_matrix.len())], + destinations: vec![RegisterOrMemory::HeapArray(r_output, initial_matrix.len())], + inputs: vec![RegisterOrMemory::HeapArray(r_input, initial_matrix.len())], }, ]; @@ -913,7 +947,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(ForeignCallResult { values: vec![expected_result.clone()] }); + vm.foreign_call_results.push(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); @@ -929,6 +963,76 @@ mod tests { assert_eq!(vm.foreign_call_counter, 1); } + /// Calling a simple foreign call function that takes any string input, concatenates it with itself, and reverses the concatenation + #[test] + fn foreign_call_opcode_vector_input_and_output() { + let r_input_pointer = RegisterIndex::from(0); + let r_input_size = RegisterIndex::from(1); + // We need to pass a location of appropriate size + let r_output_pointer = RegisterIndex::from(2); + let r_output_size = RegisterIndex::from(3); + + // Our first string to use the identity function with + let input_string = + vec![Value::from(1u128), Value::from(2u128), Value::from(3u128), Value::from(4u128)]; + // Double the string (concatenate it with itself) + let mut output_string: Vec = + input_string.iter().cloned().chain(input_string.clone()).collect(); + // Reverse the concatenated string + output_string.reverse(); + + // First call: + let string_double_program = vec![ + // input_pointer = 0 + Opcode::Const { destination: r_input_pointer, value: Value::from(0u128) }, + // input_size = input_string.len() (constant here) + Opcode::Const { destination: r_input_size, value: Value::from(input_string.len()) }, + // output_pointer = 0 + input_size = input_size + Opcode::Const { destination: r_output_pointer, value: Value::from(input_string.len()) }, + // output_size = input_size * 2 + Opcode::Const { + destination: r_output_size, + value: Value::from(input_string.len() * 2), + }, + // output_pointer[0..output_size] = string_double(input_pointer[0...input_size]) + Opcode::ForeignCall { + function: "string_double".into(), + destinations: vec![RegisterOrMemory::HeapVector(r_output_pointer, r_output_size)], + inputs: vec![RegisterOrMemory::HeapVector(r_input_pointer, r_input_size)], + }, + ]; + + let mut vm = brillig_execute_and_get_vm(input_string.clone(), string_double_program); + + // Check that VM is waiting + assert_eq!( + vm.status, + VMStatus::ForeignCallWait { + function: "string_double".into(), + inputs: vec![input_string.clone()] + } + ); + + // Push result we're waiting for + vm.foreign_call_results.push(ForeignCallResult { + values: vec![ForeignCallOutput::Array(output_string.clone())], + }); + + // Resume VM + brillig_execute(&mut vm); + + // Check that VM finished once resumed + assert_eq!(vm.status, VMStatus::Finished); + + // Check result in memory + let result_values = + vm.memory[input_string.len()..(input_string.len() + output_string.len())].to_vec(); + assert_eq!(result_values, output_string); + + // Ensure the foreign call counter has been incremented + assert_eq!(vm.foreign_call_counter, 1); + } + #[test] fn foreign_call_opcode_memory_alloc_result() { let r_input = RegisterIndex::from(0); @@ -950,8 +1054,8 @@ mod tests { // *output = matrix_2x2_transpose(*input) Opcode::ForeignCall { function: "matrix_2x2_transpose".into(), - destinations: vec![RegisterValueOrArray::HeapArray(r_output, initial_matrix.len())], - inputs: vec![RegisterValueOrArray::HeapArray(r_input, initial_matrix.len())], + destinations: vec![RegisterOrMemory::HeapArray(r_output, initial_matrix.len())], + inputs: vec![RegisterOrMemory::HeapArray(r_input, initial_matrix.len())], }, ]; @@ -967,7 +1071,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(ForeignCallResult { values: vec![expected_result.clone()] }); + vm.foreign_call_results.push(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); @@ -1022,10 +1126,10 @@ mod tests { // *output = matrix_2x2_transpose(*input) Opcode::ForeignCall { function: "matrix_2x2_transpose".into(), - destinations: vec![RegisterValueOrArray::HeapArray(r_output, matrix_a.len())], + destinations: vec![RegisterOrMemory::HeapArray(r_output, matrix_a.len())], inputs: vec![ - RegisterValueOrArray::HeapArray(r_input_a, matrix_a.len()), - RegisterValueOrArray::HeapArray(r_input_b, matrix_b.len()), + RegisterOrMemory::HeapArray(r_input_a, matrix_a.len()), + RegisterOrMemory::HeapArray(r_input_b, matrix_b.len()), ], }, ]; @@ -1043,7 +1147,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(ForeignCallResult { values: vec![expected_result.clone()] }); + vm.foreign_call_results.push(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); diff --git a/brillig_vm/src/opcodes.rs b/brillig_vm/src/opcodes.rs index 565f48beb..cd531934a 100644 --- a/brillig_vm/src/opcodes.rs +++ b/brillig_vm/src/opcodes.rs @@ -4,10 +4,27 @@ use serde::{Deserialize, Serialize}; pub type Label = usize; +/// Lays out various ways an external foreign call's input and output data may be interpreted inside Brillig. +/// This data can either be an individual register value or memory. +/// +/// While we are usually agnostic to how memory is passed within Brillig, +/// this needs to be encoded somehow when dealing with an external system. +/// For simplicity, the extra type information is given right in the ForeignCall instructions. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)] -pub enum RegisterValueOrArray { +pub enum RegisterOrMemory { + /// A single register value passed to or from an external call + /// It is an 'immediate' value - used without dereferencing memory. + /// For a foreign call input, the value is read directly from the register. + /// For a foreign call output, the value is written directly to the register. RegisterIndex(RegisterIndex), + /// A fix-sized array passed starting from a Brillig register memory location. + /// In the case of a foreign call input, the array is read from this Brillig memory location + usize more cells. + /// In the case of a foreign call output, the array is written to this Brillig memory location with the usize being here just as a sanity check for the size write. HeapArray(RegisterIndex, usize), + /// A register-sized vector passed starting from a Brillig register memory location and with a register-held size + /// In the case of a foreign call input, the vector is read from this Brillig memory location + as many cells as the 2nd register indicates. + /// In the case of a foreign call output, the vector is written to this Brillig memory location and as 'size' cells, with size being stored in the second register. + HeapVector(RegisterIndex, RegisterIndex), } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -64,9 +81,9 @@ pub enum Opcode { /// who the caller is. function: String, /// Destination registers (may be single values or memory pointers). - destinations: Vec, + destinations: Vec, /// Input registers (may be single values or memory pointers). - inputs: Vec, + inputs: Vec, }, Mov { destination: RegisterIndex, diff --git a/cspell.json b/cspell.json index f03f326c7..8a4dc7645 100644 --- a/cspell.json +++ b/cspell.json @@ -20,6 +20,7 @@ "deflater", "endianness", "euclidian", + "funcs", "hasher", "keccak", "Merkle",