From 55a55dd218093cd170e7927585b92e31d8208f43 Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Wed, 24 Jan 2024 16:44:37 -0800 Subject: [PATCH] Optimizations to KZG commitment scheme (#300) * simplify kzg_verify_batch closure credit: storojs72 * parallel computation of polynomials credit: storojs72 * eliminate computation of the last commitment credit: storojs72 * update tests to account for the new behavior of the Prove method * simplify tests --- src/provider/hyperkzg.rs | 155 ++++++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 69 deletions(-) diff --git a/src/provider/hyperkzg.rs b/src/provider/hyperkzg.rs index 1698f991..ccf3e031 100644 --- a/src/provider/hyperkzg.rs +++ b/src/provider/hyperkzg.rs @@ -18,6 +18,7 @@ use crate::{ evaluation::EvaluationEngineTrait, AbsorbInROTrait, Engine, ROTrait, TranscriptEngineTrait, TranscriptReprTrait, }, + zip_with, }; use core::{ marker::PhantomData, @@ -25,6 +26,7 @@ use core::{ }; use ff::Field; use halo2curves::bn256::{Fq as Bn256Fq, Fr as Bn256Fr, G1 as Bn256G1}; +use itertools::Itertools; use rand_core::OsRng; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -366,10 +368,10 @@ where ck: &CommitmentKey, _pk: &Self::ProverKey, transcript: &mut ::TE, - C: &Commitment, + _C: &Commitment, hat_P: &[E::Scalar], point: &[E::Scalar], - eval: &E::Scalar, + _eval: &E::Scalar, ) -> Result { let x: Vec = point.to_vec(); @@ -406,8 +408,7 @@ where E::CE::commit(ck, &h).comm.preprocessed() }; - let kzg_open_batch = |C: &[G1], - f: &[Vec], + let kzg_open_batch = |f: &[Vec], u: &[E::Scalar], transcript: &mut ::TE| -> (Vec>, Vec>) { @@ -447,18 +448,18 @@ where let k = f.len(); let t = u.len(); - assert!(C.len() == k); // The verifier needs f_i(u_j), so we compute them here // (V will compute B(u_j) itself) let mut v = vec![vec!(E::Scalar::ZERO; k); t]; - for i in 0..t { + v.par_iter_mut().enumerate().for_each(|(i, v_i)| { // for each point u - for (j, f_j) in f.iter().enumerate().take(k) { + v_i.par_iter_mut().zip_eq(f).for_each(|(v_ij, f)| { // for each poly f - v[i][j] = poly_eval(f_j, u[i]); // = f_j(u_i) - } - } + // for each poly f (except the last one - since it is constant) + *v_ij = poly_eval(f, u[i]); + }); + }); let q = Self::get_batch_challenge(&v, transcript); let B = kzg_compute_batch_polynomial(f, q); @@ -484,21 +485,18 @@ where assert_eq!(n, 1 << ell); // Below we assume that n is a power of two // Phase 1 -- create commitments com_1, ..., com_\ell + // We do not compute final Pi (and its commitment) as it is constant and equals to 'eval' + // also known to verifier, so can be derived on its side as well let mut polys: Vec> = Vec::new(); polys.push(hat_P.to_vec()); - for i in 0..ell { + for i in 0..ell - 1 { let Pi_len = polys[i].len() / 2; let mut Pi = vec![E::Scalar::ZERO; Pi_len]; #[allow(clippy::needless_range_loop)] - for j in 0..Pi_len { - Pi[j] = x[ell-i-1] * polys[i][2*j + 1] // Odd part of P^(i-1) - + (E::Scalar::ONE - x[ell-i-1]) * polys[i][2*j]; // Even part of P^(i-1) - } - - if i == ell - 1 && *eval != Pi[0] { - return Err(NovaError::UnSat); - } + Pi.par_iter_mut().enumerate().for_each(|(j, Pi_j)| { + *Pi_j = x[ell - i - 1] * (polys[i][2 * j + 1] - polys[i][2 * j]) + polys[i][2 * j]; + }); polys.push(Pi); } @@ -517,9 +515,7 @@ where let u = vec![r, -r, r * r]; // Phase 3 -- create response - let mut com_all = com.clone(); - com_all.insert(0, C.comm.preprocessed()); - let (w, v) = kzg_open_batch(&com_all, &polys, &u, transcript); + let (w, v) = kzg_open_batch(&polys, &u, transcript); Ok(EvaluationArgument { com, w, v }) } @@ -551,52 +547,70 @@ where let q = Self::get_batch_challenge(v, transcript); let q_powers = Self::batch_challenge_powers(q, k); // 1, q, q^2, ..., q^(k-1) - // Compute the commitment to the batched polynomial B(X) - let C_B = (::group(&C[0]) - + E::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k])) - .preprocessed(); - - // Compute the batched openings - // compute B(u_i) = v[i][0] + q*v[i][1] + ... + q^(t-1) * v[i][t-1] - let B_u = (0..t) - .map(|i| { - assert_eq!(q_powers.len(), v[i].len()); - q_powers.iter().zip(v[i].iter()).map(|(a, b)| *a * *b).sum() - }) - .collect::>(); - let d_0 = Self::verifier_second_challenge(W, transcript); - let d = [d_0, d_0 * d_0]; + let d_1 = d_0 * d_0; // Shorthand to convert from preprocessed G1 elements to non-preprocessed let from_ppG1 = |P: &G1| ::group(P); // Shorthand to convert from preprocessed G2 elements to non-preprocessed let from_ppG2 = |P: &G2| <::G2 as DlogGroup>::group(P); - assert!(t == 3); + assert_eq!(t, 3); + assert_eq!(W.len(), 3); // We write a special case for t=3, since this what is required for - // mlkzg. Following the paper directly, we must compute: + // hyperkzg. Following the paper directly, we must compute: // let L0 = C_B - vk.G * B_u[0] + W[0] * u[0]; // let L1 = C_B - vk.G * B_u[1] + W[1] * u[1]; // let L2 = C_B - vk.G * B_u[2] + W[2] * u[2]; // let R0 = -W[0]; // let R1 = -W[1]; // let R2 = -W[2]; - // let L = L0 + L1*d[0] + L2*d[1]; - // let R = R0 + R1*d[0] + R2*d[1]; + // let L = L0 + L1*d_0 + L2*d_1; + // let R = R0 + R1*d_0 + R2*d_1; // // We group terms to reduce the number of scalar mults (to seven): // In Rust, we could use MSMs for these, and speed up verification. - let L = from_ppG1(&C_B) * (E::Scalar::ONE + d[0] + d[1]) - - from_ppG1(&vk.G) * (B_u[0] + d[0] * B_u[1] + d[1] * B_u[2]) - + from_ppG1(&W[0]) * u[0] - + from_ppG1(&W[1]) * (u[1] * d[0]) - + from_ppG1(&W[2]) * (u[2] * d[1]); + // + // Note, that while computing L, the intermediate computation of C_B together with computing + // L0, L1, L2 can be replaced by single MSM of C with the powers of q multiplied by (1 + d_0 + d_1) + // with additionally concatenated inputs for scalars/bases. + + let q_power_multiplier = E::Scalar::ONE + d_0 + d_1; + + let q_powers_multiplied: Vec = q_powers + .par_iter() + .map(|q_power| *q_power * q_power_multiplier) + .collect(); + + // Compute the batched openings + // compute B(u_i) = v[i][0] + q*v[i][1] + ... + q^(t-1) * v[i][t-1] + let B_u = v + .into_par_iter() + .map(|v_i| zip_with!(iter, (q_powers, v_i), |a, b| *a * *b).sum()) + .collect::>(); + + let L = E::GE::vartime_multiscalar_mul( + &[ + &q_powers_multiplied[..k], + &[ + u[0], + (u[1] * d_0), + (u[2] * d_1), + -(B_u[0] + d_0 * B_u[1] + d_1 * B_u[2]), + ], + ] + .concat(), + &[ + &C[..k], + &[W[0].clone(), W[1].clone(), W[2].clone(), vk.G.clone()], + ] + .concat(), + ); let R0 = from_ppG1(&W[0]); let R1 = from_ppG1(&W[1]); let R2 = from_ppG1(&W[2]); - let R = R0 + R1 * d[0] + R2 * d[1]; + let R = R0 + R1 * d_0 + R2 * d_1; // Check that e(L, vk.H) == e(R, vk.tau_H) (::pairing(&L, &from_ppG2(&vk.H))) @@ -624,18 +638,15 @@ where if v.len() != 3 { return Err(NovaError::ProofVerifyError); } - if v[0].len() != ell + 1 || v[1].len() != ell + 1 || v[2].len() != ell + 1 { + if v[0].len() != ell || v[1].len() != ell || v[2].len() != ell { return Err(NovaError::ProofVerifyError); } let ypos = &v[0]; let yneg = &v[1]; - let Y = &v[2]; + let mut Y = v[2].to_vec(); + Y.push(*y); // Check consistency of (Y, ypos, yneg) - if Y[ell] != *y { - return Err(NovaError::ProofVerifyError); - } - let two = E::Scalar::from(2u64); for i in 0..ell { if two * r * Y[i + 1] @@ -685,52 +696,58 @@ mod tests { type Fr = ::Scalar; #[test] - fn test_mlkzg_eval() { + fn test_hyperkzg_eval() { // Test with poly(X1, X2) = 1 + X1 + X2 + X1*X2 let n = 4; let ck: CommitmentKey = CommitmentEngine::setup(b"test", n); - let (pk, _vk): (ProverKey, VerifierKey) = EvaluationEngine::setup(&ck); + let (pk, vk): (ProverKey, VerifierKey) = EvaluationEngine::setup(&ck); // poly is in eval. representation; evaluated at [(0,0), (0,1), (1,0), (1,1)] let poly = vec![Fr::from(1), Fr::from(2), Fr::from(2), Fr::from(4)]; let C = CommitmentEngine::commit(&ck, &poly); - let mut tr = Keccak256Transcript::new(b"TestEval"); - // Call the prover with a (point, eval) pair. The prover recomputes - // poly(point) = eval', and fails if eval' != eval + let test_inner = |point: Vec, eval: Fr| -> Result<(), NovaError> { + let mut tr = Keccak256Transcript::new(b"TestEval"); + let proof = EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).unwrap(); + let mut tr = Keccak256Transcript::new(b"TestEval"); + EvaluationEngine::verify(&vk, &mut tr, &C, &point, &eval, &proof) + }; + + // Call the prover with a (point, eval) pair. + // The prover does not recompute so it may produce a proof, but it should not verify let point = vec![Fr::from(0), Fr::from(0)]; let eval = Fr::ONE; - assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok()); + assert!(test_inner(point, eval).is_ok()); let point = vec![Fr::from(0), Fr::from(1)]; let eval = Fr::from(2); - assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok()); + assert!(test_inner(point, eval).is_ok()); let point = vec![Fr::from(1), Fr::from(1)]; let eval = Fr::from(4); - assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok()); + assert!(test_inner(point, eval).is_ok()); let point = vec![Fr::from(0), Fr::from(2)]; let eval = Fr::from(3); - assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok()); + assert!(test_inner(point, eval).is_ok()); let point = vec![Fr::from(2), Fr::from(2)]; let eval = Fr::from(9); - assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_ok()); + assert!(test_inner(point, eval).is_ok()); // Try a couple incorrect evaluations and expect failure let point = vec![Fr::from(2), Fr::from(2)]; let eval = Fr::from(50); - assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_err()); + assert!(test_inner(point, eval).is_err()); let point = vec![Fr::from(0), Fr::from(2)]; let eval = Fr::from(4); - assert!(EvaluationEngine::prove(&ck, &pk, &mut tr, &C, &poly, &point, &eval).is_err()); + assert!(test_inner(point, eval).is_err()); } #[test] - fn test_mlkzg() { + fn test_hyperkzg() { let n = 4; // poly = [1, 2, 1, 4] @@ -778,7 +795,7 @@ mod tests { // Change the proof and expect verification to fail let mut bad_proof = proof.clone(); - bad_proof.com[0] = (bad_proof.com[0] + bad_proof.com[1]).to_affine(); + bad_proof.com[0] = (bad_proof.com[0] + bad_proof.com[0]).to_affine(); let mut verifier_transcript2 = Keccak256Transcript::new(b"TestEval"); assert!(EvaluationEngine::verify( &vk, @@ -792,8 +809,8 @@ mod tests { } #[test] - fn test_mlkzg_more() { - // test the mlkzg prover and verifier with random instances (derived from a seed) + fn test_hyperkzg_more() { + // test the hyperkzg prover and verifier with random instances (derived from a seed) for ell in [4, 5, 6] { let mut rng = rand::rngs::StdRng::seed_from_u64(ell as u64);