From be631052be64f373818e4427fe42dd2426481618 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Mon, 11 Dec 2017 13:54:58 -0800 Subject: [PATCH] Make Shr for negative BigInt round down, like primitives do Primitive integers always round down when shifting right, but `BigInt` was effectively rounding toward zero, because it just kept its sign and used the `BigUint` magnitude rounded down (always toward zero). Now we adjust the result of shifting negative values, and explicitly test that it matches the result for primitive integers. --- src/bigint.rs | 17 +++++++++++++++-- src/biguint.rs | 34 ++++++++++++++++++---------------- src/tests/bigint.rs | 25 +++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/src/bigint.rs b/src/bigint.rs index 9a8583e3..c5f1d87d 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -228,12 +228,23 @@ impl<'a> Shl for &'a BigInt { } } +// Negative values need a rounding adjustment if there are any ones in the +// bits that are getting shifted out. +fn shr_round_down(i: &BigInt, rhs: usize) -> bool { + i.is_negative() && + biguint::trailing_zeros(&i.data) + .map(|n| n < rhs) + .unwrap_or(false) +} + impl Shr for BigInt { type Output = BigInt; #[inline] fn shr(self, rhs: usize) -> BigInt { - BigInt::from_biguint(self.sign, self.data >> rhs) + let round_down = shr_round_down(&self, rhs); + let data = self.data >> rhs; + BigInt::from_biguint(self.sign, if round_down { data + 1u8 } else { data }) } } @@ -242,7 +253,9 @@ impl<'a> Shr for &'a BigInt { #[inline] fn shr(self, rhs: usize) -> BigInt { - BigInt::from_biguint(self.sign, &self.data >> rhs) + let round_down = shr_round_down(&self, rhs); + let data = &self.data >> rhs; + BigInt::from_biguint(self.sign, if round_down { data + 1u8 } else { data }) } } diff --git a/src/biguint.rs b/src/biguint.rs index 3891fe32..cd21b10c 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -944,6 +944,11 @@ impl Integer for BigUint { /// The result is always positive. #[inline] fn gcd(&self, other: &Self) -> Self { + #[inline] + fn twos(x: &BigUint) -> usize { + trailing_zeros(x).unwrap_or(0) + } + // Stein's algorithm if self.is_zero() { return other.clone(); @@ -955,17 +960,14 @@ impl Integer for BigUint { let mut n = other.clone(); // find common factors of 2 - let shift = cmp::min( - n.trailing_zeros(), - m.trailing_zeros() - ); + let shift = cmp::min(twos(&n), twos(&m)); // divide m and n by 2 until odd // m inside loop - n >>= n.trailing_zeros(); + n >>= twos(&n); while !m.is_zero() { - m >>= m.trailing_zeros(); + m >>= twos(&m); if n > m { mem::swap(&mut n, &mut m) } m -= &n; } @@ -1628,16 +1630,6 @@ impl BigUint { return self.data.len() * big_digit::BITS - zeros as usize; } - // self is assumed to be normalized - fn trailing_zeros(&self) -> usize { - self.data - .iter() - .enumerate() - .find(|&(_, &digit)| digit != 0) - .map(|(i, digit)| i * big_digit::BITS + digit.trailing_zeros() as usize) - .unwrap_or(0) - } - /// Strips off trailing zero bigdigits - comparisons require the last element in the vector to /// be nonzero. #[inline] @@ -1689,6 +1681,16 @@ impl BigUint { } } +/// Returns the number of least-significant bits that are zero, +/// or `None` if the entire number is zero. +pub fn trailing_zeros(u: &BigUint) -> Option { + u.data + .iter() + .enumerate() + .find(|&(_, &digit)| digit != 0) + .map(|(i, digit)| i * big_digit::BITS + digit.trailing_zeros() as usize) +} + #[cfg(feature = "serde")] impl serde::Serialize for BigUint { fn serialize(&self, serializer: &mut S) -> Result<(), S::Error> diff --git a/src/tests/bigint.rs b/src/tests/bigint.rs index a3c8f6fe..1ff1d673 100644 --- a/src/tests/bigint.rs +++ b/src/tests/bigint.rs @@ -1192,3 +1192,28 @@ fn test_negative_rand_range() { // Switching u and l should fail: let _n: BigInt = rng.gen_bigint_range(&u, &l); } + +#[test] +fn test_negative_shr() { + assert_eq!(BigInt::from(-1) >> 1, BigInt::from(-1)); + assert_eq!(BigInt::from(-2) >> 1, BigInt::from(-1)); + assert_eq!(BigInt::from(-3) >> 1, BigInt::from(-2)); + assert_eq!(BigInt::from(-3) >> 2, BigInt::from(-1)); +} + +#[test] +fn test_random_shr() { + use rand::Rng; + let mut rng = thread_rng(); + + for p in rng.gen_iter::().take(1000) { + let big = BigInt::from(p); + let bigger = &big << 1000; + assert_eq!(&bigger >> 1000, big); + for i in 0..64 { + let answer = BigInt::from(p >> i); + assert_eq!(&big >> i, answer); + assert_eq!(&bigger >> (1000 + i), answer); + } + } +}