Skip to content

Commit

Permalink
feat: reorganize acir composer (AztecProtocol#3957)
Browse files Browse the repository at this point in the history
Previously, the proof generation flows defined in bb/main.cpp performed
1 or more additional circuit constructions than were necessary. This
work removes the unnecessary computation and generally makes circuit
construction more explicit to avoid similar confusion in the future. For
example, `init_proving_key()` no longer makes an internal call to
`create_circuit()` - instead, the circuit must be constructed
explicitly. The code has been otherwise simplified/clarified where
possible.

Note: removal of the redundant circuit construction should also avoid
issues like the one recently encountered by @guipublic where circuit
generation without a witness was inconsistent with generation _with_ a
witness, leading to srs size issues.
  • Loading branch information
ledwards2225 authored Jan 12, 2024
1 parent 70b2ffd commit e6232e8
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 109 deletions.
12 changes: 0 additions & 12 deletions barretenberg/cpp/src/barretenberg/bb/get_witness.hpp

This file was deleted.

77 changes: 40 additions & 37 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "config.hpp"
#include "get_bn254_crs.hpp"
#include "get_bytecode.hpp"
#include "get_grumpkin_crs.hpp"
#include "get_witness.hpp"
#include "log.hpp"
#include <barretenberg/common/benchmark.hpp>
#include <barretenberg/common/container.hpp>
Expand All @@ -32,34 +30,22 @@ bool verbose = false;
const std::filesystem::path current_path = std::filesystem::current_path();
const auto current_dir = current_path.filename().string();

acir_proofs::AcirComposer init(acir_format::acir_format& constraint_system)
/**
* @brief Initialize the global crs_factory for bn254 based on a known dyadic circuit size
*
* @param dyadic_circuit_size power-of-2 circuit size
*/
void init_bn254_crs(size_t dyadic_circuit_size)
{
acir_proofs::AcirComposer acir_composer(0, verbose);
acir_composer.create_circuit(constraint_system);
auto subgroup_size = acir_composer.get_circuit_subgroup_size();

// Must +1!
auto bn254_g1_data = get_bn254_g1_data(CRS_PATH, subgroup_size + 1);
// Must +1 for Plonk only!
auto bn254_g1_data = get_bn254_g1_data(CRS_PATH, dyadic_circuit_size + 1);
auto bn254_g2_data = get_bn254_g2_data(CRS_PATH);
srs::init_crs_factory(bn254_g1_data, bn254_g2_data);

return acir_composer;
}

void init_reference_strings()
void init_grumpkin_crs(size_t eccvm_dyadic_circuit_size)
{
// TODO(https://github.com/AztecProtocol/barretenberg/issues/811): Don't hardcode subgroup size. Currently set to
// max circuit size present in acir tests suite.
size_t hardcoded_subgroup_size_hack = 262144;

// TODO(https://github.com/AztecProtocol/barretenberg/issues/811) reduce duplication with above
// Must +1!
auto g1_data = get_bn254_g1_data(CRS_PATH, hardcoded_subgroup_size_hack + 1);
auto g2_data = get_bn254_g2_data(CRS_PATH);
srs::init_crs_factory(g1_data, g2_data);

// Must +1!
auto grumpkin_g1_data = get_grumpkin_g1_data(CRS_PATH, hardcoded_subgroup_size_hack + 1);
auto grumpkin_g1_data = get_grumpkin_g1_data(CRS_PATH, eccvm_dyadic_circuit_size);
srs::init_grumpkin_crs_factory(grumpkin_g1_data);
}

Expand All @@ -75,7 +61,7 @@ acir_proofs::AcirComposer verifier_init()

acir_format::WitnessVector get_witness(std::string const& witness_path)
{
auto witness_data = get_witness_data(witness_path);
auto witness_data = get_bytecode(witness_path);
return acir_format::witness_buf_to_witness_data(witness_data);
}

Expand Down Expand Up @@ -103,16 +89,20 @@ bool proveAndVerify(const std::string& bytecodePath, const std::string& witnessP
auto constraint_system = get_constraint_system(bytecodePath);
auto witness = get_witness(witnessPath);

auto acir_composer = init(constraint_system);
acir_proofs::AcirComposer acir_composer{ 0, verbose };
acir_composer.create_circuit(constraint_system, witness);

init_bn254_crs(acir_composer.get_dyadic_circuit_size());

Timer pk_timer;
acir_composer.init_proving_key(constraint_system);
acir_composer.init_proving_key();
write_benchmark("pk_construction_time", pk_timer.milliseconds(), "acir_test", current_dir);

write_benchmark("gate_count", acir_composer.get_total_circuit_size(), "acir_test", current_dir);
write_benchmark("subgroup_size", acir_composer.get_circuit_subgroup_size(), "acir_test", current_dir);
write_benchmark("subgroup_size", acir_composer.get_dyadic_circuit_size(), "acir_test", current_dir);

Timer proof_timer;
auto proof = acir_composer.create_proof(constraint_system, witness, recursive);
auto proof = acir_composer.create_proof(recursive);
write_benchmark("proof_construction_time", proof_timer.milliseconds(), "acir_test", current_dir);

Timer vk_timer;
Expand Down Expand Up @@ -145,11 +135,16 @@ bool proveAndVerifyGoblin(const std::string& bytecodePath,
auto constraint_system = get_constraint_system(bytecodePath);
auto witness = get_witness(witnessPath);

init_reference_strings();

acir_proofs::AcirComposer acir_composer;
acir_composer.create_goblin_circuit(constraint_system, witness);

// TODO(https://github.com/AztecProtocol/barretenberg/issues/811): Don't hardcode dyadic circuit size. Currently set
// to max circuit size present in acir tests suite.
size_t hardcoded_bn254_dyadic_size_hack = 1 << 18;
init_bn254_crs(hardcoded_bn254_dyadic_size_hack);
size_t hardcoded_grumpkin_dyadic_size_hack = 1 << 10; // For eccvm only
init_grumpkin_crs(hardcoded_grumpkin_dyadic_size_hack);

auto proof = acir_composer.create_goblin_proof();

auto verified = acir_composer.verify_goblin_proof(proof);
Expand All @@ -176,8 +171,12 @@ void prove(const std::string& bytecodePath,
{
auto constraint_system = get_constraint_system(bytecodePath);
auto witness = get_witness(witnessPath);
auto acir_composer = init(constraint_system);
auto proof = acir_composer.create_proof(constraint_system, witness, recursive);

acir_proofs::AcirComposer acir_composer{ 0, verbose };
acir_composer.create_circuit(constraint_system, witness);
init_bn254_crs(acir_composer.get_dyadic_circuit_size());
acir_composer.init_proving_key();
auto proof = acir_composer.create_proof(recursive);

if (outputPath == "-") {
writeRawBytesToStdout(proof);
Expand Down Expand Up @@ -247,8 +246,10 @@ bool verify(const std::string& proof_path, bool recursive, const std::string& vk
void write_vk(const std::string& bytecodePath, const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto acir_composer = init(constraint_system);
acir_composer.init_proving_key(constraint_system);
acir_proofs::AcirComposer acir_composer{ 0, verbose };
acir_composer.create_circuit(constraint_system);
init_bn254_crs(acir_composer.get_dyadic_circuit_size());
acir_composer.init_proving_key();
auto vk = acir_composer.init_verification_key();
auto serialized_vk = to_buffer(*vk);
if (outputPath == "-") {
Expand All @@ -263,8 +264,10 @@ void write_vk(const std::string& bytecodePath, const std::string& outputPath)
void write_pk(const std::string& bytecodePath, const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto acir_composer = init(constraint_system);
auto pk = acir_composer.init_proving_key(constraint_system);
acir_proofs::AcirComposer acir_composer{ 0, verbose };
acir_composer.create_circuit(constraint_system);
init_bn254_crs(acir_composer.get_dyadic_circuit_size());
auto pk = acir_composer.init_proving_key();
auto serialized_pk = to_buffer(*pk);

if (outputPath == "-") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "acir_format.hpp"
#include "barretenberg/common/log.hpp"
#include "barretenberg/dsl/acir_format/pedersen.hpp"
#include "barretenberg/dsl/acir_format/recursion_constraint.hpp"
#include "barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp"
#include <cstddef>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once
#include "barretenberg/common/slab_allocator.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/serialize/msgpack.hpp"
#include "blake2s_constraint.hpp"
#include "blake3_constraint.hpp"
Expand Down
57 changes: 19 additions & 38 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/acir_composer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
#include "barretenberg/common/serialize.hpp"
#include "barretenberg/common/throw_or_abort.hpp"
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/dsl/acir_format/recursion_constraint.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/goblin/mock_circuits.hpp"
#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "barretenberg/plonk/proof_system/verification_key/sol_gen.hpp"
#include "barretenberg/plonk/proof_system/verification_key/verification_key.hpp"
#include "barretenberg/srs/factories/crs_factory.hpp"
#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp"
#include "contract.hpp"

Expand All @@ -20,55 +17,36 @@ AcirComposer::AcirComposer(size_t size_hint, bool verbose)
, verbose_(verbose)
{}

template <typename Builder> void AcirComposer::create_circuit(acir_format::acir_format& constraint_system)
/**
* @brief Populate acir_composer-owned builder with circuit generated from constraint system and an optional witness
*
* @tparam Builder
* @param constraint_system
* @param witness
*/
template <typename Builder>
void AcirComposer::create_circuit(acir_format::acir_format& constraint_system, WitnessVector const& witness)
{
// this seems to have made sense for plonk but no longer makes sense for Honk? if we return early then the
// sizes below never get set and that eventually causes too few srs points to be extracted
if (builder_.get_num_gates() > 1) {
return;
}
vinfo("building circuit...");
builder_ = acir_format::create_circuit<Builder>(constraint_system, size_hint_);
exact_circuit_size_ = builder_.get_num_gates();
total_circuit_size_ = builder_.get_total_circuit_size();
circuit_subgroup_size_ = builder_.get_circuit_subgroup_size(total_circuit_size_);
size_hint_ = circuit_subgroup_size_;
builder_ = acir_format::create_circuit<Builder>(constraint_system, size_hint_, witness);
vinfo("gates: ", builder_.get_total_circuit_size());
}

template void AcirComposer::create_circuit<proof_system::UltraCircuitBuilder>(
acir_format::acir_format& constraint_system);

std::shared_ptr<proof_system::plonk::proving_key> AcirComposer::init_proving_key(
acir_format::acir_format& constraint_system)
std::shared_ptr<proof_system::plonk::proving_key> AcirComposer::init_proving_key()
{
create_circuit(constraint_system);
acir_format::Composer composer;
vinfo("computing proving key...");
proving_key_ = composer.compute_proving_key(builder_);
return proving_key_;
}

std::vector<uint8_t> AcirComposer::create_proof(acir_format::acir_format& constraint_system,
acir_format::WitnessVector& witness,
bool is_recursive)
std::vector<uint8_t> AcirComposer::create_proof(bool is_recursive)
{
vinfo("building circuit with witness...");
builder_ = acir_format::create_circuit(constraint_system, size_hint_, witness);

vinfo("gates: ", builder_.get_total_circuit_size());

auto composer = [&]() {
if (proving_key_) {
return acir_format::Composer(proving_key_, nullptr);
}
if (!proving_key_) {
throw_or_abort("Must compute proving key before constructing proof.");
}

acir_format::Composer composer;
vinfo("computing proving key...");
proving_key_ = composer.compute_proving_key(builder_);
vinfo("done.");
return composer;
}();
acir_format::Composer composer(proving_key_, nullptr);

vinfo("creating proof...");
std::vector<uint8_t> proof;
Expand Down Expand Up @@ -208,4 +186,7 @@ std::vector<barretenberg::fr> AcirComposer::serialize_verification_key_into_fiel
return acir_format::export_key_in_recursion_format(verification_key_);
}

template void AcirComposer::create_circuit<UltraCircuitBuilder>(acir_format::acir_format& constraint_system,
WitnessVector const& witness);

} // namespace acir_proofs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once
#include <barretenberg/dsl/acir_format/acir_format.hpp>
#include <barretenberg/goblin/goblin.hpp>
#include <barretenberg/proof_system/op_queue/ecc_op_queue.hpp>

namespace acir_proofs {

Expand All @@ -12,16 +11,18 @@ namespace acir_proofs {
* structure of the newer code since there's much more of that code now?
*/
class AcirComposer {

using WitnessVector = std::vector<fr, ContainerSlabAllocator<fr>>;

public:
AcirComposer(size_t size_hint = 0, bool verbose = true);

template <typename Builder = UltraCircuitBuilder> void create_circuit(acir_format::acir_format& constraint_system);
template <typename Builder = UltraCircuitBuilder>
void create_circuit(acir_format::acir_format& constraint_system, WitnessVector const& witness = {});

std::shared_ptr<proof_system::plonk::proving_key> init_proving_key(acir_format::acir_format& constraint_system);
std::shared_ptr<proof_system::plonk::proving_key> init_proving_key();

std::vector<uint8_t> create_proof(acir_format::acir_format& constraint_system,
acir_format::WitnessVector& witness,
bool is_recursive);
std::vector<uint8_t> create_proof(bool is_recursive);

void load_verification_key(proof_system::plonk::verification_key_data&& data);

Expand All @@ -30,9 +31,8 @@ class AcirComposer {
bool verify_proof(std::vector<uint8_t> const& proof, bool is_recursive);

std::string get_solidity_verifier();
size_t get_exact_circuit_size() { return exact_circuit_size_; };
size_t get_total_circuit_size() { return total_circuit_size_; };
size_t get_circuit_subgroup_size() { return circuit_subgroup_size_; };
size_t get_total_circuit_size() { return builder_.get_total_circuit_size(); };
size_t get_dyadic_circuit_size() { return builder_.get_circuit_subgroup_size(builder_.get_total_circuit_size()); };

std::vector<barretenberg::fr> serialize_proof_into_fields(std::vector<uint8_t> const& proof,
size_t num_inner_public_inputs);
Expand All @@ -49,9 +49,6 @@ class AcirComposer {
acir_format::GoblinBuilder goblin_builder_;
Goblin goblin;
size_t size_hint_;
size_t exact_circuit_size_;
size_t total_circuit_size_;
size_t circuit_subgroup_size_;
std::shared_ptr<proof_system::plonk::proving_key> proving_key_;
std::shared_ptr<proof_system::plonk::verification_key> verification_key_;
bool verbose_ = true;
Expand Down
19 changes: 12 additions & 7 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
WASM_EXPORT void acir_get_circuit_sizes(uint8_t const* acir_vec, uint32_t* exact, uint32_t* total, uint32_t* subgroup)
{
auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
auto composer = acir_format::create_circuit(constraint_system, 1 << 19);
*exact = htonl((uint32_t)composer.get_num_gates());
*total = htonl((uint32_t)composer.get_total_circuit_size());
*subgroup = htonl((uint32_t)composer.get_circuit_subgroup_size(composer.get_total_circuit_size()));
auto builder = acir_format::create_circuit(constraint_system, 1 << 19);
*exact = htonl((uint32_t)builder.get_num_gates());
*total = htonl((uint32_t)builder.get_total_circuit_size());
*subgroup = htonl((uint32_t)builder.get_circuit_subgroup_size(builder.get_total_circuit_size()));
}

WASM_EXPORT void acir_new_acir_composer(uint32_t const* size_hint, out_ptr out)
Expand All @@ -35,8 +35,9 @@ WASM_EXPORT void acir_init_proving_key(in_ptr acir_composer_ptr, uint8_t const*
{
auto acir_composer = reinterpret_cast<acir_proofs::AcirComposer*>(*acir_composer_ptr);
auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
acir_composer->create_circuit(constraint_system);

acir_composer->init_proving_key(constraint_system);
acir_composer->init_proving_key();
}

WASM_EXPORT void acir_create_proof(in_ptr acir_composer_ptr,
Expand All @@ -49,7 +50,10 @@ WASM_EXPORT void acir_create_proof(in_ptr acir_composer_ptr,
auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
auto witness = acir_format::witness_buf_to_witness_data(from_buffer<std::vector<uint8_t>>(witness_vec));

auto proof_data = acir_composer->create_proof(constraint_system, witness, *is_recursive);
acir_composer->create_circuit(constraint_system, witness);

acir_composer->init_proving_key();
auto proof_data = acir_composer->create_proof(*is_recursive);
*out = to_heap_buffer(proof_data);
}

Expand Down Expand Up @@ -92,7 +96,8 @@ WASM_EXPORT void acir_get_proving_key(in_ptr acir_composer_ptr, uint8_t const* a
{
auto acir_composer = reinterpret_cast<acir_proofs::AcirComposer*>(*acir_composer_ptr);
auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
auto pk = acir_composer->init_proving_key(constraint_system);
acir_composer->create_circuit(constraint_system);
auto pk = acir_composer->init_proving_key();
// We flatten to a vector<uint8_t> first, as that's how we treat it on the calling side.
*out = to_heap_buffer(to_buffer(*pk));
}
Expand Down

0 comments on commit e6232e8

Please sign in to comment.