From 5997113c3684d17805dd59233349c2444a8a7a96 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 12 May 2019 23:18:15 +0900 Subject: [PATCH 1/3] Optimize BigUint::modpow for even exponents --- src/biguint.rs | 134 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 106 insertions(+), 28 deletions(-) diff --git a/src/biguint.rs b/src/biguint.rs index da953dac..084c7da7 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -2126,36 +2126,13 @@ impl BigUint { pub fn modpow(&self, exponent: &Self, modulus: &Self) -> Self { assert!(!modulus.is_zero(), "divide by zero!"); - // For an odd modulus, we can use Montgomery multiplication in base 2^32. if modulus.is_odd() { - return monty_modpow(self, exponent, modulus); - } - - // Otherwise do basically the same as `num::pow`, but with a modulus. - let one = BigUint::one(); - if exponent.is_zero() { - return one; - } - - let mut base = self % modulus; - let mut exp = exponent.clone(); - while exp.is_even() { - base = &base * &base % modulus; - exp >>= 1; - } - if exp == one { - return base; - } - - let mut acc = base.clone(); - while exp > one { - exp >>= 1; - base = &base * &base % modulus; - if exp.is_odd() { - acc = acc * &base % modulus; - } + // For an odd modulus, we can use Montgomery multiplication in base 2^32. + monty_modpow(self, exponent, modulus) + } else { + // Otherwise do basically the same as `num::pow`, but with a modulus. + plain_modpow(self, &exponent.data, modulus) } - acc } /// Returns the truncated principal square root of `self` -- @@ -2177,6 +2154,107 @@ impl BigUint { } } +fn plain_modpow<'a, T>(base: &BigUint, exp_data: &Vec, modulus: &BigUint) -> BigUint +where + T: Copy + PartialOrd + Unsigned + ShrAssign + BitAnd, +{ + assert!(!modulus.is_zero(), "divide by zero!"); + + if exp_data.len() == 0 { + return BigUint::one(); + } + + let digit_bits = mem::size_of::() * 8; + let one = One::one(); + let last_i = exp_data.len() - 1; + let mut base = base.clone(); + let mut i = 0usize; + while exp_data[i].is_zero() { + for _ in 0..digit_bits { + base = &base * &base % modulus; + } + i += 1; + } + + let mut b = 0usize; + let mut r = exp_data[i]; + while (r & one).is_zero() { + base = &base * &base % modulus; + r >>= one; + b += 1; + } + + let last = i == last_i; + if last && r.is_one() { + return base; + } + + let mut acc = base.clone(); + { + let mut unit = |bit| { + base = &base * &base % modulus; + if bit == one { + acc = &acc * &base % modulus; + } + }; + + r >>= one; + b += 1; + if !last { + // consume exp_data[i] + for _ in b..digit_bits { + unit(r & one); + r >>= one; + } + for i in (i + 1)..last_i { + r = exp_data[i]; + for _ in 0..digit_bits { + unit(r & one); + r >>= one; + } + } + r = exp_data[last_i]; + } + while !r.is_zero() { + unit(r & one); + r >>= one; + } + } + acc +} + +#[test] +fn test_plain_modpow() { + let two = BigUint::from(2u32); + let modulus = BigUint::from(0x1100u32); + + let exp: Vec = vec![0, 0b1]; + assert_eq!( + two.pow(0b1_00000000_u32) % &modulus, + plain_modpow(&two, &exp, &modulus) + ); + let exp: Vec = vec![0, 0b10]; + assert_eq!( + two.pow(0b10_00000000_u32) % &modulus, + plain_modpow(&two, &exp, &modulus) + ); + let exp: Vec = vec![0, 0b110010]; + assert_eq!( + two.pow(0b110010_00000000_u32) % &modulus, + plain_modpow(&two, &exp, &modulus) + ); + let exp: Vec = vec![0b1, 0b1]; + assert_eq!( + two.pow(0b1_00000001_u32) % &modulus, + plain_modpow(&two, &exp, &modulus) + ); + let exp: Vec = vec![0b1100, 0, 0b1]; + assert_eq!( + two.pow(0b1_00000000_00001100_u32) % &modulus, + plain_modpow(&two, &exp, &modulus) + ); +} + /// Returns the number of least-significant bits that are zero, /// or `None` if the entire number is zero. pub fn trailing_zeros(u: &BigUint) -> Option { From 2905608107194e887ff96b1b4569fdcd5c68d5a2 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Mon, 5 Aug 2019 18:24:59 -0700 Subject: [PATCH 2/3] Simplify plain_modpow for BigDigit, not generic --- src/biguint.rs | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/biguint.rs b/src/biguint.rs index 084c7da7..194699d2 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -2154,23 +2154,18 @@ impl BigUint { } } -fn plain_modpow<'a, T>(base: &BigUint, exp_data: &Vec, modulus: &BigUint) -> BigUint -where - T: Copy + PartialOrd + Unsigned + ShrAssign + BitAnd, -{ +fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint { assert!(!modulus.is_zero(), "divide by zero!"); if exp_data.len() == 0 { return BigUint::one(); } - let digit_bits = mem::size_of::() * 8; - let one = One::one(); let last_i = exp_data.len() - 1; let mut base = base.clone(); let mut i = 0usize; while exp_data[i].is_zero() { - for _ in 0..digit_bits { + for _ in 0..big_digit::BITS { base = &base * &base % modulus; } i += 1; @@ -2178,9 +2173,9 @@ where let mut b = 0usize; let mut r = exp_data[i]; - while (r & one).is_zero() { + while r.is_even() { base = &base * &base % modulus; - r >>= one; + r >>= 1; b += 1; } @@ -2191,33 +2186,33 @@ where let mut acc = base.clone(); { - let mut unit = |bit| { + let mut unit = |exp_is_odd| { base = &base * &base % modulus; - if bit == one { + if exp_is_odd { acc = &acc * &base % modulus; } }; - r >>= one; + r >>= 1; b += 1; if !last { // consume exp_data[i] - for _ in b..digit_bits { - unit(r & one); - r >>= one; + for _ in b..big_digit::BITS { + unit(r.is_odd()); + r >>= 1; } for i in (i + 1)..last_i { r = exp_data[i]; - for _ in 0..digit_bits { - unit(r & one); - r >>= one; + for _ in 0..big_digit::BITS { + unit(r.is_odd()); + r >>= 1; } } r = exp_data[last_i]; } while !r.is_zero() { - unit(r & one); - r >>= one; + unit(r.is_odd()); + r >>= 1; } } acc @@ -2228,27 +2223,27 @@ fn test_plain_modpow() { let two = BigUint::from(2u32); let modulus = BigUint::from(0x1100u32); - let exp: Vec = vec![0, 0b1]; + let exp = vec![0, 0b1]; assert_eq!( two.pow(0b1_00000000_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); - let exp: Vec = vec![0, 0b10]; + let exp = vec![0, 0b10]; assert_eq!( two.pow(0b10_00000000_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); - let exp: Vec = vec![0, 0b110010]; + let exp = vec![0, 0b110010]; assert_eq!( two.pow(0b110010_00000000_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); - let exp: Vec = vec![0b1, 0b1]; + let exp = vec![0b1, 0b1]; assert_eq!( two.pow(0b1_00000001_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); - let exp: Vec = vec![0b1100, 0, 0b1]; + let exp = vec![0b1100, 0, 0b1]; assert_eq!( two.pow(0b1_00000000_00001100_u32) % &modulus, plain_modpow(&two, &exp, &modulus) From 9616b0c7e6034a6045a75772c132241361d11546 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Tue, 6 Aug 2019 09:46:46 -0700 Subject: [PATCH 3/3] Use an iterator instead of open indexing in plain_modpow --- src/biguint.rs | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/biguint.rs b/src/biguint.rs index 194699d2..32273550 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -2157,34 +2157,35 @@ impl BigUint { fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint { assert!(!modulus.is_zero(), "divide by zero!"); - if exp_data.len() == 0 { - return BigUint::one(); - } + let i = match exp_data.iter().position(|&r| r != 0) { + None => return BigUint::one(), + Some(i) => i, + }; - let last_i = exp_data.len() - 1; let mut base = base.clone(); - let mut i = 0usize; - while exp_data[i].is_zero() { + for _ in 0..i { for _ in 0..big_digit::BITS { base = &base * &base % modulus; } - i += 1; } - let mut b = 0usize; let mut r = exp_data[i]; + let mut b = 0usize; while r.is_even() { base = &base * &base % modulus; r >>= 1; b += 1; } - let last = i == last_i; - if last && r.is_one() { + let mut exp_iter = exp_data[i + 1..].iter(); + if exp_iter.len() == 0 && r.is_one() { return base; } let mut acc = base.clone(); + r >>= 1; + b += 1; + { let mut unit = |exp_is_odd| { base = &base * &base % modulus; @@ -2193,23 +2194,25 @@ fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> Big } }; - r >>= 1; - b += 1; - if !last { + if let Some(&last) = exp_iter.next_back() { // consume exp_data[i] for _ in b..big_digit::BITS { unit(r.is_odd()); r >>= 1; } - for i in (i + 1)..last_i { - r = exp_data[i]; + + // consume all other digits before the last + for &r in exp_iter { + let mut r = r; for _ in 0..big_digit::BITS { unit(r.is_odd()); r >>= 1; } } - r = exp_data[last_i]; + r = last; } + + debug_assert_ne!(r, 0); while !r.is_zero() { unit(r.is_odd()); r >>= 1;