Skip to content

Commit

Permalink
builds w UCB inheritance w new arith
Browse files Browse the repository at this point in the history
  • Loading branch information
ledwards2225 committed Nov 1, 2023
1 parent 0370b13 commit 4362af5
Show file tree
Hide file tree
Showing 15 changed files with 368 additions and 113 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include <cstddef>
#include <cstdint>
#include <gtest/gtest.h>

#include "barretenberg/common/log.hpp"
#include "barretenberg/honk/composer/ultra_composer.hpp"
#include "barretenberg/honk/proof_system/ultra_prover.hpp"
#include "barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp"
#include "barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp"

using namespace proof_system::honk;

namespace test_ultra_honk_composer {

namespace {
auto& engine = numeric::random::get_debug_engine();
}

class DataBusComposerTests : public ::testing::Test {
protected:
static void SetUpTestSuite() { barretenberg::srs::init_crs_factory("../srs_db/ignition"); }

using Curve = curve::BN254;
using FF = Curve::ScalarField;
using Point = Curve::AffineElement;
using CommitmentKey = pcs::CommitmentKey<Curve>;

/**
* @brief Generate a simple test circuit with some ECC op gates and conventional arithmetic gates
*
* @param builder
*/
void generate_test_circuit(auto& builder)
{
// Add some ecc op gates
for (size_t i = 0; i < 3; ++i) {
auto point = Point::one() * FF::random_element();
auto scalar = FF::random_element();
builder.queue_ecc_mul_accum(point, scalar);
}
builder.queue_ecc_eq();

// Add some conventional gates that utilize public inputs
for (size_t i = 0; i < 10; ++i) {
FF a = FF::random_element();
FF b = FF::random_element();
FF c = FF::random_element();
FF d = a + b + c;
uint32_t a_idx = builder.add_public_variable(a);
uint32_t b_idx = builder.add_variable(b);
uint32_t c_idx = builder.add_variable(c);
uint32_t d_idx = builder.add_variable(d);

builder.create_big_add_gate({ a_idx, b_idx, c_idx, d_idx, FF(1), FF(1), FF(1), FF(-1), FF(0) });
}
}

/**
* @brief Construct and a verify a Honk proof
*
*/
bool construct_and_verify_honk_proof(auto& composer, auto& builder)
{
auto instance = composer.create_instance(builder);
auto prover = composer.create_prover(instance);
auto verifier = composer.create_verifier(instance);
auto proof = prover.construct_proof();
bool verified = verifier.verify_proof(proof);

return verified;
}
};

/**
* @brief Test proof construction/verification for a circuit with ECC op gates, public inputs, and basic arithmetic
* gates
* @note We simulate op queue interactions with a previous circuit so the actual circuit under test utilizes an op queue
* with non-empty 'previous' data. This avoid complications with zero-commitments etc.
*
*/
TEST_F(DataBusComposerTests, SingleCircuit)
{
auto op_queue = std::make_shared<proof_system::ECCOpQueue>();

// Add mock data to op queue to simulate interaction with a previous circuit
op_queue->populate_with_mock_initital_data();

auto builder = proof_system::GoblinUltraCircuitBuilder(op_queue);

generate_test_circuit(builder);

auto composer = GoblinUltraComposer();

// Construct and verify Honk proof
auto honk_verified = construct_and_verify_honk_proof(composer, builder);
EXPECT_TRUE(honk_verified);
}

} // namespace test_ultra_honk_composer
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ TEST_F(UltraHonkComposerTests, non_native_field_multiplication)
const auto q_indices = get_limb_witness_indices(split_into_limbs(uint256_t(q)));
const auto r_indices = get_limb_witness_indices(split_into_limbs(uint256_t(r)));

proof_system::UltraCircuitBuilder::non_native_field_witnesses inputs{
proof_system::non_native_field_witnesses<fr> inputs{
a_indices, b_indices, q_indices, r_indices, modulus_limbs, fr(uint256_t(modulus)),
};
const auto [lo_1_idx, hi_1_idx] = circuit_builder.evaluate_non_native_field_multiplication(inputs);
Expand Down
183 changes: 108 additions & 75 deletions barretenberg/cpp/src/barretenberg/honk/flavor/goblin_ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ class GoblinUltra {
// The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often
// need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`.
// Note: this number does not include the individual sorted list polynomials.
static constexpr size_t NUM_ALL_ENTITIES = 48; // 43 (UH) + 4 op wires + 1 op wire "selector"
// NUM = 43 (UH) + 4 op wires + 1 op wire "selector" + 3 (calldata + calldata_read_counts + q_busread)
static constexpr size_t NUM_ALL_ENTITIES = 51;
// The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying
// assignment of witnesses. We again choose a neutral name.
static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 26; // 25 (UH) + 1 op wire "selector"
static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 27; // 25 (UH) + 1 op wire "selector" + q_busread
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 15; // 11 (UH) + 4 op wires
static constexpr size_t NUM_WITNESS_ENTITIES = 17; // 11 (UH) + 4 op wires + (calldata + calldata_read_counts)

using GrandProductRelations =
std::tuple<proof_system::UltraPermutationRelation<FF>, proof_system::LookupRelation<FF>>;
Expand All @@ -50,6 +51,7 @@ class GoblinUltra {
proof_system::EllipticRelation<FF>,
proof_system::AuxiliaryRelation<FF>,
proof_system::EccOpQueueRelation<FF>>;
// WORKTODO: add bus lookup relation!

static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length<Relations>();
static constexpr size_t MAX_TOTAL_RELATION_LENGTH = compute_max_total_relation_length<Relations>();
Expand Down Expand Up @@ -89,27 +91,28 @@ class GoblinUltra {
DataType& q_elliptic = std::get<8>(this->_data);
DataType& q_aux = std::get<9>(this->_data);
DataType& q_lookup = std::get<10>(this->_data);
DataType& sigma_1 = std::get<11>(this->_data);
DataType& sigma_2 = std::get<12>(this->_data);
DataType& sigma_3 = std::get<13>(this->_data);
DataType& sigma_4 = std::get<14>(this->_data);
DataType& id_1 = std::get<15>(this->_data);
DataType& id_2 = std::get<16>(this->_data);
DataType& id_3 = std::get<17>(this->_data);
DataType& id_4 = std::get<18>(this->_data);
DataType& table_1 = std::get<19>(this->_data);
DataType& table_2 = std::get<20>(this->_data);
DataType& table_3 = std::get<21>(this->_data);
DataType& table_4 = std::get<22>(this->_data);
DataType& lagrange_first = std::get<23>(this->_data);
DataType& lagrange_last = std::get<24>(this->_data);
DataType& lagrange_ecc_op = std::get<25>(this->_data); // indicator poly for ecc op gates
DataType& q_busread = std::get<11>(this->_data);
DataType& sigma_1 = std::get<12>(this->_data);
DataType& sigma_2 = std::get<13>(this->_data);
DataType& sigma_3 = std::get<14>(this->_data);
DataType& sigma_4 = std::get<15>(this->_data);
DataType& id_1 = std::get<16>(this->_data);
DataType& id_2 = std::get<17>(this->_data);
DataType& id_3 = std::get<18>(this->_data);
DataType& id_4 = std::get<19>(this->_data);
DataType& table_1 = std::get<20>(this->_data);
DataType& table_2 = std::get<21>(this->_data);
DataType& table_3 = std::get<22>(this->_data);
DataType& table_4 = std::get<23>(this->_data);
DataType& lagrange_first = std::get<24>(this->_data);
DataType& lagrange_last = std::get<25>(this->_data);
DataType& lagrange_ecc_op = std::get<26>(this->_data); // indicator poly for ecc op gates

static constexpr CircuitType CIRCUIT_TYPE = CircuitBuilder::CIRCUIT_TYPE;

std::vector<HandleType> get_selectors() override
{
return { q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_sort, q_elliptic, q_aux, q_lookup };
return { q_m, q_c, q_l, q_r, q_o, q_4, q_arith, q_sort, q_elliptic, q_aux, q_lookup, q_busread };
};
std::vector<HandleType> get_sigma_polynomials() override { return { sigma_1, sigma_2, sigma_3, sigma_4 }; };
std::vector<HandleType> get_id_polynomials() override { return { id_1, id_2, id_3, id_4 }; };
Expand Down Expand Up @@ -139,6 +142,8 @@ class GoblinUltra {
DataType& ecc_op_wire_2 = std::get<12>(this->_data);
DataType& ecc_op_wire_3 = std::get<13>(this->_data);
DataType& ecc_op_wire_4 = std::get<14>(this->_data);
DataType& calldata = std::get<15>(this->_data);
DataType& calldata_read_counts = std::get<16>(this->_data);

std::vector<HandleType> get_wires() override { return { w_l, w_r, w_o, w_4 }; };
std::vector<HandleType> get_ecc_op_wires()
Expand Down Expand Up @@ -172,43 +177,46 @@ class GoblinUltra {
DataType& q_elliptic = std::get<8>(this->_data);
DataType& q_aux = std::get<9>(this->_data);
DataType& q_lookup = std::get<10>(this->_data);
DataType& sigma_1 = std::get<11>(this->_data);
DataType& sigma_2 = std::get<12>(this->_data);
DataType& sigma_3 = std::get<13>(this->_data);
DataType& sigma_4 = std::get<14>(this->_data);
DataType& id_1 = std::get<15>(this->_data);
DataType& id_2 = std::get<16>(this->_data);
DataType& id_3 = std::get<17>(this->_data);
DataType& id_4 = std::get<18>(this->_data);
DataType& table_1 = std::get<19>(this->_data);
DataType& table_2 = std::get<20>(this->_data);
DataType& table_3 = std::get<21>(this->_data);
DataType& table_4 = std::get<22>(this->_data);
DataType& lagrange_first = std::get<23>(this->_data);
DataType& lagrange_last = std::get<24>(this->_data);
DataType& lagrange_ecc_op = std::get<25>(this->_data);
DataType& w_l = std::get<26>(this->_data);
DataType& w_r = std::get<27>(this->_data);
DataType& w_o = std::get<28>(this->_data);
DataType& w_4 = std::get<29>(this->_data);
DataType& sorted_accum = std::get<30>(this->_data);
DataType& z_perm = std::get<31>(this->_data);
DataType& z_lookup = std::get<32>(this->_data);
DataType& ecc_op_wire_1 = std::get<33>(this->_data);
DataType& ecc_op_wire_2 = std::get<34>(this->_data);
DataType& ecc_op_wire_3 = std::get<35>(this->_data);
DataType& ecc_op_wire_4 = std::get<36>(this->_data);
DataType& table_1_shift = std::get<37>(this->_data);
DataType& table_2_shift = std::get<38>(this->_data);
DataType& table_3_shift = std::get<39>(this->_data);
DataType& table_4_shift = std::get<40>(this->_data);
DataType& w_l_shift = std::get<41>(this->_data);
DataType& w_r_shift = std::get<42>(this->_data);
DataType& w_o_shift = std::get<43>(this->_data);
DataType& w_4_shift = std::get<44>(this->_data);
DataType& sorted_accum_shift = std::get<45>(this->_data);
DataType& z_perm_shift = std::get<46>(this->_data);
DataType& z_lookup_shift = std::get<47>(this->_data);
DataType& q_busread = std::get<11>(this->_data);
DataType& sigma_1 = std::get<12>(this->_data);
DataType& sigma_2 = std::get<13>(this->_data);
DataType& sigma_3 = std::get<14>(this->_data);
DataType& sigma_4 = std::get<15>(this->_data);
DataType& id_1 = std::get<16>(this->_data);
DataType& id_2 = std::get<17>(this->_data);
DataType& id_3 = std::get<18>(this->_data);
DataType& id_4 = std::get<19>(this->_data);
DataType& table_1 = std::get<20>(this->_data);
DataType& table_2 = std::get<21>(this->_data);
DataType& table_3 = std::get<22>(this->_data);
DataType& table_4 = std::get<23>(this->_data);
DataType& lagrange_first = std::get<24>(this->_data);
DataType& lagrange_last = std::get<25>(this->_data);
DataType& lagrange_ecc_op = std::get<26>(this->_data);
DataType& w_l = std::get<27>(this->_data);
DataType& w_r = std::get<28>(this->_data);
DataType& w_o = std::get<29>(this->_data);
DataType& w_4 = std::get<30>(this->_data);
DataType& sorted_accum = std::get<31>(this->_data);
DataType& z_perm = std::get<32>(this->_data);
DataType& z_lookup = std::get<33>(this->_data);
DataType& ecc_op_wire_1 = std::get<34>(this->_data);
DataType& ecc_op_wire_2 = std::get<35>(this->_data);
DataType& ecc_op_wire_3 = std::get<36>(this->_data);
DataType& ecc_op_wire_4 = std::get<37>(this->_data);
DataType& calldata = std::get<38>(this->_data);
DataType& calldata_read_counts = std::get<39>(this->_data);
DataType& table_1_shift = std::get<40>(this->_data);
DataType& table_2_shift = std::get<41>(this->_data);
DataType& table_3_shift = std::get<42>(this->_data);
DataType& table_4_shift = std::get<43>(this->_data);
DataType& w_l_shift = std::get<44>(this->_data);
DataType& w_r_shift = std::get<45>(this->_data);
DataType& w_o_shift = std::get<46>(this->_data);
DataType& w_4_shift = std::get<47>(this->_data);
DataType& sorted_accum_shift = std::get<48>(this->_data);
DataType& z_perm_shift = std::get<49>(this->_data);
DataType& z_lookup_shift = std::get<50>(this->_data);

std::vector<HandleType> get_wires() override { return { w_l, w_r, w_o, w_4 }; };
std::vector<HandleType> get_ecc_op_wires()
Expand All @@ -218,25 +226,46 @@ class GoblinUltra {
// Gemini-specific getters.
std::vector<HandleType> get_unshifted() override
{
return { q_c, q_l,
q_r, q_o,
q_4, q_m,
q_arith, q_sort,
q_elliptic, q_aux,
q_lookup, sigma_1,
sigma_2, sigma_3,
sigma_4, id_1,
id_2, id_3,
id_4, table_1,
table_2, table_3,
table_4, lagrange_first,
lagrange_last, lagrange_ecc_op,
w_l, w_r,
w_o, w_4,
sorted_accum, z_perm,
z_lookup, ecc_op_wire_1,
ecc_op_wire_2, ecc_op_wire_3,
ecc_op_wire_4 };
return { q_c,
q_l,
q_r,
q_o,
q_4,
q_m,
q_arith,
q_sort,
q_elliptic,
q_aux,
q_lookup,
q_busread,
sigma_1,
sigma_2,
sigma_3,
sigma_4,
id_1,
id_2,
id_3,
id_4,
table_1,
table_2,
table_3,
table_4,
lagrange_first,
lagrange_last,
lagrange_ecc_op,
w_l,
w_r,
w_o,
w_4,
sorted_accum,
z_perm,
z_lookup,
ecc_op_wire_1,
ecc_op_wire_2,
ecc_op_wire_3,
ecc_op_wire_4,
calldata,
calldata_read_counts };
};
std::vector<HandleType> get_to_be_shifted() override
{
Expand Down Expand Up @@ -384,6 +413,8 @@ class GoblinUltra {
ecc_op_wire_2 = "ECC_OP_WIRE_2";
ecc_op_wire_3 = "ECC_OP_WIRE_3";
ecc_op_wire_4 = "ECC_OP_WIRE_4";
calldata = "CALLDATA";
calldata_read_counts = "CALLDATA_READ_COUNTS";

// The ones beginning with "__" are only used for debugging
q_c = "__Q_C";
Expand All @@ -397,6 +428,7 @@ class GoblinUltra {
q_elliptic = "__Q_ELLIPTIC";
q_aux = "__Q_AUX";
q_lookup = "__Q_LOOKUP";
q_busread = "__Q_BUSREAD";
sigma_1 = "__SIGMA_1";
sigma_2 = "__SIGMA_2";
sigma_3 = "__SIGMA_3";
Expand Down Expand Up @@ -432,6 +464,7 @@ class GoblinUltra {
q_elliptic = verification_key->q_elliptic;
q_aux = verification_key->q_aux;
q_lookup = verification_key->q_lookup;
q_busread = verification_key->q_busread;
sigma_1 = verification_key->sigma_1;
sigma_2 = verification_key->sigma_2;
sigma_3 = verification_key->sigma_3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ template <class Flavor> void ProverInstance_<Flavor>::initialise_prover_polynomi
prover_polynomials.ecc_op_wire_3 = proving_key->ecc_op_wire_3;
prover_polynomials.ecc_op_wire_4 = proving_key->ecc_op_wire_4;
prover_polynomials.lagrange_ecc_op = proving_key->lagrange_ecc_op;
// DataBus polynomials
prover_polynomials.calldata = proving_key->calldata;
prover_polynomials.calldata_read_counts = proving_key->calldata_read_counts;
prover_polynomials.q_busread = proving_key->q_busread;
}

std::span<FF> public_wires_source = prover_polynomials.w_r;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ TYPED_TEST(ultra_plonk_composer, non_native_field_multiplication)
const auto q_indices = get_limb_witness_indices(split_into_limbs(uint256_t(q)));
const auto r_indices = get_limb_witness_indices(split_into_limbs(uint256_t(r)));

proof_system::UltraCircuitBuilder::non_native_field_witnesses inputs{
proof_system::non_native_field_witnesses<fr> inputs{
a_indices, b_indices, q_indices, r_indices, modulus_limbs, fr(uint256_t(modulus)),
};
const auto [lo_1_idx, hi_1_idx] = builder.evaluate_non_native_field_multiplication(inputs);
Expand Down
Loading

0 comments on commit 4362af5

Please sign in to comment.