Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: move witness computation into class plus some other cleanup #11140

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "barretenberg/benchmark/mega_memory_bench/memory_estimator.hpp"
#include "barretenberg/stdlib/primitives/field/field.hpp"
#include "barretenberg/stdlib/primitives/plookup/plookup.hpp"
#include "barretenberg/stdlib_circuit_builders/plookup_tables/fixed_base/fixed_base.hpp"
Expand Down Expand Up @@ -312,10 +313,10 @@ void fill_trace(State& state, TraceSettings settings)
}

builder.finalize_circuit(/* ensure_nonzero */ true);
uint64_t builder_estimate = builder.estimate_memory();
uint64_t builder_estimate = MegaMemoryEstimator::estimate_builder_memory(builder);
for (auto _ : state) {
DeciderProvingKey proving_key(builder, settings);
uint64_t memory_estimate = proving_key.proving_key.estimate_memory();
uint64_t memory_estimate = MegaMemoryEstimator::estimate_proving_key_memory(proving_key.proving_key);
state.counters["poly_mem_est"] = static_cast<double>(memory_estimate);
state.counters["builder_mem_est"] = static_cast<double>(builder_estimate);
benchmark::DoNotOptimize(proving_key);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#pragma once

#include "barretenberg/stdlib_circuit_builders/mega_flavor.hpp"
#include <cstdint>

namespace bb {

/**
* @brief Methods for estimating memory in key components of MegaHonk
*
*/
class MegaMemoryEstimator {
using FF = MegaFlavor::FF;

public:
static uint64_t estimate_proving_key_memory(MegaFlavor::ProvingKey& proving_key)
{
vinfo("++Estimating proving key memory++");

auto& polynomials = proving_key.polynomials;

for (auto [polynomial, label] : zip_view(polynomials.get_all(), polynomials.get_labels())) {
uint64_t size = polynomial.size();
vinfo(label, " num: ", size, " size: ", (size * sizeof(FF)) >> 10, " KiB");
}

uint64_t result(0);
for (auto& polynomial : polynomials.get_unshifted()) {
result += polynomial.size() * sizeof(FF);
}

result += proving_key.public_inputs.capacity() * sizeof(FF);

return result;
}

static uint64_t estimate_builder_memory(MegaFlavor::CircuitBuilder& builder)
{
vinfo("++Estimating builder memory++");
uint64_t result{ 0 };

// gates:
for (auto [block, label] : zip_view(builder.blocks.get(), builder.blocks.get_labels())) {
uint64_t size{ 0 };
for (const auto& wire : block.wires) {
size += wire.capacity() * sizeof(uint32_t);
}
for (const auto& selector : block.selectors) {
size += selector.capacity() * sizeof(FF);
}
vinfo(label, " size ", size >> 10, " KiB");
result += size;
}

// variables
size_t to_add{ builder.variables.capacity() * sizeof(FF) };
result += to_add;
vinfo("variables: ", to_add);

// public inputs
to_add = builder.public_inputs.capacity() * sizeof(uint32_t);
result += to_add;
vinfo("public inputs: ", to_add);

// other variable indices
to_add = builder.next_var_index.capacity() * sizeof(uint32_t);
to_add += builder.prev_var_index.capacity() * sizeof(uint32_t);
to_add += builder.real_variable_index.capacity() * sizeof(uint32_t);
to_add += builder.real_variable_tags.capacity() * sizeof(uint32_t);
result += to_add;
vinfo("variable indices: ", to_add);

return result;
}
};

} // namespace bb
7 changes: 0 additions & 7 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,6 @@ class ECCVMFlavor {
// this getter is necessary for more uniform zk verifiers
auto get_shifted_witnesses() { return ShiftedEntities<DataType>::get_all(); };
auto get_precomputed() { return PrecomputedEntities<DataType>::get_all(); };
// the getter for all witnesses including derived and shifted ones
auto get_all_witnesses()
{
return concatenate(WitnessEntities<DataType>::get_all(), ShiftedEntities<DataType>::get_all());
};
// this getter is necessary for a universal ZK Sumcheck
auto get_non_witnesses() { return PrecomputedEntities<DataType>::get_all(); };
};

public:
Expand Down
4 changes: 0 additions & 4 deletions barretenberg/cpp/src/barretenberg/polynomials/polynomial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,6 @@ template <typename Fr> class Polynomial {
// safety check for in place operations
bool in_place_operation_viable(size_t domain_size) { return (size() >= domain_size); }

// When a polynomial is instantiated from a size alone, the memory allocated corresponds to
// input size + MAXIMUM_COEFFICIENT_SHIFT to support 'shifted' coefficients efficiently.
const static size_t MAXIMUM_COEFFICIENT_SHIFT = 1;

// The underlying memory, with a bespoke (but minimal) shared array struct that fits our needs.
// Namely, it supports polynomial shifts and 'virtual' zeroes past a size up until a 'virtual' size.
SharedShiftedVirtualZeroesArray<Fr> coefficients_;
Expand Down
14 changes: 2 additions & 12 deletions barretenberg/cpp/src/barretenberg/protogalaxy/protogalaxy.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "barretenberg/stdlib_circuit_builders/mock_circuits.hpp"
#include "barretenberg/ultra_honk/decider_prover.hpp"
#include "barretenberg/ultra_honk/decider_verifier.hpp"
#include "barretenberg/ultra_honk/witness_computation.hpp"

#include <gtest/gtest.h>

Expand Down Expand Up @@ -116,18 +117,7 @@ template <typename Flavor> class ProtogalaxyTests : public testing::Test {

auto decider_pk = std::make_shared<DeciderProvingKey>(builder);

decider_pk->relation_parameters.eta = FF::random_element();
decider_pk->relation_parameters.eta_two = FF::random_element();
decider_pk->relation_parameters.eta_three = FF::random_element();
decider_pk->relation_parameters.beta = FF::random_element();
decider_pk->relation_parameters.gamma = FF::random_element();

decider_pk->proving_key.add_ram_rom_memory_records_to_wire_4(decider_pk->relation_parameters.eta,
decider_pk->relation_parameters.eta_two,
decider_pk->relation_parameters.eta_three);
decider_pk->proving_key.compute_logderivative_inverses(decider_pk->relation_parameters);
decider_pk->proving_key.compute_grand_product_polynomial(decider_pk->relation_parameters,
decider_pk->final_active_wire_idx + 1);
WitnessComputation<Flavor>::complete_proving_key_for_test(decider_pk);

for (auto& alpha : decider_pk->alphas) {
alpha = FF::random_element();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,44 +237,6 @@ template <typename FF> class MegaCircuitBuilder_ : public UltraCircuitBuilder_<M
const BusVector& get_calldata() const { return databus[static_cast<size_t>(BusId::CALLDATA)]; }
const BusVector& get_secondary_calldata() const { return databus[static_cast<size_t>(BusId::SECONDARY_CALLDATA)]; }
const BusVector& get_return_data() const { return databus[static_cast<size_t>(BusId::RETURNDATA)]; }
uint64_t estimate_memory() const
{
vinfo("++Estimating builder memory++");
uint64_t result{ 0 };

// gates:
for (auto [block, label] : zip_view(this->blocks.get(), this->blocks.get_labels())) {
uint64_t size{ 0 };
for (const auto& wire : block.wires) {
size += wire.capacity() * sizeof(uint32_t);
}
for (const auto& selector : block.selectors) {
size += selector.capacity() * sizeof(FF);
}
vinfo(label, " size ", size >> 10, " KiB");
result += size;
}

// variables
size_t to_add{ this->variables.capacity() * sizeof(FF) };
result += to_add;
vinfo("variables: ", to_add);

// public inputs
to_add = this->public_inputs.capacity() * sizeof(uint32_t);
result += to_add;
vinfo("public inputs: ", to_add);

// other variable indices
to_add = this->next_var_index.capacity() * sizeof(uint32_t);
to_add += this->prev_var_index.capacity() * sizeof(uint32_t);
to_add += this->real_variable_index.capacity() * sizeof(uint32_t);
to_add += this->real_variable_tags.capacity() * sizeof(uint32_t);
result += to_add;
vinfo("variable indices: ", to_add);

return result;
}
};
using MegaCircuitBuilder = MegaCircuitBuilder_<bb::fr>;
} // namespace bb
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
#include "barretenberg/flavor/flavor_macros.hpp"
#include "barretenberg/flavor/relation_definitions.hpp"
#include "barretenberg/flavor/repeated_commitments_data.hpp"
#include "barretenberg/honk/proof_system/types/proof.hpp"
#include "barretenberg/plonk_honk_shared/library/grand_product_delta.hpp"
#include "barretenberg/plonk_honk_shared/library/grand_product_library.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/relations/auxiliary_relation.hpp"
#include "barretenberg/relations/databus_lookup_relation.hpp"
Expand All @@ -18,7 +15,6 @@
#include "barretenberg/relations/permutation_relation.hpp"
#include "barretenberg/relations/poseidon2_external_relation.hpp"
#include "barretenberg/relations/poseidon2_internal_relation.hpp"
#include "barretenberg/relations/relation_parameters.hpp"
#include "barretenberg/relations/ultra_arithmetic_relation.hpp"
#include "barretenberg/stdlib_circuit_builders/mega_circuit_builder.hpp"
#include "barretenberg/transcript/transcript.hpp"
Expand Down Expand Up @@ -86,8 +82,8 @@ class MegaFlavor {
static constexpr size_t NUM_ALL_WITNESS_ENTITIES = NUM_WITNESS_ENTITIES + NUM_SHIFTED_WITNESSES;

// For instances of this flavour, used in folding, we need a unique sumcheck batching challenges for each
// subrelation. This
// is because using powers of alpha would increase the degree of Protogalaxy polynomial $G$ (the combiner) too much.
// subrelation. This is because using powers of alpha would increase the degree of Protogalaxy polynomial $G$ (the
// combiner) too much.
static constexpr size_t NUM_SUBRELATIONS = compute_number_of_subrelations<Relations>();
using RelationSeparator = std::array<FF, NUM_SUBRELATIONS - 1>;

Expand Down Expand Up @@ -172,7 +168,7 @@ class MegaFlavor {

// Mega needs to expose more public classes than most flavors due to MegaRecursive reuse, but these
// are internal:
public:

// WireEntities for basic witness entities
template <typename DataType> class WireEntities {
public:
Expand Down Expand Up @@ -284,8 +280,6 @@ class MegaFlavor {
w_o_shift, // column 2
w_4_shift, // column 3
z_perm_shift) // column 4

auto get_shifted() { return RefArray{ w_l_shift, w_r_shift, w_o_shift, w_4_shift, z_perm_shift }; };
};

public:
Expand Down Expand Up @@ -320,13 +314,6 @@ class MegaFlavor {
auto get_witness() { return WitnessEntities<DataType>::get_all(); };
auto get_to_be_shifted() { return WitnessEntities<DataType>::get_to_be_shifted(); };
auto get_shifted() { return ShiftedEntities<DataType>::get_all(); };
// this getter is used in ZK Sumcheck, where all witness evaluations (including shifts) have to be masked
auto get_all_witnesses()
{
return concatenate(WitnessEntities<DataType>::get_all(), ShiftedEntities<DataType>::get_all());
};
// getter for the complement of all witnesses inside all entities
auto get_non_witnesses() { return PrecomputedEntities<DataType>::get_all(); };
};

/**
Expand Down Expand Up @@ -433,115 +420,17 @@ class MegaFlavor {

// Data pertaining to transfer of databus return data via public inputs
DatabusPropagationData databus_propagation_data;

/**
* @brief Add plookup memory records to the fourth wire polynomial
*
* @details This operation must be performed after the first three wires have been committed to, hence the
* dependence on the `eta` challenge.
*
* @tparam Flavor
* @param eta challenge produced after commitment to first three wire polynomials
*/
void add_ram_rom_memory_records_to_wire_4(const FF& eta, const FF& eta_two, const FF& eta_three)
{
// The plookup memory record values are computed at the indicated indices as
// w4 = w3 * eta^3 + w2 * eta^2 + w1 * eta + read_write_flag;
// (See plookup_auxiliary_widget.hpp for details)
auto wires = polynomials.get_wires();

// Compute read record values
for (const auto& gate_idx : memory_read_records) {
wires[3].at(gate_idx) += wires[2][gate_idx] * eta_three;
wires[3].at(gate_idx) += wires[1][gate_idx] * eta_two;
wires[3].at(gate_idx) += wires[0][gate_idx] * eta;
}

// Compute write record values
for (const auto& gate_idx : memory_write_records) {
wires[3].at(gate_idx) += wires[2][gate_idx] * eta_three;
wires[3].at(gate_idx) += wires[1][gate_idx] * eta_two;
wires[3].at(gate_idx) += wires[0][gate_idx] * eta;
wires[3].at(gate_idx) += 1;
}
}

/**
* @brief Compute the inverse polynomials used in the log derivative lookup relations
*
* @tparam Flavor
* @param beta
* @param gamma
*/
void compute_logderivative_inverses(const RelationParameters<FF>& relation_parameters)
{
PROFILE_THIS_NAME("compute_logderivative_inverses");

// Compute inverses for conventional lookups
LogDerivLookupRelation<FF>::compute_logderivative_inverse(
this->polynomials, relation_parameters, this->circuit_size);

// Compute inverses for calldata reads
DatabusLookupRelation<FF>::compute_logderivative_inverse</*bus_idx=*/0>(
this->polynomials, relation_parameters, this->circuit_size);

// Compute inverses for secondary_calldata reads
DatabusLookupRelation<FF>::compute_logderivative_inverse</*bus_idx=*/1>(
this->polynomials, relation_parameters, this->circuit_size);

// Compute inverses for return data reads
DatabusLookupRelation<FF>::compute_logderivative_inverse</*bus_idx=*/2>(
this->polynomials, relation_parameters, this->circuit_size);
}

/**
* @brief Computes public_input_delta and the permutation grand product polynomial
*
* @param relation_parameters
* @param size_override override the size of the domain over which to compute the grand product
*/
void compute_grand_product_polynomial(RelationParameters<FF>& relation_parameters, size_t size_override = 0)
{
relation_parameters.public_input_delta = compute_public_input_delta<MegaFlavor>(this->public_inputs,
relation_parameters.beta,
relation_parameters.gamma,
this->circuit_size,
this->pub_inputs_offset);

// Compute permutation grand product polynomial
compute_grand_product<MegaFlavor, UltraPermutationRelation<FF>>(
this->polynomials, relation_parameters, size_override, this->active_region_data);
}

uint64_t estimate_memory()
{
vinfo("++Estimating proving key memory++");
for (auto [polynomial, label] : zip_view(polynomials.get_all(), polynomials.get_labels())) {
uint64_t size = polynomial.size();
vinfo(label, " num: ", size, " size: ", (size * sizeof(FF)) >> 10, " KiB");
}

uint64_t result(0);
for (auto& polynomial : polynomials.get_unshifted()) {
result += polynomial.size() * sizeof(FF);
}

result += public_inputs.capacity() * sizeof(FF);

return result;
}
};

/**
* @brief The verification key is responsible for storing the commitments to the precomputed (non-witnessk)
* @brief The verification key is responsible for storing the commitments to the precomputed (non-witness)
* polynomials used by the verifier.
*
* @note Note the discrepancy with what sort of data is stored here vs in the proving key. We may want to resolve
* that, and split out separate PrecomputedPolynomials/Commitments data for clarity but also for portability of our
* circuits.
* @todo TODO(https://github.com/AztecProtocol/barretenberg/issues/876)
*/
// using VerificationKey = VerificationKey_<PrecomputedEntities<Commitment>, VerifierCommitmentKey>;
class VerificationKey : public VerificationKey_<PrecomputedEntities<Commitment>, VerifierCommitmentKey> {
public:
// Data pertaining to transfer of databus return data via public inputs of the proof being recursively verified
Expand Down Expand Up @@ -614,6 +503,7 @@ class MegaFlavor {
}

// TODO(https://github.com/AztecProtocol/barretenberg/issues/964): Clean the boilerplate up.
// Explicit constructor for msgpack serialization
VerificationKey(const size_t circuit_size,
const size_t num_public_inputs,
const size_t pub_inputs_offset,
Expand Down
Loading
Loading