Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: replace arithmetic equalities with assert equal #8386

Merged
merged 8 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ void build_constraints(Builder& builder,
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.bigint_to_le_bytes_constraints.at(i));
}
// assert equals
for (size_t i = 0; i < constraint_system.assert_equalities.size(); ++i) {
const auto& constraint = constraint_system.assert_equalities.at(i);

builder.assert_equal(constraint.a, constraint.b);
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.assert_equalities.at(i));
}

// RecursionConstraints
// TODO(https://github.com/AztecProtocol/barretenberg/issues/817): disable these for MegaHonk for now since we're
Expand All @@ -227,10 +235,11 @@ void build_constraints(Builder& builder,
process_plonk_recursion_constraints(builder, constraint_system, has_valid_witness_assignments, gate_counter);
process_honk_recursion_constraints(builder, constraint_system, has_valid_witness_assignments, gate_counter);

// If the circuit does not itself contain honk recursion constraints but is going to be proven with honk then
// recursively verified, add a default aggregation object
// If the circuit does not itself contain honk recursion constraints but is going to be
// proven with honk then recursively verified, add a default aggregation object
if (constraint_system.honk_recursion_constraints.empty() && honk_recursion &&
builder.is_recursive_circuit) { // Set a default aggregation object if we don't have one.
builder.is_recursive_circuit) { // Set a default aggregation object if we don't have
// one.
AggregationObjectIndices current_aggregation_object =
stdlib::recursion::init_default_agg_obj_indices<Builder>(builder);
// Make sure the verification key records the public input indices of the
Expand Down Expand Up @@ -265,31 +274,34 @@ void process_plonk_recursion_constraints(Builder& builder,
for (size_t constraint_idx = 0; constraint_idx < constraint_system.recursion_constraints.size(); ++constraint_idx) {
auto constraint = constraint_system.recursion_constraints[constraint_idx];

// A proof passed into the constraint should be stripped of its public inputs, except in the case where a
// proof contains an aggregation object itself. We refer to this as the `nested_aggregation_object`. The
// verifier circuit requires that the indices to a nested proof aggregation state are a circuit constant.
// The user tells us they how they want these constants set by keeping the nested aggregation object
// attached to the proof as public inputs. As this is the only object that can prepended to the proof if the
// proof is above the expected size (with public inputs stripped)
// A proof passed into the constraint should be stripped of its public inputs, except in
// the case where a proof contains an aggregation object itself. We refer to this as the
// `nested_aggregation_object`. The verifier circuit requires that the indices to a
// nested proof aggregation state are a circuit constant. The user tells us they how
// they want these constants set by keeping the nested aggregation object attached to
// the proof as public inputs. As this is the only object that can prepended to the
// proof if the proof is above the expected size (with public inputs stripped)
AggregationObjectPubInputIndices nested_aggregation_object = {};
// If the proof has public inputs attached to it, we should handle setting the nested aggregation object
// If the proof has public inputs attached to it, we should handle setting the nested
// aggregation object
if (constraint.proof.size() > proof_size_no_pub_inputs) {
// The public inputs attached to a proof should match the aggregation object in size
if (constraint.proof.size() - proof_size_no_pub_inputs != bb::AGGREGATION_OBJECT_SIZE) {
auto error_string = format(
"Public inputs are always stripped from proofs unless we have a recursive proof.\n"
"Thus, public inputs attached to a proof must match the recursive aggregation object in size "
"which is ",
bb::AGGREGATION_OBJECT_SIZE);
auto error_string = format("Public inputs are always stripped from proofs "
"unless we have a recursive proof.\n"
"Thus, public inputs attached to a proof must match "
"the recursive aggregation object in size "
"which is ",
bb::AGGREGATION_OBJECT_SIZE);
throw_or_abort(error_string);
}
for (size_t i = 0; i < bb::AGGREGATION_OBJECT_SIZE; ++i) {
// Set the nested aggregation object indices to the current size of the public inputs
// This way we know that the nested aggregation object indices will always be the last
// indices of the public inputs
// Set the nested aggregation object indices to the current size of the public
// inputs This way we know that the nested aggregation object indices will
// always be the last indices of the public inputs
nested_aggregation_object[i] = static_cast<uint32_t>(constraint.public_inputs.size());
// Attach the nested aggregation object to the end of the public inputs to fill in
// the slot where the nested aggregation object index will point into
// Attach the nested aggregation object to the end of the public inputs to fill
// in the slot where the nested aggregation object index will point into
constraint.public_inputs.emplace_back(constraint.proof[i]);
}
// Remove the aggregation object so that they can be handled as normal public inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct AcirFormatOriginalOpcodeIndices {
std::vector<size_t> bigint_from_le_bytes_constraints;
std::vector<size_t> bigint_to_le_bytes_constraints;
std::vector<size_t> bigint_operations;
std::vector<size_t> assert_equalities;
std::vector<size_t> poly_triple_constraints;
std::vector<size_t> quad_constraints;
// Multiple opcode indices per block:
Expand Down Expand Up @@ -98,6 +99,7 @@ struct AcirFormat {
std::vector<BigIntFromLeBytes> bigint_from_le_bytes_constraints;
std::vector<BigIntToLeBytes> bigint_to_le_bytes_constraints;
std::vector<BigIntOperation> bigint_operations;
std::vector<bb::poly_triple_<bb::curve::BN254::ScalarField>> assert_equalities;

// A standard plonk arithmetic constraint, as defined in the poly_triple struct, consists of selector values
// for q_M,q_L,q_R,q_O,q_C and indices of three variables taking the role of left, right and output wire
Expand All @@ -110,6 +112,9 @@ struct AcirFormat {
// Has length equal to num_acir_opcodes.
std::vector<size_t> gates_per_opcode = {};

// Set of constrained witnesses
std::set<uint32_t> constrained_witness = {};

// Indices of the original opcode that originated each constraint in AcirFormat.
AcirFormatOriginalOpcodeIndices original_opcode_indices;

Expand Down Expand Up @@ -139,7 +144,8 @@ struct AcirFormat {
block_constraints,
bigint_from_le_bytes_constraints,
bigint_to_le_bytes_constraints,
bigint_operations);
bigint_operations,
assert_equalities);

friend bool operator==(AcirFormat const& lhs, AcirFormat const& rhs) = default;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ TEST_F(AcirFormatTests, TestASingleConstraintNoPubInputs)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { constraint },
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -185,6 +186,7 @@ TEST_F(AcirFormatTests, TestLogicGateFromNoirCircuit)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { expr_a, expr_b, expr_c, expr_d },
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -264,6 +266,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifyPass)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { poly_triple{
.a = schnorr_constraint.result,
.b = schnorr_constraint.result,
Expand Down Expand Up @@ -370,6 +373,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifySmallRange)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { poly_triple{
.a = schnorr_constraint.result,
.b = schnorr_constraint.result,
Expand Down Expand Up @@ -489,6 +493,7 @@ TEST_F(AcirFormatTests, TestVarKeccak)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { dummy },
.quad_constraints = {},
.block_constraints = {},
Expand All @@ -510,7 +515,7 @@ TEST_F(AcirFormatTests, TestKeccakPermutation)
{
Keccakf1600
keccak_permutation{
.state = {
.state = {
WitnessOrConstant<bb::fr>::from_index(1),
WitnessOrConstant<bb::fr>::from_index(2),
WitnessOrConstant<bb::fr>::from_index(3),
Expand Down Expand Up @@ -568,6 +573,7 @@ TEST_F(AcirFormatTests, TestKeccakPermutation)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -644,6 +650,7 @@ TEST_F(AcirFormatTests, TestCollectsGateCounts)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { first_gate, second_gate },
.quad_constraints = {},
.block_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ acir_format::AcirFormatOriginalOpcodeIndices create_empty_original_opcode_indice
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -100,6 +101,9 @@ void mock_opcode_indices(acir_format::AcirFormat& constraint_system)
for (size_t i = 0; i < constraint_system.bigint_operations.size(); i++) {
constraint_system.original_opcode_indices.bigint_operations.push_back(current_opcode++);
}
for (size_t i = 0; i < constraint_system.assert_equalities.size(); i++) {
constraint_system.original_opcode_indices.assert_equalities.push_back(current_opcode++);
}
for (size_t i = 0; i < constraint_system.poly_triple_constraints.size(); i++) {
constraint_system.original_opcode_indices.poly_triple_constraints.push_back(current_opcode++);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "acir_to_constraint_buf.hpp"
#include "barretenberg/common/container.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp"
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <utility>
#ifndef __wasm__
Expand Down Expand Up @@ -167,10 +170,50 @@ mul_quad_<fr> serialize_mul_quad_gate(Program::Expression const& arg)
return quad;
}

void constrain_witnesses(Program::Opcode::AssertZero const& arg, AcirFormat& af)
{
for (const auto& linear_term : arg.value.linear_combinations) {
uint32_t witness_idx = std::get<1>(linear_term).value;
af.constrained_witness.insert(witness_idx);
}
for (const auto& linear_term : arg.value.mul_terms) {
uint32_t witness_idx = std::get<1>(linear_term).value;
af.constrained_witness.insert(witness_idx);
witness_idx = std::get<2>(linear_term).value;
af.constrained_witness.insert(witness_idx);
}
}

std::pair<uint32_t, uint32_t> is_assert_equal(Program::Opcode::AssertZero const& arg,
poly_triple const& pt,
AcirFormat const& af)
{
if (!arg.value.mul_terms.empty() || arg.value.linear_combinations.size() != 2) {
return { 0, 0 };
}
if (pt.q_l == -pt.q_r && pt.q_l != bb::fr::zero() && pt.q_c == bb::fr::zero()) {
if (af.constrained_witness.contains(pt.a) && af.constrained_witness.contains(pt.b)) {
return { pt.a, pt.b };
}
}
return { 0, 0 };
}

void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, size_t opcode_index)
{
if (arg.value.linear_combinations.size() <= 3) {
poly_triple pt = serialize_arithmetic_gate(arg.value);

auto assert_equal = is_assert_equal(arg, pt, af);
uint32_t w1 = std::get<0>(assert_equal);
uint32_t w2 = std::get<1>(assert_equal);
if (w1 != 0) {
if (w1 != w2) {
af.assert_equalities.push_back(pt);
af.original_opcode_indices.assert_equalities.push_back(opcode_index);
}
return;
}
// Even if the number of linear terms is less than 3, we might not be able to fit it into a width-3 arithmetic
// gate. This is the case if the linear terms are all disctinct witness from the multiplication term. In that
// case, the serialize_arithmetic_gate() function will return a poly_triple with all 0's, and we use a width-4
Expand All @@ -187,6 +230,7 @@ void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, s
af.quad_constraints.push_back(serialize_mul_quad_gate(arg.value));
af.original_opcode_indices.quad_constraints.push_back(opcode_index);
}
constrain_witnesses(arg, af);
}

uint32_t get_witness_from_function_input(Program::FunctionInput input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ TEST_F(BigIntTests, TestBigIntConstraintMultiple)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -270,6 +271,7 @@ TEST_F(BigIntTests, TestBigIntConstraintSimple)
.bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1 },
.bigint_to_le_bytes_constraints = { result2_to_le_bytes },
.bigint_operations = { add_constraint },
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -327,6 +329,7 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -389,6 +392,7 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse2)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -472,6 +476,7 @@ TEST_F(BigIntTests, TestBigIntDIV)
.bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1, from_le_bytes_constraint_bigint2 },
.bigint_to_le_bytes_constraints = { result3_to_le_bytes },
.bigint_operations = { div_constraint },
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ TEST_F(UltraPlonkRAM, TestBlockConstraint)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = { block },
Expand Down Expand Up @@ -216,6 +217,7 @@ TEST_F(MegaHonk, Databus)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = { block },
Expand Down Expand Up @@ -322,6 +324,7 @@ TEST_F(MegaHonk, DatabusReturn)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { assert_equal },
.quad_constraints = {},
.block_constraints = { block },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ TEST_F(EcOperations, TestECOperations)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -223,6 +224,7 @@ TEST_F(EcOperations, TestECMultiScalarMul)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { assert_equal },
.quad_constraints = {},
.block_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintSucceed)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -173,6 +174,7 @@ TEST_F(ECDSASecp256k1, TestECDSACompilesForVerifier)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -222,6 +224,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintFail)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down
Loading
Loading