From b633aea9ac580f01d75215a31166d7984bc4a834 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sun, 27 Aug 2023 21:21:12 +0900 Subject: [PATCH 01/65] Use number theoretic transform for multiplication --- src/biguint.rs | 1 + src/biguint/multiplication.rs | 12 +- src/biguint/ntt.rs | 693 ++++++++++++++++++++++++++++++++++ 3 files changed, 705 insertions(+), 1 deletion(-) create mode 100644 src/biguint/ntt.rs diff --git a/src/biguint.rs b/src/biguint.rs index 1554eb0f..373205af 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -22,6 +22,7 @@ mod bits; mod convert; mod iter; mod monty; +mod ntt; mod power; mod shift; diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs index 4d7f1f21..a1a93b2e 100644 --- a/src/biguint/multiplication.rs +++ b/src/biguint/multiplication.rs @@ -13,6 +13,8 @@ use core::iter::Product; use core::ops::{Mul, MulAssign}; use num_traits::{CheckedMul, FromPrimitive, One, Zero}; +use super::ntt; + #[inline] pub(super) fn mac_with_carry( a: BigDigit, @@ -217,7 +219,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { } NoSign => (), } - } else { + } else if x.len() <= 512 { // Toom-3 multiplication: // // Toom-3 is like Karatsuba above, but dividing the inputs into three parts. @@ -346,6 +348,14 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { NoSign => {} } } + } else { + // Number-theoretic transform (NTT) multiplication: + // + // NTT multiplies two integers by computing the convolution of the arrays + // modulo a prime. Since the result may exceed the prime, we use three + // distinct primes and combine the results using the Chinese Remainder + // Theroem (CRT). + ntt::mac3_ntt(acc, b, c); } } diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs new file mode 100644 index 00000000..3d005181 --- /dev/null +++ b/src/biguint/ntt.rs @@ -0,0 +1,693 @@ +mod arith { + // Extended Euclid algorithm: + // (g, x, y) is a solution to ax + by = g, where g = gcd(a, b) + pub const fn egcd(mut a: i128, mut b: i128) -> (i128, i128, i128) { + if a < 0 { a = -a; } + if b < 0 { b = -b; } + assert!(a > 0 || b > 0); + let mut c = [1, 0, 0, 1]; // treat as a row-major 2x2 matrix + let (g, x, y) = loop { + if a == 0 { break (b, 0, 1); } + if b == 0 { break (a, 1, 0); } + if a < b { + let (q, r) = (b/a, b%a); + b = r; + c = [c[0], c[1] - q*c[0], c[2], c[3] - q*c[2]]; + } else { + let (q, r) = (a/b, a%b); + a = r; + c = [c[0] - q*c[1], c[1], c[2] - q*c[3], c[3]]; + } + }; + (g, c[0]*x + c[1]*y, c[2]*x + c[3]*y) + } + // Modular inverse: a^-1 mod modulus + // (m == 0 means m == 2^64) + pub const fn invmod(a: u64, modulus: u64) -> u64 { + let m = if modulus == 0 { 1i128 << 64 } else { modulus as i128 }; + let (g, mut x, _y) = egcd(a as i128, m); + assert!(g == 1); + if x < 0 { x += m; } + assert!(x > 0 && x < 1i128 << 64); + x as u64 + } +} + +struct Arith {} +impl Arith

{ + pub const FACTOR_TWO: usize = (P-1).trailing_zeros() as usize; + pub const FACTOR_THREE: usize = Self::factor_three(); + pub const FACTOR_FIVE: usize = Self::factor_five(); + pub const MAX_NTT_LEN: u64 = Self::max_ntt_len(); + pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P + pub const R2: u64 = ((Self::R as u128 * Self::R as u128) % P as u128) as u64; // R^2 mod P + pub const R3: u64 = ((Self::R2 as u128 * Self::R as u128) % P as u128) as u64; // R^3 mod P + pub const RNEG: u64 = P.wrapping_sub(Self::R); // -2^64 mod P + pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 + pub const ROOT: u64 = Self::ntt_root(); // MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN + pub const ROOTR: u64 = Self::mulmod_naive(Self::ROOT, Self::R); // ROOT * R mod P + const fn factor_three() -> usize { + let mut tmp = P-1; + let mut ans = 0; + while tmp % 3 == 0 { tmp /= 3; ans += 1; } + ans + } + const fn factor_five() -> usize { + let mut tmp = P-1; + let mut ans = 0; + while tmp % 5 == 0 { tmp /= 5; ans += 1; } + ans + } + const fn max_ntt_len() -> u64 { + let mut ans = 1u64 << Self::FACTOR_TWO; + let mut i = 0; + while i < Self::FACTOR_THREE { ans *= 3; i += 1; } + let mut i = 0; + while i < Self::FACTOR_FIVE { ans *= 5; i += 1; } + assert!(ans % 4050 == 0); + ans + } + const fn ntt_root() -> u64 { + let mut p = 1; + 'outer: loop { + /* ensure p is prime */ + p += 1; + let mut i = 2; + while i * i <= p { + if p % i == 0 { continue 'outer; } + i += 1; + } + let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN); + let mut j = 0; + while j <= Self::FACTOR_TWO { + let mut k = 0; + while k <= Self::FACTOR_THREE { + let mut l = 0; + while l <= Self::FACTOR_FIVE { + let p2 = Self::powmod_naive(2, j as u64); + let p3 = Self::powmod_naive(3, k as u64); + let p5 = Self::powmod_naive(5, l as u64); + let exponent = p2 * p3 * p5; + if exponent < Self::MAX_NTT_LEN && Self::powmod_naive(root, exponent) == 1 { + continue 'outer; + } + l += 1; + } + k += 1; + } + j += 1; + } + break root + } + } + // Computes a * b mod P + const fn mulmod_naive(a: u64, b: u64) -> u64 { + ((a as u128 * b as u128) % P as u128) as u64 + } + // Computes base^exponent mod P + const fn powmod_naive(base: u64, exponent: u64) -> u64 { + let mut cur = 1; + let mut pow = base as u128; + let mut p = exponent; + while p > 0 { + if p % 2 > 0 { + cur = (cur * pow) % P as u128; + } + p /= 2; + pow = (pow * pow) % P as u128; + } + cur as u64 + } + // Multiplication with Montgomery reduction: + // a * b * R^-1 mod P + pub const fn mmulmod(a: u64, b: u64) -> u64 { + let x = a as u128 * b as u128; + let m = (x as u64).wrapping_mul(Self::PINV); + let y = ((m as u128 * P as u128) >> 64) as u64; + let (out, overflow) = ((x >> 64) as u64).overflowing_sub(y); + if overflow { out.wrapping_add(P) } else { out } + } + pub const fn mmulmod_cond(a: u64, b: u64) -> u64 { + if INV { Self::mmulmod(a, b) } else { b } + } + // Fused-multiply-add with Montgomery reduction: + // a * b * R^-1 + c mod P + pub const fn mmuladdmod(a: u64, b: u64, c: u64) -> u64 { + let x = a as u128 * b as u128; + let hi = Self::addmod((x >> 64) as u64, c); + let m = (x as u64).wrapping_mul(Self::PINV); + let y = ((m as u128 * P as u128) >> 64) as u64; + let (out, overflow) = hi.overflowing_sub(y); + if overflow { out.wrapping_add(P) } else { out } + } + // Fused-multiply-sub with Montgomery reduction: + // a * b * R^-1 - c mod P + pub const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 { + let x = a as u128 * b as u128; + let hi = Self::submod((x >> 64) as u64, c); + let m = (x as u64).wrapping_mul(Self::PINV); + let y = ((m as u128 * P as u128) >> 64) as u64; + let (out, overflow) = hi.overflowing_sub(y); + if overflow { out.wrapping_add(P) } else { out } + } + // Computes base^exponent mod P with Montgomery reduction + pub const fn mpowmod(base: u64, exponent: u64) -> u64 { + let mut cur = Self::R; + let mut pow = base; + let mut p = exponent; + while p > 0 { + if p % 2 > 0 { + cur = Self::mmulmod(cur, pow); + } + p /= 2; + pow = Self::mmulmod(pow, pow); + } + cur as u64 + } + // Computes a + b mod P, output range [0, P) + pub const fn addmod(a: u64, b: u64) -> u64 { + Self::submod(a, P.wrapping_sub(b)) + } + // Computes a + b mod P, output range [0, 2^64) + pub const fn addmod64(a: u64, b: u64) -> u64 { + let (out, overflow) = a.overflowing_add(b); + if overflow { out.wrapping_sub(P) } else { out } + } + // Computes a + b mod P, selects addmod64 or addmod depending on INV + pub const fn addmodopt(a: u64, b: u64) -> u64 { + if INV { Self::addmod64(a, b) } else { Self::addmod(a, b) } + } + // Computes a - b mod P, output range [0, P) + pub const fn submod(a: u64, b: u64) -> u64 { + let (out, overflow) = a.overflowing_sub(b); + if overflow { out.wrapping_add(P) } else { out } + } +} + +struct NttKernelImpl; +impl NttKernelImpl { + pub const ROOTR: u64 = Arith::

::mpowmod(Arith::

::ROOTR, if INV { Arith::

::MAX_NTT_LEN - 1 } else { 1 }); + pub const U3: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/3); + pub const U4: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/4); + pub const U5: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/5); + pub const U6: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/6); + pub const C51: u64 = Self::c5().0; + pub const C52: u64 = Self::c5().1; + pub const C53: u64 = Self::c5().2; + pub const C54: u64 = Self::c5().3; + pub const C55: u64 = Self::c5().4; + const fn c5() -> (u64, u64, u64, u64, u64) { + let w = Self::U5; + let w2 = Arith::

::mpowmod(w, 2); + let w4 = Arith::

::mpowmod(w, 4); + let inv2 = Arith::

::mmulmod(Arith::

::R2, arith::invmod(2, P)); + let inv4 = Arith::

::mmulmod(Arith::

::R2, arith::invmod(4, P)); + let c51 = Arith::

::submod(Arith::

::submod(0, Arith::

::R), inv4); // (-1) + (-1) * 4^-1 mod P + let c52 = Arith::

::addmod(Arith::

::mmulmod(inv2, Arith::

::addmod(w, w4)), inv4); // 4^-1 * (2*w + 2*w^4 + 1) mod P + let c53 = Arith::

::mmulmod(inv2, Arith::

::submod(w, w4)); // 2^-1 * (w - w^4) mod P + let c54 = Arith::

::addmod(Arith::

::addmod(w, w2), inv2); // 2^-1 * (2*w + 2*w^2 + 1) mod P + let c55 = Arith::

::addmod(Arith::

::addmod(w2, w4), inv2); // 2^-1 * (2*w^2 + 2*w^4 + 1) mod P + (c51, c52, c53, c54, c55) + } +} + +impl NttKernelImpl { + unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + let mut src = x.as_ptr(); + let mut dst = y.as_mut_ptr(); + let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); + let (n1, n1s) = (n/2, n/2*s); + let mut w1p = Arith::

::R; + for _ in 0..n1 { + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(n1s); + *dst.wrapping_add(0) = Arith::

::addmod(a, b); + *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::submod(a, b)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + dst = dst.wrapping_add(s); + w1p = Arith::

::mmulmod(w1p, omega1); + } + (n/2, s*2, !eo, y, x) + } + unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + assert_eq!(n, 2); + let mut src = x.as_ptr(); + let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(s); + *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(a, b)); + *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(a, b)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + if eo { (n/2, s*2, !eo, y, x) } else { (n/2, s*2, eo, x, y) } + } +} + +impl NttKernelImpl { + unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + let mut src = x.as_ptr(); + let mut dst = y.as_mut_ptr(); + let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); + let (n1, n1s) = (n/3, n/3*s); + let (mut w1p, mut w2p) = (Arith::

::R, Arith::

::R); + for _ in 0..n1 { + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(n1s); + let c = *src.wrapping_add(2*n1s); + let kbmc = Arith::

::mmulmod(Self::U3, Arith::

::submod(b, c)); + *dst.wrapping_add(0) = Arith::

::addmod(a, Arith::

::addmod(b, c)); + *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::addmod64(Arith::

::submod(a, c), kbmc)); + *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::submod(Arith::

::submod(a, b), kbmc)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + dst = dst.wrapping_add(2*s); + w1p = Arith::

::mmulmod(w1p, omega1); + w2p = Arith::

::mmulmod(w1p, w1p); + } + (n/3, s*3, !eo, y, x) + } + unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + assert_eq!(n, 3); + let mut src = x.as_ptr(); + let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(s); + let c = *src.wrapping_add(2*s); + let kbmc = Arith::

::mmulmod(Self::U3, Arith::

::submod(b, c)); + *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(a, Arith::

::addmodopt::(b, c))); + *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(a, c), kbmc)); + *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(Arith::

::submod(a, b), kbmc)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + if eo { (n/3, s*3, !eo, y, x) } else { (n/3, s*3, eo, x, y) } + } +} + +impl NttKernelImpl { + unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + let mut src = x.as_ptr(); + let mut dst = y.as_mut_ptr(); + let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); + let (n1, n1s) = (n/4, n/4*s); + let (mut w1p, mut w2p, mut w3p) = (Arith::

::R, Arith::

::R, P.wrapping_sub(Self::U4)); + for _ in 0..n1 { + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(n1s); + let c = *src.wrapping_add(2*n1s); + let d = *src.wrapping_add(3*n1s); + let apc = Arith::

::addmod(a, c); + let amc = Arith::

::mmulmod(w1p, Arith::

::submod(a, c)); + let bpd = Arith::

::addmod(b, d); + let bmd = Arith::

::submod(b, d); + let jbmd = Arith::

::mmulmod(w3p, bmd); + *dst.wrapping_add(0) = Arith::

::addmod(apc, bpd); + *dst.wrapping_add(s) = Arith::

::submod(amc, jbmd); + *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::submod(apc, bpd)); + *dst.wrapping_add(3*s) = Arith::

::mmulmod(w2p, Arith::

::addmod64(amc, jbmd)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + dst = dst.wrapping_add(3*s); + w1p = Arith::

::mmulmod(w1p, omega1); + w2p = Arith::

::mmulmod(w1p, w1p); + w3p = Arith::

::mmulmod(w3p, omega1); + } + (n/4, s*4, !eo, y, x) + } + unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + assert_eq!(n, 4); + let mut src = x.as_ptr(); + let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(s); + let c = *src.wrapping_add(2*s); + let d = *src.wrapping_add(3*s); + let apc = Arith::

::addmod(a, c); + let amc = Arith::

::submod(a, c); + let bpd = Arith::

::addmod(b, d); + let bmd = Arith::

::submod(b, d); + let jbmd = Arith::

::mmulmod(bmd, P.wrapping_sub(Self::U4)); + *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(apc, bpd)); + *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(amc, jbmd)); + *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(apc, bpd)); + *dst.wrapping_add(3*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(amc, jbmd)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + if eo { (n/4, s*4, !eo, y, x) } else { (n/4, s*4, eo, x, y) } + } +} + +impl NttKernelImpl { + unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + let mut src = x.as_ptr(); + let mut dst = y.as_mut_ptr(); + let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); + let (n1, n1s) = (n/5, n/5*s); + let (mut w1p, mut w2p, mut w3p, mut w4p) = (Arith::

::R, Arith::

::RNEG, Arith::

::RNEG, Arith::

::R); + for _ in 0..n1 { + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(n1s); + let c = *src.wrapping_add(2*n1s); + let d = *src.wrapping_add(3*n1s); + let e = *src.wrapping_add(4*n1s); + let t1 = Arith::

::addmod(b, e); + let t2 = Arith::

::addmod(c, d); + let t3 = Arith::

::submod(b, e); + let t4= Arith::

::submod(d, c); + let t5 = Arith::

::addmod(t1, t2); + let t6 = Arith::

::submod(t1, t2); + let t7 = Arith::

::addmod64(t3, t4); + let m1 = Arith::

::addmod(a, t5); + let m2 = Arith::

::mmulsubmod(P.wrapping_sub(Self::C51), t5, m1); + let m3 = Arith::

::mmulmod(Self::C52, t6); + let m4 = Arith::

::mmulmod(Self::C53, t7); + let m5 = Arith::

::mmulsubmod(Self::C54, t4, m4); + let m6 = Arith::

::mmulsubmod(P.wrapping_sub(Self::C55), t3, m4); + let s2 = Arith::

::submod(m3, m2); + let s4 = Arith::

::addmod64(m2, m3); + *dst.wrapping_add(0) = m1; + *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::submod(s2, m5)); + *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::addmod64(s4, m6)); + *dst.wrapping_add(3*s) = Arith::

::mmulmod(w3p, Arith::

::submod(s4, m6)); + *dst.wrapping_add(4*s) = Arith::

::mmulmod(w4p, Arith::

::addmod64(s2, m5)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + dst = dst.wrapping_add(4*s); + w1p = Arith::

::mmulmod(w1p, omega1); + w2p = Arith::

::mmulmod(w1p, P.wrapping_sub(w1p)); + w3p = Arith::

::mmulmod(w1p, w2p); + w4p = Arith::

::mmulmod(w2p, w2p); + } + (n/5, s*5, !eo, y, x) + } + unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + assert_eq!(n, 5); + let mut src = x.as_ptr(); + let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; + for _ in 0..s { + let a = *src.wrapping_add(0); + let b = *src.wrapping_add(s); + let c = *src.wrapping_add(2*s); + let d = *src.wrapping_add(3*s); + let e = *src.wrapping_add(4*s); + let t1 = Arith::

::addmod(b, e); + let t2 = Arith::

::addmod(c, d); + let t3 = Arith::

::submod(b, e); + let t4= Arith::

::submod(d, c); + let t5 = Arith::

::addmod(t1, t2); + let t6 = Arith::

::submod(t1, t2); + let t7 = Arith::

::addmod64(t3, t4); + let m1 = Arith::

::addmod(a, t5); + let m2 = Arith::

::mmuladdmod(Self::C51, t5, m1); + let m3 = Arith::

::mmulmod(Self::C52, t6); + let m4 = Arith::

::mmulmod(Self::C53, t7); + let m5 = Arith::

::mmulsubmod(Self::C54, t4, m4); + let m6 = Arith::

::mmulsubmod(P.wrapping_sub(Self::C55), t3, m4); + let s2 = Arith::

::addmod(m2, m3); + let s4 = Arith::

::submod(m2, m3); + *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, m1); + *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(s2, m5)); + *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(s4, m6)); + *dst.wrapping_add(3*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(s4, m6)); + *dst.wrapping_add(4*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(s2, m5)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + if eo { (n/5, s*5, !eo, y, x) } else { (n/5, s*5, eo, x, y) } + } +} + +impl NttKernelImpl { + unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + let mut src = x.as_ptr(); + let mut dst = y.as_mut_ptr(); + let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); + let (n1, n1s) = (n/6, n/6*s); + let (mut w1p, mut w2p, mut w3p, mut w4p, mut w5p) = (Arith::

::R, Arith::

::R, Arith::

::R, Arith::

::R, Arith::

::R); + for _ in 0..n1 { + for _ in 0..s { + let mut a = *src.wrapping_add(0); + let mut b = *src.wrapping_add(n1s); + let mut c = *src.wrapping_add(2*n1s); + let mut d = *src.wrapping_add(3*n1s); + let mut e = *src.wrapping_add(4*n1s); + let mut f = *src.wrapping_add(5*n1s); + (a, d) = (Arith::

::addmod(a, d), Arith::

::submod(a, d)); + (b, e) = (Arith::

::addmod(b, e), Arith::

::submod(b, e)); + (c, f) = (Arith::

::addmod(c, f), Arith::

::submod(c, f)); + let lbmc = Arith::

::mmulmod(Self::U6, Arith::

::submod(b, c)); + *dst.wrapping_add(0) = Arith::

::addmod(a, Arith::

::addmod(b, c)); + *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::addmod64(Arith::

::submod(a, b), lbmc)); + *dst.wrapping_add(4*s) = Arith::

::mmulmod(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); + let mlepf = Arith::

::mmulmod(P.wrapping_sub(Self::U6), Arith::

::addmod64(e, f)); + *dst.wrapping_add(1*s) = Arith::

::mmulmod(w1p, Arith::

::submod(Arith::

::submod(d, f), mlepf)); + *dst.wrapping_add(3*s) = Arith::

::mmulmod(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); + *dst.wrapping_add(5*s) = Arith::

::mmulmod(w5p, Arith::

::addmod64(Arith::

::addmod64(d, mlepf), e)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + dst = dst.wrapping_add(5*s); + w1p = Arith::

::mmulmod(w1p, omega1); + w2p = Arith::

::mmulmod(w1p, w1p); + w3p = Arith::

::mmulmod(w1p, w2p); + w4p = Arith::

::mmulmod(w2p, w2p); + w5p = Arith::

::mmulmod(w2p, w3p); + } + (n/6, s*6, !eo, y, x) + } + unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { + assert_eq!(n, 6); + let mut src = x.as_ptr(); + let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; + for _ in 0..s { + let mut a = *src.wrapping_add(0); + let mut b = *src.wrapping_add(s); + let mut c = *src.wrapping_add(2*s); + let mut d = *src.wrapping_add(3*s); + let mut e = *src.wrapping_add(4*s); + let mut f = *src.wrapping_add(5*s); + (a, d) = (Arith::

::addmod(a, d), Arith::

::submod(a, d)); + (b, e) = (Arith::

::addmod(b, e), Arith::

::submod(b, e)); + (c, f) = (Arith::

::addmod(c, f), Arith::

::submod(c, f)); + let lbmc = Arith::

::mmulmod(Self::U6, Arith::

::submod(b, c)); + *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(a, Arith::

::addmodopt::(b, c))); + *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(a, b), lbmc)); + *dst.wrapping_add(4*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(Arith::

::submod(a, c), lbmc)); + let lepf = Arith::

::mmulmod(Self::U6, Arith::

::addmod64(e, f)); + *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(d, f), lepf)); + *dst.wrapping_add(3*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(d, Arith::

::submod(e, f))); + *dst.wrapping_add(5*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(d, lepf), e)); + src = src.wrapping_add(1); + dst = dst.wrapping_add(1); + } + if eo { (n/6, s*6, !eo, y, x) } else { (n/6, s*6, eo, x, y) } + } +} + +fn ntt_stockham(input: &mut [u64], buf: &mut [u64]) { + let (mut n, mut s, mut eo, mut x, mut y) = (input.len(), 1, false, input, buf); + assert!(Arith::

::MAX_NTT_LEN % n as u64 == 0); + let inv_p2 = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/n as u64)); + if n == 1 { + x[0] = Arith::

::mmulmod_cond::(inv_p2, x[0]); + return; + } + let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0); + let mut tmp = n; + while tmp % 6 == 0 { tmp /= 6; cnt6 += 1; } + while tmp % 5 == 0 { tmp /= 5; cnt5 += 1; } + while tmp % 4 == 0 { tmp /= 4; cnt4 += 1; } + while tmp % 3 == 0 { tmp /= 3; cnt3 += 1; } + while tmp % 2 == 0 { tmp /= 2; cnt2 += 1; } + while cnt6 > 0 && cnt2 > 0 { cnt6 -= 1; cnt2 -= 1; cnt4 += 1; cnt3 += 1; } + unsafe { + while cnt2 > 0 { + (n, s, eo, x, y) = if n > 2 { + NttKernelImpl::::apply(n, s, eo, x, y) + } else { + NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) + }; + cnt2 -= 1; + } + while cnt3 > 0 { + (n, s, eo, x, y) = if n > 3 { + NttKernelImpl::::apply(n, s, eo, x, y) + } else { + NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) + }; + cnt3 -= 1; + } + while cnt4 > 0 { + (n, s, eo, x, y) = if n > 4 { + NttKernelImpl::::apply(n, s, eo, x, y) + } else { + NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) + }; + cnt4 -= 1; + } + while cnt5 > 0 { + (n, s, eo, x, y) = if n > 5 { + NttKernelImpl::::apply(n, s, eo, x, y) + } else { + NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) + }; + cnt5 -= 1; + } + while cnt6 > 0 { + (n, s, eo, x, y) = if n > 6 { + NttKernelImpl::::apply(n, s, eo, x, y) + } else { + NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) + }; + cnt6 -= 1; + } + } +} + +fn plan_ntt(min_len: usize) -> (usize, usize) { + assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); + let (mut len_max, mut len_max_cost) = (0usize, usize::MAX); + let mut len5 = 1; + for _ in 0..Arith::

::FACTOR_FIVE+1 { + let mut len35 = len5; + for _ in 0..Arith::

::FACTOR_THREE+1 { + let mut len = len35; + let mut i = 0; + while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } + if len >= min_len && len < len_max_cost { + let (mut tmp, mut cost) = (len, 0); + while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len); } + while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len/5); } + while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } + while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } + while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); } + if cost < len_max_cost { (len_max, len_max_cost) = (len, cost); } + } + len35 *= 3; + } + len5 *= 5; + } + (len_max, len_max_cost) +} + +// Performs (cyclic) integer convolution modulo P using NTT. +// Modifies the three buffers in-place. +// The output is saved in the slice `x`. +// The three slices must have the same length which divides `Arith::

::MAX_NTT_LEN`. +fn conv(x: &mut [u64], y: &mut [u64], buf: &mut [u64]) { + assert!(x.len() > 0 && x.len() == y.len() && y.len() == buf.len()); + ntt_stockham::(x, buf); + ntt_stockham::(y, buf); + for i in 0..x.len() { x[i] = Arith::

::mmulmod(x[i], y[i]); } + ntt_stockham::(x, buf); +} + +//////////////////////////////////////////////////////////////////////////////// + +use core::cmp::max; +use crate::big_digit::BigDigit; + +const P1: u64 = 10237243632176332801; // Max NTT length = 2^24 * 3^20 * 5^2 = 1462463376025190400 +const P2: u64 = 13649658176235110401; // Max NTT length = 2^26 * 3^19 * 5^2 = 1949951168033587200 +const P3: u64 = 14259017916245606401; // Max NTT length = 2^22 * 3^21 * 5^2 = 1096847532018892800 + +const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod(P1, P2)); +const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod( + Arith::::R3, + Arith::::mmulmod( + arith::invmod(P1, P3), + arith::invmod(P2, P3) + ) +); +const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); +const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; +const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; + +#[cfg(u64_digit)] +#[allow(clippy::many_single_char_names)] +pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { + let min_len = acc.len(); + let len_max_1 = plan_ntt::(min_len).0; + let len_max_2 = plan_ntt::(min_len).0; + let len_max_3 = plan_ntt::(min_len).0; + let len_max = max(len_max_1, max(len_max_2, len_max_3)); + let mut x = vec![0u64; len_max_1]; + let mut y = vec![0u64; len_max_2]; + let mut z = vec![0u64; len_max_3]; + let mut u = vec![0u64; len_max]; + let mut v = vec![0u64; len_max]; + + /* convolution with modulo P1 */ + for i in 0..b.len() { x[i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } + for i in 0..c.len() { u[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } + u[c.len()..len_max_1].fill(0u64); + conv::(&mut x, &mut u[..len_max_1], &mut v[..len_max_1]); + + /* convolution with modulo P2 */ + for i in 0..b.len() { y[i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } + for i in 0..c.len() { u[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } + u[c.len()..len_max_2].fill(0u64); + conv::(&mut y, &mut u[..len_max_2], &mut v[..len_max_2]); + + /* convolution with modulo P3 */ + for i in 0..b.len() { z[i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; } + for i in 0..c.len() { u[i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } + u[c.len()..len_max_3].fill(0u64); + conv::(&mut z, &mut u[..len_max_3], &mut v[..len_max_3]); + + /* merge the result in {x, y, z} into acc (process carry along the way) */ + let mut carry: u128 = 0; + for i in 0..min_len-1 { + let (a, b, c) = (x[i], y[i], z[i]); + // We need to solve the following system of linear congruences: + // x === a mod P1, + // x === b mod P2, + // x === c mod P3. + // The first two equations are equivalent to + // x === a + P1 * (U * (b-a) mod P2) mod P1P2, + // where U is the solution to + // P1 * U === 1 mod P2. + let bma = Arith::::submod(b, a); + let u = Arith::::mmulmod(bma, P1INV_R_MOD_P2); + let v = a as u128 + P1 as u128 * u as u128; + let v_mod_p3 = Arith::::addmod(a, Arith::::mmulmod(P1_R_MOD_P3, u)); + // Now we have reduced the congruences into two: + // x === v mod P1P2, + // x === c mod P3. + // The solution is + // x === v + P1P2 * (V * (c-v) mod P3) mod P1P2P3, + // where V is the solution to + // P1P2 * V === 1 mod P3. + let cmv = Arith::::submod(c, v_mod_p3); + let vcmv = Arith::::mmulmod(cmv, P1P2INV_R_MOD_P3); + let (out_01, overflow) = carry.overflowing_add(v + P1P2_LO as u128 * vcmv as u128); + let out_0 = out_01 as u64; + let out_12 = P1P2_HI as u128 * vcmv as u128 + (out_01 >> 64) + + if overflow { 1u128 << 64 } else { 0 }; + let out_1 = out_12 as u64; + let out_2 = (out_12 >> 64) as u64; + + let (v, overflow) = acc[i].overflowing_add(out_0); + acc[i] = v; + carry = out_1 as u128 + ((out_2 as u128) << 64) + if overflow { 1 } else { 0 }; + } + acc[min_len-1] += carry as u64; +} + +#[cfg(not(u64_digit))] +pub fn mac3_ntt(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { +} \ No newline at end of file From 98eba1c7e6d0c69fdd13a8944ce47d15f4e3e75b Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sun, 27 Aug 2023 23:37:11 +0900 Subject: [PATCH 02/65] Use 2 primes to multiply short arrays Previously 3 primes were used, which was suboptimal in terms of speed. Currently, the threshold for switching from 2 to 3 primes is 2^38. --- src/biguint/ntt.rs | 134 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 116 insertions(+), 18 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 3d005181..c79a1a85 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -598,7 +598,7 @@ fn conv(x: &mut [u64], y: &mut [u64], buf: &mut [u64]) { //////////////////////////////////////////////////////////////////////////////// -use core::cmp::max; +use core::cmp::{min, max}; use crate::big_digit::BigDigit; const P1: u64 = 10237243632176332801; // Max NTT length = 2^24 * 3^20 * 5^2 = 1462463376025190400 @@ -617,10 +617,65 @@ const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; -#[cfg(u64_digit)] #[allow(clippy::many_single_char_names)] -pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { - let min_len = acc.len(); +fn mac3_ntt_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { + let min_len = b.len() + c.len(); + let len_max_1 = plan_ntt::(min_len).0; + let len_max_2 = plan_ntt::(min_len).0; + let len_max = max(len_max_1, len_max_2); + let mut x = vec![0u64; len_max_1]; + let mut y = vec![0u64; len_max_2]; + let mut r = vec![0u64; len_max]; + let mut s = vec![0u64; len_max]; + + /* convolution with modulo P1 */ + for i in 0..b.len() { x[i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } + for i in 0..c.len() { r[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } + r[c.len()..len_max_1].fill(0u64); + conv::(&mut x, &mut r[..len_max_1], &mut s[..len_max_1]); + + /* convolution with modulo P2 */ + for i in 0..b.len() { y[i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } + for i in 0..c.len() { r[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } + r[c.len()..len_max_2].fill(0u64); + conv::(&mut y, &mut r[..len_max_2], &mut s[..len_max_2]); + + /* merge the results in {x, y} into r (process carry along the way) */ + let mask = (1u64 << bits) - 1; + let mut carry: u128 = 0; + let (mut j, mut p) = (0usize, 0u64); + for i in 0..min_len { + /* extract the convolution result */ + let (a, b) = (x[i], y[i]); + let bma = Arith::::submod(b, a); + let u = Arith::::mmulmod(bma, P1INV_R_MOD_P2); + let v = a as u128 + P1 as u128 * u as u128 + carry; + carry = v >> bits; + + /* write to r */ + let out = (v as u64) & mask; + r[j] = (r[j] & ((1u64 << p) - 1)) | (out << p); + p += bits; + if p >= 64 { + (j, p) = (j+1, p-64); + r[j] = out >> (bits - p); + } + } + + /* add r to acc */ + let mut carry: u64 = 0; + for i in 0..min(acc.len(), j+1) { + let w = r[i]; + let (v, overflow1) = acc[i].overflowing_add(w); + let (v, overflow2) = v.overflowing_add(carry); + acc[i] = v; + carry = if overflow1 || overflow2 { 1 } else { 0 }; + } +} + +#[allow(clippy::many_single_char_names)] +fn mac3_ntt_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { + let min_len = b.len() + c.len(); let len_max_1 = plan_ntt::(min_len).0; let len_max_2 = plan_ntt::(min_len).0; let len_max_3 = plan_ntt::(min_len).0; @@ -628,30 +683,30 @@ pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { let mut x = vec![0u64; len_max_1]; let mut y = vec![0u64; len_max_2]; let mut z = vec![0u64; len_max_3]; - let mut u = vec![0u64; len_max]; - let mut v = vec![0u64; len_max]; + let mut r = vec![0u64; len_max]; + let mut s = vec![0u64; len_max]; /* convolution with modulo P1 */ for i in 0..b.len() { x[i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } - for i in 0..c.len() { u[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } - u[c.len()..len_max_1].fill(0u64); - conv::(&mut x, &mut u[..len_max_1], &mut v[..len_max_1]); + for i in 0..c.len() { r[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } + r[c.len()..len_max_1].fill(0u64); + conv::(&mut x, &mut r[..len_max_1], &mut s[..len_max_1]); /* convolution with modulo P2 */ for i in 0..b.len() { y[i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } - for i in 0..c.len() { u[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } - u[c.len()..len_max_2].fill(0u64); - conv::(&mut y, &mut u[..len_max_2], &mut v[..len_max_2]); + for i in 0..c.len() { r[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } + r[c.len()..len_max_2].fill(0u64); + conv::(&mut y, &mut r[..len_max_2], &mut s[..len_max_2]); /* convolution with modulo P3 */ for i in 0..b.len() { z[i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; } - for i in 0..c.len() { u[i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } - u[c.len()..len_max_3].fill(0u64); - conv::(&mut z, &mut u[..len_max_3], &mut v[..len_max_3]); + for i in 0..c.len() { r[i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } + r[c.len()..len_max_3].fill(0u64); + conv::(&mut z, &mut r[..len_max_3], &mut s[..len_max_3]); - /* merge the result in {x, y, z} into acc (process carry along the way) */ + /* merge the results in {x, y, z} into acc (process carry along the way) */ let mut carry: u128 = 0; - for i in 0..min_len-1 { + for i in 0..min_len { let (a, b, c) = (x[i], y[i], z[i]); // We need to solve the following system of linear congruences: // x === a mod P1, @@ -685,9 +740,52 @@ pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { acc[i] = v; carry = out_1 as u128 + ((out_2 as u128) << 64) + if overflow { 1 } else { 0 }; } - acc[min_len-1] += carry as u64; + let mut carry = carry as u64; + for i in min_len..acc.len() { + let (v, overflow) = acc[i].overflowing_add(carry); + acc[i] = v; + carry = if overflow { 1 } else { 0 }; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +#[cfg(u64_digit)] +#[allow(clippy::many_single_char_names)] +pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { + let max_cnt = max(b.len(), c.len()) as u64; + let mut bits = 0u64; + while 1u64 << (2*bits) < max_cnt { bits += 1; } + bits = 63 - bits; + if bits >= 44 { + /* can pack more effective bits per u64 with two primes than with three primes */ + fn pack_into(src: &[u64], dst: &mut [u64], bits: u64) -> usize { + let (mut j, mut p) = (0usize, 0u64); + for i in 0..src.len() { + let mut k = 0; + while k < 64 { + let bits_this_time = min(64 - k, bits - p); + dst[j] = (dst[j] & ((1u64 << p) - 1)) | (((src[i] >> k) & ((1u64 << bits_this_time) - 1)) << p); + k += bits_this_time; + p += bits_this_time; + if p == bits { (j, p) = (j+1, 0); } + } + } + if p == 0 { j } else { j+1 } + } + let mut b2 = vec![0u64; ((64 * b.len() as u64 + bits - 1) / bits) as usize]; + let mut c2 = vec![0u64; ((64 * c.len() as u64 + bits - 1) / bits) as usize]; + let b2_len = pack_into(b, &mut b2, bits); + let c2_len = pack_into(c, &mut c2, bits); + mac3_ntt_two_primes(acc, &b2[..b2_len], &c2[..c2_len], bits); + } else { + /* can pack at most 21 effective bits per u64, which is worse than + 64/3 = 21.3333.. effective bits per u64 achieved with three primes */ + mac3_ntt_three_primes(acc, b, c); + } } #[cfg(not(u64_digit))] pub fn mac3_ntt(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { + unimplemented!("Please enable u64_digit option until we implement u32 support"); } \ No newline at end of file From 5d8b725fada8b2050b0cfbbcfb11d694050065fb Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 11:10:00 +0900 Subject: [PATCH 03/65] Speed up unbalanced multiplication (1) Despite the simple implementation with obvious inefficiencies (e.g., not reusing the NTT of the shorter array), this leads to speed gains in multiple benchmarks, although there is a small regression in others. --- src/biguint/multiplication.rs | 2 +- src/biguint/ntt.rs | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs index a1a93b2e..4bcbe2df 100644 --- a/src/biguint/multiplication.rs +++ b/src/biguint/multiplication.rs @@ -355,7 +355,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { // modulo a prime. Since the result may exceed the prime, we use three // distinct primes and combine the results using the Chinese Remainder // Theroem (CRT). - ntt::mac3_ntt(acc, b, c); + ntt::mac3_ntt(acc, b, c, true); } } diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index c79a1a85..b4cbfd5f 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -752,7 +752,39 @@ fn mac3_ntt_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { #[cfg(u64_digit)] #[allow(clippy::many_single_char_names)] -pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { +pub fn mac3_ntt(acc: &mut [BigDigit], bb: &[BigDigit], cc: &[BigDigit], split_unbalanced: bool) { + let (b, c) = if bb.len() < cc.len() { (bb, cc) } else { (cc, bb) }; + if split_unbalanced && b.len() * 2 <= c.len() { + /* special handling for unbalanced multiplication: + we reduce it to about `c.len()/b.len()` balanced multiplications */ + let mut i = 0usize; + let mut carry = 0u64; + while i < c.len() { + let j = min(i + b.len(), c.len()); + let k = j + b.len(); + let tmp = acc[k]; + acc[k] = 0; + mac3_ntt(&mut acc[i..k+1], b, &c[i..j], false); + let mut l = j; + while carry > 0 && l < k { + let (v, overflow) = acc[l].overflowing_add(carry); + acc[l] = v; + carry = if overflow { 1 } else { 0 }; + l += 1; + } + i = j; + carry += tmp; + } + i += b.len(); + while i < acc.len() { + let (v, overflow) = acc[i].overflowing_add(carry); + acc[i] = v; + carry = if overflow { 1 } else { 0 }; + i += 1; + } + return; + } + let max_cnt = max(b.len(), c.len()) as u64; let mut bits = 0u64; while 1u64 << (2*bits) < max_cnt { bits += 1; } From 5abc87940405800cd3ade14de01972e5a8bc1e31 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 12:30:32 +0900 Subject: [PATCH 04/65] Support 32bit BigDigit --- src/biguint/multiplication.rs | 2 +- src/biguint/ntt.rs | 43 +++++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs index 4bcbe2df..a1a93b2e 100644 --- a/src/biguint/multiplication.rs +++ b/src/biguint/multiplication.rs @@ -355,7 +355,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { // modulo a prime. Since the result may exceed the prime, we use three // distinct primes and combine the results using the Chinese Remainder // Theroem (CRT). - ntt::mac3_ntt(acc, b, c, true); + ntt::mac3_ntt(acc, b, c); } } diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index b4cbfd5f..c24e2179 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -748,11 +748,8 @@ fn mac3_ntt_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { } } -//////////////////////////////////////////////////////////////////////////////// - -#[cfg(u64_digit)] #[allow(clippy::many_single_char_names)] -pub fn mac3_ntt(acc: &mut [BigDigit], bb: &[BigDigit], cc: &[BigDigit], split_unbalanced: bool) { +fn mac3_ntt_u64(acc: &mut [u64], bb: &[u64], cc: &[u64], split_unbalanced: bool) { let (b, c) = if bb.len() < cc.len() { (bb, cc) } else { (cc, bb) }; if split_unbalanced && b.len() * 2 <= c.len() { /* special handling for unbalanced multiplication: @@ -764,7 +761,7 @@ pub fn mac3_ntt(acc: &mut [BigDigit], bb: &[BigDigit], cc: &[BigDigit], split_un let k = j + b.len(); let tmp = acc[k]; acc[k] = 0; - mac3_ntt(&mut acc[i..k+1], b, &c[i..j], false); + mac3_ntt_u64(&mut acc[i..k+1], b, &c[i..j], false); let mut l = j; while carry > 0 && l < k { let (v, overflow) = acc[l].overflowing_add(carry); @@ -817,7 +814,39 @@ pub fn mac3_ntt(acc: &mut [BigDigit], bb: &[BigDigit], cc: &[BigDigit], split_un } } +//////////////////////////////////////////////////////////////////////////////// + +#[cfg(u64_digit)] +pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { + mac3_ntt_u64(acc, b, c, true); +} + #[cfg(not(u64_digit))] -pub fn mac3_ntt(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { - unimplemented!("Please enable u64_digit option until we implement u32 support"); +pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { + fn bigdigit_to_u64(src: &[BigDigit]) -> crate::biguint::Vec:: { + let mut out = vec![0u64; (src.len() + 1) / 2]; + for i in 0..src.len()/2 { + out[i] = (src[2*i] as u64) | ((src[2*i+1] as u64) << 32); + } + if src.len() % 2 == 1 { + out[src.len()/2] = src[src.len()-1] as u64; + } + out + } + fn u64_to_bigdigit(src: &[u64], dst: &mut [BigDigit]) { + for i in 0..dst.len()/2 { + dst[2*i] = src[i] as BigDigit; + dst[2*i+1] = (src[i] >> 32) as BigDigit; + } + if dst.len() % 2 == 1 { + dst[dst.len()-1] = src[src.len()-1] as BigDigit; + } + } + + /* convert to u64 => process => convert back to BigDigit (u32) */ + let mut acc_u64 = bigdigit_to_u64(acc); + let b_u64 = bigdigit_to_u64(b); + let c_u64 = bigdigit_to_u64(c); + mac3_ntt_u64(&mut acc_u64, &b_u64, &c_u64, true); + u64_to_bigdigit(&acc_u64, acc); } \ No newline at end of file From 5d2bdd518c5246e224350dbc63cff4a147b33368 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 14:36:35 +0900 Subject: [PATCH 05/65] Fix clippy warnings --- src/biguint/multiplication.rs | 2 +- src/biguint/ntt.rs | 56 +++++++++++++++++++---------------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs index a1a93b2e..479d5cb7 100644 --- a/src/biguint/multiplication.rs +++ b/src/biguint/multiplication.rs @@ -355,7 +355,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { // modulo a prime. Since the result may exceed the prime, we use three // distinct primes and combine the results using the Chinese Remainder // Theroem (CRT). - ntt::mac3_ntt(acc, b, c); + ntt::mac3(acc, b, c); } } diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index c24e2179..4501118d 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -1,3 +1,10 @@ +#![allow(clippy::cast_sign_loss)] +#![allow(clippy::cast_lossless)] +#![allow(clippy::cast_possible_truncation)] +#![allow(clippy::many_single_char_names)] +#![allow(clippy::needless_range_loop)] +#![allow(clippy::similar_names)] + mod arith { // Extended Euclid algorithm: // (g, x, y) is a solution to ax + by = g, where g = gcd(a, b) @@ -162,7 +169,7 @@ impl Arith

{ p /= 2; pow = Self::mmulmod(pow, pow); } - cur as u64 + cur } // Computes a + b mod P, output range [0, P) pub const fn addmod(a: u64, b: u64) -> u64 { @@ -454,7 +461,7 @@ impl NttKernelImpl { *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::addmod64(Arith::

::submod(a, b), lbmc)); *dst.wrapping_add(4*s) = Arith::

::mmulmod(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); let mlepf = Arith::

::mmulmod(P.wrapping_sub(Self::U6), Arith::

::addmod64(e, f)); - *dst.wrapping_add(1*s) = Arith::

::mmulmod(w1p, Arith::

::submod(Arith::

::submod(d, f), mlepf)); + *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::submod(Arith::

::submod(d, f), mlepf)); *dst.wrapping_add(3*s) = Arith::

::mmulmod(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); *dst.wrapping_add(5*s) = Arith::

::mmulmod(w5p, Arith::

::addmod64(Arith::

::addmod64(d, mlepf), e)); src = src.wrapping_add(1); @@ -562,9 +569,9 @@ fn plan_ntt(min_len: usize) -> (usize, usize) { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); let (mut len_max, mut len_max_cost) = (0usize, usize::MAX); let mut len5 = 1; - for _ in 0..Arith::

::FACTOR_FIVE+1 { + for _ in 0..=Arith::

::FACTOR_FIVE { let mut len35 = len5; - for _ in 0..Arith::

::FACTOR_THREE+1 { + for _ in 0..=Arith::

::FACTOR_THREE { let mut len = len35; let mut i = 0; while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } @@ -589,7 +596,7 @@ fn plan_ntt(min_len: usize) -> (usize, usize) { // The output is saved in the slice `x`. // The three slices must have the same length which divides `Arith::

::MAX_NTT_LEN`. fn conv(x: &mut [u64], y: &mut [u64], buf: &mut [u64]) { - assert!(x.len() > 0 && x.len() == y.len() && y.len() == buf.len()); + assert!(!x.is_empty() && x.len() == y.len() && y.len() == buf.len()); ntt_stockham::(x, buf); ntt_stockham::(y, buf); for i in 0..x.len() { x[i] = Arith::

::mmulmod(x[i], y[i]); } @@ -601,9 +608,9 @@ fn conv(x: &mut [u64], y: &mut [u64], buf: &mut [u64]) { use core::cmp::{min, max}; use crate::big_digit::BigDigit; -const P1: u64 = 10237243632176332801; // Max NTT length = 2^24 * 3^20 * 5^2 = 1462463376025190400 -const P2: u64 = 13649658176235110401; // Max NTT length = 2^26 * 3^19 * 5^2 = 1949951168033587200 -const P3: u64 = 14259017916245606401; // Max NTT length = 2^22 * 3^21 * 5^2 = 1096847532018892800 +const P1: u64 = 10_237_243_632_176_332_801; // Max NTT length = 2^24 * 3^20 * 5^2 = 1_462_463_376_025_190_400 +const P2: u64 = 13_649_658_176_235_110_401; // Max NTT length = 2^26 * 3^19 * 5^2 = 1_949_951_168_033_587_200 +const P3: u64 = 14_259_017_916_245_606_401; // Max NTT length = 2^22 * 3^21 * 5^2 = 1_096_847_532_018_892_800 const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod(P1, P2)); const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod( @@ -617,8 +624,7 @@ const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; -#[allow(clippy::many_single_char_names)] -fn mac3_ntt_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { +fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { let min_len = b.len() + c.len(); let len_max_1 = plan_ntt::(min_len).0; let len_max_2 = plan_ntt::(min_len).0; @@ -669,12 +675,11 @@ fn mac3_ntt_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { let (v, overflow1) = acc[i].overflowing_add(w); let (v, overflow2) = v.overflowing_add(carry); acc[i] = v; - carry = if overflow1 || overflow2 { 1 } else { 0 }; + carry = u64::from(overflow1 || overflow2); } } -#[allow(clippy::many_single_char_names)] -fn mac3_ntt_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { +fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { let min_len = b.len() + c.len(); let len_max_1 = plan_ntt::(min_len).0; let len_max_2 = plan_ntt::(min_len).0; @@ -738,18 +743,17 @@ fn mac3_ntt_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { let (v, overflow) = acc[i].overflowing_add(out_0); acc[i] = v; - carry = out_1 as u128 + ((out_2 as u128) << 64) + if overflow { 1 } else { 0 }; + carry = out_1 as u128 + ((out_2 as u128) << 64) + u128::from(overflow); } let mut carry = carry as u64; for i in min_len..acc.len() { let (v, overflow) = acc[i].overflowing_add(carry); acc[i] = v; - carry = if overflow { 1 } else { 0 }; + carry = u64::from(overflow); } } -#[allow(clippy::many_single_char_names)] -fn mac3_ntt_u64(acc: &mut [u64], bb: &[u64], cc: &[u64], split_unbalanced: bool) { +fn mac3_u64(acc: &mut [u64], bb: &[u64], cc: &[u64], split_unbalanced: bool) { let (b, c) = if bb.len() < cc.len() { (bb, cc) } else { (cc, bb) }; if split_unbalanced && b.len() * 2 <= c.len() { /* special handling for unbalanced multiplication: @@ -761,12 +765,12 @@ fn mac3_ntt_u64(acc: &mut [u64], bb: &[u64], cc: &[u64], split_unbalanced: bool) let k = j + b.len(); let tmp = acc[k]; acc[k] = 0; - mac3_ntt_u64(&mut acc[i..k+1], b, &c[i..j], false); + mac3_u64(&mut acc[i..=k], b, &c[i..j], false); let mut l = j; while carry > 0 && l < k { let (v, overflow) = acc[l].overflowing_add(carry); acc[l] = v; - carry = if overflow { 1 } else { 0 }; + carry = u64::from(overflow); l += 1; } i = j; @@ -776,7 +780,7 @@ fn mac3_ntt_u64(acc: &mut [u64], bb: &[u64], cc: &[u64], split_unbalanced: bool) while i < acc.len() { let (v, overflow) = acc[i].overflowing_add(carry); acc[i] = v; - carry = if overflow { 1 } else { 0 }; + carry = u64::from(overflow); i += 1; } return; @@ -806,23 +810,23 @@ fn mac3_ntt_u64(acc: &mut [u64], bb: &[u64], cc: &[u64], split_unbalanced: bool) let mut c2 = vec![0u64; ((64 * c.len() as u64 + bits - 1) / bits) as usize]; let b2_len = pack_into(b, &mut b2, bits); let c2_len = pack_into(c, &mut c2, bits); - mac3_ntt_two_primes(acc, &b2[..b2_len], &c2[..c2_len], bits); + mac3_two_primes(acc, &b2[..b2_len], &c2[..c2_len], bits); } else { /* can pack at most 21 effective bits per u64, which is worse than 64/3 = 21.3333.. effective bits per u64 achieved with three primes */ - mac3_ntt_three_primes(acc, b, c); + mac3_three_primes(acc, b, c); } } //////////////////////////////////////////////////////////////////////////////// #[cfg(u64_digit)] -pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { - mac3_ntt_u64(acc, b, c, true); +pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { + mac3_u64(acc, b, c, true); } #[cfg(not(u64_digit))] -pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { +pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { fn bigdigit_to_u64(src: &[BigDigit]) -> crate::biguint::Vec:: { let mut out = vec![0u64; (src.len() + 1) / 2]; for i in 0..src.len()/2 { @@ -847,6 +851,6 @@ pub fn mac3_ntt(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { let mut acc_u64 = bigdigit_to_u64(acc); let b_u64 = bigdigit_to_u64(b); let c_u64 = bigdigit_to_u64(c); - mac3_ntt_u64(&mut acc_u64, &b_u64, &c_u64, true); + mac3_u64(&mut acc_u64, &b_u64, &c_u64, true); u64_to_bigdigit(&acc_u64, acc); } \ No newline at end of file From 09edfac05089df16782e06bb56d9918eb746add4 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 14:37:02 +0900 Subject: [PATCH 06/65] Fix multiplication overflow on 32bit --- src/biguint/ntt.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 4501118d..9e1d8800 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -585,8 +585,10 @@ fn plan_ntt(min_len: usize) -> (usize, usize) { if cost < len_max_cost { (len_max, len_max_cost) = (len, cost); } } len35 *= 3; + if len35 >= min_len { break; } } len5 *= 5; + if len5 >= min_len { break; } } (len_max, len_max_cost) } From 6700b64d34a82bd1187d470fa98e22753bc843b3 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 21:54:27 +0900 Subject: [PATCH 07/65] Speed up unbalanced multiplication (2) --- src/biguint/ntt.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 9e1d8800..010c9f57 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -755,9 +755,12 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { } } -fn mac3_u64(acc: &mut [u64], bb: &[u64], cc: &[u64], split_unbalanced: bool) { - let (b, c) = if bb.len() < cc.len() { (bb, cc) } else { (cc, bb) }; - if split_unbalanced && b.len() * 2 <= c.len() { +fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64], split_unbalanced: bool) { + let (b, c) = if b.len() < c.len() { (b, c) } else { (c, b) }; + let naive_cost = plan_ntt::(b.len() + c.len()).1 * 3; + let split_cost = plan_ntt::(b.len() + b.len()).1 * (2 * c.len() / b.len() + 1) + + if c.len() % b.len() > 0 { plan_ntt::(b.len() + (c.len() % b.len())).1 * 3 } else { 0 }; + if split_unbalanced && split_cost < naive_cost { /* special handling for unbalanced multiplication: we reduce it to about `c.len()/b.len()` balanced multiplications */ let mut i = 0usize; From dc73b87661cb15f267d4ffb608e8b88b9283cdd9 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 22:36:11 +0900 Subject: [PATCH 08/65] Adjust NTT threshold for u32 digits --- src/biguint/multiplication.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs index 479d5cb7..219cc593 100644 --- a/src/biguint/multiplication.rs +++ b/src/biguint/multiplication.rs @@ -219,7 +219,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { } NoSign => (), } - } else if x.len() <= 512 { + } else if x.len() <= if cfg!(u64_digit) { 512 } else { 2048 } { // Toom-3 multiplication: // // Toom-3 is like Karatsuba above, but dividing the inputs into three parts. From c803e43e2c49d267747b930a0d9ffb9ac93f917b Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 22:36:59 +0900 Subject: [PATCH 09/65] Add more benchmarks for large integers --- benches/bigint.rs | 15 +++++++++++++++ benches/factorial.rs | 30 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/benches/bigint.rs b/benches/bigint.rs index 80ec191c..3bfea576 100644 --- a/benches/bigint.rs +++ b/benches/bigint.rs @@ -87,6 +87,21 @@ fn multiply_3(b: &mut Bencher) { multiply_bench(b, 1 << 16, 1 << 17); } +#[bench] +fn multiply_4(b: &mut Bencher) { + multiply_bench(b, 100_000, 1_003_741); +} + +#[bench] +fn multiply_5(b: &mut Bencher) { + multiply_bench(b, 2_718_328, 2_738_633); +} + +#[bench] +fn multiply_6(b: &mut Bencher) { + multiply_bench(b, 27_183_279, 27_386_321); +} + #[bench] fn divide_0(b: &mut Bencher) { divide_bench(b, 1 << 8, 1 << 6); diff --git a/benches/factorial.rs b/benches/factorial.rs index a1e7b3cf..61e91ab5 100644 --- a/benches/factorial.rs +++ b/benches/factorial.rs @@ -16,6 +16,36 @@ fn factorial_mul_biguint(b: &mut Bencher) { }); } +fn factorial_product(l: usize, r: usize) -> BigUint { + if l >= r { + BigUint::from(l) + } else { + let m = (l+r)/2; + factorial_product(l, m) * factorial_product(m+1, r) + } +} + +#[bench] +fn factorial_mul_biguint_dnc_10k(b: &mut Bencher) { + b.iter(|| { + factorial_product(1, 10_000); + }); +} + +#[bench] +fn factorial_mul_biguint_dnc_100k(b: &mut Bencher) { + b.iter(|| { + factorial_product(1, 100_000); + }); +} + +#[bench] +fn factorial_mul_biguint_dnc_300k(b: &mut Bencher) { + b.iter(|| { + factorial_product(1, 300_000); + }); +} + #[bench] fn factorial_mul_u32(b: &mut Bencher) { b.iter(|| (1u32..1000).fold(BigUint::one(), Mul::mul)); From 701bdbc0e0cd0a74e555608a637206d34a091ff0 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 22:48:50 +0900 Subject: [PATCH 10/65] Update three-prime threshold (44 -> 43) --- src/biguint/ntt.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 010c9f57..bbf4d6fc 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -791,11 +791,27 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64], split_unbalanced: bool) { return; } + // We have two choices: + // 1. NTT with two primes. + // 2. NTT with three primes. + // Obviously we want to do only two passes for efficiency, not three. + // However, the number of bits per u64 we can pack for NTT + // depends on the length of the arrays being multiplied (convolved). + // If the arrays are too long, the resulting values may exceed the + // modulus range P1 * P2, which leads to incorrect results. + // Hence, we compute the number of bits required by the length of NTT, + // and use it to determine whether to use two-prime or three-prime. + // Since we can pack 64 bits per u64 in three-prime NTT, the effective + // number of bits in three-prime NTT is 64/3 = 21.3333..., which means + // two-prime NTT can only do better when at least 43 bits per u64 can + // be packed into each u64. + // Finally note that there should be no issues with overflow since + // 2^126 * 64 / 43 < 1.3 * 10^38 < P1 * P2. let max_cnt = max(b.len(), c.len()) as u64; let mut bits = 0u64; while 1u64 << (2*bits) < max_cnt { bits += 1; } bits = 63 - bits; - if bits >= 44 { + if bits >= 43 { /* can pack more effective bits per u64 with two primes than with three primes */ fn pack_into(src: &[u64], dst: &mut [u64], bits: u64) -> usize { let (mut j, mut p) = (0usize, 0u64); From a888b8eeca85efc8949a4ac8a5051d63afb543eb Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 28 Aug 2023 23:56:57 +0900 Subject: [PATCH 11/65] Update multiplication.rs --- src/biguint/multiplication.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs index 219cc593..ea4f8065 100644 --- a/src/biguint/multiplication.rs +++ b/src/biguint/multiplication.rs @@ -352,7 +352,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { // Number-theoretic transform (NTT) multiplication: // // NTT multiplies two integers by computing the convolution of the arrays - // modulo a prime. Since the result may exceed the prime, we use three + // modulo a prime. Since the result may exceed the prime, we use two or three // distinct primes and combine the results using the Chinese Remainder // Theroem (CRT). ntt::mac3(acc, b, c); From d9970e01e74581e6930042cca8db27ce2e45f02d Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 29 Aug 2023 11:32:39 +0900 Subject: [PATCH 12/65] Speed up unbalanced multiplication (3) --- src/biguint/ntt.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index bbf4d6fc..199daba5 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -755,12 +755,12 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { } } -fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64], split_unbalanced: bool) { +fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { let (b, c) = if b.len() < c.len() { (b, c) } else { (c, b) }; let naive_cost = plan_ntt::(b.len() + c.len()).1 * 3; - let split_cost = plan_ntt::(b.len() + b.len()).1 * (2 * c.len() / b.len() + 1) + let split_cost = plan_ntt::(b.len() + b.len()).1 * 3 * (c.len() / b.len()) + if c.len() % b.len() > 0 { plan_ntt::(b.len() + (c.len() % b.len())).1 * 3 } else { 0 }; - if split_unbalanced && split_cost < naive_cost { + if b.len() >= 128 && split_cost < naive_cost { /* special handling for unbalanced multiplication: we reduce it to about `c.len()/b.len()` balanced multiplications */ let mut i = 0usize; @@ -770,7 +770,7 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64], split_unbalanced: bool) { let k = j + b.len(); let tmp = acc[k]; acc[k] = 0; - mac3_u64(&mut acc[i..=k], b, &c[i..j], false); + mac3_u64(&mut acc[i..=k], b, &c[i..j]); let mut l = j; while carry > 0 && l < k { let (v, overflow) = acc[l].overflowing_add(carry); @@ -843,7 +843,7 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64], split_unbalanced: bool) { #[cfg(u64_digit)] pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { - mac3_u64(acc, b, c, true); + mac3_u64(acc, b, c); } #[cfg(not(u64_digit))] @@ -872,6 +872,6 @@ pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { let mut acc_u64 = bigdigit_to_u64(acc); let b_u64 = bigdigit_to_u64(b); let c_u64 = bigdigit_to_u64(c); - mac3_u64(&mut acc_u64, &b_u64, &c_u64, true); + mac3_u64(&mut acc_u64, &b_u64, &c_u64); u64_to_bigdigit(&acc_u64, acc); } \ No newline at end of file From e441c9250096b7cd065ae1255840605fe7daa4c0 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 7 Sep 2023 16:33:21 +0900 Subject: [PATCH 13/65] Add DIF-DIT, optimize CRT, etc. --- src/biguint/ntt.rs | 876 +++++++++++++++++++++++++-------------------- 1 file changed, 480 insertions(+), 396 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 199daba5..69b7ca6f 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -5,6 +5,8 @@ #![allow(clippy::needless_range_loop)] #![allow(clippy::similar_names)] +use crate::biguint::Vec; + mod arith { // Extended Euclid algorithm: // (g, x, y) is a solution to ax + by = g, where g = gcd(a, b) @@ -125,15 +127,19 @@ impl Arith

{ } cur as u64 } - // Multiplication with Montgomery reduction: - // a * b * R^-1 mod P - pub const fn mmulmod(a: u64, b: u64) -> u64 { - let x = a as u128 * b as u128; + // Montgomery reduction: + // x * R^-1 mod P + pub const fn mreduce(x: u128) -> u64 { let m = (x as u64).wrapping_mul(Self::PINV); let y = ((m as u128 * P as u128) >> 64) as u64; let (out, overflow) = ((x >> 64) as u64).overflowing_sub(y); if overflow { out.wrapping_add(P) } else { out } } + // Multiplication with Montgomery reduction: + // a * b * R^-1 mod P + pub const fn mmulmod(a: u64, b: u64) -> u64 { + Self::mreduce(a as u128 * b as u128) + } pub const fn mmulmod_cond(a: u64, b: u64) -> u64 { if INV { Self::mmulmod(a, b) } else { b } } @@ -141,21 +147,17 @@ impl Arith

{ // a * b * R^-1 + c mod P pub const fn mmuladdmod(a: u64, b: u64, c: u64) -> u64 { let x = a as u128 * b as u128; + let lo = x as u64; let hi = Self::addmod((x >> 64) as u64, c); - let m = (x as u64).wrapping_mul(Self::PINV); - let y = ((m as u128 * P as u128) >> 64) as u64; - let (out, overflow) = hi.overflowing_sub(y); - if overflow { out.wrapping_add(P) } else { out } + Self::mreduce(lo as u128 | ((hi as u128) << 64)) } // Fused-multiply-sub with Montgomery reduction: // a * b * R^-1 - c mod P pub const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 { let x = a as u128 * b as u128; + let lo = x as u64; let hi = Self::submod((x >> 64) as u64, c); - let m = (x as u64).wrapping_mul(Self::PINV); - let y = ((m as u128 * P as u128) >> 64) as u64; - let (out, overflow) = hi.overflowing_sub(y); - if overflow { out.wrapping_add(P) } else { out } + Self::mreduce(lo as u128 | ((hi as u128) << 64)) } // Computes base^exponent mod P with Montgomery reduction pub const fn mpowmod(base: u64, exponent: u64) -> u64 { @@ -191,9 +193,105 @@ impl Arith

{ } } -struct NttKernelImpl; -impl NttKernelImpl { +struct NttPlan { + pub n: usize, // n == g*m + pub g: usize, // g <= NttPlan::GMAX + pub m: usize, // m divides Arith::

::MAX_NTT_LEN + pub cost: usize, + pub last_radix: usize, + pub s_list: Vec<(usize, usize)>, +} +impl NttPlan { + pub const GMAX: usize = 6; + pub fn build(min_len: usize) -> NttPlan { + assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); + let (mut len_max, mut len_max_cost) = (0usize, usize::MAX); + let mut len5 = 10; + for _ in 0..Arith::

::FACTOR_FIVE+1 { + let mut len35 = len5; + for _ in 0..Arith::

::FACTOR_THREE+1 { + let mut len = len35; + let mut i = 0; + while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } + if len >= min_len && len < len_max_cost { + let (mut tmp, mut cost) = (len, 0); + while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len); } + while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len/5); } + while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } + while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } + while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); } + if cost < len_max_cost { (len_max, len_max_cost) = (len, cost); } + } + len35 *= 3; + } + len5 *= 5; + } + let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0); + let mut tmp = len_max; + while tmp % 6 == 0 { tmp /= 6; cnt6 += 1; } + while tmp % 5 == 0 { tmp /= 5; cnt5 += 1; } + while tmp % 4 == 0 { tmp /= 4; cnt4 += 1; } + while tmp % 3 == 0 { tmp /= 3; cnt3 += 1; } + while tmp % 2 == 0 { tmp /= 2; cnt2 += 1; } + let mut g = 1; + while 5*g <= Self::GMAX && cnt5 > 0 { g *= 5; cnt5 -= 1; } + while 9*g <= Self::GMAX && cnt3 >= 2 { g *= 9; cnt3 -= 2; } + while 8*g <= Self::GMAX && cnt4 > 0 && cnt2 > 0 { g *= 8; cnt4 -= 1; cnt2 -= 1; } + while 6*g <= Self::GMAX && cnt6 > 0 { g *= 6; cnt6 -= 1; } + while 4*g <= Self::GMAX && cnt4 > 0 { g *= 4; cnt4 -= 1; } + while 3*g <= Self::GMAX && cnt3 > 0 { g *= 3; cnt3 -= 1; } + while 2*g <= Self::GMAX && cnt2 > 0 { g *= 2; cnt2 -= 1; } + while cnt6 > 0 && cnt2 > 0 { cnt6 -= 1; cnt2 -= 1; cnt4 += 1; cnt3 += 1; } + let s_list = { + let mut out = vec![]; + let mut tmp = len_max; + for _ in 0..cnt2 { out.push((tmp, 2)); tmp /= 2; } + for _ in 0..cnt3 { out.push((tmp, 3)); tmp /= 3; } + for _ in 0..cnt4 { out.push((tmp, 4)); tmp /= 4; } + for _ in 0..cnt5 { out.push((tmp, 5)); tmp /= 5; } + for _ in 0..cnt6 { out.push((tmp, 6)); tmp /= 6; } + out + }; + NttPlan { + n: len_max, + g: g, + m: len_max / g, + cost: len_max_cost, + last_radix: s_list.last().unwrap_or(&(1, 1)).1, + s_list: s_list, + } + } +} + +fn conv_base(n: usize, x: *mut u64, y: *mut u64, buf: *mut u64, c: u64, mult: u64) { + unsafe { + for i in 0..n { + *buf.wrapping_add(i) = Arith::

::mmulmod(*x.wrapping_add(i), mult); + } + for i in 0..n { + let mut v1: u128 = 0; + for j in 0..=i { + let (w, overflow) = v1.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); + v1 = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; + } + let mut v2: u128 = 0; + for j in i+1..n { + let (w, overflow) = v2.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); + v2 = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; + } + if v1 >= (P as u128) << 64 { v1 = v1.wrapping_sub((P as u128) << 64); } + if v2 >= (P as u128) << 64 { v2 = v2.wrapping_sub((P as u128) << 64); } + let u1 = Arith::

::mreduce(v1); + let u2 = Arith::

::mreduce(v2); + *x.wrapping_add(i) = Arith::

::mmuladdmod(c, u2, u1); + } + } +} + +struct NttKernelImpl; +impl NttKernelImpl { pub const ROOTR: u64 = Arith::

::mpowmod(Arith::

::ROOTR, if INV { Arith::

::MAX_NTT_LEN - 1 } else { 1 }); + pub const U2: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/2); // U2 == P - Arith::

::R pub const U3: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/3); pub const U4: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/4); pub const U5: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/5); @@ -217,392 +315,374 @@ impl NttKernelImpl NttKernelImpl { - unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - let mut src = x.as_ptr(); - let mut dst = y.as_mut_ptr(); - let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); - let (n1, n1s) = (n/2, n/2*s); - let mut w1p = Arith::

::R; - for _ in 0..n1 { - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(n1s); - *dst.wrapping_add(0) = Arith::

::addmod(a, b); - *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::submod(a, b)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); - } - dst = dst.wrapping_add(s); - w1p = Arith::

::mmulmod(w1p, omega1); - } - (n/2, s*2, !eo, y, x) - } - unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - assert_eq!(n, 2); - let mut src = x.as_ptr(); - let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(s); - *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(a, b)); - *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(a, b)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); - } - if eo { (n/2, s*2, !eo, y, x) } else { (n/2, s*2, eo, x, y) } +const fn ntt2_kernel_core( + w1p: u64, + a: u64, mut b: u64) -> (u64, u64) { + if !INV && TWIDDLE { + b = Arith::

::mmulmod(w1p, b); + } + let out0 = Arith::

::addmod(a, b); + let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(a, b)); + (out0, out1) +} +const fn ntt2_kernel( + w1p: u64, + a: u64, b: u64) -> (u64, u64) { + match (INV, TWIDDLE) { + (_, false) => ntt2_kernel_core::(w1p, a, b), + (false, true) => ntt2_kernel_core::(w1p, a, b), + (true, true) => ntt2_kernel_core::(w1p, a, b) } } - -impl NttKernelImpl { - unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - let mut src = x.as_ptr(); - let mut dst = y.as_mut_ptr(); - let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); - let (n1, n1s) = (n/3, n/3*s); - let (mut w1p, mut w2p) = (Arith::

::R, Arith::

::R); - for _ in 0..n1 { - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(n1s); - let c = *src.wrapping_add(2*n1s); - let kbmc = Arith::

::mmulmod(Self::U3, Arith::

::submod(b, c)); - *dst.wrapping_add(0) = Arith::

::addmod(a, Arith::

::addmod(b, c)); - *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::addmod64(Arith::

::submod(a, c), kbmc)); - *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::submod(Arith::

::submod(a, b), kbmc)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); - } - dst = dst.wrapping_add(2*s); - w1p = Arith::

::mmulmod(w1p, omega1); - w2p = Arith::

::mmulmod(w1p, w1p); - } - (n/3, s*3, !eo, y, x) - } - unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - assert_eq!(n, 3); - let mut src = x.as_ptr(); - let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(s); - let c = *src.wrapping_add(2*s); - let kbmc = Arith::

::mmulmod(Self::U3, Arith::

::submod(b, c)); - *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(a, Arith::

::addmodopt::(b, c))); - *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(a, c), kbmc)); - *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(Arith::

::submod(a, b), kbmc)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); +fn ntt2_single_block( + s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { + unsafe { + let w1p = if TWIDDLE { *ptf } else { 0 }; + for _ in 0..s1 { + (*px, *px.wrapping_add(s1)) = + ntt2_kernel::(w1p, + *px, *px.wrapping_add(s1)); + px = px.wrapping_add(1); } - if eo { (n/3, s*3, !eo, y, x) } else { (n/3, s*3, eo, x, y) } } + (px.wrapping_add(s1), ptf.wrapping_add(1)) } - -impl NttKernelImpl { - unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - let mut src = x.as_ptr(); - let mut dst = y.as_mut_ptr(); - let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); - let (n1, n1s) = (n/4, n/4*s); - let (mut w1p, mut w2p, mut w3p) = (Arith::

::R, Arith::

::R, P.wrapping_sub(Self::U4)); - for _ in 0..n1 { - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(n1s); - let c = *src.wrapping_add(2*n1s); - let d = *src.wrapping_add(3*n1s); - let apc = Arith::

::addmod(a, c); - let amc = Arith::

::mmulmod(w1p, Arith::

::submod(a, c)); - let bpd = Arith::

::addmod(b, d); - let bmd = Arith::

::submod(b, d); - let jbmd = Arith::

::mmulmod(w3p, bmd); - *dst.wrapping_add(0) = Arith::

::addmod(apc, bpd); - *dst.wrapping_add(s) = Arith::

::submod(amc, jbmd); - *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::submod(apc, bpd)); - *dst.wrapping_add(3*s) = Arith::

::mmulmod(w2p, Arith::

::addmod64(amc, jbmd)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); - } - dst = dst.wrapping_add(3*s); - w1p = Arith::

::mmulmod(w1p, omega1); - w2p = Arith::

::mmulmod(w1p, w1p); - w3p = Arith::

::mmulmod(w3p, omega1); - } - (n/4, s*4, !eo, y, x) - } - unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - assert_eq!(n, 4); - let mut src = x.as_ptr(); - let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(s); - let c = *src.wrapping_add(2*s); - let d = *src.wrapping_add(3*s); - let apc = Arith::

::addmod(a, c); - let amc = Arith::

::submod(a, c); - let bpd = Arith::

::addmod(b, d); - let bmd = Arith::

::submod(b, d); - let jbmd = Arith::

::mmulmod(bmd, P.wrapping_sub(Self::U4)); - *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(apc, bpd)); - *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(amc, jbmd)); - *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(apc, bpd)); - *dst.wrapping_add(3*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(amc, jbmd)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); - } - if eo { (n/4, s*4, !eo, y, x) } else { (n/4, s*4, eo, x, y) } +const fn ntt3_kernel_core( + w1p: u64, w2p: u64, + a: u64, mut b: u64, mut c: u64) -> (u64, u64, u64) { + if !INV && TWIDDLE { + b = Arith::

::mmulmod(w1p, b); + c = Arith::

::mmulmod(w2p, c); + } + let kbmc = Arith::

::mmulmod(NttKernelImpl::::U3, Arith::

::submod(b, c)); + let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); + let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::addmodopt::(Arith::

::submod(a, c), kbmc)); + let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(Arith::

::submod(a, b), kbmc)); + (out0, out1, out2) +} +const fn ntt3_kernel( + w1p: u64, w2p: u64, + a: u64, b: u64, c: u64) -> (u64, u64, u64) { + match (INV, TWIDDLE) { + (_, false) => ntt3_kernel_core::(w1p, w2p, a, b, c), + (false, true) => ntt3_kernel_core::(w1p, w2p, a, b, c), + (true, true) => ntt3_kernel_core::(w1p, w2p, a, b, c) } } - -impl NttKernelImpl { - unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - let mut src = x.as_ptr(); - let mut dst = y.as_mut_ptr(); - let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); - let (n1, n1s) = (n/5, n/5*s); - let (mut w1p, mut w2p, mut w3p, mut w4p) = (Arith::

::R, Arith::

::RNEG, Arith::

::RNEG, Arith::

::R); - for _ in 0..n1 { - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(n1s); - let c = *src.wrapping_add(2*n1s); - let d = *src.wrapping_add(3*n1s); - let e = *src.wrapping_add(4*n1s); - let t1 = Arith::

::addmod(b, e); - let t2 = Arith::

::addmod(c, d); - let t3 = Arith::

::submod(b, e); - let t4= Arith::

::submod(d, c); - let t5 = Arith::

::addmod(t1, t2); - let t6 = Arith::

::submod(t1, t2); - let t7 = Arith::

::addmod64(t3, t4); - let m1 = Arith::

::addmod(a, t5); - let m2 = Arith::

::mmulsubmod(P.wrapping_sub(Self::C51), t5, m1); - let m3 = Arith::

::mmulmod(Self::C52, t6); - let m4 = Arith::

::mmulmod(Self::C53, t7); - let m5 = Arith::

::mmulsubmod(Self::C54, t4, m4); - let m6 = Arith::

::mmulsubmod(P.wrapping_sub(Self::C55), t3, m4); - let s2 = Arith::

::submod(m3, m2); - let s4 = Arith::

::addmod64(m2, m3); - *dst.wrapping_add(0) = m1; - *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::submod(s2, m5)); - *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::addmod64(s4, m6)); - *dst.wrapping_add(3*s) = Arith::

::mmulmod(w3p, Arith::

::submod(s4, m6)); - *dst.wrapping_add(4*s) = Arith::

::mmulmod(w4p, Arith::

::addmod64(s2, m5)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); - } - dst = dst.wrapping_add(4*s); - w1p = Arith::

::mmulmod(w1p, omega1); - w2p = Arith::

::mmulmod(w1p, P.wrapping_sub(w1p)); - w3p = Arith::

::mmulmod(w1p, w2p); - w4p = Arith::

::mmulmod(w2p, w2p); - } - (n/5, s*5, !eo, y, x) - } - unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - assert_eq!(n, 5); - let mut src = x.as_ptr(); - let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; - for _ in 0..s { - let a = *src.wrapping_add(0); - let b = *src.wrapping_add(s); - let c = *src.wrapping_add(2*s); - let d = *src.wrapping_add(3*s); - let e = *src.wrapping_add(4*s); - let t1 = Arith::

::addmod(b, e); - let t2 = Arith::

::addmod(c, d); - let t3 = Arith::

::submod(b, e); - let t4= Arith::

::submod(d, c); - let t5 = Arith::

::addmod(t1, t2); - let t6 = Arith::

::submod(t1, t2); - let t7 = Arith::

::addmod64(t3, t4); - let m1 = Arith::

::addmod(a, t5); - let m2 = Arith::

::mmuladdmod(Self::C51, t5, m1); - let m3 = Arith::

::mmulmod(Self::C52, t6); - let m4 = Arith::

::mmulmod(Self::C53, t7); - let m5 = Arith::

::mmulsubmod(Self::C54, t4, m4); - let m6 = Arith::

::mmulsubmod(P.wrapping_sub(Self::C55), t3, m4); - let s2 = Arith::

::addmod(m2, m3); - let s4 = Arith::

::submod(m2, m3); - *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, m1); - *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(s2, m5)); - *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(s4, m6)); - *dst.wrapping_add(3*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(s4, m6)); - *dst.wrapping_add(4*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(s2, m5)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); +fn ntt3_single_block( + s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { + unsafe { + let (w1p, w2p) = if TWIDDLE { + let w1p = *ptf; + let w2p = Arith::

::mmulmod(w1p, w1p); + (w1p, w2p) + } else { + (0, 0) + }; + for _ in 0..s1 { + (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1)) = + ntt3_kernel::(w1p, w2p, + *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1)); + px = px.wrapping_add(1); } - if eo { (n/5, s*5, !eo, y, x) } else { (n/5, s*5, eo, x, y) } } + (px.wrapping_add(2*s1), ptf.wrapping_add(1)) } - -impl NttKernelImpl { - unsafe fn apply<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64]) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - let mut src = x.as_ptr(); - let mut dst = y.as_mut_ptr(); - let omega1 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/n as u64); - let (n1, n1s) = (n/6, n/6*s); - let (mut w1p, mut w2p, mut w3p, mut w4p, mut w5p) = (Arith::

::R, Arith::

::R, Arith::

::R, Arith::

::R, Arith::

::R); - for _ in 0..n1 { - for _ in 0..s { - let mut a = *src.wrapping_add(0); - let mut b = *src.wrapping_add(n1s); - let mut c = *src.wrapping_add(2*n1s); - let mut d = *src.wrapping_add(3*n1s); - let mut e = *src.wrapping_add(4*n1s); - let mut f = *src.wrapping_add(5*n1s); - (a, d) = (Arith::

::addmod(a, d), Arith::

::submod(a, d)); - (b, e) = (Arith::

::addmod(b, e), Arith::

::submod(b, e)); - (c, f) = (Arith::

::addmod(c, f), Arith::

::submod(c, f)); - let lbmc = Arith::

::mmulmod(Self::U6, Arith::

::submod(b, c)); - *dst.wrapping_add(0) = Arith::

::addmod(a, Arith::

::addmod(b, c)); - *dst.wrapping_add(2*s) = Arith::

::mmulmod(w2p, Arith::

::addmod64(Arith::

::submod(a, b), lbmc)); - *dst.wrapping_add(4*s) = Arith::

::mmulmod(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); - let mlepf = Arith::

::mmulmod(P.wrapping_sub(Self::U6), Arith::

::addmod64(e, f)); - *dst.wrapping_add(s) = Arith::

::mmulmod(w1p, Arith::

::submod(Arith::

::submod(d, f), mlepf)); - *dst.wrapping_add(3*s) = Arith::

::mmulmod(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); - *dst.wrapping_add(5*s) = Arith::

::mmulmod(w5p, Arith::

::addmod64(Arith::

::addmod64(d, mlepf), e)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); - } - dst = dst.wrapping_add(5*s); - w1p = Arith::

::mmulmod(w1p, omega1); - w2p = Arith::

::mmulmod(w1p, w1p); - w3p = Arith::

::mmulmod(w1p, w2p); - w4p = Arith::

::mmulmod(w2p, w2p); - w5p = Arith::

::mmulmod(w2p, w3p); - } - (n/6, s*6, !eo, y, x) - } - unsafe fn apply_last<'a>(n: usize, s: usize, eo: bool, x: &'a mut [u64], y: &'a mut [u64], mult: u64) -> (usize, usize, bool, &'a mut [u64], &'a mut [u64]) { - assert_eq!(n, 6); - let mut src = x.as_ptr(); - let mut dst = if eo { y.as_mut_ptr() } else { x.as_mut_ptr() }; - for _ in 0..s { - let mut a = *src.wrapping_add(0); - let mut b = *src.wrapping_add(s); - let mut c = *src.wrapping_add(2*s); - let mut d = *src.wrapping_add(3*s); - let mut e = *src.wrapping_add(4*s); - let mut f = *src.wrapping_add(5*s); - (a, d) = (Arith::

::addmod(a, d), Arith::

::submod(a, d)); - (b, e) = (Arith::

::addmod(b, e), Arith::

::submod(b, e)); - (c, f) = (Arith::

::addmod(c, f), Arith::

::submod(c, f)); - let lbmc = Arith::

::mmulmod(Self::U6, Arith::

::submod(b, c)); - *dst.wrapping_add(0) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(a, Arith::

::addmodopt::(b, c))); - *dst.wrapping_add(2*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(a, b), lbmc)); - *dst.wrapping_add(4*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(Arith::

::submod(a, c), lbmc)); - let lepf = Arith::

::mmulmod(Self::U6, Arith::

::addmod64(e, f)); - *dst.wrapping_add(s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(d, f), lepf)); - *dst.wrapping_add(3*s) = Arith::

::mmulmod_cond::(mult, Arith::

::submod(d, Arith::

::submod(e, f))); - *dst.wrapping_add(5*s) = Arith::

::mmulmod_cond::(mult, Arith::

::addmodopt::(Arith::

::submod(d, lepf), e)); - src = src.wrapping_add(1); - dst = dst.wrapping_add(1); +const fn ntt4_kernel_core( + w1p: u64, w2p: u64, w3p: u64, + a: u64, mut b: u64, mut c: u64, mut d: u64) -> (u64, u64, u64, u64) { + if !INV && TWIDDLE { + b = Arith::

::mmulmod(w1p, b); + c = Arith::

::mmulmod(w2p, c); + d = Arith::

::mmulmod(w3p, d); + } + let apc = Arith::

::addmod(a, c); + let amc = Arith::

::submod(a, c); + let bpd = Arith::

::addmod(b, d); + let bmd = Arith::

::submod(b, d); + let jbmd = Arith::

::mmulmod(bmd, P.wrapping_sub(NttKernelImpl::::U4)); + let out0 = Arith::

::addmod(apc, bpd); + let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(amc, jbmd)); + let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(apc, bpd)); + let out3 = Arith::

::mmulmod_cond::(w3p, Arith::

::addmodopt::(amc, jbmd)); + (out0, out1, out2, out3) +} +const fn ntt4_kernel( + w1p: u64, w2p: u64, w3p: u64, + a: u64, b: u64, c: u64, d: u64) -> (u64, u64, u64, u64) { + match (INV, TWIDDLE) { + (_, false) => ntt4_kernel_core::(w1p, w2p, w3p, a, b, c, d), + (false, true) => ntt4_kernel_core::(w1p, w2p, w3p, a, b, c, d), + (true, true) => ntt4_kernel_core::(w1p, w2p, w3p, a, b, c, d) + } +} +fn ntt4_single_block( + s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { + unsafe { + let (w1p, w2p, w3p) = if TWIDDLE { + let w1p = *ptf; + let w2p = Arith::

::mmulmod(w1p, w1p); + let w3p = Arith::

::mmulmod(w1p, w2p); + (w1p, w2p, w3p) + } else { + (0, 0, 0) + }; + for _ in 0..s1 { + (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), + *px.wrapping_add(3*s1)) = + ntt4_kernel::(w1p, w2p, w3p, + *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), + *px.wrapping_add(3*s1)); + px = px.wrapping_add(1); } - if eo { (n/6, s*6, !eo, y, x) } else { (n/6, s*6, eo, x, y) } } + (px.wrapping_add(3*s1), ptf.wrapping_add(1)) } - -fn ntt_stockham(input: &mut [u64], buf: &mut [u64]) { - let (mut n, mut s, mut eo, mut x, mut y) = (input.len(), 1, false, input, buf); - assert!(Arith::

::MAX_NTT_LEN % n as u64 == 0); - let inv_p2 = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/n as u64)); - if n == 1 { - x[0] = Arith::

::mmulmod_cond::(inv_p2, x[0]); - return; +const fn ntt5_kernel_core( + w1p: u64, w2p: u64, w3p: u64, w4p: u64, + a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64) -> (u64, u64, u64, u64, u64) { + if !INV && TWIDDLE { + b = Arith::

::mmulmod(w1p, b); + c = Arith::

::mmulmod(w2p, c); + d = Arith::

::mmulmod(w3p, d); + e = Arith::

::mmulmod(w4p, e); + } + let t1 = Arith::

::addmod(b, e); + let t2 = Arith::

::addmod(c, d); + let t3 = Arith::

::submod(b, e); + let t4= Arith::

::submod(d, c); + let t5 = Arith::

::addmod(t1, t2); + let t6 = Arith::

::submod(t1, t2); + let t7 = Arith::

::addmod64(t3, t4); + let m1 = Arith::

::addmod(a, t5); + let m2 = Arith::

::mmulsubmod(P.wrapping_sub(NttKernelImpl::::C51), t5, m1); + let m3 = Arith::

::mmulmod(NttKernelImpl::::C52, t6); + let m4 = Arith::

::mmulmod(NttKernelImpl::::C53, t7); + let m5 = Arith::

::mmulsubmod(NttKernelImpl::::C54, t4, m4); + let m6 = Arith::

::mmulsubmod(P.wrapping_sub(NttKernelImpl::::C55), t3, m4); + let s2 = Arith::

::submod(m3, m2); + let s4 = Arith::

::addmod(m2, m3); + let out0 = m1; + let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(s2, m5)); + let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(0, Arith::

::addmod(s4, m6))); + let out3 = Arith::

::mmulmod_cond::(w3p, Arith::

::submod(m6, s4)); + let out4 = Arith::

::mmulmod_cond::(w4p, Arith::

::addmodopt::(s2, m5)); + (out0, out1, out2, out3, out4) +} +const fn ntt5_kernel( + w1p: u64, w2p: u64, w3p: u64, w4p: u64, + a: u64, b: u64, c: u64, d: u64, e: u64) -> (u64, u64, u64, u64, u64) { + match (INV, TWIDDLE) { + (_, false) => ntt5_kernel_core::(w1p, w2p, w3p, w4p, a, b, c, d, e), + (false, true) => ntt5_kernel_core::(w1p, w2p, w3p, w4p, a, b, c, d, e), + (true, true) => ntt5_kernel_core::(w1p, w2p, w3p, w4p, a, b, c, d, e) } - let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0); - let mut tmp = n; - while tmp % 6 == 0 { tmp /= 6; cnt6 += 1; } - while tmp % 5 == 0 { tmp /= 5; cnt5 += 1; } - while tmp % 4 == 0 { tmp /= 4; cnt4 += 1; } - while tmp % 3 == 0 { tmp /= 3; cnt3 += 1; } - while tmp % 2 == 0 { tmp /= 2; cnt2 += 1; } - while cnt6 > 0 && cnt2 > 0 { cnt6 -= 1; cnt2 -= 1; cnt4 += 1; cnt3 += 1; } +} +fn ntt5_single_block( + s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { - while cnt2 > 0 { - (n, s, eo, x, y) = if n > 2 { - NttKernelImpl::::apply(n, s, eo, x, y) - } else { - NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) - }; - cnt2 -= 1; - } - while cnt3 > 0 { - (n, s, eo, x, y) = if n > 3 { - NttKernelImpl::::apply(n, s, eo, x, y) - } else { - NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) - }; - cnt3 -= 1; - } - while cnt4 > 0 { - (n, s, eo, x, y) = if n > 4 { - NttKernelImpl::::apply(n, s, eo, x, y) - } else { - NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) - }; - cnt4 -= 1; + let (w1p, w2p, w3p, w4p) = if TWIDDLE { + let w1p = *ptf; + let w2p = Arith::

::mmulmod(w1p, w1p); + let w3p = Arith::

::mmulmod(w1p, w2p); + let w4p = Arith::

::mmulmod(w2p, w2p); + (w1p, w2p, w3p, w4p) + } else { + (0, 0, 0, 0) + }; + for _ in 0..s1 { + (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), + *px.wrapping_add(3*s1), *px.wrapping_add(4*s1)) = + ntt5_kernel::(w1p, w2p, w3p, w4p, + *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), + *px.wrapping_add(3*s1), *px.wrapping_add(4*s1)); + px = px.wrapping_add(1); } - while cnt5 > 0 { - (n, s, eo, x, y) = if n > 5 { - NttKernelImpl::::apply(n, s, eo, x, y) - } else { - NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) - }; - cnt5 -= 1; + } + (px.wrapping_add(4*s1), ptf.wrapping_add(1)) +} +const fn ntt6_kernel_core( + w1p: u64, w2p: u64, w3p: u64, w4p: u64, w5p: u64, + mut a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64, mut f: u64) -> (u64, u64, u64, u64, u64, u64) { + if !INV && TWIDDLE { + b = Arith::

::mmulmod(w1p, b); + c = Arith::

::mmulmod(w2p, c); + d = Arith::

::mmulmod(w3p, d); + e = Arith::

::mmulmod(w4p, e); + f = Arith::

::mmulmod(w5p, f); + } + (a, d) = (Arith::

::addmod(a, d), Arith::

::submod(a, d)); + (b, e) = (Arith::

::addmod(b, e), Arith::

::submod(b, e)); + (c, f) = (Arith::

::addmod(c, f), Arith::

::submod(c, f)); + let lbmc = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::submod(b, c)); + let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); + let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::addmodopt::(Arith::

::submod(a, b), lbmc)); + let out4 = Arith::

::mmulmod_cond::(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); + let lepf = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::addmod64(e, f)); + let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::addmodopt::(Arith::

::submod(d, f), lepf)); + let out3 = Arith::

::mmulmod_cond::(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); + let out5 = Arith::

::mmulmod_cond::(w5p, Arith::

::addmodopt::(Arith::

::submod(d, lepf), e)); + (out0, out1, out2, out3, out4, out5) +} +const fn ntt6_kernel( + w1p: u64, w2p: u64, w3p: u64, w4p: u64, w5p: u64, + a: u64, b: u64, c: u64, d: u64, e: u64, f: u64) -> (u64, u64, u64, u64, u64, u64) { + match (INV, TWIDDLE) { + (_, false) => ntt6_kernel_core::(w1p, w2p, w3p, w4p, w5p, a, b, c, d, e, f), + (false, true) => ntt6_kernel_core::(w1p, w2p, w3p, w4p, w5p, a, b, c, d, e, f), + (true, true) => ntt6_kernel_core::(w1p, w2p, w3p, w4p, w5p, a, b, c, d, e, f) + } +} +fn ntt6_single_block( + s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { + unsafe { + let (w1p, w2p, w3p, w4p, w5p) = if TWIDDLE { + let w1p = *ptf; + let w2p = Arith::

::mmulmod(w1p, w1p); + let w3p = Arith::

::mmulmod(w1p, w2p); + let w4p = Arith::

::mmulmod(w2p, w2p); + let w5p = Arith::

::mmulmod(w2p, w3p); + (w1p, w2p, w3p, w4p, w5p) + } else { + (0, 0, 0, 0, 0) + }; + for _ in 0..s1 { + (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), + *px.wrapping_add(3*s1), *px.wrapping_add(4*s1), *px.wrapping_add(5*s1)) = + ntt6_kernel::(w1p, w2p, w3p, w4p, w5p, + *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), + *px.wrapping_add(3*s1), *px.wrapping_add(4*s1), *px.wrapping_add(5*s1)); + px = px.wrapping_add(1); } - while cnt6 > 0 { - (n, s, eo, x, y) = if n > 6 { - NttKernelImpl::::apply(n, s, eo, x, y) - } else { - NttKernelImpl::::apply_last(n, s, eo, x, y, inv_p2) - }; - cnt6 -= 1; + } + (px.wrapping_add(5*s1), ptf.wrapping_add(1)) +} + +fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_list: &[Vec]) { + let mut i_list = vec![]; + for i in 0..plan.s_list.len() { i_list.push(i); } + if INV { i_list.reverse(); } + for i in i_list { + let (s, radix) = plan.s_list[i]; + let s1 = s/radix; + let mut px = x.as_mut_ptr(); + let px_end = x.as_mut_ptr().wrapping_add(plan.n); + let mut ptf = tf_list[i].as_ptr(); + match radix { + 2 => { + (px, ptf) = ntt2_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt2_single_block::(s1, px, ptf); + } + }, + 3 => { + (px, ptf) = ntt3_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt3_single_block::(s1, px, ptf); + } + }, + 4 => { + (px, ptf) = ntt4_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt4_single_block::(s1, px, ptf); + } + }, + 5 => { + (px, ptf) = ntt5_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt5_single_block::(s1, px, ptf); + } + }, + 6 => { + (px, ptf) = ntt6_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt6_single_block::(s1, px, ptf); + } + }, + _ => { unreachable!() } } } } -fn plan_ntt(min_len: usize) -> (usize, usize) { - assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); - let (mut len_max, mut len_max_cost) = (0usize, usize::MAX); - let mut len5 = 1; - for _ in 0..=Arith::

::FACTOR_FIVE { - let mut len35 = len5; - for _ in 0..=Arith::

::FACTOR_THREE { - let mut len = len35; - let mut i = 0; - while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } - if len >= min_len && len < len_max_cost { - let (mut tmp, mut cost) = (len, 0); - while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len); } - while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len/5); } - while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } - while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } - while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); } - if cost < len_max_cost { (len_max, len_max_cost) = (len, cost); } - } - len35 *= 3; - if len35 >= min_len { break; } +fn compute_twiddle_factors(s_list: &[(usize, usize)]) -> Vec { + let mut len = 1; + for &(_, radix) in s_list { len *= radix; } + len /= s_list.last().unwrap().1; + let mut tf = vec![Arith::

::R; len]; + let r = s_list.last().unwrap_or(&(1, 1)).1; + let mut p = 1; + for i in (1..s_list.len()).rev() { + let radix = s_list[i-1].1; + let w = Arith::

::mpowmod(NttKernelImpl::::ROOTR, Arith::

::MAX_NTT_LEN/(p as u64 * radix as u64 * r as u64)); + for j in p..radix*p { + tf[j] = Arith::

::mmulmod(w, tf[j - p]); } - len5 *= 5; - if len5 >= min_len { break; } + p *= radix; } - (len_max, len_max_cost) + tf } // Performs (cyclic) integer convolution modulo P using NTT. // Modifies the three buffers in-place. // The output is saved in the slice `x`. -// The three slices must have the same length which divides `Arith::

::MAX_NTT_LEN`. -fn conv(x: &mut [u64], y: &mut [u64], buf: &mut [u64]) { - assert!(!x.is_empty() && x.len() == y.len() && y.len() == buf.len()); - ntt_stockham::(x, buf); - ntt_stockham::(y, buf); - for i in 0..x.len() { x[i] = Arith::

::mmulmod(x[i], y[i]); } - ntt_stockham::(x, buf); +// The three slices must have the same length. For maximum performance, +// the length should contain as many factors of 6 as possible. +fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u64) { + assert!(!x.is_empty() && x.len() == y.len()); + + let (_n, g, m) = (plan.n, plan.g, plan.m); + let last_radix = plan.last_radix; + + /* build twiddle factors */ + let mut tf_list = vec![vec![Arith::

::R; 1]; 1]; + for i in 1..plan.s_list.len() { + tf_list.push(compute_twiddle_factors::(&plan.s_list[0..=i])); + } + + /* dif fft */ + ntt_dif_dit::(&plan, x, &tf_list); + ntt_dif_dit::(&plan, y, &tf_list); + + /* naive or Karatsuba multiplication */ + let len_inv = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/m as u64)); + mult = Arith::

::mmulmod(Arith::

::mmulmod(Arith::

::R2, mult), len_inv); + let mut i = 0; + let (mut ii, mut ii_mod_last_radix) = (0, 0); + let mut buf = vec![0u64; g]; + let tf = tf_list.last().unwrap(); + let mut tf_current = tf[0]; + let tf_mult = match plan.last_radix { + 2 => NttKernelImpl::::U2, + 3 => NttKernelImpl::::U3, + 4 => NttKernelImpl::::U4, + 5 => NttKernelImpl::::U5, + 6 => NttKernelImpl::::U6, + _ => Arith::

::R + }; + while i < plan.n { + if ii_mod_last_radix == 0 { + tf_current = tf[ii]; + } else { + tf_current = Arith::

::mmulmod(tf_current, tf_mult); + } + + /* we multiply the inverse of the length here to save time */ + conv_base::

(g, x.as_mut_ptr().wrapping_add(i), y.as_mut_ptr().wrapping_add(i), + buf.as_mut_ptr(), tf_current, mult); + i += g; + ii_mod_last_radix += 1; + if ii_mod_last_radix == last_radix { + ii += 1; + ii_mod_last_radix = 0; + } + } + + /* dit fft */ + let mut tf_list = vec![vec![Arith::

::R; 1]; 1]; + for i in 1..plan.s_list.len() { + tf_list.push(compute_twiddle_factors::(&plan.s_list[0..=i])); + } + ntt_dif_dit::(&plan, x, &tf_list); } //////////////////////////////////////////////////////////////////////////////// @@ -614,6 +694,7 @@ const P1: u64 = 10_237_243_632_176_332_801; // Max NTT length = 2^24 * 3^20 * 5^ const P2: u64 = 13_649_658_176_235_110_401; // Max NTT length = 2^26 * 3^19 * 5^2 = 1_949_951_168_033_587_200 const P3: u64 = 14_259_017_916_245_606_401; // Max NTT length = 2^22 * 3^21 * 5^2 = 1_096_847_532_018_892_800 +const P1P2: u128 = P1 as u128 * P2 as u128; const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod(P1, P2)); const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod( Arith::::R3, @@ -627,26 +708,28 @@ const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { + assert!(bits < 64); + let min_len = b.len() + c.len(); - let len_max_1 = plan_ntt::(min_len).0; - let len_max_2 = plan_ntt::(min_len).0; + let plan_1 = NttPlan::build::(min_len); + let plan_2 = NttPlan::build::(min_len); + let len_max_1 = plan_1.n; + let len_max_2 = plan_2.n; let len_max = max(len_max_1, len_max_2); let mut x = vec![0u64; len_max_1]; let mut y = vec![0u64; len_max_2]; let mut r = vec![0u64; len_max]; - let mut s = vec![0u64; len_max]; /* convolution with modulo P1 */ - for i in 0..b.len() { x[i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } - for i in 0..c.len() { r[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } - r[c.len()..len_max_1].fill(0u64); - conv::(&mut x, &mut r[..len_max_1], &mut s[..len_max_1]); + x[0..b.len()].clone_from_slice(b); + r[0..c.len()].clone_from_slice(c); + conv::(&plan_1, &mut x, &mut r[..len_max_1], arith::invmod(P2, P1)); /* convolution with modulo P2 */ - for i in 0..b.len() { y[i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } - for i in 0..c.len() { r[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } + y[0..b.len()].clone_from_slice(b); + r[0..c.len()].clone_from_slice(c); r[c.len()..len_max_2].fill(0u64); - conv::(&mut y, &mut r[..len_max_2], &mut s[..len_max_2]); + conv::(&plan_2, &mut y, &mut r[..len_max_2], arith::invmod(P1, P2)); /* merge the results in {x, y} into r (process carry along the way) */ let mask = (1u64 << bits) - 1; @@ -655,9 +738,8 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { for i in 0..min_len { /* extract the convolution result */ let (a, b) = (x[i], y[i]); - let bma = Arith::::submod(b, a); - let u = Arith::::mmulmod(bma, P1INV_R_MOD_P2); - let v = a as u128 + P1 as u128 * u as u128 + carry; + let mut v = a as u128 * P2 as u128 + b as u128 * P1 as u128 + carry; + if v >= P1P2 { v = v.wrapping_sub(P1P2); } carry = v >> bits; /* write to r */ @@ -683,33 +765,35 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { let min_len = b.len() + c.len(); - let len_max_1 = plan_ntt::(min_len).0; - let len_max_2 = plan_ntt::(min_len).0; - let len_max_3 = plan_ntt::(min_len).0; + let plan_1 = NttPlan::build::(min_len); + let plan_2 = NttPlan::build::(min_len); + let plan_3 = NttPlan::build::(min_len); + let len_max_1 = plan_1.n; + let len_max_2 = plan_2.n; + let len_max_3 = plan_3.n; let len_max = max(len_max_1, max(len_max_2, len_max_3)); let mut x = vec![0u64; len_max_1]; let mut y = vec![0u64; len_max_2]; let mut z = vec![0u64; len_max_3]; let mut r = vec![0u64; len_max]; - let mut s = vec![0u64; len_max]; /* convolution with modulo P1 */ for i in 0..b.len() { x[i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } for i in 0..c.len() { r[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } r[c.len()..len_max_1].fill(0u64); - conv::(&mut x, &mut r[..len_max_1], &mut s[..len_max_1]); + conv::(&plan_1, &mut x, &mut r[..len_max_1], 1); /* convolution with modulo P2 */ for i in 0..b.len() { y[i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } for i in 0..c.len() { r[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } r[c.len()..len_max_2].fill(0u64); - conv::(&mut y, &mut r[..len_max_2], &mut s[..len_max_2]); + conv::(&plan_2, &mut y, &mut r[..len_max_2], 1); /* convolution with modulo P3 */ for i in 0..b.len() { z[i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; } for i in 0..c.len() { r[i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } r[c.len()..len_max_3].fill(0u64); - conv::(&mut z, &mut r[..len_max_3], &mut s[..len_max_3]); + conv::(&plan_3, &mut z, &mut r[..len_max_3], 1); /* merge the results in {x, y, z} into acc (process carry along the way) */ let mut carry: u128 = 0; @@ -757,9 +841,9 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { let (b, c) = if b.len() < c.len() { (b, c) } else { (c, b) }; - let naive_cost = plan_ntt::(b.len() + c.len()).1 * 3; - let split_cost = plan_ntt::(b.len() + b.len()).1 * 3 * (c.len() / b.len()) - + if c.len() % b.len() > 0 { plan_ntt::(b.len() + (c.len() % b.len())).1 * 3 } else { 0 }; + let naive_cost = NttPlan::build::(b.len() + c.len()).cost * 3; + let split_cost = NttPlan::build::(b.len() + b.len()).cost * 3 * (c.len() / b.len()) + + if c.len() % b.len() > 0 { NttPlan::build::(b.len() + (c.len() % b.len())).cost * 3 } else { 0 }; if b.len() >= 128 && split_cost < naive_cost { /* special handling for unbalanced multiplication: we reduce it to about `c.len()/b.len()` balanced multiplications */ @@ -848,7 +932,7 @@ pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { #[cfg(not(u64_digit))] pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { - fn bigdigit_to_u64(src: &[BigDigit]) -> crate::biguint::Vec:: { + fn bigdigit_to_u64(src: &[BigDigit]) -> Vec { let mut out = vec![0u64; (src.len() + 1) / 2]; for i in 0..src.len()/2 { out[i] = (src[2*i] as u64) | ((src[2*i+1] as u64) << 32); From ebc5f0d4b1ac144c2cece5b00f6bda96fe30c659 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 7 Sep 2023 21:22:29 +0900 Subject: [PATCH 14/65] Reduce memory access --- src/biguint/ntt.rs | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 69b7ca6f..578feb85 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -735,6 +735,8 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { let mask = (1u64 << bits) - 1; let mut carry: u128 = 0; let (mut j, mut p) = (0usize, 0u64); + let mut s: u64 = 0; + let mut carry_acc: u64 = 0; for i in 0..min_len { /* extract the convolution result */ let (a, b) = (x[i], y[i]); @@ -742,24 +744,30 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { if v >= P1P2 { v = v.wrapping_sub(P1P2); } carry = v >> bits; - /* write to r */ + /* write to s */ let out = (v as u64) & mask; - r[j] = (r[j] & ((1u64 << p) - 1)) | (out << p); + s = (s & ((1u64 << p) - 1)) | (out << p); p += bits; if p >= 64 { + /* flush s to the output buffer */ + let (w, overflow1) = acc[j].overflowing_add(s); + let (w, overflow2) = w.overflowing_add(carry_acc); + acc[j] = w; + carry_acc = u64::from(overflow1 || overflow2); + + /* roll-over */ (j, p) = (j+1, p-64); - r[j] = out >> (bits - p); + s = out >> (bits - p); } } - /* add r to acc */ - let mut carry: u64 = 0; - for i in 0..min(acc.len(), j+1) { - let w = r[i]; - let (v, overflow1) = acc[i].overflowing_add(w); - let (v, overflow2) = v.overflowing_add(carry); - acc[i] = v; - carry = u64::from(overflow1 || overflow2); + /* process remaining carries */ + carry_acc += s; + while j < acc.len() { + let (w, overflow) = acc[j].overflowing_add(carry_acc); + acc[j] = w; + carry_acc = u64::from(overflow); + j += 1; } } @@ -866,7 +874,7 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { carry += tmp; } i += b.len(); - while i < acc.len() { + while carry > 0 && i < acc.len() { let (v, overflow) = acc[i].overflowing_add(carry); acc[i] = v; carry = u64::from(overflow); From 7e6f558a9cf5b5f09db2f9272150db7555217948 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 7 Sep 2023 22:46:48 +0900 Subject: [PATCH 15/65] Optimize add-with-carry --- src/biguint/ntt.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 578feb85..7a5c1459 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -708,7 +708,7 @@ const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { - assert!(bits < 64); + assert!(bits < 63); let min_len = b.len() + c.len(); let plan_1 = NttPlan::build::(min_len); @@ -750,10 +750,10 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { p += bits; if p >= 64 { /* flush s to the output buffer */ - let (w, overflow1) = acc[j].overflowing_add(s); - let (w, overflow2) = w.overflowing_add(carry_acc); + s += carry_acc; + let (w, overflow) = acc[j].overflowing_add(s); acc[j] = w; - carry_acc = u64::from(overflow1 || overflow2); + carry_acc = u64::from(overflow); /* roll-over */ (j, p) = (j+1, p-64); From 4d4c0dc858df183d9021595c23287879ef601e00 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Fri, 8 Sep 2023 09:36:04 +0900 Subject: [PATCH 16/65] Optimize base case multiplication --- src/biguint/ntt.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 7a5c1459..443de402 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -269,21 +269,19 @@ fn conv_base(n: usize, x: *mut u64, y: *mut u64, buf: *mut u64, c: *buf.wrapping_add(i) = Arith::

::mmulmod(*x.wrapping_add(i), mult); } for i in 0..n { - let mut v1: u128 = 0; - for j in 0..=i { - let (w, overflow) = v1.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); - v1 = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; - } - let mut v2: u128 = 0; + let mut v: u128 = 0; for j in i+1..n { - let (w, overflow) = v2.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); - v2 = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; + let (w, overflow) = v.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); + v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; + } + if v >= (P as u128) << 64 { v = v.wrapping_sub((P as u128) << 64); } + v = c as u128 * Arith::

::mreduce(v) as u128; + for j in 0..=i { + let (w, overflow) = v.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); + v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; } - if v1 >= (P as u128) << 64 { v1 = v1.wrapping_sub((P as u128) << 64); } - if v2 >= (P as u128) << 64 { v2 = v2.wrapping_sub((P as u128) << 64); } - let u1 = Arith::

::mreduce(v1); - let u2 = Arith::

::mreduce(v2); - *x.wrapping_add(i) = Arith::

::mmuladdmod(c, u2, u1); + if v >= (P as u128) << 64 { v = v.wrapping_sub((P as u128) << 64); } + *x.wrapping_add(i) = Arith::

::mreduce(v); } } } From 253031f2e85be478d383b09f30eae540f1b9720a Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Fri, 8 Sep 2023 10:04:04 +0900 Subject: [PATCH 17/65] Update ntt.rs --- src/biguint/ntt.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 443de402..96a0bfa0 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -274,7 +274,6 @@ fn conv_base(n: usize, x: *mut u64, y: *mut u64, buf: *mut u64, c: let (w, overflow) = v.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; } - if v >= (P as u128) << 64 { v = v.wrapping_sub((P as u128) << 64); } v = c as u128 * Arith::

::mreduce(v) as u128; for j in 0..=i { let (w, overflow) = v.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); From a6ca654bffa3016c058cbcab93a20e05ceee2b8e Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Fri, 8 Sep 2023 13:35:27 +0900 Subject: [PATCH 18/65] Speed up bit repacking --- src/biguint/ntt.rs | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 96a0bfa0..58757492 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -447,7 +447,7 @@ const fn ntt5_kernel_core::addmod(b, e); let t2 = Arith::

::addmod(c, d); let t3 = Arith::

::submod(b, e); - let t4= Arith::

::submod(d, c); + let t4 = Arith::

::submod(d, c); let t5 = Arith::

::addmod(t1, t2); let t6 = Arith::

::submod(t1, t2); let t7 = Arith::

::addmod64(t3, t4); @@ -903,18 +903,29 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { if bits >= 43 { /* can pack more effective bits per u64 with two primes than with three primes */ fn pack_into(src: &[u64], dst: &mut [u64], bits: u64) -> usize { - let (mut j, mut p) = (0usize, 0u64); - for i in 0..src.len() { + let mut p = 0u64; + let mut pdst = dst.as_mut_ptr(); + let mut x = 0u64; + let mask = (1u64 << bits).wrapping_sub(1); + for v in src { let mut k = 0; while k < 64 { - let bits_this_time = min(64 - k, bits - p); - dst[j] = (dst[j] & ((1u64 << p) - 1)) | (((src[i] >> k) & ((1u64 << bits_this_time) - 1)) << p); - k += bits_this_time; - p += bits_this_time; - if p == bits { (j, p) = (j+1, 0); } + x |= (v >> k) << p; + let q = 64 - k; + if p + q >= bits { + unsafe { *pdst = x & mask; } + x = 0; + (pdst, k, p) = (pdst.wrapping_add(1), k + bits - p, 0); + } else { + p += q; + break; + } } } - if p == 0 { j } else { j+1 } + unsafe { + if p > 0 { *pdst = x & mask; pdst = pdst.wrapping_add(1); } + pdst.offset_from(dst.as_mut_ptr()) as usize + } } let mut b2 = vec![0u64; ((64 * b.len() as u64 + bits - 1) / bits) as usize]; let mut c2 = vec![0u64; ((64 * c.len() as u64 + bits - 1) / bits) as usize]; From 3b32e9330bde34170db722bb3a27dda29570b18b Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Fri, 8 Sep 2023 18:21:34 +0900 Subject: [PATCH 19/65] Share the same Vec for all twiddle factors --- src/biguint/ntt.rs | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 58757492..84f359f4 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -555,16 +555,16 @@ fn ntt6_single_block( (px.wrapping_add(5*s1), ptf.wrapping_add(1)) } -fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_list: &[Vec]) { +fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_list: &[u64]) { let mut i_list = vec![]; for i in 0..plan.s_list.len() { i_list.push(i); } if INV { i_list.reverse(); } + let mut ptf = tf_list.as_ptr(); for i in i_list { let (s, radix) = plan.s_list[i]; let s1 = s/radix; let mut px = x.as_mut_ptr(); let px_end = x.as_mut_ptr().wrapping_add(plan.n); - let mut ptf = tf_list[i].as_ptr(); match radix { 2 => { (px, ptf) = ntt2_single_block::(s1, px, ptf); @@ -601,22 +601,22 @@ fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_ } } -fn compute_twiddle_factors(s_list: &[(usize, usize)]) -> Vec { +fn compute_twiddle_factors(s_list: &[(usize, usize)], out: &mut [u64]) -> usize { let mut len = 1; for &(_, radix) in s_list { len *= radix; } len /= s_list.last().unwrap().1; - let mut tf = vec![Arith::

::R; len]; let r = s_list.last().unwrap_or(&(1, 1)).1; let mut p = 1; + out[0] = Arith::

::R; for i in (1..s_list.len()).rev() { let radix = s_list[i-1].1; let w = Arith::

::mpowmod(NttKernelImpl::::ROOTR, Arith::

::MAX_NTT_LEN/(p as u64 * radix as u64 * r as u64)); for j in p..radix*p { - tf[j] = Arith::

::mmulmod(w, tf[j - p]); + out[j] = Arith::

::mmulmod(w, out[j - p]); } p *= radix; } - tf + len } // Performs (cyclic) integer convolution modulo P using NTT. @@ -630,10 +630,23 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 let (_n, g, m) = (plan.n, plan.g, plan.m); let last_radix = plan.last_radix; + /* compute the total space needed for twiddle factors */ + let tf_all_count = (|| -> usize { + let (mut radix_cumul, mut out) = (1, 0); + for &(_, radix) in plan.s_list.iter() { + out += radix_cumul; + radix_cumul *= radix; + } + core::cmp::max(out, 1) + })(); + /* build twiddle factors */ - let mut tf_list = vec![vec![Arith::

::R; 1]; 1]; + let mut tf_list = vec![0u64; tf_all_count]; + tf_list[0] = Arith::

::R; + let mut tf_last_start = core::cmp::min(tf_all_count - 1, 1); for i in 1..plan.s_list.len() { - tf_list.push(compute_twiddle_factors::(&plan.s_list[0..=i])); + let x = compute_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); + if i + 1 < plan.s_list.len() { tf_last_start += x; } } /* dif fft */ @@ -646,7 +659,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 let mut i = 0; let (mut ii, mut ii_mod_last_radix) = (0, 0); let mut buf = vec![0u64; g]; - let tf = tf_list.last().unwrap(); + let tf = &tf_list[tf_last_start..]; let mut tf_current = tf[0]; let tf_mult = match plan.last_radix { 2 => NttKernelImpl::::U2, @@ -675,10 +688,11 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 } /* dit fft */ - let mut tf_list = vec![vec![Arith::

::R; 1]; 1]; - for i in 1..plan.s_list.len() { - tf_list.push(compute_twiddle_factors::(&plan.s_list[0..=i])); + let mut tf_last_start = 0; + for i in (1..plan.s_list.len()).rev() { + tf_last_start += compute_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); } + tf_list[tf_last_start] = Arith::

::R; ntt_dif_dit::(&plan, x, &tf_list); } From d3b478fd78ea1b49ef970d9e859e2b3ce3059e04 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 9 Sep 2023 13:53:13 +0900 Subject: [PATCH 20/65] Pack more bits per one u64 digit if possible The prime numbers were replaced by larger ones to allow for tighter packing. Also, we compute the maximum number of bits that can be packed into one digit more precisely. --- src/biguint/ntt.rs | 84 ++++++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 40 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 84f359f4..4699740b 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -701,11 +701,11 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 use core::cmp::{min, max}; use crate::big_digit::BigDigit; -const P1: u64 = 10_237_243_632_176_332_801; // Max NTT length = 2^24 * 3^20 * 5^2 = 1_462_463_376_025_190_400 -const P2: u64 = 13_649_658_176_235_110_401; // Max NTT length = 2^26 * 3^19 * 5^2 = 1_949_951_168_033_587_200 -const P3: u64 = 14_259_017_916_245_606_401; // Max NTT length = 2^22 * 3^21 * 5^2 = 1_096_847_532_018_892_800 +const P1: u64 = 14_259_017_916_245_606_401; // Max NTT length = 2^22 * 3^21 * 5^2 = 1_096_847_532_018_892_800 +const P2: u64 = 17_984_575_660_032_000_001; // Max NTT length = 2^19 * 3^17 * 5^6 = 1_057_916_215_296_000_000 +const P3: u64 = 17_995_154_822_184_960_001; // Max NTT length = 2^17 * 3^22 * 5^4 = 2_570_736_403_169_280_000 -const P1P2: u128 = P1 as u128 * P2 as u128; +const P2P3: u128 = P2 as u128 * P3 as u128; const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod(P1, P2)); const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod( Arith::::R3, @@ -722,25 +722,23 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { assert!(bits < 63); let min_len = b.len() + c.len(); - let plan_1 = NttPlan::build::(min_len); - let plan_2 = NttPlan::build::(min_len); - let len_max_1 = plan_1.n; - let len_max_2 = plan_2.n; - let len_max = max(len_max_1, len_max_2); - let mut x = vec![0u64; len_max_1]; - let mut y = vec![0u64; len_max_2]; + let plan_x = NttPlan::build::(min_len); + let plan_y = NttPlan::build::(min_len); + let len_max = max(plan_x.n, plan_y.n); + let mut x = vec![0u64; plan_x.n]; + let mut y = vec![0u64; plan_y.n]; let mut r = vec![0u64; len_max]; - /* convolution with modulo P1 */ + /* convolution with modulo P2 */ x[0..b.len()].clone_from_slice(b); r[0..c.len()].clone_from_slice(c); - conv::(&plan_1, &mut x, &mut r[..len_max_1], arith::invmod(P2, P1)); + conv::(&plan_x, &mut x, &mut r[..plan_x.n], arith::invmod(P3, P2)); - /* convolution with modulo P2 */ + /* convolution with modulo P3 */ y[0..b.len()].clone_from_slice(b); r[0..c.len()].clone_from_slice(c); - r[c.len()..len_max_2].fill(0u64); - conv::(&plan_2, &mut y, &mut r[..len_max_2], arith::invmod(P1, P2)); + r[c.len()..plan_y.n].fill(0u64); + conv::(&plan_y, &mut y, &mut r[..plan_y.n], Arith::::submod(0, arith::invmod(P2, P3))); /* merge the results in {x, y} into r (process carry along the way) */ let mask = (1u64 << bits) - 1; @@ -751,8 +749,8 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { for i in 0..min_len { /* extract the convolution result */ let (a, b) = (x[i], y[i]); - let mut v = a as u128 * P2 as u128 + b as u128 * P1 as u128 + carry; - if v >= P1P2 { v = v.wrapping_sub(P1P2); } + let (mut v, overflow) = (a as u128 * P3 as u128 + carry).overflowing_sub(b as u128 * P2 as u128); + if overflow { v = v.wrapping_add(P2P3); } carry = v >> bits; /* write to s */ @@ -784,35 +782,32 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { let min_len = b.len() + c.len(); - let plan_1 = NttPlan::build::(min_len); - let plan_2 = NttPlan::build::(min_len); - let plan_3 = NttPlan::build::(min_len); - let len_max_1 = plan_1.n; - let len_max_2 = plan_2.n; - let len_max_3 = plan_3.n; - let len_max = max(len_max_1, max(len_max_2, len_max_3)); - let mut x = vec![0u64; len_max_1]; - let mut y = vec![0u64; len_max_2]; - let mut z = vec![0u64; len_max_3]; + let plan_x = NttPlan::build::(min_len); + let plan_y = NttPlan::build::(min_len); + let plan_z = NttPlan::build::(min_len); + let len_max = max(plan_x.n, max(plan_y.n, plan_z.n)); + let mut x = vec![0u64; plan_x.n]; + let mut y = vec![0u64; plan_y.n]; + let mut z = vec![0u64; plan_z.n]; let mut r = vec![0u64; len_max]; /* convolution with modulo P1 */ for i in 0..b.len() { x[i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } for i in 0..c.len() { r[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } - r[c.len()..len_max_1].fill(0u64); - conv::(&plan_1, &mut x, &mut r[..len_max_1], 1); + r[c.len()..plan_x.n].fill(0u64); + conv::(&plan_x, &mut x, &mut r[..plan_x.n], 1); /* convolution with modulo P2 */ for i in 0..b.len() { y[i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } for i in 0..c.len() { r[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } - r[c.len()..len_max_2].fill(0u64); - conv::(&plan_2, &mut y, &mut r[..len_max_2], 1); + r[c.len()..plan_y.n].fill(0u64); + conv::(&plan_y, &mut y, &mut r[..plan_y.n], 1); /* convolution with modulo P3 */ for i in 0..b.len() { z[i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; } for i in 0..c.len() { r[i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } - r[c.len()..len_max_3].fill(0u64); - conv::(&plan_3, &mut z, &mut r[..len_max_3], 1); + r[c.len()..plan_z.n].fill(0u64); + conv::(&plan_z, &mut z, &mut r[..plan_z.n], 1); /* merge the results in {x, y, z} into acc (process carry along the way) */ let mut carry: u128 = 0; @@ -901,19 +896,28 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { // However, the number of bits per u64 we can pack for NTT // depends on the length of the arrays being multiplied (convolved). // If the arrays are too long, the resulting values may exceed the - // modulus range P1 * P2, which leads to incorrect results. + // modulus range P2 * P3, which leads to incorrect results. // Hence, we compute the number of bits required by the length of NTT, // and use it to determine whether to use two-prime or three-prime. // Since we can pack 64 bits per u64 in three-prime NTT, the effective // number of bits in three-prime NTT is 64/3 = 21.3333..., which means // two-prime NTT can only do better when at least 43 bits per u64 can // be packed into each u64. - // Finally note that there should be no issues with overflow since - // 2^126 * 64 / 43 < 1.3 * 10^38 < P1 * P2. + const fn compute_bits(l: u64) -> u64 { + let total_bits = l * 64; + let (mut lo, mut hi) = (42, 62); + while lo < hi { + let mid = (lo + hi + 1) / 2; + let single_digit_max_val = (1u64 << mid) - 1; + let l_corrected = (total_bits + mid - 1) / mid; + let (lhs, overflow) = (single_digit_max_val as u128 * single_digit_max_val as u128).overflowing_mul(l_corrected as u128); + if !overflow && lhs < P2 as u128 * P3 as u128 { lo = mid; } + else { hi = mid - 1; } + } + lo + } let max_cnt = max(b.len(), c.len()) as u64; - let mut bits = 0u64; - while 1u64 << (2*bits) < max_cnt { bits += 1; } - bits = 63 - bits; + let bits = compute_bits(max_cnt); if bits >= 43 { /* can pack more effective bits per u64 with two primes than with three primes */ fn pack_into(src: &[u64], dst: &mut [u64], bits: u64) -> usize { From 6d4acd075e794529a67dd424eb0a66e1eb0634f4 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 9 Sep 2023 17:49:26 +0900 Subject: [PATCH 21/65] Don't use intermediate buffer for conv_base --- src/biguint/ntt.rs | 74 ++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 4699740b..24fd1324 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -263,24 +263,22 @@ impl NttPlan { } } -fn conv_base(n: usize, x: *mut u64, y: *mut u64, buf: *mut u64, c: u64, mult: u64) { +fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64, mult: u64) { unsafe { - for i in 0..n { - *buf.wrapping_add(i) = Arith::

::mmulmod(*x.wrapping_add(i), mult); - } + let out = x.wrapping_sub(n); for i in 0..n { let mut v: u128 = 0; for j in i+1..n { - let (w, overflow) = v.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); + let (w, overflow) = v.overflowing_add(*x.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; } v = c as u128 * Arith::

::mreduce(v) as u128; for j in 0..=i { - let (w, overflow) = v.overflowing_add(*buf.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); + let (w, overflow) = v.overflowing_add(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; } if v >= (P as u128) << 64 { v = v.wrapping_sub((P as u128) << 64); } - *x.wrapping_add(i) = Arith::

::mreduce(v); + *out.wrapping_add(i) = Arith::

::mmulmod(Arith::

::mreduce(v), mult); } } } @@ -650,15 +648,14 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 } /* dif fft */ - ntt_dif_dit::(&plan, x, &tf_list); - ntt_dif_dit::(&plan, y, &tf_list); + ntt_dif_dit::(&plan, &mut x[g..], &tf_list); + ntt_dif_dit::(&plan, &mut y[g..], &tf_list); /* naive or Karatsuba multiplication */ let len_inv = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/m as u64)); mult = Arith::

::mmulmod(Arith::

::mmulmod(Arith::

::R2, mult), len_inv); - let mut i = 0; + let mut i = g; let (mut ii, mut ii_mod_last_radix) = (0, 0); - let mut buf = vec![0u64; g]; let tf = &tf_list[tf_last_start..]; let mut tf_current = tf[0]; let tf_mult = match plan.last_radix { @@ -669,7 +666,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 6 => NttKernelImpl::::U6, _ => Arith::

::R }; - while i < plan.n { + while i < g + plan.n { if ii_mod_last_radix == 0 { tf_current = tf[ii]; } else { @@ -678,7 +675,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 /* we multiply the inverse of the length here to save time */ conv_base::

(g, x.as_mut_ptr().wrapping_add(i), y.as_mut_ptr().wrapping_add(i), - buf.as_mut_ptr(), tf_current, mult); + tf_current, mult); i += g; ii_mod_last_radix += 1; if ii_mod_last_radix == last_radix { @@ -724,21 +721,21 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { let min_len = b.len() + c.len(); let plan_x = NttPlan::build::(min_len); let plan_y = NttPlan::build::(min_len); - let len_max = max(plan_x.n, plan_y.n); - let mut x = vec![0u64; plan_x.n]; - let mut y = vec![0u64; plan_y.n]; + let len_max = max(plan_x.g + plan_x.n, plan_y.g + plan_y.n); + let mut x = vec![0u64; plan_x.g + plan_x.n]; + let mut y = vec![0u64; plan_y.g + plan_y.n]; let mut r = vec![0u64; len_max]; /* convolution with modulo P2 */ - x[0..b.len()].clone_from_slice(b); - r[0..c.len()].clone_from_slice(c); - conv::(&plan_x, &mut x, &mut r[..plan_x.n], arith::invmod(P3, P2)); + x[plan_x.g..plan_x.g+b.len()].clone_from_slice(b); + r[plan_x.g..plan_x.g+c.len()].clone_from_slice(c); + conv::(&plan_x, &mut x, &mut r[..plan_x.g+plan_x.n], arith::invmod(P3, P2)); /* convolution with modulo P3 */ - y[0..b.len()].clone_from_slice(b); - r[0..c.len()].clone_from_slice(c); - r[c.len()..plan_y.n].fill(0u64); - conv::(&plan_y, &mut y, &mut r[..plan_y.n], Arith::::submod(0, arith::invmod(P2, P3))); + y[plan_y.g..plan_y.g+b.len()].clone_from_slice(b); + r[plan_y.g..plan_y.g+c.len()].clone_from_slice(c); + (&mut r[plan_y.g..])[c.len()..plan_y.n].fill(0u64); + conv::(&plan_y, &mut y, &mut r[..plan_y.g+plan_y.n], Arith::::submod(0, arith::invmod(P2, P3))); /* merge the results in {x, y} into r (process carry along the way) */ let mask = (1u64 << bits) - 1; @@ -785,29 +782,28 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { let plan_x = NttPlan::build::(min_len); let plan_y = NttPlan::build::(min_len); let plan_z = NttPlan::build::(min_len); - let len_max = max(plan_x.n, max(plan_y.n, plan_z.n)); - let mut x = vec![0u64; plan_x.n]; - let mut y = vec![0u64; plan_y.n]; - let mut z = vec![0u64; plan_z.n]; + let len_max = max(plan_x.g + plan_x.n, max(plan_y.g + plan_y.n, plan_z.g + plan_z.n)); + let mut x = vec![0u64; plan_x.g + plan_x.n]; + let mut y = vec![0u64; plan_y.g + plan_y.n]; + let mut z = vec![0u64; plan_z.g + plan_z.n]; let mut r = vec![0u64; len_max]; /* convolution with modulo P1 */ - for i in 0..b.len() { x[i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } - for i in 0..c.len() { r[i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } - r[c.len()..plan_x.n].fill(0u64); - conv::(&plan_x, &mut x, &mut r[..plan_x.n], 1); + for i in 0..b.len() { x[plan_x.g + i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } + for i in 0..c.len() { r[plan_x.g + i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } + conv::(&plan_x, &mut x, &mut r[..plan_x.g+plan_x.n], 1); /* convolution with modulo P2 */ - for i in 0..b.len() { y[i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } - for i in 0..c.len() { r[i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } - r[c.len()..plan_y.n].fill(0u64); - conv::(&plan_y, &mut y, &mut r[..plan_y.n], 1); + for i in 0..b.len() { y[plan_y.g + i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } + for i in 0..c.len() { r[plan_y.g + i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } + (&mut r[plan_y.g..])[c.len()..plan_y.n].fill(0u64); + conv::(&plan_y, &mut y, &mut r[..plan_y.g+plan_y.n], 1); /* convolution with modulo P3 */ - for i in 0..b.len() { z[i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; } - for i in 0..c.len() { r[i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } - r[c.len()..plan_z.n].fill(0u64); - conv::(&plan_z, &mut z, &mut r[..plan_z.n], 1); + for i in 0..b.len() { z[plan_z.g + i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; } + for i in 0..c.len() { r[plan_z.g + i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } + (&mut r[plan_z.g..])[c.len()..plan_z.n].fill(0u64); + conv::(&plan_z, &mut z, &mut r[..plan_z.g+plan_z.n], 1); /* merge the results in {x, y, z} into acc (process carry along the way) */ let mut carry: u128 = 0; From de73df6211d3692e3945d8c17728793512b0098d Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 12 Sep 2023 13:26:29 +0900 Subject: [PATCH 22/65] Remove unnecessary operation --- src/biguint/ntt.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 24fd1324..017a6059 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -277,7 +277,6 @@ fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64, mult: u64 let (w, overflow) = v.overflowing_add(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; } - if v >= (P as u128) << 64 { v = v.wrapping_sub((P as u128) << 64); } *out.wrapping_add(i) = Arith::

::mmulmod(Arith::

::mreduce(v), mult); } } From 5350ae6782ce004d6fa3c8eb8fd75f1accf647f5 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 12 Sep 2023 17:17:13 +0900 Subject: [PATCH 23/65] Fix NTT planning bug --- src/biguint/ntt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 017a6059..e2a8e1bf 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -206,7 +206,7 @@ impl NttPlan { pub fn build(min_len: usize) -> NttPlan { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); let (mut len_max, mut len_max_cost) = (0usize, usize::MAX); - let mut len5 = 10; + let mut len5 = 1; for _ in 0..Arith::

::FACTOR_FIVE+1 { let mut len35 = len5; for _ in 0..Arith::

::FACTOR_THREE+1 { From e60b67863866743b41e91923f3eed0145ff28359 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 12 Sep 2023 20:58:17 +0900 Subject: [PATCH 24/65] Optimize base case multiplication --- src/biguint/ntt.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index e2a8e1bf..eee15355 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -173,6 +173,22 @@ impl Arith

{ } cur } + // Computes c as u128 * mreduce(v) as u128, + // using d: u64 = mmulmod(P-1, c). + // It is caller's responsibility to ensure that d is correct. + // Note that d can be computed by calling mreducelo(c). + pub const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 { + let a: u128 = c as u128 * (v >> 64); + let b: u128 = d as u128 * (v as u64 as u128); + let (w, overflow) = a.overflowing_sub(b); + if overflow { w.wrapping_add((P as u128) << 64) } else { w } + } + // Computes submod(0, mreduce(x as u128)) fast. + pub const fn mreducelo(x: u64) -> u64 { + let m = x.wrapping_mul(Self::PINV); + let y = ((m as u128 * P as u128) >> 64) as u64; + y + } // Computes a + b mod P, output range [0, P) pub const fn addmod(a: u64, b: u64) -> u64 { Self::submod(a, P.wrapping_sub(b)) @@ -262,9 +278,9 @@ impl NttPlan { } } } - -fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64, mult: u64) { +fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64, mult: u64, mult2: u64) { unsafe { + let c2 = Arith::

::mreducelo(c); let out = x.wrapping_sub(n); for i in 0..n { let mut v: u128 = 0; @@ -272,12 +288,12 @@ fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64, mult: u64 let (w, overflow) = v.overflowing_add(*x.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; } - v = c as u128 * Arith::

::mreduce(v) as u128; + v = Arith::

::mmulmod_noreduce(v, c, c2); for j in 0..=i { let (w, overflow) = v.overflowing_add(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; } - *out.wrapping_add(i) = Arith::

::mmulmod(Arith::

::mreduce(v), mult); + *out.wrapping_add(i) = Arith::

::mreduce(Arith::

::mmulmod_noreduce(v, mult, mult2)); } } } @@ -653,6 +669,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 /* naive or Karatsuba multiplication */ let len_inv = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/m as u64)); mult = Arith::

::mmulmod(Arith::

::mmulmod(Arith::

::R2, mult), len_inv); + let mult2 = Arith::

::mreducelo(mult); let mut i = g; let (mut ii, mut ii_mod_last_radix) = (0, 0); let tf = &tf_list[tf_last_start..]; @@ -674,7 +691,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 /* we multiply the inverse of the length here to save time */ conv_base::

(g, x.as_mut_ptr().wrapping_add(i), y.as_mut_ptr().wrapping_add(i), - tf_current, mult); + tf_current, mult, mult2); i += g; ii_mod_last_radix += 1; if ii_mod_last_radix == last_radix { From 4475f49d3548a0a5799d59f495dfb754b858677e Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 12 Sep 2023 21:41:48 +0900 Subject: [PATCH 25/65] Replace some addmodopt calls with submod --- src/biguint/ntt.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index eee15355..00e54b0b 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -366,7 +366,7 @@ const fn ntt3_kernel_core::mmulmod(NttKernelImpl::::U3, Arith::

::submod(b, c)); let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); - let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::addmodopt::(Arith::

::submod(a, c), kbmc)); + let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(a, Arith::

::submod(c, kbmc))); let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(Arith::

::submod(a, b), kbmc)); (out0, out1, out2) } @@ -526,12 +526,12 @@ const fn ntt6_kernel_core::addmod(c, f), Arith::

::submod(c, f)); let lbmc = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::submod(b, c)); let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); - let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::addmodopt::(Arith::

::submod(a, b), lbmc)); + let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(a, Arith::

::submod(b, lbmc))); let out4 = Arith::

::mmulmod_cond::(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); let lepf = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::addmod64(e, f)); - let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::addmodopt::(Arith::

::submod(d, f), lepf)); + let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(d, Arith::

::submod(f, lepf))); let out3 = Arith::

::mmulmod_cond::(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); - let out5 = Arith::

::mmulmod_cond::(w5p, Arith::

::addmodopt::(Arith::

::submod(d, lepf), e)); + let out5 = Arith::

::mmulmod_cond::(w5p, Arith::

::submod(d, Arith::

::submod(lepf, e))); (out0, out1, out2, out3, out4, out5) } const fn ntt6_kernel( From f782e935abda4a82ecf86658c597013ca16c1c86 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 16 Sep 2023 08:50:40 +0900 Subject: [PATCH 26/65] Improve NTT planning --- src/biguint/ntt.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 00e54b0b..608a003f 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -231,8 +231,10 @@ impl NttPlan { while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } if len >= min_len && len < len_max_cost { let (mut tmp, mut cost) = (len, 0); - while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len); } - while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len/5); } + if tmp % 6 == 0 && tmp % 5 != 0 { (tmp, cost) = (tmp/6, cost + len*93/100); } + while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len + len*5/100); } + if tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len*95/100); } + while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len*22/100); } while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); } From 6370aa0969e5561688e51740acae23b6be12e80d Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 16 Sep 2023 08:51:22 +0900 Subject: [PATCH 27/65] Reduce constant multiplication operations --- src/biguint/ntt.rs | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 608a003f..f49dbb1b 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -280,22 +280,22 @@ impl NttPlan { } } } -fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64, mult: u64, mult2: u64) { +fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64) { unsafe { let c2 = Arith::

::mreducelo(c); let out = x.wrapping_sub(n); for i in 0..n { let mut v: u128 = 0; for j in i+1..n { - let (w, overflow) = v.overflowing_add(*x.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); - v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; + let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); + v = if overflow { w.wrapping_add((P as u128) << 64) } else { w }; } v = Arith::

::mmulmod_noreduce(v, c, c2); for j in 0..=i { - let (w, overflow) = v.overflowing_add(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); - v = if overflow { w.wrapping_sub((P as u128) << 64) } else { w }; + let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); + v = if overflow { w.wrapping_add((P as u128) << 64) } else { w }; } - *out.wrapping_add(i) = Arith::

::mreduce(Arith::

::mmulmod_noreduce(v, mult, mult2)); + *out.wrapping_add(i) = Arith::

::mreduce(v); } } } @@ -639,12 +639,20 @@ fn compute_twiddle_factors(s_list: &[(usize, usiz // The output is saved in the slice `x`. // The three slices must have the same length. For maximum performance, // the length should contain as many factors of 6 as possible. -fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u64) { +fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], ylen: usize, mut mult: u64) { assert!(!x.is_empty() && x.len() == y.len()); let (_n, g, m) = (plan.n, plan.g, plan.m); let last_radix = plan.last_radix; + /* multiply by a constant in advance */ + let len_inv = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/m as u64)); + mult = Arith::

::mmulmod(Arith::

::mmulmod(Arith::

::R2, mult), len_inv); + mult = Arith::

::submod(0, mult); + for v in if xlen < ylen { &mut x[g..g+xlen] } else { &mut y[g..g+ylen] } { + *v = Arith::

::mmulmod(*v, mult); + } + /* compute the total space needed for twiddle factors */ let tf_all_count = (|| -> usize { let (mut radix_cumul, mut out) = (1, 0); @@ -669,9 +677,6 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 ntt_dif_dit::(&plan, &mut y[g..], &tf_list); /* naive or Karatsuba multiplication */ - let len_inv = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/m as u64)); - mult = Arith::

::mmulmod(Arith::

::mmulmod(Arith::

::R2, mult), len_inv); - let mult2 = Arith::

::mreducelo(mult); let mut i = g; let (mut ii, mut ii_mod_last_radix) = (0, 0); let tf = &tf_list[tf_last_start..]; @@ -693,7 +698,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], y: &mut [u64], mut mult: u6 /* we multiply the inverse of the length here to save time */ conv_base::

(g, x.as_mut_ptr().wrapping_add(i), y.as_mut_ptr().wrapping_add(i), - tf_current, mult, mult2); + tf_current); i += g; ii_mod_last_radix += 1; if ii_mod_last_radix == last_radix { @@ -747,13 +752,13 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { /* convolution with modulo P2 */ x[plan_x.g..plan_x.g+b.len()].clone_from_slice(b); r[plan_x.g..plan_x.g+c.len()].clone_from_slice(c); - conv::(&plan_x, &mut x, &mut r[..plan_x.g+plan_x.n], arith::invmod(P3, P2)); + conv::(&plan_x, &mut x, b.len(), &mut r[..plan_x.g+plan_x.n], c.len(), arith::invmod(P3, P2)); /* convolution with modulo P3 */ y[plan_y.g..plan_y.g+b.len()].clone_from_slice(b); r[plan_y.g..plan_y.g+c.len()].clone_from_slice(c); (&mut r[plan_y.g..])[c.len()..plan_y.n].fill(0u64); - conv::(&plan_y, &mut y, &mut r[..plan_y.g+plan_y.n], Arith::::submod(0, arith::invmod(P2, P3))); + conv::(&plan_y, &mut y, b.len(), &mut r[..plan_y.g+plan_y.n], c.len(), Arith::::submod(0, arith::invmod(P2, P3))); /* merge the results in {x, y} into r (process carry along the way) */ let mask = (1u64 << bits) - 1; @@ -809,19 +814,19 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { /* convolution with modulo P1 */ for i in 0..b.len() { x[plan_x.g + i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } for i in 0..c.len() { r[plan_x.g + i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; } - conv::(&plan_x, &mut x, &mut r[..plan_x.g+plan_x.n], 1); + conv::(&plan_x, &mut x, b.len(), &mut r[..plan_x.g+plan_x.n], c.len(), 1); /* convolution with modulo P2 */ for i in 0..b.len() { y[plan_y.g + i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; } for i in 0..c.len() { r[plan_y.g + i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; } (&mut r[plan_y.g..])[c.len()..plan_y.n].fill(0u64); - conv::(&plan_y, &mut y, &mut r[..plan_y.g+plan_y.n], 1); + conv::(&plan_y, &mut y, b.len(), &mut r[..plan_y.g+plan_y.n], c.len(), 1); /* convolution with modulo P3 */ for i in 0..b.len() { z[plan_z.g + i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; } for i in 0..c.len() { r[plan_z.g + i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; } (&mut r[plan_z.g..])[c.len()..plan_z.n].fill(0u64); - conv::(&plan_z, &mut z, &mut r[..plan_z.g+plan_z.n], 1); + conv::(&plan_z, &mut z, b.len(), &mut r[..plan_z.g+plan_z.n], c.len(), 1); /* merge the results in {x, y, z} into acc (process carry along the way) */ let mut carry: u128 = 0; From 4343d5f11f3e1b67d534761d7118858392b3014f Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 16 Sep 2023 09:21:33 +0900 Subject: [PATCH 28/65] Simplify code --- src/biguint/ntt.rs | 99 ++++++++++++++-------------------------------- 1 file changed, 29 insertions(+), 70 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index f49dbb1b..7e67e098 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -140,8 +140,12 @@ impl Arith

{ pub const fn mmulmod(a: u64, b: u64) -> u64 { Self::mreduce(a as u128 * b as u128) } - pub const fn mmulmod_cond(a: u64, b: u64) -> u64 { - if INV { Self::mmulmod(a, b) } else { b } + // Multiplication with Montgomery reduction: + // a * b * R^-1 mod P + // This function only applies the multiplication when INV && TWIDDLE, + // otherwise it just returns b. + pub const fn mmulmod_invtw(a: u64, b: u64) -> u64 { + if INV && TWIDDLE { Self::mmulmod(a, b) } else { b } } // Fused-multiply-add with Montgomery reduction: // a * b * R^-1 + c mod P @@ -198,9 +202,9 @@ impl Arith

{ let (out, overflow) = a.overflowing_add(b); if overflow { out.wrapping_sub(P) } else { out } } - // Computes a + b mod P, selects addmod64 or addmod depending on INV - pub const fn addmodopt(a: u64, b: u64) -> u64 { - if INV { Self::addmod64(a, b) } else { Self::addmod(a, b) } + // Computes a + b mod P, selects addmod64 or addmod depending on INV && TWIDDLE + pub const fn addmodopt_invtw(a: u64, b: u64) -> u64 { + if INV && TWIDDLE { Self::addmod64(a, b) } else { Self::addmod(a, b) } } // Computes a - b mod P, output range [0, P) pub const fn submod(a: u64, b: u64) -> u64 { @@ -327,25 +331,16 @@ impl NttKernelImpl { (c51, c52, c53, c54, c55) } } -const fn ntt2_kernel_core( +const fn ntt2_kernel( w1p: u64, a: u64, mut b: u64) -> (u64, u64) { if !INV && TWIDDLE { b = Arith::

::mmulmod(w1p, b); } let out0 = Arith::

::addmod(a, b); - let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(a, b)); + let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(a, b)); (out0, out1) } -const fn ntt2_kernel( - w1p: u64, - a: u64, b: u64) -> (u64, u64) { - match (INV, TWIDDLE) { - (_, false) => ntt2_kernel_core::(w1p, a, b), - (false, true) => ntt2_kernel_core::(w1p, a, b), - (true, true) => ntt2_kernel_core::(w1p, a, b) - } -} fn ntt2_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { @@ -359,7 +354,7 @@ fn ntt2_single_block( } (px.wrapping_add(s1), ptf.wrapping_add(1)) } -const fn ntt3_kernel_core( +const fn ntt3_kernel( w1p: u64, w2p: u64, a: u64, mut b: u64, mut c: u64) -> (u64, u64, u64) { if !INV && TWIDDLE { @@ -368,19 +363,10 @@ const fn ntt3_kernel_core::mmulmod(NttKernelImpl::::U3, Arith::

::submod(b, c)); let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); - let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(a, Arith::

::submod(c, kbmc))); - let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(Arith::

::submod(a, b), kbmc)); + let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(a, Arith::

::submod(c, kbmc))); + let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(Arith::

::submod(a, b), kbmc)); (out0, out1, out2) } -const fn ntt3_kernel( - w1p: u64, w2p: u64, - a: u64, b: u64, c: u64) -> (u64, u64, u64) { - match (INV, TWIDDLE) { - (_, false) => ntt3_kernel_core::(w1p, w2p, a, b, c), - (false, true) => ntt3_kernel_core::(w1p, w2p, a, b, c), - (true, true) => ntt3_kernel_core::(w1p, w2p, a, b, c) - } -} fn ntt3_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { @@ -400,7 +386,7 @@ fn ntt3_single_block( } (px.wrapping_add(2*s1), ptf.wrapping_add(1)) } -const fn ntt4_kernel_core( +const fn ntt4_kernel( w1p: u64, w2p: u64, w3p: u64, a: u64, mut b: u64, mut c: u64, mut d: u64) -> (u64, u64, u64, u64) { if !INV && TWIDDLE { @@ -414,20 +400,11 @@ const fn ntt4_kernel_core::submod(b, d); let jbmd = Arith::

::mmulmod(bmd, P.wrapping_sub(NttKernelImpl::::U4)); let out0 = Arith::

::addmod(apc, bpd); - let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(amc, jbmd)); - let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(apc, bpd)); - let out3 = Arith::

::mmulmod_cond::(w3p, Arith::

::addmodopt::(amc, jbmd)); + let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(amc, jbmd)); + let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(apc, bpd)); + let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::addmodopt_invtw::(amc, jbmd)); (out0, out1, out2, out3) } -const fn ntt4_kernel( - w1p: u64, w2p: u64, w3p: u64, - a: u64, b: u64, c: u64, d: u64) -> (u64, u64, u64, u64) { - match (INV, TWIDDLE) { - (_, false) => ntt4_kernel_core::(w1p, w2p, w3p, a, b, c, d), - (false, true) => ntt4_kernel_core::(w1p, w2p, w3p, a, b, c, d), - (true, true) => ntt4_kernel_core::(w1p, w2p, w3p, a, b, c, d) - } -} fn ntt4_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { @@ -450,7 +427,7 @@ fn ntt4_single_block( } (px.wrapping_add(3*s1), ptf.wrapping_add(1)) } -const fn ntt5_kernel_core( +const fn ntt5_kernel( w1p: u64, w2p: u64, w3p: u64, w4p: u64, a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64) -> (u64, u64, u64, u64, u64) { if !INV && TWIDDLE { @@ -475,21 +452,12 @@ const fn ntt5_kernel_core::submod(m3, m2); let s4 = Arith::

::addmod(m2, m3); let out0 = m1; - let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(s2, m5)); - let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(0, Arith::

::addmod(s4, m6))); - let out3 = Arith::

::mmulmod_cond::(w3p, Arith::

::submod(m6, s4)); - let out4 = Arith::

::mmulmod_cond::(w4p, Arith::

::addmodopt::(s2, m5)); + let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(s2, m5)); + let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(0, Arith::

::addmod(s4, m6))); + let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(m6, s4)); + let out4 = Arith::

::mmulmod_invtw::(w4p, Arith::

::addmodopt_invtw::(s2, m5)); (out0, out1, out2, out3, out4) } -const fn ntt5_kernel( - w1p: u64, w2p: u64, w3p: u64, w4p: u64, - a: u64, b: u64, c: u64, d: u64, e: u64) -> (u64, u64, u64, u64, u64) { - match (INV, TWIDDLE) { - (_, false) => ntt5_kernel_core::(w1p, w2p, w3p, w4p, a, b, c, d, e), - (false, true) => ntt5_kernel_core::(w1p, w2p, w3p, w4p, a, b, c, d, e), - (true, true) => ntt5_kernel_core::(w1p, w2p, w3p, w4p, a, b, c, d, e) - } -} fn ntt5_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { @@ -513,7 +481,7 @@ fn ntt5_single_block( } (px.wrapping_add(4*s1), ptf.wrapping_add(1)) } -const fn ntt6_kernel_core( +const fn ntt6_kernel( w1p: u64, w2p: u64, w3p: u64, w4p: u64, w5p: u64, mut a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64, mut f: u64) -> (u64, u64, u64, u64, u64, u64) { if !INV && TWIDDLE { @@ -528,23 +496,14 @@ const fn ntt6_kernel_core::addmod(c, f), Arith::

::submod(c, f)); let lbmc = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::submod(b, c)); let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); - let out2 = Arith::

::mmulmod_cond::(w2p, Arith::

::submod(a, Arith::

::submod(b, lbmc))); - let out4 = Arith::

::mmulmod_cond::(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); + let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(a, Arith::

::submod(b, lbmc))); + let out4 = Arith::

::mmulmod_invtw::(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); let lepf = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::addmod64(e, f)); - let out1 = Arith::

::mmulmod_cond::(w1p, Arith::

::submod(d, Arith::

::submod(f, lepf))); - let out3 = Arith::

::mmulmod_cond::(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); - let out5 = Arith::

::mmulmod_cond::(w5p, Arith::

::submod(d, Arith::

::submod(lepf, e))); + let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(d, Arith::

::submod(f, lepf))); + let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); + let out5 = Arith::

::mmulmod_invtw::(w5p, Arith::

::submod(d, Arith::

::submod(lepf, e))); (out0, out1, out2, out3, out4, out5) } -const fn ntt6_kernel( - w1p: u64, w2p: u64, w3p: u64, w4p: u64, w5p: u64, - a: u64, b: u64, c: u64, d: u64, e: u64, f: u64) -> (u64, u64, u64, u64, u64, u64) { - match (INV, TWIDDLE) { - (_, false) => ntt6_kernel_core::(w1p, w2p, w3p, w4p, w5p, a, b, c, d, e, f), - (false, true) => ntt6_kernel_core::(w1p, w2p, w3p, w4p, w5p, a, b, c, d, e, f), - (true, true) => ntt6_kernel_core::(w1p, w2p, w3p, w4p, w5p, a, b, c, d, e, f) - } -} fn ntt6_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { From f0b7f9602d49f5295c991da6adb8bb1946508edc Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 16 Sep 2023 11:12:44 +0900 Subject: [PATCH 29/65] Simplify code --- src/biguint/ntt.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 7e67e098..161c0cbd 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -398,11 +398,11 @@ const fn ntt4_kernel( let amc = Arith::

::submod(a, c); let bpd = Arith::

::addmod(b, d); let bmd = Arith::

::submod(b, d); - let jbmd = Arith::

::mmulmod(bmd, P.wrapping_sub(NttKernelImpl::::U4)); + let jbmd = Arith::

::mmulmod(NttKernelImpl::::U4, bmd); let out0 = Arith::

::addmod(apc, bpd); - let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(amc, jbmd)); + let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::addmodopt_invtw::(amc, jbmd)); let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(apc, bpd)); - let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::addmodopt_invtw::(amc, jbmd)); + let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(amc, jbmd)); (out0, out1, out2, out3) } fn ntt4_single_block( @@ -449,13 +449,13 @@ const fn ntt5_kernel( let m4 = Arith::

::mmulmod(NttKernelImpl::::C53, t7); let m5 = Arith::

::mmulsubmod(NttKernelImpl::::C54, t4, m4); let m6 = Arith::

::mmulsubmod(P.wrapping_sub(NttKernelImpl::::C55), t3, m4); - let s2 = Arith::

::submod(m3, m2); - let s4 = Arith::

::addmod(m2, m3); + let s1 = Arith::

::submod(m3, m2); + let s2 = Arith::

::addmod(m2, m3); let out0 = m1; - let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(s2, m5)); - let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(0, Arith::

::addmod(s4, m6))); - let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(m6, s4)); - let out4 = Arith::

::mmulmod_invtw::(w4p, Arith::

::addmodopt_invtw::(s2, m5)); + let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(s1, m5)); + let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(0, Arith::

::addmod(s2, m6))); + let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(m6, s2)); + let out4 = Arith::

::mmulmod_invtw::(w4p, Arith::

::addmodopt_invtw::(s1, m5)); (out0, out1, out2, out3, out4) } fn ntt5_single_block( From aab024a570bd0d6df469d243c724dee53356f457 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 16 Sep 2023 13:52:21 +0900 Subject: [PATCH 30/65] Update multiplication.rs comment --- src/biguint/multiplication.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs index ea4f8065..bec130ee 100644 --- a/src/biguint/multiplication.rs +++ b/src/biguint/multiplication.rs @@ -99,7 +99,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) { // number of operations, but uses more temporary allocations. // // The thresholds are somewhat arbitrary, chosen by evaluating the results - // of `cargo bench --bench bigint multiply`. + // of `cargo bench --bench bigint multiply --features rand`. if x.len() <= 32 { // Long multiplication: From 48fcb57331e204ba49f7cb7d5f4314fbaff69722 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sat, 16 Sep 2023 14:04:18 +0900 Subject: [PATCH 31/65] Improve NTT planning --- src/biguint/ntt.rs | 57 +++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 161c0cbd..384b417a 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -222,31 +222,44 @@ struct NttPlan { pub s_list: Vec<(usize, usize)>, } impl NttPlan { - pub const GMAX: usize = 6; + pub const GMAX: usize = 8; pub fn build(min_len: usize) -> NttPlan { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); - let (mut len_max, mut len_max_cost) = (0usize, usize::MAX); - let mut len5 = 1; - for _ in 0..Arith::

::FACTOR_FIVE+1 { - let mut len35 = len5; - for _ in 0..Arith::

::FACTOR_THREE+1 { - let mut len = len35; - let mut i = 0; - while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } - if len >= min_len && len < len_max_cost { - let (mut tmp, mut cost) = (len, 0); - if tmp % 6 == 0 && tmp % 5 != 0 { (tmp, cost) = (tmp/6, cost + len*93/100); } - while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len + len*5/100); } - if tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len*95/100); } - while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len*22/100); } - while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } - while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } - while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); } - if cost < len_max_cost { (len_max, len_max_cost) = (len, cost); } + let (mut len_max, mut len_max_cost) = (usize::MAX, usize::MAX); + let mut len7 = 1; + for _ in 0..2 { + let mut len5 = len7; + for _ in 0..Arith::

::FACTOR_FIVE+1 { + let mut len3 = len5; + for _ in 0..Arith::

::FACTOR_THREE+1 { + let mut len = len3; + let mut i = 0; + while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } + if len >= min_len && len < len_max_cost { + let (mut tmp, mut cost) = (len, 0); + if len % 7 == 0 { + (tmp, cost) = (tmp/7, cost + len*115/100); + } else if len % 5 == 0 { + (tmp, cost) = (tmp/5, cost + len*89/100); + } else if len % 8 == 0 { + (tmp, cost) = (tmp/8, cost + len*130/100); + } else if len % 6 == 0 { + (tmp, cost) = (tmp/6, cost + len*91/100); + } + while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len + len*6/100); } + while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len*28/100); } + while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } + while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } + while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); } + if cost < len_max_cost { (len_max, len_max_cost) = (len, cost); } + } + len3 *= 3; + if len3 >= 2*min_len { break; } } - len35 *= 3; + len5 *= 5; + if len5 >= 2*min_len { break; } } - len5 *= 5; + len7 *= 7; } let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0); let mut tmp = len_max; @@ -256,8 +269,10 @@ impl NttPlan { while tmp % 3 == 0 { tmp /= 3; cnt3 += 1; } while tmp % 2 == 0 { tmp /= 2; cnt2 += 1; } let mut g = 1; + while 7*g <= Self::GMAX && len_max % 7 == 0 { g *= 7; } while 5*g <= Self::GMAX && cnt5 > 0 { g *= 5; cnt5 -= 1; } while 9*g <= Self::GMAX && cnt3 >= 2 { g *= 9; cnt3 -= 2; } + while 8*g <= Self::GMAX && cnt4 >= 2 { g *= 8; cnt4 -= 2; cnt2 += 1; if cnt2 >= 2 { cnt4 += 1; cnt2 -= 2; } } while 8*g <= Self::GMAX && cnt4 > 0 && cnt2 > 0 { g *= 8; cnt4 -= 1; cnt2 -= 1; } while 6*g <= Self::GMAX && cnt6 > 0 { g *= 6; cnt6 -= 1; } while 4*g <= Self::GMAX && cnt4 > 0 { g *= 4; cnt4 -= 1; } From 050ee54d595749ac2d4348c9a783789c8bfdcce8 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Sun, 17 Sep 2023 22:37:39 +0900 Subject: [PATCH 32/65] Reduce memory access when repacking --- src/biguint/ntt.rs | 82 +++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 45 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 384b417a..d32fb0c1 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -715,24 +715,46 @@ const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { assert!(bits < 63); - let min_len = b.len() + c.len(); + let b_len = ((64 * b.len() as u64 + bits - 1) / bits) as usize; + let c_len = ((64 * c.len() as u64 + bits - 1) / bits) as usize; + let min_len = b_len + c_len; let plan_x = NttPlan::build::(min_len); let plan_y = NttPlan::build::(min_len); - let len_max = max(plan_x.g + plan_x.n, plan_y.g + plan_y.n); - let mut x = vec![0u64; plan_x.g + plan_x.n]; - let mut y = vec![0u64; plan_y.g + plan_y.n]; - let mut r = vec![0u64; len_max]; - /* convolution with modulo P2 */ - x[plan_x.g..plan_x.g+b.len()].clone_from_slice(b); - r[plan_x.g..plan_x.g+c.len()].clone_from_slice(c); - conv::(&plan_x, &mut x, b.len(), &mut r[..plan_x.g+plan_x.n], c.len(), arith::invmod(P3, P2)); + fn pack_into(src: &[u64], dst1: &mut [u64], dst2: &mut [u64], bits: u64) { + let mut p = 0u64; + let mut pdst1 = dst1.as_mut_ptr(); + let mut pdst2 = dst2.as_mut_ptr(); + let mut x = 0u64; + let mask = (1u64 << bits).wrapping_sub(1); + for v in src { + let mut k = 0; + while k < 64 { + x |= (v >> k) << p; + let q = 64 - k; + if p + q >= bits { + unsafe { let out = x & mask; (*pdst1, *pdst2) = (out, out); } + x = 0; + (pdst1, pdst2, k, p) = (pdst1.wrapping_add(1), pdst2.wrapping_add(1), k + bits - p, 0); + } else { + p += q; + break; + } + } + } + unsafe { + if p > 0 { let out = x & mask; (*pdst1, *pdst2) = (out, out); } + } + } - /* convolution with modulo P3 */ - y[plan_y.g..plan_y.g+b.len()].clone_from_slice(b); - r[plan_y.g..plan_y.g+c.len()].clone_from_slice(c); - (&mut r[plan_y.g..])[c.len()..plan_y.n].fill(0u64); - conv::(&plan_y, &mut y, b.len(), &mut r[..plan_y.g+plan_y.n], c.len(), Arith::::submod(0, arith::invmod(P2, P3))); + let mut x = vec![0u64; plan_x.g + plan_x.n]; + let mut y = vec![0u64; plan_y.g + plan_y.n]; + let mut r = vec![0u64; plan_x.g + plan_x.n]; + let mut s = vec![0u64; plan_y.g + plan_y.n]; + pack_into(b, &mut x[plan_x.g..], &mut y[plan_y.g..], bits); + pack_into(c, &mut r[plan_x.g..], &mut s[plan_y.g..], bits); + conv::(&plan_x, &mut x, b_len, &mut r[..plan_x.g+plan_x.n], c_len, arith::invmod(P3, P2)); + conv::(&plan_y, &mut y, b_len, &mut s[..plan_y.g+plan_y.n], c_len, Arith::::submod(0, arith::invmod(P2, P3))); /* merge the results in {x, y} into r (process carry along the way) */ let mask = (1u64 << bits) - 1; @@ -912,37 +934,7 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { let max_cnt = max(b.len(), c.len()) as u64; let bits = compute_bits(max_cnt); if bits >= 43 { - /* can pack more effective bits per u64 with two primes than with three primes */ - fn pack_into(src: &[u64], dst: &mut [u64], bits: u64) -> usize { - let mut p = 0u64; - let mut pdst = dst.as_mut_ptr(); - let mut x = 0u64; - let mask = (1u64 << bits).wrapping_sub(1); - for v in src { - let mut k = 0; - while k < 64 { - x |= (v >> k) << p; - let q = 64 - k; - if p + q >= bits { - unsafe { *pdst = x & mask; } - x = 0; - (pdst, k, p) = (pdst.wrapping_add(1), k + bits - p, 0); - } else { - p += q; - break; - } - } - } - unsafe { - if p > 0 { *pdst = x & mask; pdst = pdst.wrapping_add(1); } - pdst.offset_from(dst.as_mut_ptr()) as usize - } - } - let mut b2 = vec![0u64; ((64 * b.len() as u64 + bits - 1) / bits) as usize]; - let mut c2 = vec![0u64; ((64 * c.len() as u64 + bits - 1) / bits) as usize]; - let b2_len = pack_into(b, &mut b2, bits); - let c2_len = pack_into(c, &mut c2, bits); - mac3_two_primes(acc, &b2[..b2_len], &c2[..c2_len], bits); + mac3_two_primes(acc, b, c, bits); } else { /* can pack at most 21 effective bits per u64, which is worse than 64/3 = 21.3333.. effective bits per u64 achieved with three primes */ From 22e352b49b8b24ba0443f2b607846e05bdc0cb37 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 18 Sep 2023 06:45:31 +0900 Subject: [PATCH 33/65] Fix potential carry bug --- src/biguint/ntt.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index d32fb0c1..7a12ceb0 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -775,10 +775,10 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { p += bits; if p >= 64 { /* flush s to the output buffer */ - s += carry_acc; - let (w, overflow) = acc[j].overflowing_add(s); + let (w, overflow1) = s.overflowing_add(carry_acc); + let (w, overflow2) = acc[j].overflowing_add(w); acc[j] = w; - carry_acc = u64::from(overflow); + carry_acc = u64::from(overflow1 || overflow2); /* roll-over */ (j, p) = (j+1, p-64); From 1ea06b93af35db10b24e1de47dfba9e6b62e55c1 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 18 Sep 2023 06:45:35 +0900 Subject: [PATCH 34/65] Update ntt.rs --- src/biguint/ntt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 7a12ceb0..a0db31d5 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -771,7 +771,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { /* write to s */ let out = (v as u64) & mask; - s = (s & ((1u64 << p) - 1)) | (out << p); + s |= out << p; p += bits; if p >= 64 { /* flush s to the output buffer */ From 8bd2ab1ac4584fc15198bca300f5072a61e5518a Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Mon, 18 Sep 2023 06:49:23 +0900 Subject: [PATCH 35/65] Update ntt.rs --- src/biguint/ntt.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index a0db31d5..55a4a860 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -934,6 +934,7 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { let max_cnt = max(b.len(), c.len()) as u64; let bits = compute_bits(max_cnt); if bits >= 43 { + /* can pack more effective bits per u64 with two primes than with three primes */ mac3_two_primes(acc, b, c, bits); } else { /* can pack at most 21 effective bits per u64, which is worse than From 5dbbd8cb049a562b354f486c969ad3bf2698f9df Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 00:27:16 +0900 Subject: [PATCH 36/65] Improve NTT planning: fix nonmonotonicity up to 1M --- src/biguint/ntt.rs | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 55a4a860..bfe37149 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -225,33 +225,49 @@ impl NttPlan { pub const GMAX: usize = 8; pub fn build(min_len: usize) -> NttPlan { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); - let (mut len_max, mut len_max_cost) = (usize::MAX, usize::MAX); + let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1); let mut len7 = 1; for _ in 0..2 { let mut len5 = len7; for _ in 0..Arith::

::FACTOR_FIVE+1 { let mut len3 = len5; - for _ in 0..Arith::

::FACTOR_THREE+1 { + for j in 0..Arith::

::FACTOR_THREE+1 { let mut len = len3; let mut i = 0; while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } if len >= min_len && len < len_max_cost { let (mut tmp, mut cost) = (len, 0); + let mut g_base_new = 1; if len % 7 == 0 { (tmp, cost) = (tmp/7, cost + len*115/100); + g_base_new = 7; } else if len % 5 == 0 { (tmp, cost) = (tmp/5, cost + len*89/100); - } else if len % 8 == 0 { + g_base_new = 5; + } else if i >= j + 3 { (tmp, cost) = (tmp/8, cost + len*130/100); + g_base_new = 8; + } else if i >= j + 2 { + (tmp, cost) = (tmp/4, cost + len*87/100); + g_base_new = 4; + } else if i == 0 && j >= 1 { + (tmp, cost) = (tmp/3, cost + len*86/100); + g_base_new = 3; + } else if j == 0 && i >= 1 { + (tmp, cost) = (tmp/2, cost + len*86/100); + g_base_new = 2; } else if len % 6 == 0 { (tmp, cost) = (tmp/6, cost + len*91/100); + g_base_new = 6; } - while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len + len*6/100); } - while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len*28/100); } + let (mut b6, mut b2) = (false, false); + while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len + len*6/100); b6 = true; } + while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len*31/100); } while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } - while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); } - if cost < len_max_cost { (len_max, len_max_cost) = (len, cost); } + while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); b2 = true; } + if b6 && b2 { cost -= len*6/100; } + if cost < len_max_cost { (len_max, len_max_cost, g) = (len, cost, g_base_new); } } len3 *= 3; if len3 >= 2*min_len { break; } @@ -262,22 +278,12 @@ impl NttPlan { len7 *= 7; } let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0); - let mut tmp = len_max; + let mut tmp = len_max/g; while tmp % 6 == 0 { tmp /= 6; cnt6 += 1; } while tmp % 5 == 0 { tmp /= 5; cnt5 += 1; } while tmp % 4 == 0 { tmp /= 4; cnt4 += 1; } while tmp % 3 == 0 { tmp /= 3; cnt3 += 1; } while tmp % 2 == 0 { tmp /= 2; cnt2 += 1; } - let mut g = 1; - while 7*g <= Self::GMAX && len_max % 7 == 0 { g *= 7; } - while 5*g <= Self::GMAX && cnt5 > 0 { g *= 5; cnt5 -= 1; } - while 9*g <= Self::GMAX && cnt3 >= 2 { g *= 9; cnt3 -= 2; } - while 8*g <= Self::GMAX && cnt4 >= 2 { g *= 8; cnt4 -= 2; cnt2 += 1; if cnt2 >= 2 { cnt4 += 1; cnt2 -= 2; } } - while 8*g <= Self::GMAX && cnt4 > 0 && cnt2 > 0 { g *= 8; cnt4 -= 1; cnt2 -= 1; } - while 6*g <= Self::GMAX && cnt6 > 0 { g *= 6; cnt6 -= 1; } - while 4*g <= Self::GMAX && cnt4 > 0 { g *= 4; cnt4 -= 1; } - while 3*g <= Self::GMAX && cnt3 > 0 { g *= 3; cnt3 -= 1; } - while 2*g <= Self::GMAX && cnt2 > 0 { g *= 2; cnt2 -= 1; } while cnt6 > 0 && cnt2 > 0 { cnt6 -= 1; cnt2 -= 1; cnt4 += 1; cnt3 += 1; } let s_list = { let mut out = vec![]; From 9740fd8f15320428bfb4ed3a7957716a2799a85a Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 08:50:12 +0900 Subject: [PATCH 37/65] Remove unused definitions from ntt.rs --- src/biguint/ntt.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index bfe37149..7a5ae8f5 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -51,7 +51,6 @@ impl Arith

{ pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P pub const R2: u64 = ((Self::R as u128 * Self::R as u128) % P as u128) as u64; // R^2 mod P pub const R3: u64 = ((Self::R2 as u128 * Self::R as u128) % P as u128) as u64; // R^3 mod P - pub const RNEG: u64 = P.wrapping_sub(Self::R); // -2^64 mod P pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 pub const ROOT: u64 = Self::ntt_root(); // MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN pub const ROOTR: u64 = Self::mulmod_naive(Self::ROOT, Self::R); // ROOT * R mod P @@ -147,14 +146,6 @@ impl Arith

{ pub const fn mmulmod_invtw(a: u64, b: u64) -> u64 { if INV && TWIDDLE { Self::mmulmod(a, b) } else { b } } - // Fused-multiply-add with Montgomery reduction: - // a * b * R^-1 + c mod P - pub const fn mmuladdmod(a: u64, b: u64, c: u64) -> u64 { - let x = a as u128 * b as u128; - let lo = x as u64; - let hi = Self::addmod((x >> 64) as u64, c); - Self::mreduce(lo as u128 | ((hi as u128) << 64)) - } // Fused-multiply-sub with Montgomery reduction: // a * b * R^-1 - c mod P pub const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 { @@ -222,7 +213,6 @@ struct NttPlan { pub s_list: Vec<(usize, usize)>, } impl NttPlan { - pub const GMAX: usize = 8; pub fn build(min_len: usize) -> NttPlan { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1); From 37e626c8f6f615ac5c26619db719fdb18a0bbc3d Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 08:53:50 +0900 Subject: [PATCH 38/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 7a5ae8f5..1f7204f0 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -76,15 +76,9 @@ impl Arith

{ ans } const fn ntt_root() -> u64 { - let mut p = 1; + let mut p = 2; 'outer: loop { /* ensure p is prime */ - p += 1; - let mut i = 2; - while i * i <= p { - if p % i == 0 { continue 'outer; } - i += 1; - } let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN); let mut j = 0; while j <= Self::FACTOR_TWO { @@ -97,6 +91,7 @@ impl Arith

{ let p5 = Self::powmod_naive(5, l as u64); let exponent = p2 * p3 * p5; if exponent < Self::MAX_NTT_LEN && Self::powmod_naive(root, exponent) == 1 { + p += 1; continue 'outer; } l += 1; From 9e3e3ed3851cfc4ffdc5c2787fd766e59cfeb719 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 09:28:44 +0900 Subject: [PATCH 39/65] Make ntt.rs shorter (simplify egcd) --- src/biguint/ntt.rs | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 1f7204f0..87e32277 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -11,22 +11,14 @@ mod arith { // Extended Euclid algorithm: // (g, x, y) is a solution to ax + by = g, where g = gcd(a, b) pub const fn egcd(mut a: i128, mut b: i128) -> (i128, i128, i128) { - if a < 0 { a = -a; } - if b < 0 { b = -b; } - assert!(a > 0 || b > 0); + assert!(a > 0 && b > 0); let mut c = [1, 0, 0, 1]; // treat as a row-major 2x2 matrix + if a > b { (a, b) = (b, a); c = [0, 1, 1, 0]; } let (g, x, y) = loop { if a == 0 { break (b, 0, 1); } - if b == 0 { break (a, 1, 0); } - if a < b { - let (q, r) = (b/a, b%a); - b = r; - c = [c[0], c[1] - q*c[0], c[2], c[3] - q*c[2]]; - } else { - let (q, r) = (a/b, a%b); - a = r; - c = [c[0] - q*c[1], c[1], c[2] - q*c[3], c[3]]; - } + let (q, r) = (b/a, b%a); + (a, b) = (r, a); + c = [c[1] - q*c[0], c[0], c[3] - q*c[2], c[2]]; }; (g, c[0]*x + c[1]*y, c[2]*x + c[3]*y) } From 36222a8c34ad83c5c2edf174ae9add6744e27bdc Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 10:37:33 +0900 Subject: [PATCH 40/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 87e32277..8040ea45 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -59,11 +59,9 @@ impl Arith

{ ans } const fn max_ntt_len() -> u64 { - let mut ans = 1u64 << Self::FACTOR_TWO; - let mut i = 0; - while i < Self::FACTOR_THREE { ans *= 3; i += 1; } - let mut i = 0; - while i < Self::FACTOR_FIVE { ans *= 5; i += 1; } + let ans = Self::powmod_naive(2, Self::FACTOR_TWO as u64) * + Self::powmod_naive(3, Self::FACTOR_THREE as u64) * + Self::powmod_naive(5, Self::FACTOR_FIVE as u64); assert!(ans % 4050 == 0); ans } @@ -214,28 +212,28 @@ impl NttPlan { while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } if len >= min_len && len < len_max_cost { let (mut tmp, mut cost) = (len, 0); - let mut g_base_new = 1; + let mut g_new = 1; if len % 7 == 0 { (tmp, cost) = (tmp/7, cost + len*115/100); - g_base_new = 7; + g_new = 7; } else if len % 5 == 0 { (tmp, cost) = (tmp/5, cost + len*89/100); - g_base_new = 5; + g_new = 5; } else if i >= j + 3 { (tmp, cost) = (tmp/8, cost + len*130/100); - g_base_new = 8; + g_new = 8; } else if i >= j + 2 { (tmp, cost) = (tmp/4, cost + len*87/100); - g_base_new = 4; + g_new = 4; } else if i == 0 && j >= 1 { (tmp, cost) = (tmp/3, cost + len*86/100); - g_base_new = 3; + g_new = 3; } else if j == 0 && i >= 1 { (tmp, cost) = (tmp/2, cost + len*86/100); - g_base_new = 2; + g_new = 2; } else if len % 6 == 0 { (tmp, cost) = (tmp/6, cost + len*91/100); - g_base_new = 6; + g_new = 6; } let (mut b6, mut b2) = (false, false); while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len + len*6/100); b6 = true; } @@ -244,7 +242,7 @@ impl NttPlan { while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); b2 = true; } if b6 && b2 { cost -= len*6/100; } - if cost < len_max_cost { (len_max, len_max_cost, g) = (len, cost, g_base_new); } + if cost < len_max_cost { (len_max, len_max_cost, g) = (len, cost, g_new); } } len3 *= 3; if len3 >= 2*min_len { break; } From 771418b101aaa36f0f15a4c3e588f0c0d15bdd02 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 11:33:37 +0900 Subject: [PATCH 41/65] Fix clippy warnings + Make ntt.rs shorter --- src/biguint/ntt.rs | 226 ++++++++++++++++++++++----------------------- 1 file changed, 109 insertions(+), 117 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 8040ea45..49443f57 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -3,6 +3,7 @@ #![allow(clippy::cast_possible_truncation)] #![allow(clippy::many_single_char_names)] #![allow(clippy::needless_range_loop)] +#![allow(clippy::too_many_arguments)] #![allow(clippy::similar_names)] use crate::biguint::Vec; @@ -12,8 +13,7 @@ mod arith { // (g, x, y) is a solution to ax + by = g, where g = gcd(a, b) pub const fn egcd(mut a: i128, mut b: i128) -> (i128, i128, i128) { assert!(a > 0 && b > 0); - let mut c = [1, 0, 0, 1]; // treat as a row-major 2x2 matrix - if a > b { (a, b) = (b, a); c = [0, 1, 1, 0]; } + let mut c = if a > b { (a, b) = (b, a); [0, 1, 1, 0] } else { [1, 0, 0, 1] }; // treat as a row-major 2x2 matrix let (g, x, y) = loop { if a == 0 { break (b, 0, 1); } let (q, r) = (b/a, b%a); @@ -166,8 +166,7 @@ impl Arith

{ // Computes submod(0, mreduce(x as u128)) fast. pub const fn mreducelo(x: u64) -> u64 { let m = x.wrapping_mul(Self::PINV); - let y = ((m as u128 * P as u128) >> 64) as u64; - y + ((m as u128 * P as u128) >> 64) as u64 } // Computes a + b mod P, output range [0, P) pub const fn addmod(a: u64, b: u64) -> u64 { @@ -198,15 +197,15 @@ struct NttPlan { pub s_list: Vec<(usize, usize)>, } impl NttPlan { - pub fn build(min_len: usize) -> NttPlan { + pub fn build(min_len: usize) -> Self { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1); let mut len7 = 1; for _ in 0..2 { let mut len5 = len7; - for _ in 0..Arith::

::FACTOR_FIVE+1 { + for _ in 0..=Arith::

::FACTOR_FIVE { let mut len3 = len5; - for j in 0..Arith::

::FACTOR_THREE+1 { + for j in 0..=Arith::

::FACTOR_THREE { let mut len = len3; let mut i = 0; while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } @@ -214,30 +213,23 @@ impl NttPlan { let (mut tmp, mut cost) = (len, 0); let mut g_new = 1; if len % 7 == 0 { - (tmp, cost) = (tmp/7, cost + len*115/100); - g_new = 7; + (g_new, tmp, cost) = (7, tmp/7, cost + len*115/100); } else if len % 5 == 0 { - (tmp, cost) = (tmp/5, cost + len*89/100); - g_new = 5; + (g_new, tmp, cost) = (5, tmp/5, cost + len*89/100); } else if i >= j + 3 { - (tmp, cost) = (tmp/8, cost + len*130/100); - g_new = 8; + (g_new, tmp, cost) = (8, tmp/8, cost + len*130/100); } else if i >= j + 2 { - (tmp, cost) = (tmp/4, cost + len*87/100); - g_new = 4; + (g_new, tmp, cost) = (4, tmp/4, cost + len*87/100); } else if i == 0 && j >= 1 { - (tmp, cost) = (tmp/3, cost + len*86/100); - g_new = 3; + (g_new, tmp, cost) = (3, tmp/3, cost + len*86/100); } else if j == 0 && i >= 1 { - (tmp, cost) = (tmp/2, cost + len*86/100); - g_new = 2; + (g_new, tmp, cost) = (2, tmp/2, cost + len*86/100); } else if len % 6 == 0 { - (tmp, cost) = (tmp/6, cost + len*91/100); - g_new = 6; + (g_new, tmp, cost) = (6, tmp/6, cost + len*91/100); } let (mut b6, mut b2) = (false, false); - while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len + len*6/100); b6 = true; } - while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len + len*31/100); } + while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len*106/100); b6 = true; } + while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len*131/100); } while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); } while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); } while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); b2 = true; } @@ -270,13 +262,13 @@ impl NttPlan { for _ in 0..cnt6 { out.push((tmp, 6)); tmp /= 6; } out }; - NttPlan { + Self { n: len_max, - g: g, + g, m: len_max / g, cost: len_max_cost, last_radix: s_list.last().unwrap_or(&(1, 1)).1, - s_list: s_list, + s_list, } } } @@ -328,22 +320,22 @@ impl NttKernelImpl { } } const fn ntt2_kernel( - w1p: u64, + w1: u64, a: u64, mut b: u64) -> (u64, u64) { if !INV && TWIDDLE { - b = Arith::

::mmulmod(w1p, b); + b = Arith::

::mmulmod(w1, b); } let out0 = Arith::

::addmod(a, b); - let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(a, b)); + let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::submod(a, b)); (out0, out1) } fn ntt2_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { - let w1p = if TWIDDLE { *ptf } else { 0 }; + let w1 = if TWIDDLE { *ptf } else { 0 }; for _ in 0..s1 { (*px, *px.wrapping_add(s1)) = - ntt2_kernel::(w1p, + ntt2_kernel::(w1, *px, *px.wrapping_add(s1)); px = px.wrapping_add(1); } @@ -351,31 +343,31 @@ fn ntt2_single_block( (px.wrapping_add(s1), ptf.wrapping_add(1)) } const fn ntt3_kernel( - w1p: u64, w2p: u64, + w1: u64, w2: u64, a: u64, mut b: u64, mut c: u64) -> (u64, u64, u64) { if !INV && TWIDDLE { - b = Arith::

::mmulmod(w1p, b); - c = Arith::

::mmulmod(w2p, c); + b = Arith::

::mmulmod(w1, b); + c = Arith::

::mmulmod(w2, c); } let kbmc = Arith::

::mmulmod(NttKernelImpl::::U3, Arith::

::submod(b, c)); let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); - let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(a, Arith::

::submod(c, kbmc))); - let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(Arith::

::submod(a, b), kbmc)); + let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::submod(a, Arith::

::submod(c, kbmc))); + let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(Arith::

::submod(a, b), kbmc)); (out0, out1, out2) } fn ntt3_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { - let (w1p, w2p) = if TWIDDLE { - let w1p = *ptf; - let w2p = Arith::

::mmulmod(w1p, w1p); - (w1p, w2p) + let (w1, w2) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + (w1, w2) } else { (0, 0) }; for _ in 0..s1 { (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1)) = - ntt3_kernel::(w1p, w2p, + ntt3_kernel::(w1, w2, *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1)); px = px.wrapping_add(1); } @@ -383,12 +375,12 @@ fn ntt3_single_block( (px.wrapping_add(2*s1), ptf.wrapping_add(1)) } const fn ntt4_kernel( - w1p: u64, w2p: u64, w3p: u64, + w1: u64, w2: u64, w3: u64, a: u64, mut b: u64, mut c: u64, mut d: u64) -> (u64, u64, u64, u64) { if !INV && TWIDDLE { - b = Arith::

::mmulmod(w1p, b); - c = Arith::

::mmulmod(w2p, c); - d = Arith::

::mmulmod(w3p, d); + b = Arith::

::mmulmod(w1, b); + c = Arith::

::mmulmod(w2, c); + d = Arith::

::mmulmod(w3, d); } let apc = Arith::

::addmod(a, c); let amc = Arith::

::submod(a, c); @@ -396,26 +388,26 @@ const fn ntt4_kernel( let bmd = Arith::

::submod(b, d); let jbmd = Arith::

::mmulmod(NttKernelImpl::::U4, bmd); let out0 = Arith::

::addmod(apc, bpd); - let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::addmodopt_invtw::(amc, jbmd)); - let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(apc, bpd)); - let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(amc, jbmd)); + let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::addmodopt_invtw::(amc, jbmd)); + let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(apc, bpd)); + let out3 = Arith::

::mmulmod_invtw::(w3, Arith::

::submod(amc, jbmd)); (out0, out1, out2, out3) } fn ntt4_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { - let (w1p, w2p, w3p) = if TWIDDLE { - let w1p = *ptf; - let w2p = Arith::

::mmulmod(w1p, w1p); - let w3p = Arith::

::mmulmod(w1p, w2p); - (w1p, w2p, w3p) + let (w1, w2, w3) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + (w1, w2, w3) } else { (0, 0, 0) }; for _ in 0..s1 { (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), *px.wrapping_add(3*s1)) = - ntt4_kernel::(w1p, w2p, w3p, + ntt4_kernel::(w1, w2, w3, *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), *px.wrapping_add(3*s1)); px = px.wrapping_add(1); @@ -424,13 +416,13 @@ fn ntt4_single_block( (px.wrapping_add(3*s1), ptf.wrapping_add(1)) } const fn ntt5_kernel( - w1p: u64, w2p: u64, w3p: u64, w4p: u64, + w1: u64, w2: u64, w3: u64, w4: u64, a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64) -> (u64, u64, u64, u64, u64) { if !INV && TWIDDLE { - b = Arith::

::mmulmod(w1p, b); - c = Arith::

::mmulmod(w2p, c); - d = Arith::

::mmulmod(w3p, d); - e = Arith::

::mmulmod(w4p, e); + b = Arith::

::mmulmod(w1, b); + c = Arith::

::mmulmod(w2, c); + d = Arith::

::mmulmod(w3, d); + e = Arith::

::mmulmod(w4, e); } let t1 = Arith::

::addmod(b, e); let t2 = Arith::

::addmod(c, d); @@ -448,28 +440,28 @@ const fn ntt5_kernel( let s1 = Arith::

::submod(m3, m2); let s2 = Arith::

::addmod(m2, m3); let out0 = m1; - let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(s1, m5)); - let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(0, Arith::

::addmod(s2, m6))); - let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(m6, s2)); - let out4 = Arith::

::mmulmod_invtw::(w4p, Arith::

::addmodopt_invtw::(s1, m5)); + let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::submod(s1, m5)); + let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(0, Arith::

::addmod(s2, m6))); + let out3 = Arith::

::mmulmod_invtw::(w3, Arith::

::submod(m6, s2)); + let out4 = Arith::

::mmulmod_invtw::(w4, Arith::

::addmodopt_invtw::(s1, m5)); (out0, out1, out2, out3, out4) } fn ntt5_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { - let (w1p, w2p, w3p, w4p) = if TWIDDLE { - let w1p = *ptf; - let w2p = Arith::

::mmulmod(w1p, w1p); - let w3p = Arith::

::mmulmod(w1p, w2p); - let w4p = Arith::

::mmulmod(w2p, w2p); - (w1p, w2p, w3p, w4p) + let (w1, w2, w3, w4) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + let w4 = Arith::

::mmulmod(w2, w2); + (w1, w2, w3, w4) } else { (0, 0, 0, 0) }; for _ in 0..s1 { (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), *px.wrapping_add(3*s1), *px.wrapping_add(4*s1)) = - ntt5_kernel::(w1p, w2p, w3p, w4p, + ntt5_kernel::(w1, w2, w3, w4, *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), *px.wrapping_add(3*s1), *px.wrapping_add(4*s1)); px = px.wrapping_add(1); @@ -478,45 +470,45 @@ fn ntt5_single_block( (px.wrapping_add(4*s1), ptf.wrapping_add(1)) } const fn ntt6_kernel( - w1p: u64, w2p: u64, w3p: u64, w4p: u64, w5p: u64, + w1: u64, w2: u64, w3: u64, w4: u64, w5: u64, mut a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64, mut f: u64) -> (u64, u64, u64, u64, u64, u64) { if !INV && TWIDDLE { - b = Arith::

::mmulmod(w1p, b); - c = Arith::

::mmulmod(w2p, c); - d = Arith::

::mmulmod(w3p, d); - e = Arith::

::mmulmod(w4p, e); - f = Arith::

::mmulmod(w5p, f); + b = Arith::

::mmulmod(w1, b); + c = Arith::

::mmulmod(w2, c); + d = Arith::

::mmulmod(w3, d); + e = Arith::

::mmulmod(w4, e); + f = Arith::

::mmulmod(w5, f); } (a, d) = (Arith::

::addmod(a, d), Arith::

::submod(a, d)); (b, e) = (Arith::

::addmod(b, e), Arith::

::submod(b, e)); (c, f) = (Arith::

::addmod(c, f), Arith::

::submod(c, f)); let lbmc = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::submod(b, c)); let out0 = Arith::

::addmod(a, Arith::

::addmod(b, c)); - let out2 = Arith::

::mmulmod_invtw::(w2p, Arith::

::submod(a, Arith::

::submod(b, lbmc))); - let out4 = Arith::

::mmulmod_invtw::(w4p, Arith::

::submod(Arith::

::submod(a, c), lbmc)); + let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(a, Arith::

::submod(b, lbmc))); + let out4 = Arith::

::mmulmod_invtw::(w4, Arith::

::submod(Arith::

::submod(a, c), lbmc)); let lepf = Arith::

::mmulmod(NttKernelImpl::::U6, Arith::

::addmod64(e, f)); - let out1 = Arith::

::mmulmod_invtw::(w1p, Arith::

::submod(d, Arith::

::submod(f, lepf))); - let out3 = Arith::

::mmulmod_invtw::(w3p, Arith::

::submod(d, Arith::

::submod(e, f))); - let out5 = Arith::

::mmulmod_invtw::(w5p, Arith::

::submod(d, Arith::

::submod(lepf, e))); + let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::submod(d, Arith::

::submod(f, lepf))); + let out3 = Arith::

::mmulmod_invtw::(w3, Arith::

::submod(d, Arith::

::submod(e, f))); + let out5 = Arith::

::mmulmod_invtw::(w5, Arith::

::submod(d, Arith::

::submod(lepf, e))); (out0, out1, out2, out3, out4, out5) } fn ntt6_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { unsafe { - let (w1p, w2p, w3p, w4p, w5p) = if TWIDDLE { - let w1p = *ptf; - let w2p = Arith::

::mmulmod(w1p, w1p); - let w3p = Arith::

::mmulmod(w1p, w2p); - let w4p = Arith::

::mmulmod(w2p, w2p); - let w5p = Arith::

::mmulmod(w2p, w3p); - (w1p, w2p, w3p, w4p, w5p) + let (w1, w2, w3, w4, w5) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + let w4 = Arith::

::mmulmod(w2, w2); + let w5 = Arith::

::mmulmod(w2, w3); + (w1, w2, w3, w4, w5) } else { (0, 0, 0, 0, 0) }; for _ in 0..s1 { (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), *px.wrapping_add(3*s1), *px.wrapping_add(4*s1), *px.wrapping_add(5*s1)) = - ntt6_kernel::(w1p, w2p, w3p, w4p, w5p, + ntt6_kernel::(w1, w2, w3, w4, w5, *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), *px.wrapping_add(3*s1), *px.wrapping_add(4*s1), *px.wrapping_add(5*s1)); px = px.wrapping_add(1); @@ -609,14 +601,14 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], } /* compute the total space needed for twiddle factors */ - let tf_all_count = (|| -> usize { + let tf_all_count = { let (mut radix_cumul, mut out) = (1, 0); - for &(_, radix) in plan.s_list.iter() { + for &(_, radix) in &plan.s_list { out += radix_cumul; radix_cumul *= radix; } core::cmp::max(out, 1) - })(); + }; /* build twiddle factors */ let mut tf_list = vec![0u64; tf_all_count]; @@ -628,8 +620,8 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], } /* dif fft */ - ntt_dif_dit::(&plan, &mut x[g..], &tf_list); - ntt_dif_dit::(&plan, &mut y[g..], &tf_list); + ntt_dif_dit::(plan, &mut x[g..], &tf_list); + ntt_dif_dit::(plan, &mut y[g..], &tf_list); /* naive or Karatsuba multiplication */ let mut i = g; @@ -668,7 +660,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], tf_last_start += compute_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); } tf_list[tf_last_start] = Arith::

::R; - ntt_dif_dit::(&plan, x, &tf_list); + ntt_dif_dit::(plan, x, &tf_list); } //////////////////////////////////////////////////////////////////////////////// @@ -694,14 +686,6 @@ const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { - assert!(bits < 63); - - let b_len = ((64 * b.len() as u64 + bits - 1) / bits) as usize; - let c_len = ((64 * c.len() as u64 + bits - 1) / bits) as usize; - let min_len = b_len + c_len; - let plan_x = NttPlan::build::(min_len); - let plan_y = NttPlan::build::(min_len); - fn pack_into(src: &[u64], dst1: &mut [u64], dst2: &mut [u64], bits: u64) { let mut p = 0u64; let mut pdst1 = dst1.as_mut_ptr(); @@ -728,6 +712,13 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { } } + assert!(bits < 63); + let b_len = ((64 * b.len() as u64 + bits - 1) / bits) as usize; + let c_len = ((64 * c.len() as u64 + bits - 1) / bits) as usize; + let min_len = b_len + c_len; + let plan_x = NttPlan::build::(min_len); + let plan_y = NttPlan::build::(min_len); + let mut x = vec![0u64; plan_x.g + plan_x.n]; let mut y = vec![0u64; plan_y.g + plan_y.n]; let mut r = vec![0u64; plan_x.g + plan_x.n]; @@ -849,7 +840,21 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { } } -fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { +fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { + const fn compute_bits(l: u64) -> u64 { + let total_bits = l * 64; + let (mut lo, mut hi) = (42, 62); + while lo < hi { + let mid = (lo + hi + 1) / 2; + let single_digit_max_val = (1u64 << mid) - 1; + let l_corrected = (total_bits + mid - 1) / mid; + let (lhs, overflow) = (single_digit_max_val as u128 * single_digit_max_val as u128).overflowing_mul(l_corrected as u128); + if !overflow && lhs < P2 as u128 * P3 as u128 { lo = mid; } + else { hi = mid - 1; } + } + lo + } + let (b, c) = if b.len() < c.len() { (b, c) } else { (c, b) }; let naive_cost = NttPlan::build::(b.len() + c.len()).cost * 3; let split_cost = NttPlan::build::(b.len() + b.len()).cost * 3 * (c.len() / b.len()) @@ -899,19 +904,6 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { // number of bits in three-prime NTT is 64/3 = 21.3333..., which means // two-prime NTT can only do better when at least 43 bits per u64 can // be packed into each u64. - const fn compute_bits(l: u64) -> u64 { - let total_bits = l * 64; - let (mut lo, mut hi) = (42, 62); - while lo < hi { - let mid = (lo + hi + 1) / 2; - let single_digit_max_val = (1u64 << mid) - 1; - let l_corrected = (total_bits + mid - 1) / mid; - let (lhs, overflow) = (single_digit_max_val as u128 * single_digit_max_val as u128).overflowing_mul(l_corrected as u128); - if !overflow && lhs < P2 as u128 * P3 as u128 { lo = mid; } - else { hi = mid - 1; } - } - lo - } let max_cnt = max(b.len(), c.len()) as u64; let bits = compute_bits(max_cnt); if bits >= 43 { From 52eb321e5ab66525bd9f8be961ba6b2e36cc908a Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 11:47:18 +0900 Subject: [PATCH 42/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 72 +++++++++++++++++----------------------------- 1 file changed, 27 insertions(+), 45 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 49443f57..34b2475d 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -11,16 +11,15 @@ use crate::biguint::Vec; mod arith { // Extended Euclid algorithm: // (g, x, y) is a solution to ax + by = g, where g = gcd(a, b) - pub const fn egcd(mut a: i128, mut b: i128) -> (i128, i128, i128) { + const fn egcd(mut a: i128, mut b: i128) -> (i128, i128, i128) { assert!(a > 0 && b > 0); let mut c = if a > b { (a, b) = (b, a); [0, 1, 1, 0] } else { [1, 0, 0, 1] }; // treat as a row-major 2x2 matrix - let (g, x, y) = loop { - if a == 0 { break (b, 0, 1); } + loop { + if a == 0 { break (b, c[1], c[3]) } let (q, r) = (b/a, b%a); (a, b) = (r, a); c = [c[1] - q*c[0], c[0], c[3] - q*c[2], c[2]]; - }; - (g, c[0]*x + c[1]*y, c[2]*x + c[3]*y) + } } // Modular inverse: a^-1 mod modulus // (m == 0 means m == 2^64) @@ -36,32 +35,30 @@ mod arith { struct Arith {} impl Arith

{ - pub const FACTOR_TWO: usize = (P-1).trailing_zeros() as usize; - pub const FACTOR_THREE: usize = Self::factor_three(); - pub const FACTOR_FIVE: usize = Self::factor_five(); + pub const FACTOR_TWO: u32 = (P-1).trailing_zeros(); + pub const FACTOR_THREE: u32 = Self::factor_three(); + pub const FACTOR_FIVE: u32 = Self::factor_five(); pub const MAX_NTT_LEN: u64 = Self::max_ntt_len(); pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P pub const R2: u64 = ((Self::R as u128 * Self::R as u128) % P as u128) as u64; // R^2 mod P pub const R3: u64 = ((Self::R2 as u128 * Self::R as u128) % P as u128) as u64; // R^3 mod P pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 pub const ROOT: u64 = Self::ntt_root(); // MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN - pub const ROOTR: u64 = Self::mulmod_naive(Self::ROOT, Self::R); // ROOT * R mod P - const fn factor_three() -> usize { + pub const ROOTR: u64 = ((Self::ROOT as u128 * Self::R as u128) % P as u128) as u64; // ROOT * R mod P + const fn factor_three() -> u32 { let mut tmp = P-1; let mut ans = 0; while tmp % 3 == 0 { tmp /= 3; ans += 1; } ans } - const fn factor_five() -> usize { + const fn factor_five() -> u32 { let mut tmp = P-1; let mut ans = 0; while tmp % 5 == 0 { tmp /= 5; ans += 1; } ans } const fn max_ntt_len() -> u64 { - let ans = Self::powmod_naive(2, Self::FACTOR_TWO as u64) * - Self::powmod_naive(3, Self::FACTOR_THREE as u64) * - Self::powmod_naive(5, Self::FACTOR_FIVE as u64); + let ans = 2u64.pow(Self::FACTOR_TWO) * 3u64.pow(Self::FACTOR_THREE) * 5u64.pow(Self::FACTOR_FIVE); assert!(ans % 4050 == 0); ans } @@ -76,10 +73,7 @@ impl Arith

{ while k <= Self::FACTOR_THREE { let mut l = 0; while l <= Self::FACTOR_FIVE { - let p2 = Self::powmod_naive(2, j as u64); - let p3 = Self::powmod_naive(3, k as u64); - let p5 = Self::powmod_naive(5, l as u64); - let exponent = p2 * p3 * p5; + let exponent = 2u64.pow(j) * 3u64.pow(k) * 5u64.pow(l); if exponent < Self::MAX_NTT_LEN && Self::powmod_naive(root, exponent) == 1 { p += 1; continue 'outer; @@ -93,10 +87,6 @@ impl Arith

{ break root } } - // Computes a * b mod P - const fn mulmod_naive(a: u64, b: u64) -> u64 { - ((a as u128 * b as u128) % P as u128) as u64 - } // Computes base^exponent mod P const fn powmod_naive(base: u64, exponent: u64) -> u64 { let mut cur = 1; @@ -189,9 +179,9 @@ impl Arith

{ } struct NttPlan { - pub n: usize, // n == g*m - pub g: usize, // g <= NttPlan::GMAX - pub m: usize, // m divides Arith::

::MAX_NTT_LEN + pub n: usize, // n == g*m + pub g: usize, // g <= NttPlan::GMAX + pub m: usize, // m divides Arith::

::MAX_NTT_LEN pub cost: usize, pub last_radix: usize, pub s_list: Vec<(usize, usize)>, @@ -200,15 +190,13 @@ impl NttPlan { pub fn build(min_len: usize) -> Self { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1); - let mut len7 = 1; - for _ in 0..2 { - let mut len5 = len7; - for _ in 0..=Arith::

::FACTOR_FIVE { - let mut len3 = len5; - for j in 0..=Arith::

::FACTOR_THREE { - let mut len = len3; - let mut i = 0; - while len < min_len && i < Arith::

::FACTOR_TWO { len *= 2; i += 1; } + for m7 in 0..=1 { + for m5 in 0..=Arith::

::FACTOR_FIVE { + for m3 in 0..=Arith::

::FACTOR_THREE { + let len = 7u64.pow(m7) * 5u64.pow(m5) * 3u64.pow(m3); + if len >= 2 * min_len as u64 { continue; } + let (mut len, mut m2) = (len as usize, 0); + while len < min_len && m2 < Arith::

::FACTOR_TWO { len *= 2; m2 += 1; } if len >= min_len && len < len_max_cost { let (mut tmp, mut cost) = (len, 0); let mut g_new = 1; @@ -216,13 +204,13 @@ impl NttPlan { (g_new, tmp, cost) = (7, tmp/7, cost + len*115/100); } else if len % 5 == 0 { (g_new, tmp, cost) = (5, tmp/5, cost + len*89/100); - } else if i >= j + 3 { + } else if m2 >= m3 + 3 { (g_new, tmp, cost) = (8, tmp/8, cost + len*130/100); - } else if i >= j + 2 { + } else if m2 >= m3 + 2 { (g_new, tmp, cost) = (4, tmp/4, cost + len*87/100); - } else if i == 0 && j >= 1 { + } else if m2 == 0 && m3 >= 1 { (g_new, tmp, cost) = (3, tmp/3, cost + len*86/100); - } else if j == 0 && i >= 1 { + } else if m3 == 0 && m2 >= 1 { (g_new, tmp, cost) = (2, tmp/2, cost + len*86/100); } else if len % 6 == 0 { (g_new, tmp, cost) = (6, tmp/6, cost + len*91/100); @@ -236,13 +224,8 @@ impl NttPlan { if b6 && b2 { cost -= len*6/100; } if cost < len_max_cost { (len_max, len_max_cost, g) = (len, cost, g_new); } } - len3 *= 3; - if len3 >= 2*min_len { break; } } - len5 *= 5; - if len5 >= 2*min_len { break; } } - len7 *= 7; } let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0); let mut tmp = len_max/g; @@ -593,9 +576,8 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], let last_radix = plan.last_radix; /* multiply by a constant in advance */ - let len_inv = Arith::

::mmulmod(Arith::

::R3, Arith::

::submod(0, (P-1)/m as u64)); + let len_inv = Arith::

::mmulmod(Arith::

::R3, (P-1)/m as u64); mult = Arith::

::mmulmod(Arith::

::mmulmod(Arith::

::R2, mult), len_inv); - mult = Arith::

::submod(0, mult); for v in if xlen < ylen { &mut x[g..g+xlen] } else { &mut y[g..g+ylen] } { *v = Arith::

::mmulmod(*v, mult); } From 0e4119264538754d8b1c61b5af6a8a07e18d45af Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 12:51:32 +0900 Subject: [PATCH 43/65] Fix stale comments --- src/biguint/ntt.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 34b2475d..1449b8e4 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -65,7 +65,6 @@ impl Arith

{ const fn ntt_root() -> u64 { let mut p = 2; 'outer: loop { - /* ensure p is prime */ let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN); let mut j = 0; while j <= Self::FACTOR_TWO { @@ -565,10 +564,9 @@ fn compute_twiddle_factors(s_list: &[(usize, usiz } // Performs (cyclic) integer convolution modulo P using NTT. -// Modifies the three buffers in-place. +// Modifies the input buffers in-place. // The output is saved in the slice `x`. -// The three slices must have the same length. For maximum performance, -// the length should contain as many factors of 6 as possible. +// The input slices must have the same length. fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], ylen: usize, mut mult: u64) { assert!(!x.is_empty() && x.len() == y.len()); @@ -605,7 +603,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], ntt_dif_dit::(plan, &mut x[g..], &tf_list); ntt_dif_dit::(plan, &mut y[g..], &tf_list); - /* naive or Karatsuba multiplication */ + /* naive multiplication */ let mut i = g; let (mut ii, mut ii_mod_last_radix) = (0, 0); let tf = &tf_list[tf_last_start..]; @@ -624,8 +622,6 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], } else { tf_current = Arith::

::mmulmod(tf_current, tf_mult); } - - /* we multiply the inverse of the length here to save time */ conv_base::

(g, x.as_mut_ptr().wrapping_add(i), y.as_mut_ptr().wrapping_add(i), tf_current); i += g; From bbb563a79c2e4f7fc1084cf1eeb0e4520abcc51f Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 14:18:11 +0900 Subject: [PATCH 44/65] Update ntt.rs --- src/biguint/ntt.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 1449b8e4..250010fe 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -500,8 +500,7 @@ fn ntt6_single_block( } fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_list: &[u64]) { - let mut i_list = vec![]; - for i in 0..plan.s_list.len() { i_list.push(i); } + let mut i_list: Vec<_> = (0..plan.s_list.len()).collect(); if INV { i_list.reverse(); } let mut ptf = tf_list.as_ptr(); for i in i_list { From ce8754523905381a36cb6945e79a515228089a4a Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 14:35:02 +0900 Subject: [PATCH 45/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 250010fe..bc04bdd5 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -40,11 +40,11 @@ impl Arith

{ pub const FACTOR_FIVE: u32 = Self::factor_five(); pub const MAX_NTT_LEN: u64 = Self::max_ntt_len(); pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P - pub const R2: u64 = ((Self::R as u128 * Self::R as u128) % P as u128) as u64; // R^2 mod P - pub const R3: u64 = ((Self::R2 as u128 * Self::R as u128) % P as u128) as u64; // R^3 mod P + pub const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P + pub const R3: u64 = (Self::R2 as u128 * Self::R as u128 % P as u128) as u64; // R^3 mod P pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 pub const ROOT: u64 = Self::ntt_root(); // MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN - pub const ROOTR: u64 = ((Self::ROOT as u128 * Self::R as u128) % P as u128) as u64; // ROOT * R mod P + pub const ROOTR: u64 = (Self::ROOT as u128 * Self::R as u128 % P as u128) as u64; // ROOT * R mod P const fn factor_three() -> u32 { let mut tmp = P-1; let mut ans = 0; @@ -545,9 +545,6 @@ fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_ } fn compute_twiddle_factors(s_list: &[(usize, usize)], out: &mut [u64]) -> usize { - let mut len = 1; - for &(_, radix) in s_list { len *= radix; } - len /= s_list.last().unwrap().1; let r = s_list.last().unwrap_or(&(1, 1)).1; let mut p = 1; out[0] = Arith::

::R; @@ -559,7 +556,7 @@ fn compute_twiddle_factors(s_list: &[(usize, usiz } p *= radix; } - len + p } // Performs (cyclic) integer convolution modulo P using NTT. @@ -651,13 +648,7 @@ const P3: u64 = 17_995_154_822_184_960_001; // Max NTT length = 2^17 * 3^22 * 5^ const P2P3: u128 = P2 as u128 * P3 as u128; const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod(P1, P2)); -const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod( - Arith::::R3, - Arith::::mmulmod( - arith::invmod(P1, P3), - arith::invmod(P2, P3) - ) -); +const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod((P1 as u128 * P2 as u128 % P3 as u128) as u64, P3)); const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; @@ -750,11 +741,10 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { let plan_x = NttPlan::build::(min_len); let plan_y = NttPlan::build::(min_len); let plan_z = NttPlan::build::(min_len); - let len_max = max(plan_x.g + plan_x.n, max(plan_y.g + plan_y.n, plan_z.g + plan_z.n)); let mut x = vec![0u64; plan_x.g + plan_x.n]; let mut y = vec![0u64; plan_y.g + plan_y.n]; let mut z = vec![0u64; plan_z.g + plan_z.n]; - let mut r = vec![0u64; len_max]; + let mut r = vec![0u64; max(x.len(), max(y.len(), z.len()))]; /* convolution with modulo P1 */ for i in 0..b.len() { x[plan_x.g + i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; } From 2f7f1ddcaa390beee0ddeac1958f75a4ac296924 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 15:59:05 +0900 Subject: [PATCH 46/65] Update ntt.rs --- src/biguint/ntt.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index bc04bdd5..8cde1224 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -807,7 +807,7 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { } } -fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { +fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { const fn compute_bits(l: u64) -> u64 { let total_bits = l * 64; let (mut lo, mut hi) = (42, 62); @@ -823,9 +823,9 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { } let (b, c) = if b.len() < c.len() { (b, c) } else { (c, b) }; - let naive_cost = NttPlan::build::(b.len() + c.len()).cost * 3; - let split_cost = NttPlan::build::(b.len() + b.len()).cost * 3 * (c.len() / b.len()) - + if c.len() % b.len() > 0 { NttPlan::build::(b.len() + (c.len() % b.len())).cost * 3 } else { 0 }; + let naive_cost = NttPlan::build::(b.len() + c.len()).cost; + let split_cost = NttPlan::build::(b.len() + b.len()).cost * (c.len() / b.len()) + + if c.len() % b.len() > 0 { NttPlan::build::(b.len() + (c.len() % b.len())).cost } else { 0 }; if b.len() >= 128 && split_cost < naive_cost { /* special handling for unbalanced multiplication: we reduce it to about `c.len()/b.len()` balanced multiplications */ From 6871c4de47455b5f1916b704f4442fd619d96589 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 16:43:14 +0900 Subject: [PATCH 47/65] Refactor & fix potential carry bug --- src/biguint/ntt.rs | 47 +++++++++++++++++----------------------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 8cde1224..1c724501 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -653,6 +653,17 @@ const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; +// Propagates carry from the beginning to the end of acc, +// and returns the resulting carry if it is nonzero. +fn propagate_carry(acc: &mut [u64], mut carry: u64) -> u64 { + for x in acc { + let (v, overflow) = x.overflowing_add(carry); + (*x, carry) = (v, u64::from(overflow)); + if !overflow { break; } + } + carry +} + fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { fn pack_into(src: &[u64], dst1: &mut [u64], dst2: &mut [u64], bits: u64) { let mut p = 0u64; @@ -725,15 +736,9 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { s = out >> (bits - p); } } - - /* process remaining carries */ - carry_acc += s; - while j < acc.len() { - let (w, overflow) = acc[j].overflowing_add(carry_acc); - acc[j] = w; - carry_acc = u64::from(overflow); - j += 1; - } + // Process remaining carries. The addition carry_acc + s should not overflow + // since s is underfilled and carry_acc is always 0 or 1. + propagate_carry(&mut acc[j..], carry_acc + s); } fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { @@ -799,12 +804,7 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { acc[i] = v; carry = out_1 as u128 + ((out_2 as u128) << 64) + u128::from(overflow); } - let mut carry = carry as u64; - for i in min_len..acc.len() { - let (v, overflow) = acc[i].overflowing_add(carry); - acc[i] = v; - carry = u64::from(overflow); - } + propagate_carry(&mut acc[min_len..], carry as u64); } fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { @@ -837,23 +837,10 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { let tmp = acc[k]; acc[k] = 0; mac3_u64(&mut acc[i..=k], b, &c[i..j]); - let mut l = j; - while carry > 0 && l < k { - let (v, overflow) = acc[l].overflowing_add(carry); - acc[l] = v; - carry = u64::from(overflow); - l += 1; - } + (acc[k], carry) = (tmp, acc[k] + propagate_carry(&mut acc[j..k], carry)); i = j; - carry += tmp; - } - i += b.len(); - while carry > 0 && i < acc.len() { - let (v, overflow) = acc[i].overflowing_add(carry); - acc[i] = v; - carry = u64::from(overflow); - i += 1; } + propagate_carry(&mut acc[i + b.len()..], carry); return; } From cac183027f4e41f248a09a02c6e4f313cb6f95aa Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 16:54:09 +0900 Subject: [PATCH 48/65] Update ntt.rs Breaking at this point is the right thing to do since future encounters will all `continue`. --- src/biguint/ntt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 1c724501..d14aa11b 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -193,7 +193,7 @@ impl NttPlan { for m5 in 0..=Arith::

::FACTOR_FIVE { for m3 in 0..=Arith::

::FACTOR_THREE { let len = 7u64.pow(m7) * 5u64.pow(m5) * 3u64.pow(m3); - if len >= 2 * min_len as u64 { continue; } + if len >= 2 * min_len as u64 { break; } let (mut len, mut m2) = (len as usize, 0); while len < min_len && m2 < Arith::

::FACTOR_TWO { len *= 2; m2 += 1; } if len >= min_len && len < len_max_cost { From 1d480ff42e613c86736c1b1a5625cdf7ebb4792b Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 17:37:36 +0900 Subject: [PATCH 49/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index d14aa11b..26501a42 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -577,20 +577,16 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], } /* compute the total space needed for twiddle factors */ - let tf_all_count = { - let (mut radix_cumul, mut out) = (1, 0); - for &(_, radix) in &plan.s_list { - out += radix_cumul; - radix_cumul *= radix; - } - core::cmp::max(out, 1) - }; + let (mut radix_cumul, mut tf_all_count) = (1, 2); // 2 extra slots + for &(_, radix) in &plan.s_list { + tf_all_count += radix_cumul; + radix_cumul *= radix; + } /* build twiddle factors */ let mut tf_list = vec![0u64; tf_all_count]; - tf_list[0] = Arith::

::R; - let mut tf_last_start = core::cmp::min(tf_all_count - 1, 1); - for i in 1..plan.s_list.len() { + let mut tf_last_start = 0; + for i in 0..plan.s_list.len() { let x = compute_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); if i + 1 < plan.s_list.len() { tf_last_start += x; } } @@ -601,9 +597,8 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], /* naive multiplication */ let mut i = g; - let (mut ii, mut ii_mod_last_radix) = (0, 0); - let tf = &tf_list[tf_last_start..]; - let mut tf_current = tf[0]; + let (mut ii, mut ii_mod_last_radix) = (tf_last_start, 0); + let mut tf_current = Arith::

::R; let tf_mult = match plan.last_radix { 2 => NttKernelImpl::::U2, 3 => NttKernelImpl::::U3, @@ -613,9 +608,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], _ => Arith::

::R }; while i < g + plan.n { - if ii_mod_last_radix == 0 { - tf_current = tf[ii]; - } else { + if ii_mod_last_radix > 0 { tf_current = Arith::

::mmulmod(tf_current, tf_mult); } conv_base::

(g, x.as_mut_ptr().wrapping_add(i), y.as_mut_ptr().wrapping_add(i), @@ -625,15 +618,15 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], if ii_mod_last_radix == last_radix { ii += 1; ii_mod_last_radix = 0; + tf_current = tf_list[ii]; } } /* dit fft */ let mut tf_last_start = 0; - for i in (1..plan.s_list.len()).rev() { + for i in (0..plan.s_list.len()).rev() { tf_last_start += compute_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); } - tf_list[tf_last_start] = Arith::

::R; ntt_dif_dit::(plan, x, &tf_list); } From 69c4ec6cc807f7e1ea6f42235cc2ea52d8c9447e Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 17:44:32 +0900 Subject: [PATCH 50/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 26501a42..807e5e93 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -277,7 +277,6 @@ fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64) { struct NttKernelImpl; impl NttKernelImpl { pub const ROOTR: u64 = Arith::

::mpowmod(Arith::

::ROOTR, if INV { Arith::

::MAX_NTT_LEN - 1 } else { 1 }); - pub const U2: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/2); // U2 == P - Arith::

::R pub const U3: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/3); pub const U4: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/4); pub const U5: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/5); @@ -507,7 +506,7 @@ fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_ let (s, radix) = plan.s_list[i]; let s1 = s/radix; let mut px = x.as_mut_ptr(); - let px_end = x.as_mut_ptr().wrapping_add(plan.n); + let px_end = px.wrapping_add(plan.n); match radix { 2 => { (px, ptf) = ntt2_single_block::(s1, px, ptf); @@ -566,8 +565,7 @@ fn compute_twiddle_factors(s_list: &[(usize, usiz fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], ylen: usize, mut mult: u64) { assert!(!x.is_empty() && x.len() == y.len()); - let (_n, g, m) = (plan.n, plan.g, plan.m); - let last_radix = plan.last_radix; + let (_n, g, m, last_radix) = (plan.n, plan.g, plan.m, plan.last_radix); /* multiply by a constant in advance */ let len_inv = Arith::

::mmulmod(Arith::

::R3, (P-1)/m as u64); @@ -599,14 +597,8 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], let mut i = g; let (mut ii, mut ii_mod_last_radix) = (tf_last_start, 0); let mut tf_current = Arith::

::R; - let tf_mult = match plan.last_radix { - 2 => NttKernelImpl::::U2, - 3 => NttKernelImpl::::U3, - 4 => NttKernelImpl::::U4, - 5 => NttKernelImpl::::U5, - 6 => NttKernelImpl::::U6, - _ => Arith::

::R - }; + let tf_mult = Arith::

::mpowmod(NttKernelImpl::::ROOTR, + Arith::

::MAX_NTT_LEN/last_radix as u64); while i < g + plan.n { if ii_mod_last_radix > 0 { tf_current = Arith::

::mmulmod(tf_current, tf_mult); From bad433f0e156968cc8e987f3135e33f9dba90af5 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 17:58:58 +0900 Subject: [PATCH 51/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 282 +++++++++++++++++++++------------------------ 1 file changed, 130 insertions(+), 152 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 807e5e93..e4130ba0 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -41,19 +41,16 @@ impl Arith

{ pub const MAX_NTT_LEN: u64 = Self::max_ntt_len(); pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P pub const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P - pub const R3: u64 = (Self::R2 as u128 * Self::R as u128 % P as u128) as u64; // R^3 mod P + pub const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 - pub const ROOT: u64 = Self::ntt_root(); // MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN - pub const ROOTR: u64 = (Self::ROOT as u128 * Self::R as u128 % P as u128) as u64; // ROOT * R mod P + pub const ROOTR: u64 = Self::ntt_root_r(); // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN) const fn factor_three() -> u32 { - let mut tmp = P-1; - let mut ans = 0; + let (mut tmp, mut ans) = (P-1, 0); while tmp % 3 == 0 { tmp /= 3; ans += 1; } ans } const fn factor_five() -> u32 { - let mut tmp = P-1; - let mut ans = 0; + let (mut tmp, mut ans) = (P-1, 0); while tmp % 5 == 0 { tmp /= 5; ans += 1; } ans } @@ -62,7 +59,7 @@ impl Arith

{ assert!(ans % 4050 == 0); ans } - const fn ntt_root() -> u64 { + const fn ntt_root_r() -> u64 { let mut p = 2; 'outer: loop { let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN); @@ -83,7 +80,7 @@ impl Arith

{ } j += 1; } - break root + break Self::mmulmod(Self::R2, root) } } // Computes base^exponent mod P @@ -227,7 +224,7 @@ impl NttPlan { } } let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0); - let mut tmp = len_max/g; + let mut tmp = len_max / g; while tmp % 6 == 0 { tmp /= 6; cnt6 += 1; } while tmp % 5 == 0 { tmp /= 5; cnt5 += 1; } while tmp % 4 == 0 { tmp /= 4; cnt4 += 1; } @@ -310,18 +307,14 @@ const fn ntt2_kernel( let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::submod(a, b)); (out0, out1) } -fn ntt2_single_block( +unsafe fn ntt2_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - unsafe { - let w1 = if TWIDDLE { *ptf } else { 0 }; - for _ in 0..s1 { - (*px, *px.wrapping_add(s1)) = - ntt2_kernel::(w1, - *px, *px.wrapping_add(s1)); - px = px.wrapping_add(1); - } + let w1 = if TWIDDLE { *ptf } else { 0 }; + for _ in 0..s1 { + (*px, *px.add(s1)) = ntt2_kernel::(w1, *px, *px.add(s1)); + px = px.add(1); } - (px.wrapping_add(s1), ptf.wrapping_add(1)) + (px.add(s1), ptf.add(1)) } const fn ntt3_kernel( w1: u64, w2: u64, @@ -336,24 +329,21 @@ const fn ntt3_kernel( let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(Arith::

::submod(a, b), kbmc)); (out0, out1, out2) } -fn ntt3_single_block( +unsafe fn ntt3_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - unsafe { - let (w1, w2) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - (w1, w2) - } else { - (0, 0) - }; - for _ in 0..s1 { - (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1)) = - ntt3_kernel::(w1, w2, - *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1)); - px = px.wrapping_add(1); - } - } - (px.wrapping_add(2*s1), ptf.wrapping_add(1)) + let (w1, w2) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + (w1, w2) + } else { + (0, 0) + }; + for _ in 0..s1 { + (*px, *px.add(s1), *px.add(2*s1)) = + ntt3_kernel::(w1, w2, *px, *px.add(s1), *px.add(2*s1)); + px = px.add(1); + } + (px.add(2*s1), ptf.add(1)) } const fn ntt4_kernel( w1: u64, w2: u64, w3: u64, @@ -374,27 +364,23 @@ const fn ntt4_kernel( let out3 = Arith::

::mmulmod_invtw::(w3, Arith::

::submod(amc, jbmd)); (out0, out1, out2, out3) } -fn ntt4_single_block( +unsafe fn ntt4_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - unsafe { - let (w1, w2, w3) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - let w3 = Arith::

::mmulmod(w1, w2); - (w1, w2, w3) - } else { - (0, 0, 0) - }; - for _ in 0..s1 { - (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), - *px.wrapping_add(3*s1)) = - ntt4_kernel::(w1, w2, w3, - *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), - *px.wrapping_add(3*s1)); - px = px.wrapping_add(1); - } - } - (px.wrapping_add(3*s1), ptf.wrapping_add(1)) + let (w1, w2, w3) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + (w1, w2, w3) + } else { + (0, 0, 0) + }; + for _ in 0..s1 { + (*px, *px.add(s1), *px.add(2*s1), *px.add(3*s1)) = + ntt4_kernel::(w1, w2, w3, + *px, *px.add(s1), *px.add(2*s1), *px.add(3*s1)); + px = px.add(1); + } + (px.add(3*s1), ptf.add(1)) } const fn ntt5_kernel( w1: u64, w2: u64, w3: u64, w4: u64, @@ -427,28 +413,26 @@ const fn ntt5_kernel( let out4 = Arith::

::mmulmod_invtw::(w4, Arith::

::addmodopt_invtw::(s1, m5)); (out0, out1, out2, out3, out4) } -fn ntt5_single_block( +unsafe fn ntt5_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - unsafe { - let (w1, w2, w3, w4) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - let w3 = Arith::

::mmulmod(w1, w2); - let w4 = Arith::

::mmulmod(w2, w2); - (w1, w2, w3, w4) - } else { - (0, 0, 0, 0) - }; - for _ in 0..s1 { - (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), - *px.wrapping_add(3*s1), *px.wrapping_add(4*s1)) = - ntt5_kernel::(w1, w2, w3, w4, - *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), - *px.wrapping_add(3*s1), *px.wrapping_add(4*s1)); - px = px.wrapping_add(1); - } - } - (px.wrapping_add(4*s1), ptf.wrapping_add(1)) + let (w1, w2, w3, w4) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + let w4 = Arith::

::mmulmod(w2, w2); + (w1, w2, w3, w4) + } else { + (0, 0, 0, 0) + }; + for _ in 0..s1 { + (*px, *px.add(s1), *px.add(2*s1), + *px.add(3*s1), *px.add(4*s1)) = + ntt5_kernel::(w1, w2, w3, w4, + *px, *px.add(s1), *px.add(2*s1), + *px.add(3*s1), *px.add(4*s1)); + px = px.add(1); + } + (px.add(4*s1), ptf.add(1)) } const fn ntt6_kernel( w1: u64, w2: u64, w3: u64, w4: u64, w5: u64, @@ -473,29 +457,27 @@ const fn ntt6_kernel( let out5 = Arith::

::mmulmod_invtw::(w5, Arith::

::submod(d, Arith::

::submod(lepf, e))); (out0, out1, out2, out3, out4, out5) } -fn ntt6_single_block( +unsafe fn ntt6_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - unsafe { - let (w1, w2, w3, w4, w5) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - let w3 = Arith::

::mmulmod(w1, w2); - let w4 = Arith::

::mmulmod(w2, w2); - let w5 = Arith::

::mmulmod(w2, w3); - (w1, w2, w3, w4, w5) - } else { - (0, 0, 0, 0, 0) - }; - for _ in 0..s1 { - (*px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), - *px.wrapping_add(3*s1), *px.wrapping_add(4*s1), *px.wrapping_add(5*s1)) = - ntt6_kernel::(w1, w2, w3, w4, w5, - *px, *px.wrapping_add(s1), *px.wrapping_add(2*s1), - *px.wrapping_add(3*s1), *px.wrapping_add(4*s1), *px.wrapping_add(5*s1)); - px = px.wrapping_add(1); - } - } - (px.wrapping_add(5*s1), ptf.wrapping_add(1)) + let (w1, w2, w3, w4, w5) = if TWIDDLE { + let w1 = *ptf; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + let w4 = Arith::

::mmulmod(w2, w2); + let w5 = Arith::

::mmulmod(w2, w3); + (w1, w2, w3, w4, w5) + } else { + (0, 0, 0, 0, 0) + }; + for _ in 0..s1 { + (*px, *px.add(s1), *px.add(2*s1), + *px.add(3*s1), *px.add(4*s1), *px.add(5*s1)) = + ntt6_kernel::(w1, w2, w3, w4, w5, + *px, *px.add(s1), *px.add(2*s1), + *px.add(3*s1), *px.add(4*s1), *px.add(5*s1)); + px = px.add(1); + } + (px.add(5*s1), ptf.add(1)) } fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_list: &[u64]) { @@ -505,45 +487,47 @@ fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_ for i in i_list { let (s, radix) = plan.s_list[i]; let s1 = s/radix; - let mut px = x.as_mut_ptr(); - let px_end = px.wrapping_add(plan.n); - match radix { - 2 => { - (px, ptf) = ntt2_single_block::(s1, px, ptf); - while px < px_end { - (px, ptf) = ntt2_single_block::(s1, px, ptf); - } - }, - 3 => { - (px, ptf) = ntt3_single_block::(s1, px, ptf); - while px < px_end { - (px, ptf) = ntt3_single_block::(s1, px, ptf); - } - }, - 4 => { - (px, ptf) = ntt4_single_block::(s1, px, ptf); - while px < px_end { - (px, ptf) = ntt4_single_block::(s1, px, ptf); - } - }, - 5 => { - (px, ptf) = ntt5_single_block::(s1, px, ptf); - while px < px_end { - (px, ptf) = ntt5_single_block::(s1, px, ptf); - } - }, - 6 => { - (px, ptf) = ntt6_single_block::(s1, px, ptf); - while px < px_end { - (px, ptf) = ntt6_single_block::(s1, px, ptf); - } - }, - _ => { unreachable!() } + unsafe { + let mut px = x.as_mut_ptr(); + let px_end = px.add(plan.n); + match radix { + 2 => { + (px, ptf) = ntt2_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt2_single_block::(s1, px, ptf); + } + }, + 3 => { + (px, ptf) = ntt3_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt3_single_block::(s1, px, ptf); + } + }, + 4 => { + (px, ptf) = ntt4_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt4_single_block::(s1, px, ptf); + } + }, + 5 => { + (px, ptf) = ntt5_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt5_single_block::(s1, px, ptf); + } + }, + 6 => { + (px, ptf) = ntt6_single_block::(s1, px, ptf); + while px < px_end { + (px, ptf) = ntt6_single_block::(s1, px, ptf); + } + }, + _ => { unreachable!() } + } } } } -fn compute_twiddle_factors(s_list: &[(usize, usize)], out: &mut [u64]) -> usize { +fn calc_twiddle_factors(s_list: &[(usize, usize)], out: &mut [u64]) -> usize { let r = s_list.last().unwrap_or(&(1, 1)).1; let mut p = 1; out[0] = Arith::

::R; @@ -564,12 +548,10 @@ fn compute_twiddle_factors(s_list: &[(usize, usiz // The input slices must have the same length. fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], ylen: usize, mut mult: u64) { assert!(!x.is_empty() && x.len() == y.len()); - - let (_n, g, m, last_radix) = (plan.n, plan.g, plan.m, plan.last_radix); + let (_n, g, m, last_radix) = (plan.n, plan.g, plan.m, plan.last_radix as u64); /* multiply by a constant in advance */ - let len_inv = Arith::

::mmulmod(Arith::

::R3, (P-1)/m as u64); - mult = Arith::

::mmulmod(Arith::

::mmulmod(Arith::

::R2, mult), len_inv); + mult = Arith::

::mmulmod(Arith::

::R4, Arith::

::mmulmod(mult, (P-1)/m as u64)); for v in if xlen < ylen { &mut x[g..g+xlen] } else { &mut y[g..g+ylen] } { *v = Arith::

::mmulmod(*v, mult); } @@ -585,7 +567,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], let mut tf_list = vec![0u64; tf_all_count]; let mut tf_last_start = 0; for i in 0..plan.s_list.len() { - let x = compute_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); + let x = calc_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); if i + 1 < plan.s_list.len() { tf_last_start += x; } } @@ -594,30 +576,26 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], ntt_dif_dit::(plan, &mut y[g..], &tf_list); /* naive multiplication */ - let mut i = g; - let (mut ii, mut ii_mod_last_radix) = (tf_last_start, 0); + let (mut i, mut ii, mut ii_mod_last_radix) = (g, tf_last_start, 0); let mut tf_current = Arith::

::R; - let tf_mult = Arith::

::mpowmod(NttKernelImpl::::ROOTR, - Arith::

::MAX_NTT_LEN/last_radix as u64); + let tf_mult = Arith::

::mpowmod(NttKernelImpl::::ROOTR, Arith::

::MAX_NTT_LEN/last_radix); while i < g + plan.n { - if ii_mod_last_radix > 0 { - tf_current = Arith::

::mmulmod(tf_current, tf_mult); - } - conv_base::

(g, x.as_mut_ptr().wrapping_add(i), y.as_mut_ptr().wrapping_add(i), - tf_current); + conv_base::

(g, (&mut x[i..]).as_mut_ptr(), (&mut y[i..]).as_mut_ptr(), tf_current); i += g; ii_mod_last_radix += 1; if ii_mod_last_radix == last_radix { ii += 1; ii_mod_last_radix = 0; tf_current = tf_list[ii]; + } else { + tf_current = Arith::

::mmulmod(tf_current, tf_mult); } } /* dit fft */ let mut tf_last_start = 0; for i in (0..plan.s_list.len()).rev() { - tf_last_start += compute_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); + tf_last_start += calc_twiddle_factors::(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]); } ntt_dif_dit::(plan, x, &tf_list); } @@ -717,7 +695,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { carry_acc = u64::from(overflow1 || overflow2); /* roll-over */ - (j, p) = (j+1, p-64); + (j, p) = (j + 1, p - 64); s = out >> (bits - p); } } @@ -800,7 +778,7 @@ fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) { let mid = (lo + hi + 1) / 2; let single_digit_max_val = (1u64 << mid) - 1; let l_corrected = (total_bits + mid - 1) / mid; - let (lhs, overflow) = (single_digit_max_val as u128 * single_digit_max_val as u128).overflowing_mul(l_corrected as u128); + let (lhs, overflow) = (single_digit_max_val as u128).pow(2).overflowing_mul(l_corrected as u128); if !overflow && lhs < P2 as u128 * P3 as u128 { lo = mid; } else { hi = mid - 1; } } From f207402ec5cb2f626229c75e2f0b926b51fde18f Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 18:43:05 +0900 Subject: [PATCH 52/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index e4130ba0..12f06ea1 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -486,7 +486,7 @@ fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_ let mut ptf = tf_list.as_ptr(); for i in i_list { let (s, radix) = plan.s_list[i]; - let s1 = s/radix; + let s1 = s / radix; unsafe { let mut px = x.as_mut_ptr(); let px_end = px.add(plan.n); @@ -633,14 +633,14 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { let mut pdst1 = dst1.as_mut_ptr(); let mut pdst2 = dst2.as_mut_ptr(); let mut x = 0u64; - let mask = (1u64 << bits).wrapping_sub(1); + let mask = (1u64 << bits) - 1; for v in src { let mut k = 0; while k < 64 { x |= (v >> k) << p; let q = 64 - k; if p + q >= bits { - unsafe { let out = x & mask; (*pdst1, *pdst2) = (out, out); } + unsafe { let out = x & mask; *pdst1 = out; *pdst2 = out; } x = 0; (pdst1, pdst2, k, p) = (pdst1.wrapping_add(1), pdst2.wrapping_add(1), k + bits - p, 0); } else { @@ -650,7 +650,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { } } unsafe { - if p > 0 { let out = x & mask; (*pdst1, *pdst2) = (out, out); } + if p > 0 { let out = x & mask; *pdst1 = out; *pdst2 = out; } } } @@ -760,12 +760,10 @@ fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) { let out_0 = out_01 as u64; let out_12 = P1P2_HI as u128 * vcmv as u128 + (out_01 >> 64) + if overflow { 1u128 << 64 } else { 0 }; - let out_1 = out_12 as u64; - let out_2 = (out_12 >> 64) as u64; let (v, overflow) = acc[i].overflowing_add(out_0); acc[i] = v; - carry = out_1 as u128 + ((out_2 as u128) << 64) + u128::from(overflow); + carry = out_12 + u128::from(overflow); } propagate_carry(&mut acc[min_len..], carry as u64); } @@ -864,8 +862,6 @@ pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { /* convert to u64 => process => convert back to BigDigit (u32) */ let mut acc_u64 = bigdigit_to_u64(acc); - let b_u64 = bigdigit_to_u64(b); - let c_u64 = bigdigit_to_u64(c); - mac3_u64(&mut acc_u64, &b_u64, &c_u64); + mac3_u64(&mut acc_u64, &bigdigit_to_u64(b), &bigdigit_to_u64(c)); u64_to_bigdigit(&acc_u64, acc); } \ No newline at end of file From d7d3de17ee72b17727ceb4819bbb4e6a93e2b712 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Tue, 19 Sep 2023 20:26:07 +0900 Subject: [PATCH 53/65] Update ntt.rs --- src/biguint/ntt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 12f06ea1..060704da 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -176,7 +176,7 @@ impl Arith

{ struct NttPlan { pub n: usize, // n == g*m - pub g: usize, // g <= NttPlan::GMAX + pub g: usize, // g: size of the base case pub m: usize, // m divides Arith::

::MAX_NTT_LEN pub cost: usize, pub last_radix: usize, From eecdf92bececd9a1be13a31f738fc509c4676d9c Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 06:54:28 +0900 Subject: [PATCH 54/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 060704da..c3326bd3 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -254,19 +254,19 @@ impl NttPlan { fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64) { unsafe { let c2 = Arith::

::mreducelo(c); - let out = x.wrapping_sub(n); + let out = x.sub(n); for i in 0..n { let mut v: u128 = 0; for j in i+1..n { - let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); + let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i+n-j) as u128); v = if overflow { w.wrapping_add((P as u128) << 64) } else { w }; } v = Arith::

::mmulmod_noreduce(v, c, c2); for j in 0..=i { - let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); + let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i-j) as u128); v = if overflow { w.wrapping_add((P as u128) << 64) } else { w }; } - *out.wrapping_add(i) = Arith::

::mreduce(v); + *out.add(i) = Arith::

::mreduce(v); } } } @@ -642,7 +642,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { if p + q >= bits { unsafe { let out = x & mask; *pdst1 = out; *pdst2 = out; } x = 0; - (pdst1, pdst2, k, p) = (pdst1.wrapping_add(1), pdst2.wrapping_add(1), k + bits - p, 0); + unsafe { (pdst1, pdst2, k, p) = (pdst1.add(1), pdst2.add(1), k + bits - p, 0); } } else { p += q; break; From a2426faff35d6caff7365899a25ed2bb80e7bf73 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 13:32:11 +0900 Subject: [PATCH 55/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 97 ++++++++++++++++++++-------------------------- 1 file changed, 42 insertions(+), 55 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index c3326bd3..d96bee93 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -35,40 +35,23 @@ mod arith { struct Arith {} impl Arith

{ - pub const FACTOR_TWO: u32 = (P-1).trailing_zeros(); - pub const FACTOR_THREE: u32 = Self::factor_three(); - pub const FACTOR_FIVE: u32 = Self::factor_five(); - pub const MAX_NTT_LEN: u64 = Self::max_ntt_len(); - pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P - pub const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P - pub const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P - pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 - pub const ROOTR: u64 = Self::ntt_root_r(); // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN) - const fn factor_three() -> u32 { - let (mut tmp, mut ans) = (P-1, 0); - while tmp % 3 == 0 { tmp /= 3; ans += 1; } - ans - } - const fn factor_five() -> u32 { - let (mut tmp, mut ans) = (P-1, 0); - while tmp % 5 == 0 { tmp /= 5; ans += 1; } - ans - } - const fn max_ntt_len() -> u64 { - let ans = 2u64.pow(Self::FACTOR_TWO) * 3u64.pow(Self::FACTOR_THREE) * 5u64.pow(Self::FACTOR_FIVE); - assert!(ans % 4050 == 0); - ans - } - const fn ntt_root_r() -> u64 { + const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P + const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P + const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P + const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 + const MAX_NTT_LEN: u64 = 2u64.pow(Self::factors(2)) * 3u64.pow(Self::factors(3)) * 5u64.pow(Self::factors(5)); + const ROOTR: u64 = { + // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN) + assert!(Self::MAX_NTT_LEN % 4050 == 0); let mut p = 2; 'outer: loop { let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN); let mut j = 0; - while j <= Self::FACTOR_TWO { + while j <= Self::factors(2) { let mut k = 0; - while k <= Self::FACTOR_THREE { + while k <= Self::factors(3) { let mut l = 0; - while l <= Self::FACTOR_FIVE { + while l <= Self::factors(5) { let exponent = 2u64.pow(j) * 3u64.pow(k) * 5u64.pow(l); if exponent < Self::MAX_NTT_LEN && Self::powmod_naive(root, exponent) == 1 { p += 1; @@ -82,59 +65,63 @@ impl Arith

{ } break Self::mmulmod(Self::R2, root) } + }; + // Counts the number of `divisor` factors in P-1. + const fn factors(divisor: u64) -> u32 { + let (mut tmp, mut ans) = (P-1, 0); + while tmp % divisor == 0 { tmp /= divisor; ans += 1; } + ans } // Computes base^exponent mod P - const fn powmod_naive(base: u64, exponent: u64) -> u64 { + const fn powmod_naive(base: u64, mut exponent: u64) -> u64 { let mut cur = 1; let mut pow = base as u128; - let mut p = exponent; - while p > 0 { - if p % 2 > 0 { + while exponent > 0 { + if exponent % 2 > 0 { cur = (cur * pow) % P as u128; } - p /= 2; + exponent /= 2; pow = (pow * pow) % P as u128; } cur as u64 } // Montgomery reduction: // x * R^-1 mod P - pub const fn mreduce(x: u128) -> u64 { + const fn mreduce(x: u128) -> u64 { let m = (x as u64).wrapping_mul(Self::PINV); - let y = ((m as u128 * P as u128) >> 64) as u64; + let y = (m as u128 * P as u128 >> 64) as u64; let (out, overflow) = ((x >> 64) as u64).overflowing_sub(y); if overflow { out.wrapping_add(P) } else { out } } // Multiplication with Montgomery reduction: // a * b * R^-1 mod P - pub const fn mmulmod(a: u64, b: u64) -> u64 { + const fn mmulmod(a: u64, b: u64) -> u64 { Self::mreduce(a as u128 * b as u128) } // Multiplication with Montgomery reduction: // a * b * R^-1 mod P // This function only applies the multiplication when INV && TWIDDLE, // otherwise it just returns b. - pub const fn mmulmod_invtw(a: u64, b: u64) -> u64 { + const fn mmulmod_invtw(a: u64, b: u64) -> u64 { if INV && TWIDDLE { Self::mmulmod(a, b) } else { b } } // Fused-multiply-sub with Montgomery reduction: // a * b * R^-1 - c mod P - pub const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 { + const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 { let x = a as u128 * b as u128; let lo = x as u64; let hi = Self::submod((x >> 64) as u64, c); Self::mreduce(lo as u128 | ((hi as u128) << 64)) } // Computes base^exponent mod P with Montgomery reduction - pub const fn mpowmod(base: u64, exponent: u64) -> u64 { + const fn mpowmod(base: u64, mut exponent: u64) -> u64 { let mut cur = Self::R; let mut pow = base; - let mut p = exponent; - while p > 0 { - if p % 2 > 0 { + while exponent > 0 { + if exponent % 2 > 0 { cur = Self::mmulmod(cur, pow); } - p /= 2; + exponent /= 2; pow = Self::mmulmod(pow, pow); } cur @@ -143,32 +130,32 @@ impl Arith

{ // using d: u64 = mmulmod(P-1, c). // It is caller's responsibility to ensure that d is correct. // Note that d can be computed by calling mreducelo(c). - pub const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 { + const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 { let a: u128 = c as u128 * (v >> 64); let b: u128 = d as u128 * (v as u64 as u128); let (w, overflow) = a.overflowing_sub(b); if overflow { w.wrapping_add((P as u128) << 64) } else { w } } // Computes submod(0, mreduce(x as u128)) fast. - pub const fn mreducelo(x: u64) -> u64 { + const fn mreducelo(x: u64) -> u64 { let m = x.wrapping_mul(Self::PINV); - ((m as u128 * P as u128) >> 64) as u64 + (m as u128 * P as u128 >> 64) as u64 } // Computes a + b mod P, output range [0, P) - pub const fn addmod(a: u64, b: u64) -> u64 { + const fn addmod(a: u64, b: u64) -> u64 { Self::submod(a, P.wrapping_sub(b)) } // Computes a + b mod P, output range [0, 2^64) - pub const fn addmod64(a: u64, b: u64) -> u64 { + const fn addmod64(a: u64, b: u64) -> u64 { let (out, overflow) = a.overflowing_add(b); if overflow { out.wrapping_sub(P) } else { out } } // Computes a + b mod P, selects addmod64 or addmod depending on INV && TWIDDLE - pub const fn addmodopt_invtw(a: u64, b: u64) -> u64 { + const fn addmodopt_invtw(a: u64, b: u64) -> u64 { if INV && TWIDDLE { Self::addmod64(a, b) } else { Self::addmod(a, b) } } // Computes a - b mod P, output range [0, P) - pub const fn submod(a: u64, b: u64) -> u64 { + const fn submod(a: u64, b: u64) -> u64 { let (out, overflow) = a.overflowing_sub(b); if overflow { out.wrapping_add(P) } else { out } } @@ -183,16 +170,16 @@ struct NttPlan { pub s_list: Vec<(usize, usize)>, } impl NttPlan { - pub fn build(min_len: usize) -> Self { + fn build(min_len: usize) -> Self { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1); for m7 in 0..=1 { - for m5 in 0..=Arith::

::FACTOR_FIVE { - for m3 in 0..=Arith::

::FACTOR_THREE { + for m5 in 0..=Arith::

::factors(5) { + for m3 in 0..=Arith::

::factors(3) { let len = 7u64.pow(m7) * 5u64.pow(m5) * 3u64.pow(m3); if len >= 2 * min_len as u64 { break; } let (mut len, mut m2) = (len as usize, 0); - while len < min_len && m2 < Arith::

::FACTOR_TWO { len *= 2; m2 += 1; } + while len < min_len && m2 < Arith::

::factors(2) { len *= 2; m2 += 1; } if len >= min_len && len < len_max_cost { let (mut tmp, mut cost) = (len, 0); let mut g_new = 1; @@ -614,7 +601,7 @@ const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod( const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod((P1 as u128 * P2 as u128 % P3 as u128) as u64, P3)); const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; -const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; +const P1P2_HI: u64 = (P1 as u128 * P2 as u128 >> 64) as u64; // Propagates carry from the beginning to the end of acc, // and returns the resulting carry if it is nonzero. From 76b5adde34368863016048660c8f65c3c15ffa33 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 13:55:41 +0900 Subject: [PATCH 56/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index d96bee93..337f84c5 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -515,12 +515,11 @@ fn ntt_dif_dit(plan: &NttPlan, x: &mut [u64], tf_ } fn calc_twiddle_factors(s_list: &[(usize, usize)], out: &mut [u64]) -> usize { - let r = s_list.last().unwrap_or(&(1, 1)).1; let mut p = 1; out[0] = Arith::

::R; for i in (1..s_list.len()).rev() { let radix = s_list[i-1].1; - let w = Arith::

::mpowmod(NttKernelImpl::::ROOTR, Arith::

::MAX_NTT_LEN/(p as u64 * radix as u64 * r as u64)); + let w = Arith::

::mpowmod(NttKernelImpl::::ROOTR, Arith::

::MAX_NTT_LEN/(p * radix * s_list.last().unwrap().1) as u64); for j in p..radix*p { out[j] = Arith::

::mmulmod(w, out[j - p]); } From 1e41e1675d8f5c83ed92283a3b8cd786cf22ecdb Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 18:55:06 +0900 Subject: [PATCH 57/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 337f84c5..5250d134 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -114,15 +114,14 @@ impl Arith

{ Self::mreduce(lo as u128 | ((hi as u128) << 64)) } // Computes base^exponent mod P with Montgomery reduction - const fn mpowmod(base: u64, mut exponent: u64) -> u64 { + const fn mpowmod(mut base: u64, mut exponent: u64) -> u64 { let mut cur = Self::R; - let mut pow = base; while exponent > 0 { if exponent % 2 > 0 { - cur = Self::mmulmod(cur, pow); + cur = Self::mmulmod(cur, base); } exponent /= 2; - pow = Self::mmulmod(pow, pow); + base = Self::mmulmod(base, base); } cur } From a8ee9ffd80751f2513f1c73221de907bfa03519e Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 18:55:14 +0900 Subject: [PATCH 58/65] Improve compile time --- src/biguint/ntt.rs | 40 +++++++--------------------------------- 1 file changed, 7 insertions(+), 33 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 5250d134..ce9d308a 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -43,27 +43,14 @@ impl Arith

{ const ROOTR: u64 = { // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN) assert!(Self::MAX_NTT_LEN % 4050 == 0); - let mut p = 2; - 'outer: loop { - let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN); - let mut j = 0; - while j <= Self::factors(2) { - let mut k = 0; - while k <= Self::factors(3) { - let mut l = 0; - while l <= Self::factors(5) { - let exponent = 2u64.pow(j) * 3u64.pow(k) * 5u64.pow(l); - if exponent < Self::MAX_NTT_LEN && Self::powmod_naive(root, exponent) == 1 { - p += 1; - continue 'outer; - } - l += 1; - } - k += 1; - } - j += 1; + let mut p = Self::R; + loop { + if Self::mpowmod(p, P/2) != Self::R && + Self::mpowmod(p, P/3) != Self::R && + Self::mpowmod(p, P/5) != Self::R { + break Self::mpowmod(p, P/Self::MAX_NTT_LEN); } - break Self::mmulmod(Self::R2, root) + p = Self::addmod(p, Self::R); } }; // Counts the number of `divisor` factors in P-1. @@ -72,19 +59,6 @@ impl Arith

{ while tmp % divisor == 0 { tmp /= divisor; ans += 1; } ans } - // Computes base^exponent mod P - const fn powmod_naive(base: u64, mut exponent: u64) -> u64 { - let mut cur = 1; - let mut pow = base as u128; - while exponent > 0 { - if exponent % 2 > 0 { - cur = (cur * pow) % P as u128; - } - exponent /= 2; - pow = (pow * pow) % P as u128; - } - cur as u64 - } // Montgomery reduction: // x * R^-1 mod P const fn mreduce(x: u128) -> u64 { From 5f564de45983739fbb74d53e4de8ebef47589f96 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 19:10:44 +0900 Subject: [PATCH 59/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index ce9d308a..64336a69 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -37,11 +37,10 @@ struct Arith {} impl Arith

{ const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P - const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 const MAX_NTT_LEN: u64 = 2u64.pow(Self::factors(2)) * 3u64.pow(Self::factors(3)) * 5u64.pow(Self::factors(5)); const ROOTR: u64 = { - // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN) + // ROOT * R mod P (ROOT: MAX_NTT_LEN divides MultiplicativeOrder[ROOT, P]) assert!(Self::MAX_NTT_LEN % 4050 == 0); let mut p = Self::R; loop { @@ -233,23 +232,23 @@ fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64) { struct NttKernelImpl; impl NttKernelImpl { - pub const ROOTR: u64 = Arith::

::mpowmod(Arith::

::ROOTR, if INV { Arith::

::MAX_NTT_LEN - 1 } else { 1 }); - pub const U3: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/3); - pub const U4: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/4); - pub const U5: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/5); - pub const U6: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/6); - pub const C51: u64 = Self::c5().0; - pub const C52: u64 = Self::c5().1; - pub const C53: u64 = Self::c5().2; - pub const C54: u64 = Self::c5().3; - pub const C55: u64 = Self::c5().4; + const ROOTR: u64 = Arith::

::mpowmod(Arith::

::ROOTR, if INV { Arith::

::MAX_NTT_LEN - 1 } else { 1 }); + const U3: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/3); + const U4: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/4); + const U5: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/5); + const U6: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/6); + const C51: u64 = Self::c5().0; + const C52: u64 = Self::c5().1; + const C53: u64 = Self::c5().2; + const C54: u64 = Self::c5().3; + const C55: u64 = Self::c5().4; const fn c5() -> (u64, u64, u64, u64, u64) { let w = Self::U5; let w2 = Arith::

::mpowmod(w, 2); let w4 = Arith::

::mpowmod(w, 4); let inv2 = Arith::

::mmulmod(Arith::

::R2, arith::invmod(2, P)); let inv4 = Arith::

::mmulmod(Arith::

::R2, arith::invmod(4, P)); - let c51 = Arith::

::submod(Arith::

::submod(0, Arith::

::R), inv4); // (-1) + (-1) * 4^-1 mod P + let c51 = Arith::

::addmod(Arith::

::R, inv4); // 1 + 4^-1 mod P let c52 = Arith::

::addmod(Arith::

::mmulmod(inv2, Arith::

::addmod(w, w4)), inv4); // 4^-1 * (2*w + 2*w^4 + 1) mod P let c53 = Arith::

::mmulmod(inv2, Arith::

::submod(w, w4)); // 2^-1 * (w - w^4) mod P let c54 = Arith::

::addmod(Arith::

::addmod(w, w2), inv2); // 2^-1 * (2*w + 2*w^2 + 1) mod P @@ -359,7 +358,7 @@ const fn ntt5_kernel( let t6 = Arith::

::submod(t1, t2); let t7 = Arith::

::addmod64(t3, t4); let m1 = Arith::

::addmod(a, t5); - let m2 = Arith::

::mmulsubmod(P.wrapping_sub(NttKernelImpl::::C51), t5, m1); + let m2 = Arith::

::mmulsubmod(NttKernelImpl::::C51, t5, m1); let m3 = Arith::

::mmulmod(NttKernelImpl::::C52, t6); let m4 = Arith::

::mmulmod(NttKernelImpl::::C53, t7); let m5 = Arith::

::mmulsubmod(NttKernelImpl::::C54, t4, m4); @@ -510,7 +509,7 @@ fn conv(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], let (_n, g, m, last_radix) = (plan.n, plan.g, plan.m, plan.last_radix as u64); /* multiply by a constant in advance */ - mult = Arith::

::mmulmod(Arith::

::R4, Arith::

::mmulmod(mult, (P-1)/m as u64)); + mult = Arith::

::mmulmod(Arith::

::mpowmod(Arith::

::R2, 3), Arith::

::mmulmod(mult, (P-1)/m as u64)); for v in if xlen < ylen { &mut x[g..g+xlen] } else { &mut y[g..g+ylen] } { *v = Arith::

::mmulmod(*v, mult); } From c85db2b6e9e0eca36965f613262a3c56b9a89d3a Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 19:50:23 +0900 Subject: [PATCH 60/65] A very slight optimization --- src/biguint/ntt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 64336a69..9ff836ba 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -367,7 +367,7 @@ const fn ntt5_kernel( let s2 = Arith::

::addmod(m2, m3); let out0 = m1; let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::submod(s1, m5)); - let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(0, Arith::

::addmod(s2, m6))); + let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(Arith::

::submod(0, s2), m6)); let out3 = Arith::

::mmulmod_invtw::(w3, Arith::

::submod(m6, s2)); let out4 = Arith::

::mmulmod_invtw::(w4, Arith::

::addmodopt_invtw::(s1, m5)); (out0, out1, out2, out3, out4) From deacbedb5a0d71f19d4f0021ef7144fbaa07eba0 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 20:25:03 +0900 Subject: [PATCH 61/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 50 ++++++++++++++-------------------------------- 1 file changed, 15 insertions(+), 35 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 9ff836ba..ba3ce36d 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -290,13 +290,8 @@ const fn ntt3_kernel( } unsafe fn ntt3_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - let (w1, w2) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - (w1, w2) - } else { - (0, 0) - }; + let w1 = if TWIDDLE { *ptf } else { 0 }; + let w2 = Arith::

::mmulmod(w1, w1); for _ in 0..s1 { (*px, *px.add(s1), *px.add(2*s1)) = ntt3_kernel::(w1, w2, *px, *px.add(s1), *px.add(2*s1)); @@ -319,20 +314,15 @@ const fn ntt4_kernel( let jbmd = Arith::

::mmulmod(NttKernelImpl::::U4, bmd); let out0 = Arith::

::addmod(apc, bpd); let out1 = Arith::

::mmulmod_invtw::(w1, Arith::

::addmodopt_invtw::(amc, jbmd)); - let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(apc, bpd)); + let out2 = Arith::

::mmulmod_invtw::(w2, Arith::

::submod(apc, bpd)); let out3 = Arith::

::mmulmod_invtw::(w3, Arith::

::submod(amc, jbmd)); (out0, out1, out2, out3) } unsafe fn ntt4_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - let (w1, w2, w3) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - let w3 = Arith::

::mmulmod(w1, w2); - (w1, w2, w3) - } else { - (0, 0, 0) - }; + let w1 = if TWIDDLE { *ptf } else { 0 }; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); for _ in 0..s1 { (*px, *px.add(s1), *px.add(2*s1), *px.add(3*s1)) = ntt4_kernel::(w1, w2, w3, @@ -374,15 +364,10 @@ const fn ntt5_kernel( } unsafe fn ntt5_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - let (w1, w2, w3, w4) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - let w3 = Arith::

::mmulmod(w1, w2); - let w4 = Arith::

::mmulmod(w2, w2); - (w1, w2, w3, w4) - } else { - (0, 0, 0, 0) - }; + let w1 = if TWIDDLE { *ptf } else { 0 }; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + let w4 = Arith::

::mmulmod(w2, w2); for _ in 0..s1 { (*px, *px.add(s1), *px.add(2*s1), *px.add(3*s1), *px.add(4*s1)) = @@ -418,16 +403,11 @@ const fn ntt6_kernel( } unsafe fn ntt6_single_block( s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) { - let (w1, w2, w3, w4, w5) = if TWIDDLE { - let w1 = *ptf; - let w2 = Arith::

::mmulmod(w1, w1); - let w3 = Arith::

::mmulmod(w1, w2); - let w4 = Arith::

::mmulmod(w2, w2); - let w5 = Arith::

::mmulmod(w2, w3); - (w1, w2, w3, w4, w5) - } else { - (0, 0, 0, 0, 0) - }; + let w1 = if TWIDDLE { *ptf } else { 0 }; + let w2 = Arith::

::mmulmod(w1, w1); + let w3 = Arith::

::mmulmod(w1, w2); + let w4 = Arith::

::mmulmod(w2, w2); + let w5 = Arith::

::mmulmod(w2, w3); for _ in 0..s1 { (*px, *px.add(s1), *px.add(2*s1), *px.add(3*s1), *px.add(4*s1), *px.add(5*s1)) = From 8ecb65dad7ac7648a8b1da4667bc98bc184fbd64 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 20:50:40 +0900 Subject: [PATCH 62/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index ba3ce36d..6836baba 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -547,7 +547,6 @@ const P1: u64 = 14_259_017_916_245_606_401; // Max NTT length = 2^22 * 3^21 * 5^ const P2: u64 = 17_984_575_660_032_000_001; // Max NTT length = 2^19 * 3^17 * 5^6 = 1_057_916_215_296_000_000 const P3: u64 = 17_995_154_822_184_960_001; // Max NTT length = 2^17 * 3^22 * 5^4 = 2_570_736_403_169_280_000 -const P2P3: u128 = P2 as u128 * P3 as u128; const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod(P1, P2)); const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod((P1 as u128 * P2 as u128 % P3 as u128) as u64, P3)); const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); @@ -618,7 +617,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { /* extract the convolution result */ let (a, b) = (x[i], y[i]); let (mut v, overflow) = (a as u128 * P3 as u128 + carry).overflowing_sub(b as u128 * P2 as u128); - if overflow { v = v.wrapping_add(P2P3); } + if overflow { v = v.wrapping_add(P2 as u128 * P3 as u128); } carry = v >> bits; /* write to s */ From 3f8c2c5a86e4b06379747a63d72fb1c1aeeadfd6 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Thu, 21 Sep 2023 20:53:29 +0900 Subject: [PATCH 63/65] Make ntt.rs shorter --- src/biguint/ntt.rs | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 6836baba..d4b114ed 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -237,12 +237,7 @@ impl NttKernelImpl { const U4: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/4); const U5: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/5); const U6: u64 = Arith::

::mpowmod(Self::ROOTR, Arith::

::MAX_NTT_LEN/6); - const C51: u64 = Self::c5().0; - const C52: u64 = Self::c5().1; - const C53: u64 = Self::c5().2; - const C54: u64 = Self::c5().3; - const C55: u64 = Self::c5().4; - const fn c5() -> (u64, u64, u64, u64, u64) { + const C5: (u64, u64, u64, u64, u64, u64) = { let w = Self::U5; let w2 = Arith::

::mpowmod(w, 2); let w4 = Arith::

::mpowmod(w, 4); @@ -253,8 +248,8 @@ impl NttKernelImpl { let c53 = Arith::

::mmulmod(inv2, Arith::

::submod(w, w4)); // 2^-1 * (w - w^4) mod P let c54 = Arith::

::addmod(Arith::

::addmod(w, w2), inv2); // 2^-1 * (2*w + 2*w^2 + 1) mod P let c55 = Arith::

::addmod(Arith::

::addmod(w2, w4), inv2); // 2^-1 * (2*w^2 + 2*w^4 + 1) mod P - (c51, c52, c53, c54, c55) - } + (0, c51, c52, c53, c54, c55) + }; } const fn ntt2_kernel( w1: u64, @@ -348,11 +343,11 @@ const fn ntt5_kernel( let t6 = Arith::

::submod(t1, t2); let t7 = Arith::

::addmod64(t3, t4); let m1 = Arith::

::addmod(a, t5); - let m2 = Arith::

::mmulsubmod(NttKernelImpl::::C51, t5, m1); - let m3 = Arith::

::mmulmod(NttKernelImpl::::C52, t6); - let m4 = Arith::

::mmulmod(NttKernelImpl::::C53, t7); - let m5 = Arith::

::mmulsubmod(NttKernelImpl::::C54, t4, m4); - let m6 = Arith::

::mmulsubmod(P.wrapping_sub(NttKernelImpl::::C55), t3, m4); + let m2 = Arith::

::mmulsubmod(NttKernelImpl::::C5.1, t5, m1); + let m3 = Arith::

::mmulmod(NttKernelImpl::::C5.2, t6); + let m4 = Arith::

::mmulmod(NttKernelImpl::::C5.3, t7); + let m5 = Arith::

::mmulsubmod(NttKernelImpl::::C5.4, t4, m4); + let m6 = Arith::

::mmulsubmod(P.wrapping_sub(NttKernelImpl::::C5.5), t3, m4); let s1 = Arith::

::submod(m3, m2); let s2 = Arith::

::addmod(m2, m3); let out0 = m1; From 2f4460b3e54e57acc04789db3aa1bca5980acf70 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Fri, 22 Sep 2023 09:37:30 +0900 Subject: [PATCH 64/65] Improve NTT planning --- src/biguint/ntt.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index d4b114ed..1870a60a 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -159,9 +159,11 @@ impl NttPlan { (g_new, tmp, cost) = (7, tmp/7, cost + len*115/100); } else if len % 5 == 0 { (g_new, tmp, cost) = (5, tmp/5, cost + len*89/100); - } else if m2 >= m3 + 3 { + } else if m3 >= m2 + 2 { + (g_new, tmp, cost) = (9, tmp/9, cost + len*180/100); + } else if m2 >= m3 + 3 && (m2 - m3) % 2 == 1 { (g_new, tmp, cost) = (8, tmp/8, cost + len*130/100); - } else if m2 >= m3 + 2 { + } else if m2 >= m3 + 2 && m3 == 0 { (g_new, tmp, cost) = (4, tmp/4, cost + len*87/100); } else if m2 == 0 && m3 >= 1 { (g_new, tmp, cost) = (3, tmp/3, cost + len*86/100); From 22914667b01a72ec3f1b19baa1efa17d7a7ccb53 Mon Sep 17 00:00:00 2001 From: Byeongkeun Ahn <7p54ks3@naver.com> Date: Fri, 9 Feb 2024 21:32:01 +0900 Subject: [PATCH 65/65] Fix NTT pack/unpack bug with u32 digits --- src/biguint/ntt.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 1870a60a..3ec36351 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -774,8 +774,8 @@ pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { #[cfg(not(u64_digit))] pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { - fn bigdigit_to_u64(src: &[BigDigit]) -> Vec { - let mut out = vec![0u64; (src.len() + 1) / 2]; + fn bigdigit_to_u64(src: &[BigDigit], is_acc: bool) -> Vec { + let mut out = vec![0u64; (src.len() + 1) / 2 + is_acc as usize]; for i in 0..src.len()/2 { out[i] = (src[2*i] as u64) | ((src[2*i+1] as u64) << 32); } @@ -795,7 +795,7 @@ pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) { } /* convert to u64 => process => convert back to BigDigit (u32) */ - let mut acc_u64 = bigdigit_to_u64(acc); - mac3_u64(&mut acc_u64, &bigdigit_to_u64(b), &bigdigit_to_u64(c)); + let mut acc_u64 = bigdigit_to_u64(acc, true); + mac3_u64(&mut acc_u64, &bigdigit_to_u64(b, false), &bigdigit_to_u64(c, false)); u64_to_bigdigit(&acc_u64, acc); } \ No newline at end of file