Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Opt: memory: linear for [group] const values #207

Merged
merged 3 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions examples/ZoKrates/pf/const_linear_lookup.zok
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
struct T {
field v
field w
field x
field y
field z
}

const T[9] TABLE = [
T { v: 1, w: 12, x: 13, y: 14, z: 15 },
T { v: 2, w: 22, x: 23, y: 24, z: 25 },
T { v: 3, w: 32, x: 33, y: 34, z: 35 },
T { v: 4, w: 42, x: 43, y: 44, z: 45 },
T { v: 5, w: 52, x: 53, y: 54, z: 55 },
T { v: 6, w: 62, x: 63, y: 64, z: 65 },
T { v: 7, w: 72, x: 73, y: 74, z: 75 },
T { v: 8, w: 82, x: 83, y: 84, z: 85 },
T { v: 9, w: 92, x: 93, y: 94, z: 95 }
]

def main(field i) -> field:
T t = TABLE[i]
return t.v + t.w + t.x + t.y + t.z
5 changes: 5 additions & 0 deletions examples/circ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ fn main() {
"Final R1cs rounds: {}",
prover_data.precompute.stage_sizes().count() - 1
);
println!(
"Final Witext steps: {}, arguments: {}",
prover_data.precompute.num_steps(),
prover_data.precompute.num_step_args()
);
match action {
ProofAction::Count => (),
#[cfg(feature = "bellman")]
Expand Down
2 changes: 2 additions & 0 deletions examples/opa_bench.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(clippy::mutable_key_type)]

use circ::cfg::clap::{self, Parser};
use circ::ir::term::*;
use circ::target::aby::assignment::ilp;
Expand Down
1 change: 1 addition & 0 deletions scripts/zokrates_test.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function pf_test_isolate {
}

r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
r1cs_test_count ./examples/ZoKrates/pf/const_linear_lookup.zok 20
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOrderCheck.zok
Expand Down
2 changes: 1 addition & 1 deletion src/circify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ pub trait Embeddable {
/// * `name`: the name
/// * `visibility`: who knows it
/// * `precompute`: an optional term for pre-computing the values of this input. If a party
/// knows the inputs to the precomputation, they can use the precomputation.
/// knows the inputs to the precomputation, they can use the precomputation.
fn declare_input(
&self,
ctx: &mut CirCtx,
Expand Down
4 changes: 0 additions & 4 deletions src/ir/opt/chall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
//!
//! Each challenge term c that depends on t is replaced with a variable v.
//! Let t' denote a rewritten term.
//!
//! Rules:
//! * round(v) >=
//! round(v
use log::{debug, trace};

use std::cell::RefCell;
Expand Down
46 changes: 41 additions & 5 deletions src/ir/opt/mem/lin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,47 @@ impl RewritePass for Linearizer {
.unwrap_or_else(|| a.val.default_term()),
)
} else {
let mut fields = (0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
}))
let value_sort = check(tup).as_tuple()[0].clone();
if value_sort.is_group() {
// if values are a group
// then emit v0 + ite(idx == i1, v1 - v0, 0) + ... it(idx = iN, vN - v0, 0)
// where +, -, 0 are defined by the group.
//
// we do this because if the values are constant, then the above sum is
// linear, which is very nice for most backends.
let mut fields =
(0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
let zero = value_sort.group_identity();
Some(
value_sort.group_add_nary(
std::iter::once(first.clone())
.chain(
a.key
.elems_iter()
.take(a.size)
.skip(1)
.zip(fields)
.map(|(idx_c, field)| {
term![Op::Ite;
term![Op::Eq; idx.clone(), idx_c],
value_sort.group_sub(field, first.clone()),
zero.clone()
]
}),
)
.collect(),
),
)
} else {
// otherwise, ite(idx == iN, vN, ... ite(idx == i1, v1, v0) ... )
let mut fields =
(0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
}))
}
}
} else {
unreachable!()
Expand Down
11 changes: 9 additions & 2 deletions src/ir/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ pub enum Opt {
pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I) -> Computations {
for c in cs.comps.values() {
trace!("Before all opts: {}", text::serialize_computation(c));
info!("Before all opts: {} terms", c.stats().main.n_terms);
info!(
"Before all opts: {} terms",
c.stats().main.n_terms + c.stats().prec.n_terms
);
debug!("Before all opts: {:#?}", c.stats());
debug!("Before all opts: {:#?}", c.detailed_stats());
}
Expand Down Expand Up @@ -167,7 +170,11 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I)
fits_in_bits_ip::fits_in_bits_ip(c);
}
}
info!("After {:?}: {} terms", i, c.stats().main.n_terms);
info!(
"After {:?}: {} terms",
i,
c.stats().main.n_terms + c.stats().prec.n_terms
);
debug!("After {:?}: {:#?}", i, c.stats());
trace!("After {:?}: {}", i, text::serialize_computation(c));
#[cfg(debug_assertions)]
Expand Down
87 changes: 87 additions & 0 deletions src/ir/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,93 @@ impl Sort {
pub fn is_scalar(&self) -> bool {
!matches!(self, Sort::Tuple(..) | Sort::Array(..) | Sort::Map(..))
}

/// Is this sort a group?
pub fn is_group(&self) -> bool {
match self {
Sort::BitVector(_) | Sort::Int | Sort::Field(_) | Sort::Bool => true,
Sort::F32 | Sort::F64 | Sort::Array(_) | Sort::Map(_) => false,
Sort::Tuple(fields) => fields.iter().all(|f| f.is_group()),
}
}

/// The (n-ary) group operation for these terms.
pub fn group_add_nary(&self, ts: Vec<Term>) -> Term {
debug_assert!(ts.iter().all(|t| &check(t) == self));
match self {
Sort::BitVector(_) => term(BV_ADD, ts),
Sort::Bool => term(XOR, ts),
Sort::Field(_) => term(PF_ADD, ts),
Sort::Int => term(INT_ADD, ts),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| {
sort.group_add_nary(
ts.iter()
.map(|t| term(Op::Field(i), vec![t.clone()]))
.collect(),
)
})
.collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group inverse
pub fn group_neg(&self, t: Term) -> Term {
debug_assert_eq!(&check(&t), self);
match self {
Sort::BitVector(_) => term(BV_NEG, vec![t]),
Sort::Bool => term(NOT, vec![t]),
Sort::Field(_) => term(PF_NEG, vec![t]),
Sort::Int => term(
INT_MUL,
vec![leaf_term(Op::new_const(Value::Int(Integer::from(-1i8)))), t],
),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| sort.group_neg(term(Op::Field(i), vec![t.clone()])))
.collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group identity
pub fn group_identity(&self) -> Term {
match self {
Sort::BitVector(n_bits) => bv_lit(0, *n_bits),
Sort::Bool => bool_lit(false),
Sort::Field(f) => pf_lit(f.new_v(0)),
Sort::Int => leaf_term(Op::new_const(Value::Int(Integer::from(0i8)))),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts.iter().map(|sort| sort.group_identity()).collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group operation
pub fn group_add(&self, s: Term, t: Term) -> Term {
debug_assert_eq!(&check(&s), self);
debug_assert_eq!(&check(&t), self);
self.group_add_nary(vec![s, t])
}

/// Group elimination
pub fn group_sub(&self, s: Term, t: Term) -> Term {
debug_assert_eq!(&check(&s), self);
debug_assert_eq!(&check(&t), self);
self.group_add(s, self.group_neg(t))
}
}

mod hc {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#![warn(missing_docs)]
#![deny(warnings)]
#![allow(rustdoc::private_intra_doc_links)]
#![allow(clippy::mutable_key_type)]

#[macro_use]
pub mod ir;
Expand Down
2 changes: 0 additions & 2 deletions src/target/aby/trans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,8 +904,6 @@ pub fn to_aby(cs: Computations, path: &Path, lang: &str, cm: &str, ss: &str) {
panic!("Unsupported sharing scheme: {}", ss);
}
};
#[cfg(feature = "bench")]
println!("LOG: Assignment {}: {:?}", name, now.elapsed());
s_map.insert(name.to_string(), assignments);
}

Expand Down
2 changes: 1 addition & 1 deletion src/target/r1cs/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ fn constantly_true((a, b, c): &(Lc, Lc, Lc)) -> bool {
/// ## Parameters
///
/// * `lc_size_thresh`: the maximum size LC (number of non-constant monomials) that will be used
/// for propagation. `None` means no size limit.
/// for propagation. `None` means no size limit.
pub fn reduce_linearities(r1cs: R1cs, cfg: &CircCfg) -> R1cs {
let mut r = LinReducer::new(r1cs, cfg.r1cs.lc_elim_thresh).run();
r.update_stats();
Expand Down
10 changes: 10 additions & 0 deletions src/target/r1cs/wit_comp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ impl StagedWitComp {
pub fn num_stage_inputs(&self, n: usize) -> usize {
self.stages[n].inputs.len()
}

/// Number of steps
pub fn num_steps(&self) -> usize {
self.steps.len()
}

/// Number of step arguments
pub fn num_step_args(&self) -> usize {
self.step_args.len()
}
}

/// Evaluator interface
Expand Down
20 changes: 8 additions & 12 deletions src/target/smt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,7 @@ pub fn check_sat(t: &Term) -> bool {
let mut solver = make_solver((), false, false);
for c in PostOrderIter::new(t.clone()) {
if let Op::Var(v) = &c.op() {
solver
.declare_const(&SmtSymDisp(&*v.name), &v.sort)
.unwrap();
solver.declare_const(SmtSymDisp(&*v.name), &v.sort).unwrap();
}
}
assert!(check(t) == Sort::Bool);
Expand All @@ -380,9 +378,7 @@ fn get_model_solver(t: &Term, inc: bool) -> rsmt2::Solver<Parser> {
//solver.path_tee("solver_com").unwrap();
for c in PostOrderIter::new(t.clone()) {
if let Op::Var(v) = &c.op() {
solver
.declare_const(&SmtSymDisp(&*v.name), &v.sort)
.unwrap();
solver.declare_const(SmtSymDisp(&*v.name), &v.sort).unwrap();
}
}
assert!(check(t) == Sort::Bool);
Expand Down Expand Up @@ -590,13 +586,13 @@ mod test {
let mut solver = make_solver((), false, false);
for (v, val) in vs {
let s = val.sort();
solver.declare_const(&SmtSymDisp(&v), &s).unwrap();
solver.declare_const(SmtSymDisp(&v), &s).unwrap();
solver
.assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.assert(term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.unwrap();
}
let val = eval(&t, vs);
solver.assert(&term![Op::Eq; t, const_(val)]).unwrap();
solver.assert(term![Op::Eq; t, const_(val)]).unwrap();
solver.check_sat().unwrap()
}

Expand All @@ -605,14 +601,14 @@ mod test {
let mut solver = make_solver((), false, false);
for (v, val) in vs {
let s = val.sort();
solver.declare_const(&SmtSymDisp(&v), &s).unwrap();
solver.declare_const(SmtSymDisp(&v), &s).unwrap();
solver
.assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.assert(term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.unwrap();
}
let val = eval(&t, vs);
solver
.assert(&term![Op::Not; term![Op::Eq; t, const_(val)]])
.assert(term![Op::Not; term![Op::Eq; t, const_(val)]])
.unwrap();
solver.check_sat().unwrap()
}
Expand Down
Loading