From bef54b879cdba15a3a70f2a6678a9399fb529944 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Mon, 1 May 2023 13:25:48 -0700 Subject: [PATCH 1/7] feat!: Introduce WitnessMap data structure to avoid leaking internal structure --- acir/src/native_types/mod.rs | 1 + acir/src/native_types/witness.rs | 70 ++++++++++++++++++++++++++++---- acvm/src/lib.rs | 20 ++++----- acvm/src/pwg.rs | 9 ++-- acvm/src/pwg/arithmetic.rs | 23 ++++------- acvm/src/pwg/block.rs | 16 ++++---- acvm/src/pwg/directives.rs | 8 ++-- acvm/src/pwg/hash.rs | 13 +++--- acvm/src/pwg/logic.rs | 15 ++++--- acvm/src/pwg/oracle.rs | 6 +-- acvm/src/pwg/range.rs | 5 +-- acvm/src/pwg/signature/ecdsa.rs | 5 +-- 12 files changed, 116 insertions(+), 75 deletions(-) diff --git a/acir/src/native_types/mod.rs b/acir/src/native_types/mod.rs index 4b54d9388..449c04b3e 100644 --- a/acir/src/native_types/mod.rs +++ b/acir/src/native_types/mod.rs @@ -3,3 +3,4 @@ mod witness; pub use expression::Expression; pub use witness::Witness; +pub use witness::WitnessMap; diff --git a/acir/src/native_types/witness.rs b/acir/src/native_types/witness.rs index 5ffd47391..0cafb58c8 100644 --- a/acir/src/native_types/witness.rs +++ b/acir/src/native_types/witness.rs @@ -1,4 +1,4 @@ -use std::io::Read; +use std::{collections::BTreeMap, io::Read, ops::Index}; use flate2::{ bufread::{DeflateDecoder, DeflateEncoder}, @@ -27,23 +27,75 @@ impl Witness { pub const fn can_defer_constraint(&self) -> bool { true } +} + +impl From for Witness { + fn from(value: u32) -> Self { + Self(value) + } +} + +/// A map from the witnesses in a constraint system to the field element values +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] +pub struct WitnessMap(BTreeMap); + +impl WitnessMap { + pub fn new() -> Self { + Self(BTreeMap::new()) + } + pub fn get(&self, witness: &Witness) -> Option<&acir_field::FieldElement> { + self.0.get(witness) + } + pub fn get_index(&self, index: u32) -> Option<&acir_field::FieldElement> { + self.0.get(&index.into()) + } + pub fn contains_key(&self, key: &Witness) -> bool { + self.0.contains_key(key) + } + pub fn insert( + &mut self, + key: Witness, + value: acir_field::FieldElement, + ) -> Option { + self.0.insert(key, value) + } +} - pub fn to_bytes( - witnesses: &std::collections::BTreeMap, - ) -> Vec { - let buf = rmp_serde::to_vec(witnesses).unwrap(); +impl Index<&Witness> for WitnessMap { + type Output = acir_field::FieldElement; + + fn index(&self, index: &Witness) -> &Self::Output { + &self.0[index] + } +} + +impl From> for WitnessMap { + fn from(value: BTreeMap) -> Self { + Self(value) + } +} + +impl From for Vec { + fn from(val: WitnessMap) -> Self { + let buf = rmp_serde::to_vec(&val).unwrap(); let mut deflater = DeflateEncoder::new(buf.as_slice(), Compression::best()); let mut buf_c = Vec::new(); deflater.read_to_end(&mut buf_c).unwrap(); buf_c } +} - pub fn from_bytes( - bytes: &[u8], - ) -> std::collections::BTreeMap { +impl From<&[u8]> for WitnessMap { + fn from(bytes: &[u8]) -> Self { let mut deflater = DeflateDecoder::new(bytes); let mut buf_d = Vec::new(); deflater.read_to_end(&mut buf_d).unwrap(); - rmp_serde::from_slice(buf_d.as_slice()).unwrap() + Self(rmp_serde::from_slice(buf_d.as_slice()).unwrap()) + } +} + +impl From for Vec { + fn from(val: WitnessMap) -> Self { + val.0.into_values().collect() } } diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index 11791df6a..f8cd3aba4 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -13,11 +13,10 @@ use acir::{ opcodes::{BlackBoxFuncCall, OracleData}, Circuit, Opcode, }, - native_types::{Expression, Witness}, + native_types::{Expression, Witness, WitnessMap}, BlackBoxFunc, }; use pwg::{block::Blocks, directives::solve_directives}; -use std::collections::BTreeMap; use thiserror::Error; // re-export acir @@ -80,7 +79,7 @@ pub enum PartialWitnessGeneratorStatus { /// /// Returns the first missing assignment if any are missing fn first_missing_assignment( - witness_assignments: &BTreeMap, + witness_assignments: &WitnessMap, func_call: &BlackBoxFuncCall, ) -> Option { func_call.inputs.iter().find_map(|input| { @@ -100,7 +99,7 @@ pub trait Backend: SmartContract + ProofSystemCompiler + PartialWitnessGenerator pub trait PartialWitnessGenerator { fn solve( &self, - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, blocks: &mut Blocks, mut opcode_to_solve: Vec, ) -> Result { @@ -189,7 +188,7 @@ pub trait PartialWitnessGenerator { fn solve_black_box_function_call( &self, - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result; } @@ -230,7 +229,7 @@ pub trait ProofSystemCompiler { fn prove_with_pk( &self, circuit: &Circuit, - witness_values: BTreeMap, + witness_values: WitnessMap, proving_key: &[u8], ) -> Result, Self::Error>; @@ -238,7 +237,7 @@ pub trait ProofSystemCompiler { fn verify_with_vk( &self, proof: &[u8], - public_inputs: BTreeMap, + public_inputs: WitnessMap, circuit: &Circuit, verification_key: &[u8], ) -> Result; @@ -319,7 +318,7 @@ mod test { opcodes::{BlackBoxFuncCall, OracleData}, Opcode, }, - native_types::{Expression, Witness}, + native_types::{Expression, Witness, WitnessMap}, FieldElement, }; @@ -333,7 +332,7 @@ mod test { impl PartialWitnessGenerator for StubbedPwg { fn solve_black_box_function_call( &self, - _initial_witness: &mut BTreeMap, + _initial_witness: &mut WitnessMap, _func_call: &BlackBoxFuncCall, ) -> Result { panic!("Path not trodden by this test") @@ -389,7 +388,8 @@ mod test { let mut witness_assignments = BTreeMap::from([ (Witness(1), FieldElement::from(2u128)), (Witness(2), FieldElement::from(3u128)), - ]); + ]) + .into(); let mut blocks = Blocks::default(); let solver_status = pwg .solve(&mut witness_assignments, &mut blocks, opcodes) diff --git a/acvm/src/pwg.rs b/acvm/src/pwg.rs index bf7bbdab5..ee8104a8a 100644 --- a/acvm/src/pwg.rs +++ b/acvm/src/pwg.rs @@ -2,10 +2,9 @@ use crate::{OpcodeNotSolvable, OpcodeResolutionError}; use acir::{ - native_types::{Expression, Witness}, + native_types::{Expression, Witness, WitnessMap}, FieldElement, }; -use std::collections::BTreeMap; use self::arithmetic::ArithmeticSolver; @@ -26,7 +25,7 @@ pub mod sorting; // If the witness has no assignment, then // an error is returned pub fn witness_to_value( - initial_witness: &BTreeMap, + initial_witness: &WitnessMap, witness: Witness, ) -> Result<&FieldElement, OpcodeResolutionError> { match initial_witness.get(&witness) { @@ -39,7 +38,7 @@ pub fn witness_to_value( // TODO versus just getting values from Witness pub fn get_value( expr: &Expression, - initial_witness: &BTreeMap, + initial_witness: &WitnessMap, ) -> Result { let expr = ArithmeticSolver::evaluate(expr, initial_witness); match expr.to_const() { @@ -59,7 +58,7 @@ pub fn get_value( fn insert_value( witness: &Witness, value_to_insert: FieldElement, - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, ) -> Result<(), OpcodeResolutionError> { let optional_old_value = initial_witness.insert(*witness, value_to_insert); diff --git a/acvm/src/pwg/arithmetic.rs b/acvm/src/pwg/arithmetic.rs index 6827f3ac1..c9cd85ff5 100644 --- a/acvm/src/pwg/arithmetic.rs +++ b/acvm/src/pwg/arithmetic.rs @@ -1,8 +1,7 @@ use acir::{ - native_types::{Expression, Witness}, + native_types::{Expression, Witness, WitnessMap}, FieldElement, }; -use std::collections::BTreeMap; use crate::{OpcodeNotSolvable, OpcodeResolution, OpcodeResolutionError}; @@ -26,7 +25,7 @@ enum MulTerm { impl ArithmeticSolver { /// Derives the rest of the witness based on the initial low level variables pub fn solve( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, gate: &Expression, ) -> Result { let gate = &ArithmeticSolver::evaluate(gate, initial_witness); @@ -120,10 +119,7 @@ impl ArithmeticSolver { /// If the witness values are not known, then the function returns a None /// XXX: Do we need to account for the case where 5xy + 6x = 0 ? We do not know y, but it can be solved given x . But I believe x can be solved with another gate /// XXX: What about making a mul gate = a constant 5xy + 7 = 0 ? This is the same as the above. - fn solve_mul_term( - arith_gate: &Expression, - witness_assignments: &BTreeMap, - ) -> MulTerm { + fn solve_mul_term(arith_gate: &Expression, witness_assignments: &WitnessMap) -> MulTerm { // First note that the mul term can only contain one/zero term // We are assuming it has been optimized. match arith_gate.mul_terms.len() { @@ -138,7 +134,7 @@ impl ArithmeticSolver { fn solve_mul_term_helper( term: &(FieldElement, Witness, Witness), - witness_assignments: &BTreeMap, + witness_assignments: &WitnessMap, ) -> MulTerm { let (q_m, w_l, w_r) = term; // Check if these values are in the witness assignments @@ -155,7 +151,7 @@ impl ArithmeticSolver { fn solve_fan_in_term_helper( term: &(FieldElement, Witness), - witness_assignments: &BTreeMap, + witness_assignments: &WitnessMap, ) -> Option { let (q_l, w_l) = term; // Check if we have w_l @@ -168,7 +164,7 @@ impl ArithmeticSolver { /// We cannot assign pub fn solve_fan_in_term( arith_gate: &Expression, - witness_assignments: &BTreeMap, + witness_assignments: &WitnessMap, ) -> GateStatus { // This is assuming that the fan-in is more than 0 @@ -202,10 +198,7 @@ impl ArithmeticSolver { } // Partially evaluate the gate using the known witnesses - pub fn evaluate( - expr: &Expression, - initial_witness: &BTreeMap, - ) -> Expression { + pub fn evaluate(expr: &Expression, initial_witness: &WitnessMap) -> Expression { let mut result = Expression::default(); for &(c, w1, w2) in &expr.mul_terms { let mul_result = ArithmeticSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness); @@ -280,7 +273,7 @@ fn arithmetic_smoke_test() { q_c: FieldElement::zero(), }; - let mut values: BTreeMap = BTreeMap::new(); + let mut values = WitnessMap::new(); values.insert(b, FieldElement::from(2_i128)); values.insert(c, FieldElement::from(1_i128)); values.insert(d, FieldElement::from(1_i128)); diff --git a/acvm/src/pwg/block.rs b/acvm/src/pwg/block.rs index 091d9c231..ebe931f99 100644 --- a/acvm/src/pwg/block.rs +++ b/acvm/src/pwg/block.rs @@ -1,8 +1,8 @@ -use std::collections::{BTreeMap, HashMap}; +use std::collections::HashMap; use acir::{ circuit::opcodes::{BlockId, MemOp}, - native_types::Witness, + native_types::{Witness, WitnessMap}, FieldElement, }; @@ -24,7 +24,7 @@ impl Blocks { &mut self, id: BlockId, trace: &[MemOp], - solved_witness: &mut BTreeMap, + solved_witness: &mut WitnessMap, ) -> Result { let solver = self.blocks.entry(id).or_default(); solver.solve(solved_witness, trace) @@ -54,7 +54,7 @@ impl BlockSolver { // We stop when an operation cannot be resolved fn solve_helper( &mut self, - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, trace: &[MemOp], ) -> Result<(), OpcodeResolutionError> { let missing_assignment = |witness: Option| { @@ -102,7 +102,7 @@ impl BlockSolver { // and converts its result into GateResolution pub(crate) fn solve( &mut self, - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, trace: &[MemOp], ) -> Result { let initial_solved_operations = self.solved_operations; @@ -123,11 +123,9 @@ impl BlockSolver { #[cfg(test)] mod test { - use std::collections::BTreeMap; - use acir::{ circuit::opcodes::{BlockId, MemOp}, - native_types::{Expression, Witness}, + native_types::{Expression, Witness, WitnessMap}, FieldElement, }; @@ -161,7 +159,7 @@ mod test { value: Expression::from(Witness(4)), }); let id = BlockId::default(); - let mut initial_witness = BTreeMap::new(); + let mut initial_witness = WitnessMap::new(); let mut value = FieldElement::zero(); insert_value(&Witness(1), value, &mut initial_witness).unwrap(); value = FieldElement::one(); diff --git a/acvm/src/pwg/directives.rs b/acvm/src/pwg/directives.rs index fb45fd8d6..bafdc3243 100644 --- a/acvm/src/pwg/directives.rs +++ b/acvm/src/pwg/directives.rs @@ -1,8 +1,8 @@ -use std::{cmp::Ordering, collections::BTreeMap}; +use std::cmp::Ordering; use acir::{ circuit::directives::{Directive, LogInfo}, - native_types::Witness, + native_types::{Witness, WitnessMap}, FieldElement, }; use num_bigint::BigUint; @@ -19,7 +19,7 @@ use super::{get_value, insert_value, sorting::route, witness_to_value}; /// /// Returns `Err(OpcodeResolutionError)` if a circuit constraint is unsatisfied. pub fn solve_directives( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, directive: &Directive, ) -> Result { match solve_directives_internal(initial_witness, directive) { @@ -32,7 +32,7 @@ pub fn solve_directives( } fn solve_directives_internal( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, directive: &Directive, ) -> Result<(), OpcodeResolutionError> { match directive { diff --git a/acvm/src/pwg/hash.rs b/acvm/src/pwg/hash.rs index 8a6f0cd4b..c2d8ceaed 100644 --- a/acvm/src/pwg/hash.rs +++ b/acvm/src/pwg/hash.rs @@ -1,15 +1,14 @@ -use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness, FieldElement}; +use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::WitnessMap, FieldElement}; use blake2::{Blake2s, Digest}; use sha2::Sha256; use sha3::Keccak256; -use std::collections::BTreeMap; use crate::{OpcodeResolution, OpcodeResolutionError}; use super::{insert_value, witness_to_value}; pub fn blake2s( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result { let hash = generic_hash_256::(initial_witness, func_call)?; @@ -26,7 +25,7 @@ pub fn blake2s( } pub fn sha256( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result { let hash = generic_hash_256::(initial_witness, func_call)?; @@ -43,7 +42,7 @@ pub fn sha256( } pub fn keccak256( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result { let hash = generic_hash_256::(initial_witness, func_call)?; @@ -60,7 +59,7 @@ pub fn keccak256( } pub fn hash_to_field_128_security( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result { let hash = generic_hash_256::(initial_witness, func_call)?; @@ -72,7 +71,7 @@ pub fn hash_to_field_128_security( } fn generic_hash_256( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result<[u8; 32], OpcodeResolutionError> { let mut hasher = D::new(); diff --git a/acvm/src/pwg/logic.rs b/acvm/src/pwg/logic.rs index eca868ec9..3500194b7 100644 --- a/acvm/src/pwg/logic.rs +++ b/acvm/src/pwg/logic.rs @@ -1,10 +1,13 @@ use super::{insert_value, witness_to_value}; use crate::{OpcodeResolution, OpcodeResolutionError}; -use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness, BlackBoxFunc, FieldElement}; -use std::collections::BTreeMap; +use acir::{ + circuit::opcodes::BlackBoxFuncCall, + native_types::{Witness, WitnessMap}, + BlackBoxFunc, +}; pub fn solve_logic_opcode( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result { match func_call.name { @@ -19,7 +22,7 @@ pub struct LogicSolver; impl LogicSolver { /// Derives the rest of the witness based on the initial low level variables fn solve_logic_gate( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, a: &Witness, b: &Witness, result: Witness, @@ -39,14 +42,14 @@ impl LogicSolver { } pub fn solve_and_gate( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, gate: &BlackBoxFuncCall, ) -> Result { let (a, b, result, num_bits) = extract_input_output(gate); LogicSolver::solve_logic_gate(initial_witness, &a, &b, result, num_bits, false) } pub fn solve_xor_gate( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, gate: &BlackBoxFuncCall, ) -> Result { let (a, b, result, num_bits) = extract_input_output(gate); diff --git a/acvm/src/pwg/oracle.rs b/acvm/src/pwg/oracle.rs index 09339f7e2..3a4fd7176 100644 --- a/acvm/src/pwg/oracle.rs +++ b/acvm/src/pwg/oracle.rs @@ -1,6 +1,4 @@ -use std::collections::BTreeMap; - -use acir::{circuit::opcodes::OracleData, native_types::Witness, FieldElement}; +use acir::{circuit::opcodes::OracleData, native_types::WitnessMap}; use crate::{OpcodeNotSolvable, OpcodeResolution, OpcodeResolutionError}; @@ -11,7 +9,7 @@ pub struct OracleSolver; impl OracleSolver { /// Derives the rest of the witness based on the initial low level variables pub fn solve( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, data: &mut OracleData, ) -> Result { // Set input values diff --git a/acvm/src/pwg/range.rs b/acvm/src/pwg/range.rs index 65757188f..99121c24b 100644 --- a/acvm/src/pwg/range.rs +++ b/acvm/src/pwg/range.rs @@ -1,9 +1,8 @@ use crate::{pwg::witness_to_value, OpcodeResolution, OpcodeResolutionError}; -use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness, BlackBoxFunc, FieldElement}; -use std::collections::BTreeMap; +use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::WitnessMap, BlackBoxFunc}; pub fn solve_range_opcode( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, func_call: &BlackBoxFuncCall, ) -> Result { // TODO: this consistency check can be moved to a general function diff --git a/acvm/src/pwg/signature/ecdsa.rs b/acvm/src/pwg/signature/ecdsa.rs index 2cb0b9a57..bd36d3e17 100644 --- a/acvm/src/pwg/signature/ecdsa.rs +++ b/acvm/src/pwg/signature/ecdsa.rs @@ -1,10 +1,9 @@ -use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness, FieldElement}; -use std::collections::BTreeMap; +use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::WitnessMap, FieldElement}; use crate::{pwg::witness_to_value, OpcodeResolution, OpcodeResolutionError}; pub fn secp256k1_prehashed( - initial_witness: &mut BTreeMap, + initial_witness: &mut WitnessMap, gadget_call: &BlackBoxFuncCall, ) -> Result { let mut inputs_iter = gadget_call.inputs.iter(); From 00356235ef583f3eeae8719f1c0beb925e00ea8f Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Mon, 1 May 2023 16:13:26 -0700 Subject: [PATCH 2/7] avoid fully qualified name --- acir/src/native_types/witness.rs | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/acir/src/native_types/witness.rs b/acir/src/native_types/witness.rs index 0cafb58c8..1bd576bb0 100644 --- a/acir/src/native_types/witness.rs +++ b/acir/src/native_types/witness.rs @@ -1,5 +1,6 @@ use std::{collections::BTreeMap, io::Read, ops::Index}; +use acir_field::FieldElement; use flate2::{ bufread::{DeflateDecoder, DeflateEncoder}, Compression, @@ -37,40 +38,36 @@ impl From for Witness { /// A map from the witnesses in a constraint system to the field element values #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] -pub struct WitnessMap(BTreeMap); +pub struct WitnessMap(BTreeMap); impl WitnessMap { pub fn new() -> Self { Self(BTreeMap::new()) } - pub fn get(&self, witness: &Witness) -> Option<&acir_field::FieldElement> { + pub fn get(&self, witness: &Witness) -> Option<&FieldElement> { self.0.get(witness) } - pub fn get_index(&self, index: u32) -> Option<&acir_field::FieldElement> { + pub fn get_index(&self, index: u32) -> Option<&FieldElement> { self.0.get(&index.into()) } pub fn contains_key(&self, key: &Witness) -> bool { self.0.contains_key(key) } - pub fn insert( - &mut self, - key: Witness, - value: acir_field::FieldElement, - ) -> Option { + pub fn insert(&mut self, key: Witness, value: FieldElement) -> Option { self.0.insert(key, value) } } impl Index<&Witness> for WitnessMap { - type Output = acir_field::FieldElement; + type Output = FieldElement; fn index(&self, index: &Witness) -> &Self::Output { &self.0[index] } } -impl From> for WitnessMap { - fn from(value: BTreeMap) -> Self { +impl From> for WitnessMap { + fn from(value: BTreeMap) -> Self { Self(value) } } @@ -94,7 +91,7 @@ impl From<&[u8]> for WitnessMap { } } -impl From for Vec { +impl From for Vec { fn from(val: WitnessMap) -> Self { val.0.into_values().collect() } From 392b19e994b2214bc79569247872d45f461802da Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 2 May 2023 13:39:23 -0700 Subject: [PATCH 3/7] Implement IntoIterator for WitnessMap --- acir/src/native_types/witness.rs | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/acir/src/native_types/witness.rs b/acir/src/native_types/witness.rs index 1bd576bb0..8098bc513 100644 --- a/acir/src/native_types/witness.rs +++ b/acir/src/native_types/witness.rs @@ -1,4 +1,8 @@ -use std::{collections::BTreeMap, io::Read, ops::Index}; +use std::{ + collections::{btree_map, BTreeMap}, + io::Read, + ops::Index, +}; use acir_field::FieldElement; use flate2::{ @@ -66,6 +70,25 @@ impl Index<&Witness> for WitnessMap { } } +pub struct IntoIter(btree_map::IntoIter); + +impl Iterator for IntoIter { + type Item = (Witness, FieldElement); + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl IntoIterator for WitnessMap { + type Item = (Witness, FieldElement); + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter()) + } +} + impl From> for WitnessMap { fn from(value: BTreeMap) -> Self { Self(value) @@ -90,9 +113,3 @@ impl From<&[u8]> for WitnessMap { Self(rmp_serde::from_slice(buf_d.as_slice()).unwrap()) } } - -impl From for Vec { - fn from(val: WitnessMap) -> Self { - val.0.into_values().collect() - } -} From 01bf81334ad4fc813179dcf97196d23b5d51d565 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 2 May 2023 13:42:03 -0700 Subject: [PATCH 4/7] Make WitnessMap its own module --- acir/src/native_types/mod.rs | 3 +- acir/src/native_types/witness.rs | 85 --------------------------- acir/src/native_types/witness_map.rs | 88 ++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 86 deletions(-) create mode 100644 acir/src/native_types/witness_map.rs diff --git a/acir/src/native_types/mod.rs b/acir/src/native_types/mod.rs index 449c04b3e..c54e63095 100644 --- a/acir/src/native_types/mod.rs +++ b/acir/src/native_types/mod.rs @@ -1,6 +1,7 @@ mod expression; mod witness; +mod witness_map; pub use expression::Expression; pub use witness::Witness; -pub use witness::WitnessMap; +pub use witness_map::WitnessMap; diff --git a/acir/src/native_types/witness.rs b/acir/src/native_types/witness.rs index 8098bc513..3e9beb510 100644 --- a/acir/src/native_types/witness.rs +++ b/acir/src/native_types/witness.rs @@ -1,14 +1,3 @@ -use std::{ - collections::{btree_map, BTreeMap}, - io::Read, - ops::Index, -}; - -use acir_field::FieldElement; -use flate2::{ - bufread::{DeflateDecoder, DeflateEncoder}, - Compression, -}; use serde::{Deserialize, Serialize}; // Witness might be a misnomer. This is an index that represents the position a witness will take @@ -39,77 +28,3 @@ impl From for Witness { Self(value) } } - -/// A map from the witnesses in a constraint system to the field element values -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] -pub struct WitnessMap(BTreeMap); - -impl WitnessMap { - pub fn new() -> Self { - Self(BTreeMap::new()) - } - pub fn get(&self, witness: &Witness) -> Option<&FieldElement> { - self.0.get(witness) - } - pub fn get_index(&self, index: u32) -> Option<&FieldElement> { - self.0.get(&index.into()) - } - pub fn contains_key(&self, key: &Witness) -> bool { - self.0.contains_key(key) - } - pub fn insert(&mut self, key: Witness, value: FieldElement) -> Option { - self.0.insert(key, value) - } -} - -impl Index<&Witness> for WitnessMap { - type Output = FieldElement; - - fn index(&self, index: &Witness) -> &Self::Output { - &self.0[index] - } -} - -pub struct IntoIter(btree_map::IntoIter); - -impl Iterator for IntoIter { - type Item = (Witness, FieldElement); - - fn next(&mut self) -> Option { - self.0.next() - } -} - -impl IntoIterator for WitnessMap { - type Item = (Witness, FieldElement); - type IntoIter = IntoIter; - - fn into_iter(self) -> Self::IntoIter { - IntoIter(self.0.into_iter()) - } -} - -impl From> for WitnessMap { - fn from(value: BTreeMap) -> Self { - Self(value) - } -} - -impl From for Vec { - fn from(val: WitnessMap) -> Self { - let buf = rmp_serde::to_vec(&val).unwrap(); - let mut deflater = DeflateEncoder::new(buf.as_slice(), Compression::best()); - let mut buf_c = Vec::new(); - deflater.read_to_end(&mut buf_c).unwrap(); - buf_c - } -} - -impl From<&[u8]> for WitnessMap { - fn from(bytes: &[u8]) -> Self { - let mut deflater = DeflateDecoder::new(bytes); - let mut buf_d = Vec::new(); - deflater.read_to_end(&mut buf_d).unwrap(); - Self(rmp_serde::from_slice(buf_d.as_slice()).unwrap()) - } -} diff --git a/acir/src/native_types/witness_map.rs b/acir/src/native_types/witness_map.rs new file mode 100644 index 000000000..c43a9de41 --- /dev/null +++ b/acir/src/native_types/witness_map.rs @@ -0,0 +1,88 @@ +use std::{ + collections::{btree_map, BTreeMap}, + io::Read, + ops::Index, +}; + +use acir_field::FieldElement; +use flate2::{ + bufread::{DeflateDecoder, DeflateEncoder}, + Compression, +}; +use serde::{Deserialize, Serialize}; + +use crate::native_types::Witness; + +/// A map from the witnesses in a constraint system to the field element values +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] +pub struct WitnessMap(BTreeMap); + +impl WitnessMap { + pub fn new() -> Self { + Self(BTreeMap::new()) + } + pub fn get(&self, witness: &Witness) -> Option<&FieldElement> { + self.0.get(witness) + } + pub fn get_index(&self, index: u32) -> Option<&FieldElement> { + self.0.get(&index.into()) + } + pub fn contains_key(&self, key: &Witness) -> bool { + self.0.contains_key(key) + } + pub fn insert(&mut self, key: Witness, value: FieldElement) -> Option { + self.0.insert(key, value) + } +} + +impl Index<&Witness> for WitnessMap { + type Output = FieldElement; + + fn index(&self, index: &Witness) -> &Self::Output { + &self.0[index] + } +} + +pub struct IntoIter(btree_map::IntoIter); + +impl Iterator for IntoIter { + type Item = (Witness, FieldElement); + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl IntoIterator for WitnessMap { + type Item = (Witness, FieldElement); + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter()) + } +} + +impl From> for WitnessMap { + fn from(value: BTreeMap) -> Self { + Self(value) + } +} + +impl From for Vec { + fn from(val: WitnessMap) -> Self { + let buf = rmp_serde::to_vec(&val).unwrap(); + let mut deflater = DeflateEncoder::new(buf.as_slice(), Compression::best()); + let mut buf_c = Vec::new(); + deflater.read_to_end(&mut buf_c).unwrap(); + buf_c + } +} + +impl From<&[u8]> for WitnessMap { + fn from(bytes: &[u8]) -> Self { + let mut deflater = DeflateDecoder::new(bytes); + let mut buf_d = Vec::new(); + deflater.read_to_end(&mut buf_d).unwrap(); + Self(rmp_serde::from_slice(buf_d.as_slice()).unwrap()) + } +} From 13f2a35795656c5f6c6c7dda76da71009e295fa4 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 4 May 2023 12:48:35 -0700 Subject: [PATCH 5/7] move thiserror to the workspace --- Cargo.toml | 2 ++ acir/Cargo.toml | 1 + acvm/Cargo.toml | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a9b9deca2..adac33497 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,4 +16,6 @@ stdlib = { package = "acvm_stdlib", version = "0.10.3", path = "stdlib", default num-bigint = "0.4" num-traits = "0.2" +thiserror = "1.0.21" + serde = { version = "1.0.136", features = ["derive"] } diff --git a/acir/Cargo.toml b/acir/Cargo.toml index 0c861d2ad..8abaed71c 100644 --- a/acir/Cargo.toml +++ b/acir/Cargo.toml @@ -12,6 +12,7 @@ rust-version.workspace = true [dependencies] acir_field.workspace = true serde.workspace = true +thiserror.workspace = true rmp-serde = "1.1.0" flate2 = "1.0.24" diff --git a/acvm/Cargo.toml b/acvm/Cargo.toml index f2cc02db2..19ac41bf8 100644 --- a/acvm/Cargo.toml +++ b/acvm/Cargo.toml @@ -12,6 +12,7 @@ rust-version.workspace = true [dependencies] num-bigint.workspace = true num-traits.workspace = true +thiserror.workspace = true acir.workspace = true stdlib.workspace = true @@ -28,7 +29,6 @@ k256 = { version = "0.7.2", features = [ "arithmetic", ] } indexmap = "1.7.0" -thiserror = "1.0.21" [features] default = ["bn254"] From 1baaf3106808d1974d84280d09fbe805943818ce Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 4 May 2023 12:48:52 -0700 Subject: [PATCH 6/7] Implement try_from instead of from where there was unwrapping --- acir/src/native_types/witness_map.rs | 35 +++++++++++++++++++++------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/acir/src/native_types/witness_map.rs b/acir/src/native_types/witness_map.rs index c43a9de41..fbe84080f 100644 --- a/acir/src/native_types/witness_map.rs +++ b/acir/src/native_types/witness_map.rs @@ -10,9 +10,22 @@ use flate2::{ Compression, }; use serde::{Deserialize, Serialize}; +use thiserror::Error; use crate::native_types::Witness; +#[derive(Debug, Error)] +pub enum WitnessMapError { + #[error(transparent)] + MsgpackEncodeError(#[from] rmp_serde::encode::Error), + + #[error(transparent)] + MsgpackDecodeError(#[from] rmp_serde::decode::Error), + + #[error(transparent)] + DeflateError(#[from] std::io::Error), +} + /// A map from the witnesses in a constraint system to the field element values #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] pub struct WitnessMap(BTreeMap); @@ -68,21 +81,25 @@ impl From> for WitnessMap { } } -impl From for Vec { - fn from(val: WitnessMap) -> Self { - let buf = rmp_serde::to_vec(&val).unwrap(); +impl TryFrom for Vec { + type Error = WitnessMapError; + + fn try_from(val: WitnessMap) -> Result { + let buf = rmp_serde::to_vec(&val)?; let mut deflater = DeflateEncoder::new(buf.as_slice(), Compression::best()); let mut buf_c = Vec::new(); - deflater.read_to_end(&mut buf_c).unwrap(); - buf_c + deflater.read_to_end(&mut buf_c)?; + Ok(buf_c) } } -impl From<&[u8]> for WitnessMap { - fn from(bytes: &[u8]) -> Self { +impl TryFrom<&[u8]> for WitnessMap { + type Error = WitnessMapError; + + fn try_from(bytes: &[u8]) -> Result { let mut deflater = DeflateDecoder::new(bytes); let mut buf_d = Vec::new(); - deflater.read_to_end(&mut buf_d).unwrap(); - Self(rmp_serde::from_slice(buf_d.as_slice()).unwrap()) + deflater.read_to_end(&mut buf_d)?; + Ok(Self(rmp_serde::from_slice(buf_d.as_slice())?)) } } From f42f0e5e9e2dfb2116edd170f61d68a519ec886f Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 4 May 2023 12:53:35 -0700 Subject: [PATCH 7/7] clippy --- acir/src/native_types/witness_map.rs | 6 +++--- acvm/src/pwg/directives.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/acir/src/native_types/witness_map.rs b/acir/src/native_types/witness_map.rs index fbe84080f..9f814eab3 100644 --- a/acir/src/native_types/witness_map.rs +++ b/acir/src/native_types/witness_map.rs @@ -17,13 +17,13 @@ use crate::native_types::Witness; #[derive(Debug, Error)] pub enum WitnessMapError { #[error(transparent)] - MsgpackEncodeError(#[from] rmp_serde::encode::Error), + MsgpackEncode(#[from] rmp_serde::encode::Error), #[error(transparent)] - MsgpackDecodeError(#[from] rmp_serde::decode::Error), + MsgpackDecode(#[from] rmp_serde::decode::Error), #[error(transparent)] - DeflateError(#[from] std::io::Error), + Deflate(#[from] std::io::Error), } /// A map from the witnesses in a constraint system to the field element values diff --git a/acvm/src/pwg/directives.rs b/acvm/src/pwg/directives.rs index bafdc3243..0a87ae5d2 100644 --- a/acvm/src/pwg/directives.rs +++ b/acvm/src/pwg/directives.rs @@ -2,7 +2,7 @@ use std::cmp::Ordering; use acir::{ circuit::directives::{Directive, LogInfo}, - native_types::{Witness, WitnessMap}, + native_types::WitnessMap, FieldElement, }; use num_bigint::BigUint;