diff --git a/benches/bigint.rs b/benches/bigint.rs
index 80ec191c..3bfea576 100644
--- a/benches/bigint.rs
+++ b/benches/bigint.rs
@@ -87,6 +87,21 @@ fn multiply_3(b: &mut Bencher) {
     multiply_bench(b, 1 << 16, 1 << 17);
 }
 
+#[bench]
+fn multiply_4(b: &mut Bencher) {
+    multiply_bench(b, 100_000, 1_003_741);
+}
+
+#[bench]
+fn multiply_5(b: &mut Bencher) {
+    multiply_bench(b, 2_718_328, 2_738_633);
+}
+
+#[bench]
+fn multiply_6(b: &mut Bencher) {
+    multiply_bench(b, 27_183_279, 27_386_321);
+}
+
 #[bench]
 fn divide_0(b: &mut Bencher) {
     divide_bench(b, 1 << 8, 1 << 6);
diff --git a/benches/factorial.rs b/benches/factorial.rs
index a1e7b3cf..61e91ab5 100644
--- a/benches/factorial.rs
+++ b/benches/factorial.rs
@@ -16,6 +16,36 @@ fn factorial_mul_biguint(b: &mut Bencher) {
     });
 }
 
+fn factorial_product(l: usize, r: usize) -> BigUint {
+    if l >= r {
+        BigUint::from(l)
+    } else {
+        let m = (l+r)/2;
+        factorial_product(l, m) * factorial_product(m+1, r)
+    }
+}
+
+#[bench]
+fn factorial_mul_biguint_dnc_10k(b: &mut Bencher) {
+    b.iter(|| {
+        factorial_product(1, 10_000);
+    });
+}
+
+#[bench]
+fn factorial_mul_biguint_dnc_100k(b: &mut Bencher) {
+    b.iter(|| {
+        factorial_product(1, 100_000);
+    });
+}
+
+#[bench]
+fn factorial_mul_biguint_dnc_300k(b: &mut Bencher) {
+    b.iter(|| {
+        factorial_product(1, 300_000);
+    });
+}
+
 #[bench]
 fn factorial_mul_u32(b: &mut Bencher) {
     b.iter(|| (1u32..1000).fold(BigUint::one(), Mul::mul));
diff --git a/src/biguint.rs b/src/biguint.rs
index 1554eb0f..373205af 100644
--- a/src/biguint.rs
+++ b/src/biguint.rs
@@ -22,6 +22,7 @@
 mod bits;
 mod convert;
 mod iter;
 mod monty;
+mod ntt;
 mod power;
 mod shift;
diff --git a/src/biguint/multiplication.rs b/src/biguint/multiplication.rs
index 4d7f1f21..bec130ee 100644
--- a/src/biguint/multiplication.rs
+++ b/src/biguint/multiplication.rs
@@ -13,6 +13,8 @@ use core::iter::Product;
 use core::ops::{Mul, MulAssign};
 use num_traits::{CheckedMul, FromPrimitive, One, Zero};
 
+use super::ntt;
+
 #[inline]
 pub(super) fn mac_with_carry(
     a: BigDigit,
@@ -97,7 +99,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
     // number of operations, but uses more temporary allocations.
     //
     // The thresholds are somewhat arbitrary, chosen by evaluating the results
-    // of `cargo bench --bench bigint multiply`.
+    // of `cargo bench --bench bigint multiply --features rand`.
 
     if x.len() <= 32 {
         // Long multiplication:
@@ -217,7 +219,7 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
             }
             NoSign => (),
         }
-    } else {
+    } else if x.len() <= if cfg!(u64_digit) { 512 } else { 2048 } {
         // Toom-3 multiplication:
         //
        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
@@ -346,6 +348,14 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
                 NoSign => {}
             }
         }
+    } else {
+        // Number-theoretic transform (NTT) multiplication:
+        //
+        // NTT multiplies two integers by computing the convolution of their digit
+        // arrays modulo a prime. Since the result may exceed the prime, we use two
+        // or three distinct primes and combine the results using the Chinese
+        // Remainder Theorem (CRT).
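+        // (For instance, knowing x mod 3 == 2 and x mod 5 == 3 already pins
+        // down x == 8 mod 15; the reconstruction below does the same thing
+        // with 64-bit primes.)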
+        ntt::mac3(acc, b, c);
     }
 }
diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs
new file mode 100644
index 00000000..3ec36351
--- /dev/null
+++ b/src/biguint/ntt.rs
@@ -0,0 +1,801 @@
+#![allow(clippy::cast_sign_loss)]
+#![allow(clippy::cast_lossless)]
+#![allow(clippy::cast_possible_truncation)]
+#![allow(clippy::many_single_char_names)]
+#![allow(clippy::needless_range_loop)]
+#![allow(clippy::too_many_arguments)]
+#![allow(clippy::similar_names)]
+
+use crate::biguint::Vec;
+
+mod arith {
+    // Extended Euclid algorithm:
+    // (g, x, y) is a solution to ax + by = g, where g = gcd(a, b)
+    const fn egcd(mut a: i128, mut b: i128) -> (i128, i128, i128) {
+        assert!(a > 0 && b > 0);
+        let mut c = if a > b { (a, b) = (b, a); [0, 1, 1, 0] } else { [1, 0, 0, 1] }; // treat as a row-major 2x2 matrix
+        loop {
+            if a == 0 { break (b, c[1], c[3]) }
+            let (q, r) = (b/a, b%a);
+            (a, b) = (r, a);
+            c = [c[1] - q*c[0], c[0], c[3] - q*c[2], c[2]];
+        }
+    }
+    // Modular inverse: a^-1 mod modulus
+    // (modulus == 0 means modulus == 2^64)
+    pub const fn invmod(a: u64, modulus: u64) -> u64 {
+        let m = if modulus == 0 { 1i128 << 64 } else { modulus as i128 };
+        let (g, mut x, _y) = egcd(a as i128, m);
+        assert!(g == 1);
+        if x < 0 { x += m; }
+        assert!(x > 0 && x < 1i128 << 64);
+        x as u64
+    }
+}
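+
+// Quick illustrative checks for `invmod` (a sketch for exposition; the exact
+// values are not load-bearing): invmod(3, 10) == 7 since 3*7 == 21 === 1
+// (mod 10), and modulus == 0 selects 2^64, so multiplying back with wrapping
+// arithmetic returns exactly 1.
+#[cfg(test)]
+mod arith_tests {
+    use super::arith;
+    #[test]
+    fn invmod_small_and_word_sized_moduli() {
+        assert_eq!(arith::invmod(3, 10), 7);
+        assert_eq!(arith::invmod(3, 0).wrapping_mul(3), 1);
+    }
+}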
+
+struct Arith<const P: u64> {}
+impl<const P: u64> Arith<P> {
+    const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P
+    const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P
+    const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64
+    const MAX_NTT_LEN: u64 = 2u64.pow(Self::factors(2)) * 3u64.pow(Self::factors(3)) * 5u64.pow(Self::factors(5));
+    const ROOTR: u64 = {
+        // ROOT * R mod P (ROOT: MAX_NTT_LEN divides MultiplicativeOrder[ROOT, P])
+        assert!(Self::MAX_NTT_LEN % 4050 == 0);
+        let mut p = Self::R;
+        loop {
+            if Self::mpowmod(p, P/2) != Self::R &&
+                Self::mpowmod(p, P/3) != Self::R &&
+                Self::mpowmod(p, P/5) != Self::R {
+                break Self::mpowmod(p, P/Self::MAX_NTT_LEN);
+            }
+            p = Self::addmod(p, Self::R);
+        }
+    };
+    // Counts the number of `divisor` factors in P-1.
+    const fn factors(divisor: u64) -> u32 {
+        let (mut tmp, mut ans) = (P-1, 0);
+        while tmp % divisor == 0 { tmp /= divisor; ans += 1; }
+        ans
+    }
+    // Montgomery reduction:
+    // x * R^-1 mod P
+    const fn mreduce(x: u128) -> u64 {
+        let m = (x as u64).wrapping_mul(Self::PINV);
+        let y = (m as u128 * P as u128 >> 64) as u64;
+        let (out, overflow) = ((x >> 64) as u64).overflowing_sub(y);
+        if overflow { out.wrapping_add(P) } else { out }
+    }
+    // Multiplication with Montgomery reduction:
+    // a * b * R^-1 mod P
+    const fn mmulmod(a: u64, b: u64) -> u64 {
+        Self::mreduce(a as u128 * b as u128)
+    }
+    // Multiplication with Montgomery reduction:
+    // a * b * R^-1 mod P
+    // This function only applies the multiplication when INV && TWIDDLE,
+    // otherwise it just returns b.
+    const fn mmulmod_invtw<const INV: bool, const TWIDDLE: bool>(a: u64, b: u64) -> u64 {
+        if INV && TWIDDLE { Self::mmulmod(a, b) } else { b }
+    }
+    // Fused multiply-sub with Montgomery reduction:
+    // a * b * R^-1 - c mod P
+    const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 {
+        let x = a as u128 * b as u128;
+        let lo = x as u64;
+        let hi = Self::submod((x >> 64) as u64, c);
+        Self::mreduce(lo as u128 | ((hi as u128) << 64))
+    }
+    // Computes base^exponent mod P with Montgomery reduction
+    const fn mpowmod(mut base: u64, mut exponent: u64) -> u64 {
+        let mut cur = Self::R;
+        while exponent > 0 {
+            if exponent % 2 > 0 {
+                cur = Self::mmulmod(cur, base);
+            }
+            exponent /= 2;
+            base = Self::mmulmod(base, base);
+        }
+        cur
+    }
+    // Computes c as u128 * mreduce(v) as u128,
+    // using d: u64 = mmulmod(P-1, c).
+    // It is the caller's responsibility to ensure that d is correct.
+    // Note that d can be computed by calling mreducelo(c).
+    const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 {
+        let a: u128 = c as u128 * (v >> 64);
+        let b: u128 = d as u128 * (v as u64 as u128);
+        let (w, overflow) = a.overflowing_sub(b);
+        if overflow { w.wrapping_add((P as u128) << 64) } else { w }
+    }
+    // Computes submod(0, mreduce(x as u128)) fast.
+    const fn mreducelo(x: u64) -> u64 {
+        let m = x.wrapping_mul(Self::PINV);
+        (m as u128 * P as u128 >> 64) as u64
+    }
+    // Computes a + b mod P, output range [0, P)
+    const fn addmod(a: u64, b: u64) -> u64 {
+        Self::submod(a, P.wrapping_sub(b))
+    }
+    // Computes a + b mod P, output range [0, 2^64)
+    const fn addmod64(a: u64, b: u64) -> u64 {
+        let (out, overflow) = a.overflowing_add(b);
+        if overflow { out.wrapping_sub(P) } else { out }
+    }
+    // Computes a + b mod P, selects addmod64 or addmod depending on INV && TWIDDLE
+    const fn addmodopt_invtw<const INV: bool, const TWIDDLE: bool>(a: u64, b: u64) -> u64 {
+        if INV && TWIDDLE { Self::addmod64(a, b) } else { Self::addmod(a, b) }
+    }
+    // Computes a - b mod P, output range [0, P)
+    const fn submod(a: u64, b: u64) -> u64 {
+        let (out, overflow) = a.overflowing_sub(b);
+        if overflow { out.wrapping_add(P) } else { out }
+    }
+}
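+
+// A hedged sanity sketch of the Montgomery invariants above (illustration
+// only): R is the Montgomery form of 1, so mreduce maps it back to 1, and
+// multiplying R2 by the plain value 1 strips one factor of R.
+#[cfg(test)]
+mod montgomery_tests {
+    use super::*;
+    #[test]
+    fn montgomery_roundtrips() {
+        // mreduce(R) == R * R^-1 == 1 (mod P1)
+        assert_eq!(Arith::<P1>::mreduce(Arith::<P1>::R as u128), 1);
+        // mmulmod(R2, 1) == R2 * R^-1 == R (mod P1)
+        assert_eq!(Arith::<P1>::mmulmod(Arith::<P1>::R2, 1), Arith::<P1>::R);
+        // addmod/submod stay in [0, P1)
+        assert_eq!(Arith::<P1>::addmod(P1 - 1, 2), 1);
+        assert_eq!(Arith::<P1>::submod(1, 2), P1 - 1);
+    }
+}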
+
+struct NttPlan {
+    pub n: usize,     // n == g*m
+    pub g: usize,     // g: size of the base case
+    pub m: usize,     // m divides Arith::<P>::MAX_NTT_LEN
+    pub cost: usize,
+    pub last_radix: usize,
+    pub s_list: Vec<(usize, usize)>,
+}
+impl NttPlan {
+    fn build<const P: u64>(min_len: usize) -> Self {
+        assert!(min_len as u64 <= Arith::<P>::MAX_NTT_LEN);
+        let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1);
+        for m7 in 0..=1 {
+            for m5 in 0..=Arith::<P>::factors(5) {
+                for m3 in 0..=Arith::<P>::factors(3) {
+                    let len = 7u64.pow(m7) * 5u64.pow(m5) * 3u64.pow(m3);
+                    if len >= 2 * min_len as u64 { break; }
+                    let (mut len, mut m2) = (len as usize, 0);
+                    while len < min_len && m2 < Arith::<P>::factors(2) { len *= 2; m2 += 1; }
+                    if len >= min_len && len < len_max_cost {
+                        let (mut tmp, mut cost) = (len, 0);
+                        let mut g_new = 1;
+                        if len % 7 == 0 {
+                            (g_new, tmp, cost) = (7, tmp/7, cost + len*115/100);
+                        } else if len % 5 == 0 {
+                            (g_new, tmp, cost) = (5, tmp/5, cost + len*89/100);
+                        } else if m3 >= m2 + 2 {
+                            (g_new, tmp, cost) = (9, tmp/9, cost + len*180/100);
+                        } else if m2 >= m3 + 3 && (m2 - m3) % 2 == 1 {
+                            (g_new, tmp, cost) = (8, tmp/8, cost + len*130/100);
+                        } else if m2 >= m3 + 2 && m3 == 0 {
+                            (g_new, tmp, cost) = (4, tmp/4, cost + len*87/100);
+                        } else if m2 == 0 && m3 >= 1 {
+                            (g_new, tmp, cost) = (3, tmp/3, cost + len*86/100);
+                        } else if m3 == 0 && m2 >= 1 {
+                            (g_new, tmp, cost) = (2, tmp/2, cost + len*86/100);
+                        } else if len % 6 == 0 {
+                            (g_new, tmp, cost) = (6, tmp/6, cost + len*91/100);
+                        }
+                        let (mut b6, mut b2) = (false, false);
+                        while tmp % 6 == 0 { (tmp, cost) = (tmp/6, cost + len*106/100); b6 = true; }
+                        while tmp % 5 == 0 { (tmp, cost) = (tmp/5, cost + len*131/100); }
+                        while tmp % 4 == 0 { (tmp, cost) = (tmp/4, cost + len); }
+                        while tmp % 3 == 0 { (tmp, cost) = (tmp/3, cost + len); }
+                        while tmp % 2 == 0 { (tmp, cost) = (tmp/2, cost + len); b2 = true; }
+                        if b6 && b2 { cost -= len*6/100; }
+                        if cost < len_max_cost { (len_max, len_max_cost, g) = (len, cost, g_new); }
+                    }
+                }
+            }
+        }
+        let (mut cnt6, mut cnt5, mut cnt4, mut cnt3, mut cnt2) = (0, 0, 0, 0, 0);
+        let mut tmp = len_max / g;
+        while tmp % 6 == 0 { tmp /= 6; cnt6 += 1; }
+        while tmp % 5 == 0 { tmp /= 5; cnt5 += 1; }
+        while tmp % 4 == 0 { tmp /= 4; cnt4 += 1; }
+        while tmp % 3 == 0 { tmp /= 3; cnt3 += 1; }
+        while tmp % 2 == 0 { tmp /= 2; cnt2 += 1; }
+        while cnt6 > 0 && cnt2 > 0 { cnt6 -= 1; cnt2 -= 1; cnt4 += 1; cnt3 += 1; }
+        let s_list = {
+            let mut out = vec![];
+            let mut tmp = len_max;
+            for _ in 0..cnt2 { out.push((tmp, 2)); tmp /= 2; }
+            for _ in 0..cnt3 { out.push((tmp, 3)); tmp /= 3; }
+            for _ in 0..cnt4 { out.push((tmp, 4)); tmp /= 4; }
+            for _ in 0..cnt5 { out.push((tmp, 5)); tmp /= 5; }
+            for _ in 0..cnt6 { out.push((tmp, 6)); tmp /= 6; }
+            out
+        };
+        Self {
+            n: len_max,
+            g,
+            m: len_max / g,
+            cost: len_max_cost,
+            last_radix: s_list.last().unwrap_or(&(1, 1)).1,
+            s_list,
+        }
+    }
+}
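+
+// Usage sketch for NttPlan (illustrative only): every plan satisfies
+// n == g*m and n >= min_len, so callers can size their buffers as g + n.
+#[cfg(test)]
+mod plan_tests {
+    use super::*;
+    #[test]
+    fn build_satisfies_documented_invariants() {
+        for &min_len in &[16usize, 1000, 4096] {
+            let plan = NttPlan::build::<P1>(min_len);
+            assert!(plan.n >= min_len);
+            assert_eq!(plan.n, plan.g * plan.m);
+        }
+    }
+}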
+fn conv_base<const P: u64>(n: usize, x: *mut u64, y: *mut u64, c: u64) {
+    unsafe {
+        let c2 = Arith::<P>::mreducelo(c);
+        let out = x.sub(n);
+        for i in 0..n {
+            let mut v: u128 = 0;
+            for j in i+1..n {
+                let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i+n-j) as u128);
+                v = if overflow { w.wrapping_add((P as u128) << 64) } else { w };
+            }
+            v = Arith::<P>::mmulmod_noreduce(v, c, c2);
+            for j in 0..=i {
+                let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i-j) as u128);
+                v = if overflow { w.wrapping_add((P as u128) << 64) } else { w };
+            }
+            *out.add(i) = Arith::<P>::mreduce(v);
+        }
+    }
+}
+
+struct NttKernelImpl<const P: u64, const INV: bool>;
+impl<const P: u64, const INV: bool> NttKernelImpl<P, INV> {
+    const ROOTR: u64 = Arith::<P>::mpowmod(Arith::<P>::ROOTR, if INV { Arith::<P>::MAX_NTT_LEN - 1 } else { 1 });
+    const U3: u64 = Arith::<P>::mpowmod(Self::ROOTR, Arith::<P>::MAX_NTT_LEN/3);
+    const U4: u64 = Arith::<P>::mpowmod(Self::ROOTR, Arith::<P>::MAX_NTT_LEN/4);
+    const U5: u64 = Arith::<P>::mpowmod(Self::ROOTR, Arith::<P>::MAX_NTT_LEN/5);
+    const U6: u64 = Arith::<P>::mpowmod(Self::ROOTR, Arith::<P>::MAX_NTT_LEN/6);
+    const C5: (u64, u64, u64, u64, u64, u64) = {
+        let w = Self::U5;
+        let w2 = Arith::<P>::mpowmod(w, 2);
+        let w4 = Arith::<P>::mpowmod(w, 4);
+        let inv2 = Arith::<P>::mmulmod(Arith::<P>::R2, arith::invmod(2, P));
+        let inv4 = Arith::<P>::mmulmod(Arith::<P>::R2, arith::invmod(4, P));
+        let c51 = Arith::<P>::addmod(Arith::<P>::R, inv4); // 1 + 4^-1 mod P
+        let c52 = Arith::<P>::addmod(Arith::<P>::mmulmod(inv2, Arith::<P>::addmod(w, w4)), inv4); // 4^-1 * (2*w + 2*w^4 + 1) mod P
+        let c53 = Arith::<P>::mmulmod(inv2, Arith::<P>::submod(w, w4)); // 2^-1 * (w - w^4) mod P
+        let c54 = Arith::<P>::addmod(Arith::<P>::addmod(w, w2), inv2); // 2^-1 * (2*w + 2*w^2 + 1) mod P
+        let c55 = Arith::<P>::addmod(Arith::<P>::addmod(w2, w4), inv2); // 2^-1 * (2*w^2 + 2*w^4 + 1) mod P
+        (0, c51, c52, c53, c54, c55)
+    };
+}
+const fn ntt2_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    w1: u64,
+    a: u64, mut b: u64) -> (u64, u64) {
+    if !INV && TWIDDLE {
+        b = Arith::<P>::mmulmod(w1, b);
+    }
+    let out0 = Arith::<P>::addmod(a, b);
+    let out1 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w1, Arith::<P>::submod(a, b));
+    (out0, out1)
+}
+unsafe fn ntt2_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) {
+    let w1 = if TWIDDLE { *ptf } else { 0 };
+    for _ in 0..s1 {
+        (*px, *px.add(s1)) = ntt2_kernel::<P, INV, TWIDDLE>(w1, *px, *px.add(s1));
+        px = px.add(1);
+    }
+    (px.add(s1), ptf.add(1))
+}
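+// (Spot check, comments only: with TWIDDLE == false this is the plain
+// butterfly (a, b) -> (a+b, a-b) mod P; applying it twice yields (2a, 2b),
+// a factor that the 1/m scaling constant in `conv` compensates for.)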
+const fn ntt3_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    w1: u64, w2: u64,
+    a: u64, mut b: u64, mut c: u64) -> (u64, u64, u64) {
+    if !INV && TWIDDLE {
+        b = Arith::<P>::mmulmod(w1, b);
+        c = Arith::<P>::mmulmod(w2, c);
+    }
+    let kbmc = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::U3, Arith::<P>::submod(b, c));
+    let out0 = Arith::<P>::addmod(a, Arith::<P>::addmod(b, c));
+    let out1 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w1, Arith::<P>::submod(a, Arith::<P>::submod(c, kbmc)));
+    let out2 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w2, Arith::<P>::submod(Arith::<P>::submod(a, b), kbmc));
+    (out0, out1, out2)
+}
+unsafe fn ntt3_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) {
+    let w1 = if TWIDDLE { *ptf } else { 0 };
+    let w2 = Arith::<P>::mmulmod(w1, w1);
+    for _ in 0..s1 {
+        (*px, *px.add(s1), *px.add(2*s1)) =
+            ntt3_kernel::<P, INV, TWIDDLE>(w1, w2, *px, *px.add(s1), *px.add(2*s1));
+        px = px.add(1);
+    }
+    (px.add(2*s1), ptf.add(1))
+}
+const fn ntt4_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    w1: u64, w2: u64, w3: u64,
+    a: u64, mut b: u64, mut c: u64, mut d: u64) -> (u64, u64, u64, u64) {
+    if !INV && TWIDDLE {
+        b = Arith::<P>::mmulmod(w1, b);
+        c = Arith::<P>::mmulmod(w2, c);
+        d = Arith::<P>::mmulmod(w3, d);
+    }
+    let apc = Arith::<P>::addmod(a, c);
+    let amc = Arith::<P>::submod(a, c);
+    let bpd = Arith::<P>::addmod(b, d);
+    let bmd = Arith::<P>::submod(b, d);
+    let jbmd = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::U4, bmd);
+    let out0 = Arith::<P>::addmod(apc, bpd);
+    let out1 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w1, Arith::<P>::addmodopt_invtw::<INV, TWIDDLE>(amc, jbmd));
+    let out2 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w2, Arith::<P>::submod(apc, bpd));
+    let out3 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w3, Arith::<P>::submod(amc, jbmd));
+    (out0, out1, out2, out3)
+}
+unsafe fn ntt4_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) {
+    let w1 = if TWIDDLE { *ptf } else { 0 };
+    let w2 = Arith::<P>::mmulmod(w1, w1);
+    let w3 = Arith::<P>::mmulmod(w1, w2);
+    for _ in 0..s1 {
+        (*px, *px.add(s1), *px.add(2*s1), *px.add(3*s1)) =
+            ntt4_kernel::<P, INV, TWIDDLE>(w1, w2, w3,
+                *px, *px.add(s1), *px.add(2*s1), *px.add(3*s1));
+        px = px.add(1);
+    }
+    (px.add(3*s1), ptf.add(1))
+}
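+// (Spot check, comments only: with TWIDDLE == false this is the plain
+// length-4 DFT, e.g. out1 == (a-c) + j*(b-d), where j == U4 is a primitive
+// fourth root of unity mod P, so j^2 === -1.)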
+const fn ntt5_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    w1: u64, w2: u64, w3: u64, w4: u64,
+    a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64) -> (u64, u64, u64, u64, u64) {
+    if !INV && TWIDDLE {
+        b = Arith::<P>::mmulmod(w1, b);
+        c = Arith::<P>::mmulmod(w2, c);
+        d = Arith::<P>::mmulmod(w3, d);
+        e = Arith::<P>::mmulmod(w4, e);
+    }
+    let t1 = Arith::<P>::addmod(b, e);
+    let t2 = Arith::<P>::addmod(c, d);
+    let t3 = Arith::<P>::submod(b, e);
+    let t4 = Arith::<P>::submod(d, c);
+    let t5 = Arith::<P>::addmod(t1, t2);
+    let t6 = Arith::<P>::submod(t1, t2);
+    let t7 = Arith::<P>::addmod64(t3, t4);
+    let m1 = Arith::<P>::addmod(a, t5);
+    let m2 = Arith::<P>::mmulsubmod(NttKernelImpl::<P, INV>::C5.1, t5, m1);
+    let m3 = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::C5.2, t6);
+    let m4 = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::C5.3, t7);
+    let m5 = Arith::<P>::mmulsubmod(NttKernelImpl::<P, INV>::C5.4, t4, m4);
+    let m6 = Arith::<P>::mmulsubmod(P.wrapping_sub(NttKernelImpl::<P, INV>::C5.5), t3, m4);
+    let s1 = Arith::<P>::submod(m3, m2);
+    let s2 = Arith::<P>::addmod(m2, m3);
+    let out0 = m1;
+    let out1 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w1, Arith::<P>::submod(s1, m5));
+    let out2 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w2, Arith::<P>::submod(Arith::<P>::submod(0, s2), m6));
+    let out3 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w3, Arith::<P>::submod(m6, s2));
+    let out4 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w4, Arith::<P>::addmodopt_invtw::<INV, TWIDDLE>(s1, m5));
+    (out0, out1, out2, out3, out4)
+}
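+// (Spot check, comments only: out0 == a+b+c+d+e mod P, since
+// m1 == a + (b+e) + (c+d); the other outputs are the length-5 DFT
+// combinations computed Winograd-style from the C5 constants.)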
+unsafe fn ntt5_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) {
+    let w1 = if TWIDDLE { *ptf } else { 0 };
+    let w2 = Arith::<P>::mmulmod(w1, w1);
+    let w3 = Arith::<P>::mmulmod(w1, w2);
+    let w4 = Arith::<P>::mmulmod(w2, w2);
+    for _ in 0..s1 {
+        (*px, *px.add(s1), *px.add(2*s1),
+            *px.add(3*s1), *px.add(4*s1)) =
+            ntt5_kernel::<P, INV, TWIDDLE>(w1, w2, w3, w4,
+                *px, *px.add(s1), *px.add(2*s1),
+                *px.add(3*s1), *px.add(4*s1));
+        px = px.add(1);
+    }
+    (px.add(4*s1), ptf.add(1))
+}
+const fn ntt6_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    w1: u64, w2: u64, w3: u64, w4: u64, w5: u64,
+    mut a: u64, mut b: u64, mut c: u64, mut d: u64, mut e: u64, mut f: u64) -> (u64, u64, u64, u64, u64, u64) {
+    if !INV && TWIDDLE {
+        b = Arith::<P>::mmulmod(w1, b);
+        c = Arith::<P>::mmulmod(w2, c);
+        d = Arith::<P>::mmulmod(w3, d);
+        e = Arith::<P>::mmulmod(w4, e);
+        f = Arith::<P>::mmulmod(w5, f);
+    }
+    (a, d) = (Arith::<P>::addmod(a, d), Arith::<P>::submod(a, d));
+    (b, e) = (Arith::<P>::addmod(b, e), Arith::<P>::submod(b, e));
+    (c, f) = (Arith::<P>::addmod(c, f), Arith::<P>::submod(c, f));
+    let lbmc = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::U6, Arith::<P>::submod(b, c));
+    let out0 = Arith::<P>::addmod(a, Arith::<P>::addmod(b, c));
+    let out2 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w2, Arith::<P>::submod(a, Arith::<P>::submod(b, lbmc)));
+    let out4 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w4, Arith::<P>::submod(Arith::<P>::submod(a, c), lbmc));
+    let lepf = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::U6, Arith::<P>::addmod64(e, f));
+    let out1 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w1, Arith::<P>::submod(d, Arith::<P>::submod(f, lepf)));
+    let out3 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w3, Arith::<P>::submod(d, Arith::<P>::submod(e, f)));
+    let out5 = Arith::<P>::mmulmod_invtw::<INV, TWIDDLE>(w5, Arith::<P>::submod(d, Arith::<P>::submod(lepf, e)));
+    (out0, out1, out2, out3, out4, out5)
+}
+unsafe fn ntt6_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
+    s1: usize, mut px: *mut u64, ptf: *const u64) -> (*mut u64, *const u64) {
+    let w1 = if TWIDDLE { *ptf } else { 0 };
+    let w2 = Arith::<P>::mmulmod(w1, w1);
+    let w3 = Arith::<P>::mmulmod(w1, w2);
+    let w4 = Arith::<P>::mmulmod(w2, w2);
+    let w5 = Arith::<P>::mmulmod(w2, w3);
+    for _ in 0..s1 {
+        (*px, *px.add(s1), *px.add(2*s1),
+            *px.add(3*s1), *px.add(4*s1), *px.add(5*s1)) =
+            ntt6_kernel::<P, INV, TWIDDLE>(w1, w2, w3, w4, w5,
+                *px, *px.add(s1), *px.add(2*s1),
+                *px.add(3*s1), *px.add(4*s1), *px.add(5*s1));
+        px = px.add(1);
+    }
+    (px.add(5*s1), ptf.add(1))
+}
+
+fn ntt_dif_dit<const P: u64, const INV: bool>(plan: &NttPlan, x: &mut [u64], tf_list: &[u64]) {
+    let mut i_list: Vec<_> = (0..plan.s_list.len()).collect();
+    if INV { i_list.reverse(); }
+    let mut ptf = tf_list.as_ptr();
+    for i in i_list {
+        let (s, radix) = plan.s_list[i];
+        let s1 = s / radix;
+        unsafe {
+            let mut px = x.as_mut_ptr();
+            let px_end = px.add(plan.n);
+            match radix {
+                2 => {
+                    (px, ptf) = ntt2_single_block::<P, INV, false>(s1, px, ptf);
+                    while px < px_end {
+                        (px, ptf) = ntt2_single_block::<P, INV, true>(s1, px, ptf);
+                    }
+                },
+                3 => {
+                    (px, ptf) = ntt3_single_block::<P, INV, false>(s1, px, ptf);
+                    while px < px_end {
+                        (px, ptf) = ntt3_single_block::<P, INV, true>(s1, px, ptf);
+                    }
+                },
+                4 => {
+                    (px, ptf) = ntt4_single_block::<P, INV, false>(s1, px, ptf);
+                    while px < px_end {
+                        (px, ptf) = ntt4_single_block::<P, INV, true>(s1, px, ptf);
+                    }
+                },
+                5 => {
+                    (px, ptf) = ntt5_single_block::<P, INV, false>(s1, px, ptf);
+                    while px < px_end {
+                        (px, ptf) = ntt5_single_block::<P, INV, true>(s1, px, ptf);
+                    }
+                },
+                6 => {
+                    (px, ptf) = ntt6_single_block::<P, INV, false>(s1, px, ptf);
+                    while px < px_end {
+                        (px, ptf) = ntt6_single_block::<P, INV, true>(s1, px, ptf);
+                    }
+                },
+                _ => { unreachable!() }
+            }
+        }
+    }
+}
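+// (Note, comments only: the forward pass (INV == false) walks s_list in
+// decimation-in-frequency order, while the inverse pass (INV == true)
+// reverses it for decimation-in-time; the first block of each level skips
+// its twiddle multiply because that twiddle factor is 1.)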
+
+fn calc_twiddle_factors<const P: u64, const INV: bool>(s_list: &[(usize, usize)], out: &mut [u64]) -> usize {
+    let mut p = 1;
+    out[0] = Arith::<P>::R;
+    for i in (1..s_list.len()).rev() {
+        let radix = s_list[i-1].1;
+        let w = Arith::<P>::mpowmod(NttKernelImpl::<P, INV>::ROOTR, Arith::<P>::MAX_NTT_LEN/(p * radix * s_list.last().unwrap().1) as u64);
+        for j in p..radix*p {
+            out[j] = Arith::<P>::mmulmod(w, out[j - p]);
+        }
+        p *= radix;
+    }
+    p
+}
+
+// Performs (cyclic) integer convolution modulo P using NTT.
+// Modifies the input buffers in-place.
+// The output is saved in the slice `x`.
+// The input slices must have the same length.
+fn conv<const P: u64>(plan: &NttPlan, x: &mut [u64], xlen: usize, y: &mut [u64], ylen: usize, mut mult: u64) {
+    assert!(!x.is_empty() && x.len() == y.len());
+    let (_n, g, m, last_radix) = (plan.n, plan.g, plan.m, plan.last_radix as u64);
+
+    /* multiply by a constant in advance */
+    mult = Arith::<P>::mmulmod(Arith::<P>::mpowmod(Arith::<P>::R2, 3), Arith::<P>::mmulmod(mult, (P-1)/m as u64));
+    for v in if xlen < ylen { &mut x[g..g+xlen] } else { &mut y[g..g+ylen] } {
+        *v = Arith::<P>::mmulmod(*v, mult);
+    }
+
+    /* compute the total space needed for twiddle factors */
+    let (mut radix_cumul, mut tf_all_count) = (1, 2); // 2 extra slots
+    for &(_, radix) in &plan.s_list {
+        tf_all_count += radix_cumul;
+        radix_cumul *= radix;
+    }
+
+    /* build twiddle factors */
+    let mut tf_list = vec![0u64; tf_all_count];
+    let mut tf_last_start = 0;
+    for i in 0..plan.s_list.len() {
+        let x = calc_twiddle_factors::<P, false>(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]);
+        if i + 1 < plan.s_list.len() { tf_last_start += x; }
+    }
+
+    /* dif fft */
+    ntt_dif_dit::<P, false>(plan, &mut x[g..], &tf_list);
+    ntt_dif_dit::<P, false>(plan, &mut y[g..], &tf_list);
+
+    /* naive multiplication */
+    let (mut i, mut ii, mut ii_mod_last_radix) = (g, tf_last_start, 0);
+    let mut tf_current = Arith::<P>::R;
+    let tf_mult = Arith::<P>::mpowmod(NttKernelImpl::<P, false>::ROOTR, Arith::<P>::MAX_NTT_LEN/last_radix);
+    while i < g + plan.n {
+        conv_base::<P>(g, (&mut x[i..]).as_mut_ptr(), (&mut y[i..]).as_mut_ptr(), tf_current);
+        i += g;
+        ii_mod_last_radix += 1;
+        if ii_mod_last_radix == last_radix {
+            ii += 1;
+            ii_mod_last_radix = 0;
+            tf_current = tf_list[ii];
+        } else {
+            tf_current = Arith::<P>::mmulmod(tf_current, tf_mult);
+        }
+    }
+
+    /* dit fft */
+    let mut tf_last_start = 0;
+    for i in (0..plan.s_list.len()).rev() {
+        tf_last_start += calc_twiddle_factors::<P, true>(&plan.s_list[0..=i], &mut tf_list[tf_last_start..]);
+    }
+    ntt_dif_dit::<P, true>(plan, x, &tf_list);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+use core::cmp::{min, max};
+use crate::big_digit::BigDigit;
+
+const P1: u64 = 14_259_017_916_245_606_401; // Max NTT length = 2^22 * 3^21 * 5^2 = 1_096_847_532_018_892_800
+const P2: u64 = 17_984_575_660_032_000_001; // Max NTT length = 2^19 * 3^17 * 5^6 = 1_057_916_215_296_000_000
+const P3: u64 = 17_995_154_822_184_960_001; // Max NTT length = 2^17 * 3^22 * 5^4 = 2_570_736_403_169_280_000
+
+const P1INV_R_MOD_P2: u64 = Arith::<P2>::mmulmod(Arith::<P2>::R2, arith::invmod(P1, P2));
+const P1P2INV_R_MOD_P3: u64 = Arith::<P3>::mmulmod(Arith::<P3>::R2, arith::invmod((P1 as u128 * P2 as u128 % P3 as u128) as u64, P3));
+const P1_R_MOD_P3: u64 = Arith::<P3>::mmulmod(Arith::<P3>::R2, P1);
+const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64;
+const P1P2_HI: u64 = (P1 as u128 * P2 as u128 >> 64) as u64;
+
+// Propagates carry from the beginning to the end of acc,
+// and returns the resulting carry if it is nonzero.
+fn propagate_carry(acc: &mut [u64], mut carry: u64) -> u64 {
+    for x in acc {
+        let (v, overflow) = x.overflowing_add(carry);
+        (*x, carry) = (v, u64::from(overflow));
+        if !overflow { break; }
+    }
+    carry
+}
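+
+// (e.g. propagate_carry(&mut [u64::MAX, 7], 1) leaves the slice as [0, 8]
+// and returns 0: the carry stops at the first limb that does not overflow.)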
+
+fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) {
+    fn pack_into(src: &[u64], dst1: &mut [u64], dst2: &mut [u64], bits: u64) {
+        let mut p = 0u64;
+        let mut pdst1 = dst1.as_mut_ptr();
+        let mut pdst2 = dst2.as_mut_ptr();
+        let mut x = 0u64;
+        let mask = (1u64 << bits) - 1;
+        for v in src {
+            let mut k = 0;
+            while k < 64 {
+                x |= (v >> k) << p;
+                let q = 64 - k;
+                if p + q >= bits {
+                    unsafe { let out = x & mask; *pdst1 = out; *pdst2 = out; }
+                    x = 0;
+                    unsafe { (pdst1, pdst2, k, p) = (pdst1.add(1), pdst2.add(1), k + bits - p, 0); }
+                } else {
+                    p += q;
+                    break;
+                }
+            }
+        }
+        unsafe {
+            if p > 0 { let out = x & mask; *pdst1 = out; *pdst2 = out; }
+        }
+    }
+
+    assert!(bits < 63);
+    let b_len = ((64 * b.len() as u64 + bits - 1) / bits) as usize;
+    let c_len = ((64 * c.len() as u64 + bits - 1) / bits) as usize;
+    let min_len = b_len + c_len;
+    let plan_x = NttPlan::build::<P2>(min_len);
+    let plan_y = NttPlan::build::<P3>(min_len);
+
+    let mut x = vec![0u64; plan_x.g + plan_x.n];
+    let mut y = vec![0u64; plan_y.g + plan_y.n];
+    let mut r = vec![0u64; plan_x.g + plan_x.n];
+    let mut s = vec![0u64; plan_y.g + plan_y.n];
+    pack_into(b, &mut x[plan_x.g..], &mut y[plan_y.g..], bits);
+    pack_into(c, &mut r[plan_x.g..], &mut s[plan_y.g..], bits);
+    conv::<P2>(&plan_x, &mut x, b_len, &mut r[..plan_x.g+plan_x.n], c_len, arith::invmod(P3, P2));
+    conv::<P3>(&plan_y, &mut y, b_len, &mut s[..plan_y.g+plan_y.n], c_len, Arith::<P3>::submod(0, arith::invmod(P2, P3)));
+
+    /* merge the results in {x, y} into acc (process carry along the way) */
+    let mask = (1u64 << bits) - 1;
+    let mut carry: u128 = 0;
+    let (mut j, mut p) = (0usize, 0u64);
+    let mut s: u64 = 0;
+    let mut carry_acc: u64 = 0;
+    for i in 0..min_len {
+        /* extract the convolution result */
+        let (a, b) = (x[i], y[i]);
+        let (mut v, overflow) = (a as u128 * P3 as u128 + carry).overflowing_sub(b as u128 * P2 as u128);
+        if overflow { v = v.wrapping_add(P2 as u128 * P3 as u128); }
+        carry = v >> bits;
+
+        /* write to s */
+        let out = (v as u64) & mask;
+        s |= out << p;
+        p += bits;
+        if p >= 64 {
+            /* flush s to the output buffer */
+            let (w, overflow1) = s.overflowing_add(carry_acc);
+            let (w, overflow2) = acc[j].overflowing_add(w);
+            acc[j] = w;
+            carry_acc = u64::from(overflow1 || overflow2);
+
+            /* roll-over */
+            (j, p) = (j + 1, p - 64);
+            s = out >> (bits - p);
+        }
+    }
+    // Process remaining carries. The addition carry_acc + s should not overflow
+    // since s is underfilled and carry_acc is always 0 or 1.
+    propagate_carry(&mut acc[j..], carry_acc + s);
+}
+
+fn mac3_three_primes(acc: &mut [u64], b: &[u64], c: &[u64]) {
+    let min_len = b.len() + c.len();
+    let plan_x = NttPlan::build::<P1>(min_len);
+    let plan_y = NttPlan::build::<P2>(min_len);
+    let plan_z = NttPlan::build::<P3>(min_len);
+    let mut x = vec![0u64; plan_x.g + plan_x.n];
+    let mut y = vec![0u64; plan_y.g + plan_y.n];
+    let mut z = vec![0u64; plan_z.g + plan_z.n];
+    let mut r = vec![0u64; max(x.len(), max(y.len(), z.len()))];
+
+    /* convolution with modulo P1 */
+    for i in 0..b.len() { x[plan_x.g + i] = if b[i] >= P1 { b[i] - P1 } else { b[i] }; }
+    for i in 0..c.len() { r[plan_x.g + i] = if c[i] >= P1 { c[i] - P1 } else { c[i] }; }
+    conv::<P1>(&plan_x, &mut x, b.len(), &mut r[..plan_x.g+plan_x.n], c.len(), 1);
+
+    /* convolution with modulo P2 */
+    for i in 0..b.len() { y[plan_y.g + i] = if b[i] >= P2 { b[i] - P2 } else { b[i] }; }
+    for i in 0..c.len() { r[plan_y.g + i] = if c[i] >= P2 { c[i] - P2 } else { c[i] }; }
+    (&mut r[plan_y.g..])[c.len()..plan_y.n].fill(0u64);
+    conv::<P2>(&plan_y, &mut y, b.len(), &mut r[..plan_y.g+plan_y.n], c.len(), 1);
+
+    /* convolution with modulo P3 */
+    for i in 0..b.len() { z[plan_z.g + i] = if b[i] >= P3 { b[i] - P3 } else { b[i] }; }
+    for i in 0..c.len() { r[plan_z.g + i] = if c[i] >= P3 { c[i] - P3 } else { c[i] }; }
+    (&mut r[plan_z.g..])[c.len()..plan_z.n].fill(0u64);
+    conv::<P3>(&plan_z, &mut z, b.len(), &mut r[..plan_z.g+plan_z.n], c.len(), 1);
+
+    /* merge the results in {x, y, z} into acc (process carry along the way) */
+    let mut carry: u128 = 0;
+    for i in 0..min_len {
+        let (a, b, c) = (x[i], y[i], z[i]);
+        // We need to solve the following system of linear congruences:
+        //     x === a mod P1,
+        //     x === b mod P2,
+        //     x === c mod P3.
+        // The first two equations are equivalent to
+        //     x === a + P1 * (U * (b-a) mod P2) mod P1P2,
+        // where U is the solution to
+        //     P1 * U === 1 mod P2.
+        let bma = Arith::<P2>::submod(b, a);
+        let u = Arith::<P2>::mmulmod(bma, P1INV_R_MOD_P2);
+        let v = a as u128 + P1 as u128 * u as u128;
+        let v_mod_p3 = Arith::<P3>::addmod(a, Arith::<P3>::mmulmod(P1_R_MOD_P3, u));
+        // Now we have reduced the congruences into two:
+        //     x === v mod P1P2,
+        //     x === c mod P3.
+        // The solution is
+        //     x === v + P1P2 * (V * (c-v) mod P3) mod P1P2P3,
+        // where V is the solution to
+        //     P1P2 * V === 1 mod P3.
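+        // Miniature illustration with small moduli (not the real constants):
+        // for x === 2 mod 3, x === 3 mod 5, x === 2 mod 7, we get
+        // U = 3^-1 mod 5 = 2, so x === 2 + 3*(2*(3-2) mod 5) === 8 mod 15;
+        // then V = 15^-1 mod 7 = 1, so x === 8 + 15*((2-8) mod 7) === 23 mod 105.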
+        let cmv = Arith::<P3>::submod(c, v_mod_p3);
+        let vcmv = Arith::<P3>::mmulmod(cmv, P1P2INV_R_MOD_P3);
+        let (out_01, overflow) = carry.overflowing_add(v + P1P2_LO as u128 * vcmv as u128);
+        let out_0 = out_01 as u64;
+        let out_12 = P1P2_HI as u128 * vcmv as u128 + (out_01 >> 64)
+            + if overflow { 1u128 << 64 } else { 0 };
+
+        let (v, overflow) = acc[i].overflowing_add(out_0);
+        acc[i] = v;
+        carry = out_12 + u128::from(overflow);
+    }
+    propagate_carry(&mut acc[min_len..], carry as u64);
+}
+
+fn mac3_u64(acc: &mut [u64], b: &[u64], c: &[u64]) {
+    const fn compute_bits(l: u64) -> u64 {
+        let total_bits = l * 64;
+        let (mut lo, mut hi) = (42, 62);
+        while lo < hi {
+            let mid = (lo + hi + 1) / 2;
+            let single_digit_max_val = (1u64 << mid) - 1;
+            let l_corrected = (total_bits + mid - 1) / mid;
+            let (lhs, overflow) = (single_digit_max_val as u128).pow(2).overflowing_mul(l_corrected as u128);
+            if !overflow && lhs < P2 as u128 * P3 as u128 { lo = mid; }
+            else { hi = mid - 1; }
+        }
+        lo
+    }
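+    // (Rough worked figures, for illustration: compute_bits(1 << 20) settles
+    // at 53, since (2^53)^2 * ceil(64 * 2^20 / 53) ~ 1.0e38 stays below
+    // P2*P3 ~ 3.2e38, while 54 bits per limb would overshoot it.)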
+
+    let (b, c) = if b.len() < c.len() { (b, c) } else { (c, b) };
+    let naive_cost = NttPlan::build::<P1>(b.len() + c.len()).cost;
+    let split_cost = NttPlan::build::<P1>(b.len() + b.len()).cost * (c.len() / b.len())
+        + if c.len() % b.len() > 0 { NttPlan::build::<P1>(b.len() + (c.len() % b.len())).cost } else { 0 };
+    if b.len() >= 128 && split_cost < naive_cost {
+        /* special handling for unbalanced multiplication:
+           we reduce it to about `c.len()/b.len()` balanced multiplications */
+        let mut i = 0usize;
+        let mut carry = 0u64;
+        while i < c.len() {
+            let j = min(i + b.len(), c.len());
+            let k = j + b.len();
+            let tmp = acc[k];
+            acc[k] = 0;
+            mac3_u64(&mut acc[i..=k], b, &c[i..j]);
+            (acc[k], carry) = (tmp, acc[k] + propagate_carry(&mut acc[j..k], carry));
+            i = j;
+        }
+        propagate_carry(&mut acc[i + b.len()..], carry);
+        return;
+    }
+
+    // We have two choices:
+    //   1. NTT with two primes.
+    //   2. NTT with three primes.
+    // Obviously we want to do only two passes for efficiency, not three.
+    // However, the number of bits per u64 we can pack for NTT
+    // depends on the length of the arrays being multiplied (convolved).
+    // If the arrays are too long, the resulting values may exceed the
+    // modulus range P2 * P3, which leads to incorrect results.
+    // Hence, we compute the number of bits allowed by the NTT length,
+    // and use it to decide between the two-prime and three-prime variants.
+    // Since we can pack 64 bits per u64 in three-prime NTT, the effective
+    // number of bits in three-prime NTT is 64/3 = 21.3333..., which means
+    // two-prime NTT can only do better when at least 43 bits can be packed
+    // into each u64.
+    let max_cnt = max(b.len(), c.len()) as u64;
+    let bits = compute_bits(max_cnt);
+    if bits >= 43 {
+        /* can pack more effective bits per u64 with two primes than with three primes */
+        mac3_two_primes(acc, b, c, bits);
+    } else {
+        /* can pack at most 21 effective bits per u64, which is worse than the
+           64/3 = 21.3333... effective bits per u64 achieved with three primes */
+        mac3_three_primes(acc, b, c);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+#[cfg(u64_digit)]
+pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
+    mac3_u64(acc, b, c);
+}
+
+#[cfg(not(u64_digit))]
+pub fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
+    fn bigdigit_to_u64(src: &[BigDigit], is_acc: bool) -> Vec<u64> {
+        let mut out = vec![0u64; (src.len() + 1) / 2 + is_acc as usize];
+        for i in 0..src.len()/2 {
+            out[i] = (src[2*i] as u64) | ((src[2*i+1] as u64) << 32);
+        }
+        if src.len() % 2 == 1 {
+            out[src.len()/2] = src[src.len()-1] as u64;
+        }
+        out
+    }
+    fn u64_to_bigdigit(src: &[u64], dst: &mut [BigDigit]) {
+        for i in 0..dst.len()/2 {
+            dst[2*i] = src[i] as BigDigit;
+            dst[2*i+1] = (src[i] >> 32) as BigDigit;
+        }
+        if dst.len() % 2 == 1 {
+            dst[dst.len()-1] = src[src.len()-1] as BigDigit;
+        }
+    }
+
+    /* convert to u64 => process => convert back to BigDigit (u32) */
+    let mut acc_u64 = bigdigit_to_u64(acc, true);
+    mac3_u64(&mut acc_u64, &bigdigit_to_u64(b, false), &bigdigit_to_u64(c, false));
+    u64_to_bigdigit(&acc_u64, acc);
+}
\ No newline at end of file