diff --git a/benches/bigint.rs b/benches/bigint.rs index 4c8ec239..3a82c77b 100644 --- a/benches/bigint.rs +++ b/benches/bigint.rs @@ -292,3 +292,11 @@ fn modpow_even(b: &mut Bencher) { b.iter(|| base.modpow(&e, &m)); } + +#[bench] +fn sqrt(b: &mut Bencher) { + let mut rng = get_rng(); + let n = rng.gen_biguint(2048); + + b.iter(|| n.sqrt()); +} diff --git a/src/bigint.rs b/src/bigint.rs index 3c8d2962..2f1feb88 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -2538,6 +2538,18 @@ impl BigInt { }; BigInt::from_biguint(sign, mag) } + + /// Finds square root of `self`. + /// + /// The result is the greatest integer less than or equal to the + /// square root of `self`. + /// + /// Panics if `self` is a negative number. + pub fn sqrt(&self) -> Self { + assert!(!self.is_negative(), "number is negative"); + + BigInt::from_biguint(self.sign, self.data.sqrt()) + } } impl_sum_iter_type!(BigInt); diff --git a/src/biguint.rs b/src/biguint.rs index 5d4aaf89..2a856355 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -1749,6 +1749,49 @@ impl BigUint { } acc } + + /// Finds square root of `self`. + /// + /// The result is the greatest integer less than or equal to the + /// square root of `self`. + pub fn sqrt(&self) -> Self { + let one = BigUint::one(); + + // Trivial cases + if self.is_zero() { + return BigUint::zero(); + } + + if self.is_one() { + return one; + } + + // Newton's method to compute the square root of an integer. + // + // Reference: + // Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 + // + // Set initial guess to something >= floor(sqrt(self)), but as low + // as possible to speed up convergence. + let bit_len = self.len() * big_digit::BITS; + let guess = one << (bit_len/2 + 1); + + let mut u = guess; + let mut s: BigUint; + + loop { + s = u; + let q = self / &s; + let t: BigUint = &s + q; + + // Compute the candidate value for next iteration + u = t >> 1; + + if u >= s { break; } + } + + s + } } /// Returns the number of least-significant bits that are zero, diff --git a/tests/biguint.rs b/tests/biguint.rs index 92c8ce9e..9565a8f2 100644 --- a/tests/biguint.rs +++ b/tests/biguint.rs @@ -956,6 +956,27 @@ fn test_lcm() { check(99, 17, 1683); } +#[test] +fn test_sqrt() { + fn check(n: usize, expected: usize) { + let big_n: BigUint = FromPrimitive::from_usize(n).unwrap(); + let big_expected: BigUint = FromPrimitive::from_usize(expected).unwrap(); + + assert_eq!(big_n.sqrt(), big_expected); + } + + check(0, 0); + check(1, 1); + check(99, 9); + check(100, 10); + check(102, 10); + check(120, 10); + + let big_n: BigUint = FromStr::from_str("123_456_789").unwrap(); + let expected : BigUint = FromStr::from_str("11_111").unwrap(); + assert_eq!(big_n.sqrt(), expected); +} + #[test] fn test_is_even() { let one: BigUint = FromStr::from_str("1").unwrap();