Skip to content

Commit

Permalink
Make Shr for negative BigInt round down, like primitives do
Browse files Browse the repository at this point in the history
Primitive integers always round down when shifting right, but `BigInt`
was effectively rounding toward zero, because it just kept its sign and
used the `BigUint` magnitude rounded down (always toward zero).

Now we adjust the result of shifting negative values, and explicitly
test that it matches the result for primitive integers.
  • Loading branch information
cuviper committed Feb 24, 2018
1 parent 5e389ca commit be63105
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
17 changes: 15 additions & 2 deletions src/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,23 @@ impl<'a> Shl<usize> for &'a BigInt {
}
}

// Negative values need a rounding adjustment if there are any ones in the
// bits that are getting shifted out.
fn shr_round_down(i: &BigInt, rhs: usize) -> bool {
i.is_negative() &&
biguint::trailing_zeros(&i.data)
.map(|n| n < rhs)
.unwrap_or(false)
}

impl Shr<usize> for BigInt {
type Output = BigInt;

#[inline]
fn shr(self, rhs: usize) -> BigInt {
BigInt::from_biguint(self.sign, self.data >> rhs)
let round_down = shr_round_down(&self, rhs);
let data = self.data >> rhs;
BigInt::from_biguint(self.sign, if round_down { data + 1u8 } else { data })
}
}

Expand All @@ -242,7 +253,9 @@ impl<'a> Shr<usize> for &'a BigInt {

#[inline]
fn shr(self, rhs: usize) -> BigInt {
BigInt::from_biguint(self.sign, &self.data >> rhs)
let round_down = shr_round_down(&self, rhs);
let data = &self.data >> rhs;
BigInt::from_biguint(self.sign, if round_down { data + 1u8 } else { data })
}
}

Expand Down
34 changes: 18 additions & 16 deletions src/biguint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,11 @@ impl Integer for BigUint {
/// The result is always positive.
#[inline]
fn gcd(&self, other: &Self) -> Self {
#[inline]
fn twos(x: &BigUint) -> usize {
trailing_zeros(x).unwrap_or(0)
}

// Stein's algorithm
if self.is_zero() {
return other.clone();
Expand All @@ -955,17 +960,14 @@ impl Integer for BigUint {
let mut n = other.clone();

// find common factors of 2
let shift = cmp::min(
n.trailing_zeros(),
m.trailing_zeros()
);
let shift = cmp::min(twos(&n), twos(&m));

// divide m and n by 2 until odd
// m inside loop
n >>= n.trailing_zeros();
n >>= twos(&n);

while !m.is_zero() {
m >>= m.trailing_zeros();
m >>= twos(&m);
if n > m { mem::swap(&mut n, &mut m) }
m -= &n;
}
Expand Down Expand Up @@ -1628,16 +1630,6 @@ impl BigUint {
return self.data.len() * big_digit::BITS - zeros as usize;
}

// self is assumed to be normalized
fn trailing_zeros(&self) -> usize {
self.data
.iter()
.enumerate()
.find(|&(_, &digit)| digit != 0)
.map(|(i, digit)| i * big_digit::BITS + digit.trailing_zeros() as usize)
.unwrap_or(0)
}

/// Strips off trailing zero bigdigits - comparisons require the last element in the vector to
/// be nonzero.
#[inline]
Expand Down Expand Up @@ -1689,6 +1681,16 @@ impl BigUint {
}
}

/// Returns the number of least-significant bits that are zero,
/// or `None` if the entire number is zero.
pub fn trailing_zeros(u: &BigUint) -> Option<usize> {
u.data
.iter()
.enumerate()
.find(|&(_, &digit)| digit != 0)
.map(|(i, digit)| i * big_digit::BITS + digit.trailing_zeros() as usize)
}

#[cfg(feature = "serde")]
impl serde::Serialize for BigUint {
fn serialize<S>(&self, serializer: &mut S) -> Result<(), S::Error>
Expand Down
25 changes: 25 additions & 0 deletions src/tests/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1192,3 +1192,28 @@ fn test_negative_rand_range() {
// Switching u and l should fail:
let _n: BigInt = rng.gen_bigint_range(&u, &l);
}

#[test]
fn test_negative_shr() {
assert_eq!(BigInt::from(-1) >> 1, BigInt::from(-1));
assert_eq!(BigInt::from(-2) >> 1, BigInt::from(-1));
assert_eq!(BigInt::from(-3) >> 1, BigInt::from(-2));
assert_eq!(BigInt::from(-3) >> 2, BigInt::from(-1));
}

#[test]
fn test_random_shr() {
use rand::Rng;
let mut rng = thread_rng();

for p in rng.gen_iter::<i64>().take(1000) {
let big = BigInt::from(p);
let bigger = &big << 1000;
assert_eq!(&bigger >> 1000, big);
for i in 0..64 {
let answer = BigInt::from(p >> i);
assert_eq!(&big >> i, answer);
assert_eq!(&bigger >> (1000 + i), answer);
}
}
}

0 comments on commit be63105

Please sign in to comment.