Skip to content

Commit 9ac4c26

Browse files
authored
further fix OOB indices in lin (#193)
test them in lin and tr
1 parent 2bf4f82 commit 9ac4c26

15 files changed

+237
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
) false ; ignored
5+
))
6+
7+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
(return #f6)
5+
) false ; ignored
6+
))
7+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
def main(field x) -> field:
2+
field[25] A = [0; 25]
3+
for field counter in 0..5 do
4+
cond_store(A, counter - 1, x, counter > 1)
5+
endfor
6+
7+
return A[x]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
) false ; ignored
5+
))
6+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
(return #f6)
5+
) false ; ignored
6+
))
7+
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
def main(field x) -> field:
22
transcript field[25] A = [0; 25]
33
for field counter in 0..30 do
4-
cond_store(A, counter, x, counter < x)
4+
bool oob = counter < x
5+
cond_store(A, if oob then counter else 0 fi, x, oob)
56
endfor
67

78
return A[x]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
) false ; ignored
5+
))
6+
7+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
(return #f6)
5+
) false ; ignored
6+
))
7+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
def main(field x) -> field:
2+
transcript field[25] A = [0; 25]
3+
for field counter in 0..5 do
4+
cond_store(A, counter - 1, x, counter > 1)
5+
endfor
6+
7+
return A[x]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
) false ; ignored
5+
))
6+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
2+
(let (
3+
(x #f6)
4+
(return #f6)
5+
) false ; ignored
6+
))
7+

examples/ZoKrates/pf/mem/sparse4.zok

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Examples of different SHA-esque operations being performed using sparse form
2+
// and lookup arguments
3+
4+
5+
// python -c "b=4;dtos=lambda d: sum(4**i*int(b) for i, b in enumerate(bin(d)[2:][::-1]));print(f'const transcript field[{2**b}] D_TO_S_{b} = [', ', '.join(str(dtos(i)) for i in range(2**b)), ']', sep='')"
6+
const transcript field[16] D_TO_S_4 = [0, 1, 4, 5, 16, 17, 20, 21, 64, 65, 68, 69, 80, 81, 84, 85]
7+
8+
const transcript field[8] D_TO_S_3 = [0, 1, 4, 5, 16, 17, 20, 21]
9+
10+
const transcript field[8] D_3 = [0, 1, 2, 3, 4, 5, 6, 7]
11+
12+
// python -c "b=4;dtos=lambda d: sum(4**i*int(b) for i, b in enumerate(bin(d)[2:][::-1]));print(f'const field S_ONES_{b} = {dtos(2**b-1)}');print(f'const field D_ONES_{b} = {2**b-1}')"
13+
const field S_ONES_4 = 85
14+
const field D_ONES_4 = 15
15+
16+
from "EMBED" import unpack, value_in_array, reverse_lookup, fits_in_bits
17+
18+
// split a number into (unchecked) high and low bits
19+
def unsafe_split<LOW_BITS,HIGH_BITS>(field x) -> field[2]:
20+
bool[LOW_BITS+HIGH_BITS] bits = unpack(x)
21+
field low = 0
22+
field high = 0
23+
for u32 i in 0..LOW_BITS do
24+
low = low + 2 ** i * (if bits[LOW_BITS+HIGH_BITS-1-i] then 1 else 0 fi)
25+
endfor
26+
for u32 i in LOW_BITS..HIGH_BITS do
27+
high = high + 2 ** i * (if bits[LOW_BITS+HIGH_BITS-1-i] then 1 else 0 fi)
28+
endfor
29+
return [low, high]
30+
31+
// split a 2N bit number into (unchecked) even and odd bits (in sparse form)
32+
def unsafe_separate_sparse<N>(field x) -> field[2]:
33+
bool[2*N] bits = unpack(x)
34+
field even = 0
35+
field odd = 0
36+
for u32 i in 0..N do
37+
even = even + 4 ** i * (if bits[2*N-1-(2*i)] then 1 else 0 fi)
38+
odd = odd + 4 ** i * (if bits[2*N-1-(2*i+1)] then 1 else 0 fi)
39+
endfor
40+
return [even, odd]
41+
42+
struct Dual {
43+
field s
44+
field d
45+
}
46+
47+
// convert a dense 8-bit value to dual form; ensures the value fits in 8 bits.
48+
def dense_to_dual_4(field x) -> Dual:
49+
field s = D_TO_S_4[x]
50+
return Dual {s: s, d: x}
51+
52+
// get the even bits of a 16-bit value in dual form; ensures the value fits in 16 bits.
53+
def split_even_dual_4(field x) -> Dual:
54+
unsafe witness field[2] split = unsafe_separate_sparse::<8>(x)
55+
field even = split[0]
56+
field odd = split[1]
57+
assert(x == 2 * odd + even)
58+
field even_d = reverse_lookup(D_TO_S_4, even)
59+
assert(value_in_array(odd, D_TO_S_4))
60+
return Dual { s: even, d: even_d }
61+
62+
// get the odd bits of a 16-bit value in dual form; ensures the value fits in 16 bits.
63+
def split_odd_dual_4(field x) -> Dual:
64+
unsafe witness field[2] split = unsafe_separate_sparse::<8>(x)
65+
// field even = split[0]
66+
field odd = split[1]
67+
field even = x - 2 * odd
68+
// assert(x == 2 * odd + even)
69+
field odd_d = reverse_lookup(D_TO_S_4, odd)
70+
assert(value_in_array(even, D_TO_S_4))
71+
return Dual { s: odd, d: odd_d }
72+
73+
// get the even and odd bits of a 16-bit value in dual form; ensures the value fits in 16 bits.
74+
def split_both_dual_4(field x) -> Dual[2]:
75+
unsafe witness field[2] split = unsafe_separate_sparse::<8>(x)
76+
field even = split[0]
77+
field odd = split[1]
78+
field odd_d = reverse_lookup(D_TO_S_4, odd)
79+
field even_d = reverse_lookup(D_TO_S_4, even)
80+
return [Dual { s: even, d: even_d }, Dual { s: odd, d: odd_d }]
81+
82+
// expected cost: 3 observed: 5
83+
def and_4(Dual x, Dual y) -> Dual:
84+
return split_odd_dual_4(x.s + y.s)
85+
86+
def maj_4(Dual x, Dual y, Dual z) -> Dual:
87+
return split_odd_dual_4(x.s + y.s + z.s)
88+
89+
def xor_4(Dual x, Dual y, Dual z) -> Dual:
90+
return split_even_dual_4(x.s + y.s + z.s)
91+
92+
def not_4(Dual x) -> Dual:
93+
return Dual { s: S_ONES_4 - x.s, d: D_ONES_4 - x.d }
94+
95+
def or_4(Dual x, Dual y) -> Dual:
96+
return not_4(and_4(not_4(x), not_4(y)))
97+
98+
// split s into 8 low bits and 3 high bits, and return the low bits
99+
// in dual form.
100+
def normalize_sum_4(field s) -> Dual:
101+
unsafe witness field[2] split = unsafe_split::<8, 3>(s)
102+
field low = split[0]
103+
field high = split[1]
104+
assert(value_in_array(high, D_3))
105+
return dense_to_dual_4(low)
106+
107+
// table costs:
108+
// 16 + 16 + 8 = 40
109+
//do a bitwise AND.
110+
def main(private field dense_x, private field dense_y) -> field:
111+
Dual z = dense_to_dual_4(0)
112+
Dual x = dense_to_dual_4(dense_x) // 1010 (10)
113+
Dual y = dense_to_dual_4(dense_y) // 1001 (9)
114+
Dual a = and_4(x, y) // 1000 (8)
115+
for field i in 0..10 do
116+
a = and_4(a, y) // idempotent
117+
endfor
118+
Dual b = or_4(x, y) // 1011 (11)
119+
Dual s = normalize_sum_4(b.d + a.d) // 0011 (3)
120+
return s.d
121+
// return reverse_lookup(D_TO_S_4, dense_x) * dense_y
122+
// return reverse_lookup(D_TO_S_4, dense_x) * reverse_lookup(D_TO_S_4, dense_y)
123+
// return dense_x * dense_y
124+
125+
126+
127+

scripts/ram_test.zsh

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ transcript_type_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok "covering ROM"
7171
# A=400; N=20; L=2; expected cost ~= N + A(L+1) = 1220
7272
cs_count_test ./examples/ZoKrates/pf/mem/rom.zok 1230
7373

74+
ram_test ./examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok mirage ""
75+
ram_test ./examples/ZoKrates/pf/mem/2024_05_24_benny_bug_tr.zok mirage ""
7476
ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
7577
ram_test ./examples/ZoKrates/pf/mem/volatile.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
7678
# waksman is broken for non-scalar array values

scripts/zokrates_test.zsh

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ function pf_test_isolate {
7272
done
7373
}
7474

75+
pf_test 2024_05_24_benny_bug
76+
pf_test 2024_05_31_benny_bug
7577
r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
7678
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
7779
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
@@ -86,8 +88,6 @@ r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/utils/casts/bool_128_to_
8688
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zok
8789
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R20.zok
8890
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok
89-
r1cs_test ./examples/ZoKrates/pf/2024_05_24_benny_bug.zok
90-
r1cs_test ./examples/ZoKrates/pf/2024_05_24_benny_bug_tr.zok
9191

9292
pf_test_only_pf sha_temp1
9393
pf_test_only_pf sha_rot

src/ir/opt/mem/lin.rs

+36-32
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,19 @@ impl RewritePass for Linearizer {
5252
let cs = rewritten_children();
5353
let idx = &cs[1];
5454
let tup = &cs[0];
55-
if let Sort::Array(key_sort, val_sort, size) = check(&orig.cs()[0]) {
56-
assert!(size > 0);
55+
if let Sort::Array(key_sort, val_sort, sz) = check(&orig.cs()[0]) {
56+
assert!(sz > 0);
5757
if idx.is_const() {
58-
let idx_usize = extras::as_uint_constant(idx).unwrap().to_usize().unwrap();
59-
if idx_usize < size {
60-
Some(term![Op::Field(idx_usize); tup.clone()])
61-
} else {
62-
Some(val_sort.default_term())
63-
}
58+
Some(
59+
extras::as_uint_constant(idx)
60+
.and_then(|cidx| cidx.to_usize())
61+
.and_then(|u| (u < sz).then_some(term![Op::Field(u); tup.clone()]))
62+
.unwrap_or_else(|| val_sort.default_term()),
63+
)
6464
} else {
65-
let mut fields = (0..size).map(|idx| term![Op::Field(idx); tup.clone()]);
65+
let mut fields = (0..sz).map(|idx| term![Op::Field(idx); tup.clone()]);
6666
let first = fields.next().unwrap();
67-
Some(key_sort.elems_iter().take(size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
67+
Some(key_sort.elems_iter().take(sz).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
6868
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
6969
}))
7070
}
@@ -77,20 +77,23 @@ impl RewritePass for Linearizer {
7777
let tup = &cs[0];
7878
let idx = &cs[1];
7979
let val = &cs[2];
80-
if let Sort::Array(key_sort, _, size) = check(&orig.cs()[0]) {
81-
assert!(size > 0);
80+
if let Sort::Array(key_sort, _, sz) = check(&orig.cs()[0]) {
81+
assert!(sz > 0);
8282
if idx.is_const() {
83-
let idx_usize = extras::as_uint_constant(idx).unwrap().to_usize().unwrap();
84-
if idx_usize < size {
85-
Some(term![Op::Update(idx_usize); tup.clone(), val.clone()])
86-
} else {
87-
Some(tup.clone())
88-
}
83+
Some(
84+
extras::as_uint_constant(idx)
85+
.and_then(|cidx| cidx.to_usize())
86+
.and_then(|u| {
87+
(u < sz)
88+
.then_some(term![Op::Update(u); tup.clone(), val.clone()])
89+
})
90+
.unwrap_or_else(|| tup.clone()),
91+
)
8992
} else {
9093
let mut updates =
91-
(0..size).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]);
94+
(0..sz).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]);
9295
let first = updates.next().unwrap();
93-
Some(key_sort.elems_iter().take(size).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| {
96+
Some(key_sort.elems_iter().take(sz).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| {
9497
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], update, acc]
9598
}))
9699
}
@@ -104,22 +107,23 @@ impl RewritePass for Linearizer {
104107
let idx = &cs[1];
105108
let val = &cs[2];
106109
let cond = &cs[3];
107-
if let Sort::Array(key_sort, _, size) = check(&orig.cs()[0]) {
108-
assert!(size > 0);
110+
if let Sort::Array(key_sort, _, sz) = check(&orig.cs()[0]) {
111+
assert!(sz > 0);
109112
if idx.is_const() {
110-
let idx_usize = extras::as_uint_constant(idx).unwrap().to_usize().unwrap();
111-
if idx_usize < size {
112-
Some(
113-
term![Op::Ite; cond.clone(), term![Op::Update(idx_usize); tup.clone(), val.clone()], tup.clone()],
114-
)
115-
} else {
116-
Some(tup.clone())
117-
}
113+
Some(
114+
extras::as_uint_constant(idx)
115+
.and_then(|cidx| cidx.to_usize())
116+
.and_then(|u| {
117+
(u < sz)
118+
.then_some(term![Op::Ite; cond.clone(), term![Op::Update(u); tup.clone(), val.clone()], tup.clone()])
119+
})
120+
.unwrap_or_else(|| tup.clone()),
121+
)
118122
} else {
119123
let mut updates =
120-
(0..size).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]);
124+
(0..sz).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]);
121125
let first = updates.next().unwrap();
122-
Some(key_sort.elems_iter().take(size).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| {
126+
Some(key_sort.elems_iter().take(sz).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| {
123127
term![Op::Ite; term![AND; term![Op::Eq; idx.clone(), idx_c], cond.clone()], update, acc]
124128
}))
125129
}

0 commit comments

Comments
 (0)