Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

feat(brillig): Allow dynamic-size foreign calls #370

Merged
merged 23 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ mod tests {
use std::collections::BTreeMap;

use acir::{
brillig_vm::{self, BinaryFieldOp, RegisterIndex, RegisterValueOrArray, Value},
brillig_vm::{self, BinaryFieldOp, RegisterIndex, RegisterOrMemory, Value},
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
directives::Directive,
Expand Down Expand Up @@ -483,8 +483,8 @@ mod tests {
// Oracles are named 'foreign calls' in brillig
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))],
destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(0))],
},
],
predicate: None,
Expand Down Expand Up @@ -611,13 +611,13 @@ mod tests {
// Oracles are named 'foreign calls' in brillig
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))],
destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(0))],
},
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(3))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(2))],
destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(3))],
inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(2))],
},
],
predicate: None,
Expand Down Expand Up @@ -759,8 +759,8 @@ mod tests {
// Oracles are named 'foreign calls' in brillig
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))],
destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(0))],
},
],
predicate: Some(Expression::default()),
Expand Down
146 changes: 116 additions & 30 deletions brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod registers;
mod value;

pub use opcodes::Opcode;
pub use opcodes::{BinaryFieldOp, BinaryIntOp, RegisterValueOrArray};
pub use opcodes::{BinaryFieldOp, BinaryIntOp, RegisterOrMemory};
pub use registers::{RegisterIndex, Registers};
use serde::{Deserialize, Serialize};
pub use value::Typ;
Expand Down Expand Up @@ -205,32 +205,46 @@ impl VM {

for (destination, values) in destinations.iter().zip(values) {
match destination {
RegisterValueOrArray::RegisterIndex(index) => {
RegisterOrMemory::RegisterIndex(value_index) => {
assert_eq!(
values.len(),
1,
"Function result size does not match brillig bytecode"
"Function result size does not match brillig (expected 1 result)"
);
self.registers.set(*index, values[0])
self.registers.set(*value_index, values[0])
}
RegisterValueOrArray::HeapArray(index, size) => {
RegisterOrMemory::HeapArray(pointer_index, size) => {
assert_eq!(
values.len(),
*size,
"Function result size does not match brillig bytecode"
"Function result size does not match brillig size"
);
// Convert the destination pointer to a usize
let destination = self.registers.get(*index).to_usize();
// Expand memory if the array to be written
// will overtake the maximum memory pointer
if (destination + size) >= self.memory.len() {
self.memory.append(&mut vec![
Value::from(0_usize);
(destination + size)
- self.memory.len()
]);
let destination = self.registers.get(*pointer_index).to_usize();
// Calculate new memory size
let new_size = std::cmp::max(self.memory.len(), destination + size);
// Expand memory to new size with default values if needed
self.memory.resize(new_size, Value::from(0_usize));
// Write to our destination memory
for (i, value) in values.iter().enumerate() {
self.memory[destination + i] = *value;
}

}
RegisterOrMemory::HeapVector(pointer_index, size_index) => {
// Convert the size pointer to a usize
let size = self.registers.get(*size_index).to_usize();
assert_eq!(
values.len(),
size,
"Function result size does not match brillig size register"
);
// Convert the destination pointer to a usize
let destination = self.registers.get(*pointer_index).to_usize();
// Calculate new memory size
let new_size = std::cmp::max(self.memory.len(), destination + size);
// Expand memory to new size with default values if needed
self.memory.resize(new_size, Value::from(0_usize));
// Write to our destination memory
for (i, value) in values.iter().enumerate() {
self.memory[destination + i] = *value;
}
Expand Down Expand Up @@ -308,13 +322,20 @@ impl VM {
self.status.clone()
}

fn get_register_value_or_memory_values(&self, input: RegisterValueOrArray) -> Vec<Value> {
fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> Vec<Value> {
match input {
RegisterValueOrArray::RegisterIndex(index) => vec![self.registers.get(index)],
RegisterValueOrArray::HeapArray(index, size) => {
let start = self.registers.get(index);
RegisterOrMemory::RegisterIndex(value_index) => {
vec![self.registers.get(value_index)]
}
RegisterOrMemory::HeapArray(pointer_index, size) => {
let start = self.registers.get(pointer_index);
self.memory[start.to_usize()..(start.to_usize() + size)].to_vec()
}
RegisterOrMemory::HeapVector(pointer_index, size_index) => {
let start = self.registers.get(pointer_index);
let size = self.registers.get(size_index);
self.memory[start.to_usize()..(start.to_usize() + size.to_usize())].to_vec()
}
}
}

Expand Down Expand Up @@ -349,7 +370,6 @@ impl VM {
let rhs_value = self.registers.get(rhs);

let result_value = op.evaluate_int(lhs_value.to_u128(), rhs_value.to_u128(), bit_size);

self.registers.set(result, result_value.into());
}
}
Expand Down Expand Up @@ -835,8 +855,8 @@ mod tests {
// Call foreign function "double" with the input register
Opcode::ForeignCall {
function: "double".into(),
destinations: vec![RegisterValueOrArray::RegisterIndex(r_result)],
inputs: vec![RegisterValueOrArray::RegisterIndex(r_input)],
destinations: vec![RegisterOrMemory::RegisterIndex(r_result)],
inputs: vec![RegisterOrMemory::RegisterIndex(r_input)],
},
];

Expand Down Expand Up @@ -890,8 +910,8 @@ mod tests {
// *output = matrix_2x2_transpose(*input)
Opcode::ForeignCall {
function: "matrix_2x2_transpose".into(),
destinations: vec![RegisterValueOrArray::HeapArray(r_output, initial_matrix.len())],
inputs: vec![RegisterValueOrArray::HeapArray(r_input, initial_matrix.len())],
destinations: vec![RegisterOrMemory::HeapArray(r_output, initial_matrix.len())],
inputs: vec![RegisterOrMemory::HeapArray(r_input, initial_matrix.len())],
},
];

Expand Down Expand Up @@ -923,6 +943,72 @@ mod tests {
assert_eq!(vm.foreign_call_counter, 1);
}

/// Calling a foreign call identity function that takes any string input and concatenates it with itself
#[test]
fn foreign_call_opcode_vector_input_and_output() {
let r_input_pointer = RegisterIndex::from(0);
let r_input_size = RegisterIndex::from(1);
// We need to pass a location of appropriate size
let r_output_pointer = RegisterIndex::from(2);
let r_output_size = RegisterIndex::from(3);

// Our first string to use the identity function with
let input_string =
vec![Value::from(1u128), Value::from(2u128), Value::from(3u128), Value::from(4u128)];
// Double the string (concate it with itself)
let output_string: Vec<Value> =
input_string.iter().cloned().chain(input_string.clone()).collect();

// Expected results are the same as the initial results
// First call:
let string_double_program = vec![
// input_pointer = 0
Opcode::Const { destination: r_input_pointer, value: Value::from(0u128) },
// input_size = input_string.len() (constant here)
Opcode::Const { destination: r_input_size, value: Value::from(input_string.len()) },
// output_pointer = 0 + input_size = input_size
Opcode::Const { destination: r_output_pointer, value: Value::from(input_string.len()) },
// output_size = input_size * 2
Opcode::Const {
destination: r_output_size,
value: Value::from(input_string.len() * 2),
},
// output_pointer[0..output_size] = string_double(input_pointer[0...input_size])
Opcode::ForeignCall {
function: "string_double".into(),
destinations: vec![RegisterOrMemory::HeapVector(r_output_pointer, r_output_size)],
inputs: vec![RegisterOrMemory::HeapVector(r_input_pointer, r_input_size)],
},
];

let mut vm = brillig_execute_and_get_vm(input_string.clone(), string_double_program);

// Check that VM is waiting
assert_eq!(
vm.status,
VMStatus::ForeignCallWait {
function: "string_double".into(),
inputs: vec![input_string]
}
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult { values: vec![output_string.clone()] });

// Resume VM
brillig_execute(&mut vm);

// Check that VM finished once resumed
assert_eq!(vm.status, VMStatus::Finished);

// Check result in memory
let result_values = vm.memory[0..output_string.len()].to_vec();
assert_eq!(result_values, output_string);

// Ensure the foreign call counter has been incremented
assert_eq!(vm.foreign_call_counter, 1);
}

#[test]
fn foreign_call_opcode_memory_alloc_result() {
let r_input = RegisterIndex::from(0);
Expand All @@ -944,8 +1030,8 @@ mod tests {
// *output = matrix_2x2_transpose(*input)
Opcode::ForeignCall {
function: "matrix_2x2_transpose".into(),
destinations: vec![RegisterValueOrArray::HeapArray(r_output, initial_matrix.len())],
inputs: vec![RegisterValueOrArray::HeapArray(r_input, initial_matrix.len())],
destinations: vec![RegisterOrMemory::HeapArray(r_output, initial_matrix.len())],
inputs: vec![RegisterOrMemory::HeapArray(r_input, initial_matrix.len())],
},
];

Expand Down Expand Up @@ -1016,10 +1102,10 @@ mod tests {
// *output = matrix_2x2_transpose(*input)
Opcode::ForeignCall {
function: "matrix_2x2_transpose".into(),
destinations: vec![RegisterValueOrArray::HeapArray(r_output, matrix_a.len())],
destinations: vec![RegisterOrMemory::HeapArray(r_output, matrix_a.len())],
inputs: vec![
RegisterValueOrArray::HeapArray(r_input_a, matrix_a.len()),
RegisterValueOrArray::HeapArray(r_input_b, matrix_b.len()),
RegisterOrMemory::HeapArray(r_input_a, matrix_a.len()),
RegisterOrMemory::HeapArray(r_input_b, matrix_b.len()),
],
},
];
Expand Down
21 changes: 18 additions & 3 deletions brillig_vm/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,25 @@ use serde::{Deserialize, Serialize};

pub type Label = usize;

/// Lays out various ways we might interpret memory when passing to a foreign call external from Brillig.
///
/// While we are usually are agnostic to how memory is passed within Brillig,
/// this needs to be encoded somehow when dealing with an external system.
/// For simplicity, the extra type information is given right in the ForeignCall instructions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)]
pub enum RegisterValueOrArray {
pub enum RegisterOrMemory {
/// An immediate value passed to or from an external call
/// For a foreign call input, this is read directly from the register.
/// For a foreign call output, this is written directly to the register.
RegisterIndex(RegisterIndex),
/// A fix-sized array passed starting from a Brillig register memory location.
/// In the case of a foreign call input, this is read from this Brillig memory location + usize more cells.
/// In the case of a foreign call output, this is written to this Brillig memory location with the usize being here just as a sanity check for the size write.
HeapArray(RegisterIndex, usize),
/// A register-sized vector passed starting from a Brillig register memory location and with a register-held size
/// In the case of a foreign call input, this is read from this Brillig memory location + as many cells as the 2nd register indicates.
/// In the case of a foreign call output, this is written to this Brillig memory location with the usize being here just as a sanity check for the size write.
HeapVector(RegisterIndex, RegisterIndex),
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
Expand Down Expand Up @@ -64,9 +79,9 @@ pub enum Opcode {
/// who the caller is.
function: String,
/// Destination registers (may be single values or memory pointers).
destinations: Vec<RegisterValueOrArray>,
destinations: Vec<RegisterOrMemory>,
/// Input registers (may be single values or memory pointers).
inputs: Vec<RegisterValueOrArray>,
inputs: Vec<RegisterOrMemory>,
},
Mov {
destination: RegisterIndex,
Expand Down