Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add CMOV instruction to brillig and brillig gen #5308

Merged
merged 5 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,24 @@ pub fn brillig_to_avm(brillig: &Brillig) -> Vec<u8> {
} => {
avm_instrs.push(generate_mov_instruction(Some(ALL_DIRECT), source.to_usize() as u32, destination.to_usize() as u32));
}
BrilligOpcode::ConditionalMov {
source_a,
source_b,
condition,
destination,
} => {
avm_instrs.push(AvmInstruction {
opcode: AvmOpcode::CMOV,
indirect: Some(ALL_DIRECT),
operands: vec![
AvmOperand::U32 { value: source_a.to_usize() as u32 },
AvmOperand::U32 { value: source_b.to_usize() as u32 },
AvmOperand::U32 { value: condition.to_usize() as u32 },
AvmOperand::U32 { value: destination.to_usize() as u32 },
],
..Default::default()
});
}
BrilligOpcode::Load {
destination,
source_pointer,
Expand Down
74 changes: 74 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,17 @@ struct BrilligOpcode {
static Mov bincodeDeserialize(std::vector<uint8_t>);
};

struct ConditionalMov {
Program::MemoryAddress destination;
Program::MemoryAddress source_a;
Program::MemoryAddress source_b;
Program::MemoryAddress condition;

friend bool operator==(const ConditionalMov&, const ConditionalMov&);
std::vector<uint8_t> bincodeSerialize() const;
static ConditionalMov bincodeDeserialize(std::vector<uint8_t>);
};

struct Load {
Program::MemoryAddress destination;
Program::MemoryAddress source_pointer;
Expand Down Expand Up @@ -644,6 +655,7 @@ struct BrilligOpcode {
Return,
ForeignCall,
Mov,
ConditionalMov,
Load,
Store,
BlackBox,
Expand Down Expand Up @@ -5832,6 +5844,68 @@ Program::BrilligOpcode::Mov serde::Deserializable<Program::BrilligOpcode::Mov>::

namespace Program {

inline bool operator==(const BrilligOpcode::ConditionalMov& lhs, const BrilligOpcode::ConditionalMov& rhs)
{
if (!(lhs.destination == rhs.destination)) {
return false;
}
if (!(lhs.source_a == rhs.source_a)) {
return false;
}
if (!(lhs.source_b == rhs.source_b)) {
return false;
}
if (!(lhs.condition == rhs.condition)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BrilligOpcode::ConditionalMov::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::ConditionalMov>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::ConditionalMov BrilligOpcode::ConditionalMov::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::ConditionalMov>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BrilligOpcode::ConditionalMov>::serialize(
const Program::BrilligOpcode::ConditionalMov& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source_a)>::serialize(obj.source_a, serializer);
serde::Serializable<decltype(obj.source_b)>::serialize(obj.source_b, serializer);
serde::Serializable<decltype(obj.condition)>::serialize(obj.condition, serializer);
}

template <>
template <typename Deserializer>
Program::BrilligOpcode::ConditionalMov serde::Deserializable<Program::BrilligOpcode::ConditionalMov>::deserialize(
Deserializer& deserializer)
{
Program::BrilligOpcode::ConditionalMov obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source_a = serde::Deserializable<decltype(obj.source_a)>::deserialize(deserializer);
obj.source_b = serde::Deserializable<decltype(obj.source_b)>::deserialize(deserializer);
obj.condition = serde::Deserializable<decltype(obj.condition)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::Load& lhs, const BrilligOpcode::Load& rhs)
{
if (!(lhs.destination == rhs.destination)) {
Expand Down
60 changes: 59 additions & 1 deletion noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,17 @@ namespace Program {
static Mov bincodeDeserialize(std::vector<uint8_t>);
};

struct ConditionalMov {
Program::MemoryAddress destination;
Program::MemoryAddress source_a;
Program::MemoryAddress source_b;
Program::MemoryAddress condition;

friend bool operator==(const ConditionalMov&, const ConditionalMov&);
std::vector<uint8_t> bincodeSerialize() const;
static ConditionalMov bincodeDeserialize(std::vector<uint8_t>);
};

struct Load {
Program::MemoryAddress destination;
Program::MemoryAddress source_pointer;
Expand Down Expand Up @@ -612,7 +623,7 @@ namespace Program {
static Stop bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, Load, Store, BlackBox, Trap, Stop> value;
std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, ConditionalMov, Load, Store, BlackBox, Trap, Stop> value;

friend bool operator==(const BrilligOpcode&, const BrilligOpcode&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4826,6 +4837,53 @@ Program::BrilligOpcode::Mov serde::Deserializable<Program::BrilligOpcode::Mov>::
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::ConditionalMov &lhs, const BrilligOpcode::ConditionalMov &rhs) {
if (!(lhs.destination == rhs.destination)) { return false; }
if (!(lhs.source_a == rhs.source_a)) { return false; }
if (!(lhs.source_b == rhs.source_b)) { return false; }
if (!(lhs.condition == rhs.condition)) { return false; }
return true;
}

inline std::vector<uint8_t> BrilligOpcode::ConditionalMov::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::ConditionalMov>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::ConditionalMov BrilligOpcode::ConditionalMov::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::ConditionalMov>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BrilligOpcode::ConditionalMov>::serialize(const Program::BrilligOpcode::ConditionalMov &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source_a)>::serialize(obj.source_a, serializer);
serde::Serializable<decltype(obj.source_b)>::serialize(obj.source_b, serializer);
serde::Serializable<decltype(obj.condition)>::serialize(obj.condition, serializer);
}

template <>
template <typename Deserializer>
Program::BrilligOpcode::ConditionalMov serde::Deserializable<Program::BrilligOpcode::ConditionalMov>::deserialize(Deserializer &deserializer) {
Program::BrilligOpcode::ConditionalMov obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source_a = serde::Deserializable<decltype(obj.source_a)>::deserialize(deserializer);
obj.source_b = serde::Deserializable<decltype(obj.source_b)>::deserialize(deserializer);
obj.condition = serde::Deserializable<decltype(obj.condition)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::Load &lhs, const BrilligOpcode::Load &rhs) {
Expand Down
8 changes: 4 additions & 4 deletions noir/noir-repo/acvm-repo/acir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ mod reflection {
generator.output(&mut source, &registry).unwrap();

// Comment this out to write updated C++ code to file.
// if let Some(old_hash) = old_hash {
// let new_hash = fxhash::hash64(&source);
// assert_eq!(new_hash, old_hash, "Serialization format has changed");
// }
if let Some(old_hash) = old_hash {
let new_hash = fxhash::hash64(&source);
assert_eq!(new_hash, old_hash, "Serialization format has changed");
}

write_to_file(&source, &path);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,11 @@ fn simple_brillig_foreign_call() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 212,
167, 216, 31, 244, 51, 61, 120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 214,
159, 216, 31, 244, 51, 61, 120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144,
176, 9, 97, 151, 84, 225, 74, 69, 50, 31, 48, 35, 85, 251, 164, 235, 53, 94, 218, 247, 75,
163, 95, 150, 12, 153, 179, 227, 191, 114, 195, 222, 216, 240, 59, 63, 75, 221, 251, 208,
106, 207, 232, 150, 65, 100, 53, 33, 2, 9, 69, 91, 82, 144, 1, 0, 0,
106, 207, 232, 150, 65, 100, 53, 33, 2, 22, 232, 178, 27, 144, 1, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down Expand Up @@ -311,15 +311,15 @@ fn complex_brillig_foreign_call() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 158,
246, 9, 38, 187, 15, 96, 247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 222,
246, 7, 38, 187, 15, 96, 247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65,
47, 98, 162, 147, 52, 20, 24, 202, 164, 45, 48, 205, 200, 157, 49, 124, 227, 44, 129, 207,
152, 75, 120, 94, 137, 209, 30, 195, 143, 227, 197, 178, 103, 105, 76, 110, 160, 209, 156,
160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 97,
254, 196, 152, 99, 157, 176, 87, 168, 188, 147, 224, 121, 20, 209, 180, 254, 109, 70, 75,
47, 178, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96, 20, 85, 75, 253, 8, 255, 171, 246,
121, 231, 220, 4, 249, 237, 132, 56, 28, 224, 109, 113, 223, 180, 164, 50, 165, 0, 137, 17,
72, 139, 88, 97, 4, 198, 90, 226, 196, 33, 5, 0, 0,
72, 139, 88, 97, 4, 173, 98, 132, 157, 33, 5, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `complex_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 158,
246, 9, 38, 187, 15, 96, 247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65,
47, 98, 162, 147, 52, 20, 24, 202, 164, 45, 48, 205, 200, 157, 49, 124, 227, 44, 129, 207,
152, 75, 120, 94, 137, 209, 30, 195, 143, 227, 197, 178, 103, 105, 76, 110, 160, 209, 156,
160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 97,
254, 196, 152, 99, 157, 176, 87, 168, 188, 147, 224, 121, 20, 209, 180, 254, 109, 70, 75,
47, 178, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96, 20, 85, 75, 253, 8, 255, 171, 246,
121, 231, 220, 4, 249, 237, 132, 56, 28, 224, 109, 113, 223, 180, 164, 50, 165, 0, 137, 17,
72, 139, 88, 97, 4, 198, 90, 226, 196, 33, 5, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 222, 246, 7, 38, 187, 15, 96,
247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65, 47, 98, 162, 147, 52, 20, 24, 202, 164, 45, 48,
205, 200, 157, 49, 124, 227, 44, 129, 207, 152, 75, 120, 94, 137, 209, 30, 195, 143, 227, 197, 178, 103, 105, 76, 110,
160, 209, 156, 160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 97, 254, 196,
152, 99, 157, 176, 87, 168, 188, 147, 224, 121, 20, 209, 180, 254, 109, 70, 75, 47, 178, 186, 251, 37, 116, 86, 93,
219, 55, 245, 96, 20, 85, 75, 253, 8, 255, 171, 246, 121, 231, 220, 4, 249, 237, 132, 56, 28, 224, 109, 113, 223, 180,
164, 50, 165, 0, 137, 17, 72, 139, 88, 97, 4, 173, 98, 132, 157, 33, 5, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000001'],
Expand Down
9 changes: 4 additions & 5 deletions noir/noir-repo/acvm-repo/acvm_js/test/shared/foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `simple_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 212,
167, 216, 31, 244, 51, 61, 120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144,
176, 9, 97, 151, 84, 225, 74, 69, 50, 31, 48, 35, 85, 251, 164, 235, 53, 94, 218, 247, 75,
163, 95, 150, 12, 153, 179, 227, 191, 114, 195, 222, 216, 240, 59, 63, 75, 221, 251, 208,
106, 207, 232, 150, 65, 100, 53, 33, 2, 9, 69, 91, 82, 144, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 214, 159, 216, 31, 244, 51, 61,
120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144, 176, 9, 97, 151, 84, 225, 74, 69, 50, 31, 48, 35, 85,
251, 164, 235, 53, 94, 218, 247, 75, 163, 95, 150, 12, 153, 179, 227, 191, 114, 195, 222, 216, 240, 59, 63, 75, 221,
251, 208, 106, 207, 232, 150, 65, 100, 53, 33, 2, 22, 232, 178, 27, 144, 1, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000005'],
Expand Down
7 changes: 7 additions & 0 deletions noir/noir-repo/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ pub enum BrilligOpcode {
destination: MemoryAddress,
source: MemoryAddress,
},
/// destination = condition > 0 ? source_a : source_b
ConditionalMov {
destination: MemoryAddress,
source_a: MemoryAddress,
source_b: MemoryAddress,
condition: MemoryAddress,
},
Load {
destination: MemoryAddress,
source_pointer: MemoryAddress,
Expand Down
55 changes: 55 additions & 0 deletions noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,15 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
self.memory.write(*destination_address, source_value);
self.increment_program_counter()
}
Opcode::ConditionalMov { destination, source_a, source_b, condition } => {
let condition_value = self.memory.read(*condition);
if condition_value.is_zero() {
self.memory.write(*destination, self.memory.read(*source_b));
} else {
self.memory.write(*destination, self.memory.read(*source_a));
}
self.increment_program_counter()
}
Opcode::Trap => self.fail("explicit trap hit in brillig".to_string()),
Opcode::Stop { return_data_offset, return_data_size } => {
self.finish(*return_data_offset, *return_data_size)
Expand Down Expand Up @@ -793,6 +802,52 @@ mod tests {
assert_eq!(source_value, Value::from(1u128));
}

#[test]
fn cmov_opcode() {
let calldata =
vec![Value::from(0u128), Value::from(1u128), Value::from(2u128), Value::from(3u128)];

let calldata_copy = Opcode::CalldataCopy {
destination_address: MemoryAddress::from(0),
size: 4,
offset: 0,
};

let opcodes = &[
calldata_copy,
Opcode::ConditionalMov {
destination: MemoryAddress(4), // Sets 3_u128 to memory address 4
source_a: MemoryAddress(2),
source_b: MemoryAddress(3),
condition: MemoryAddress(0),
},
Opcode::ConditionalMov {
destination: MemoryAddress(5), // Sets 2_u128 to memory address 5
source_a: MemoryAddress(2),
source_b: MemoryAddress(3),
condition: MemoryAddress(1),
},
];
let mut vm = VM::new(calldata, opcodes, vec![], &DummyBlackBoxSolver);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::Finished { return_data_offset: 0, return_data_size: 0 });

let VM { memory, .. } = vm;

let destination_value = memory.read(MemoryAddress::from(4));
assert_eq!(destination_value, Value::from(3_u128));

let source_value = memory.read(MemoryAddress::from(5));
assert_eq!(source_value, Value::from(2_u128));
}

#[test]
fn cmp_binary_ops() {
let bit_size = 32;
Expand Down
Loading
Loading