diff --git a/benches/bigint.rs b/benches/bigint.rs index e731ce1a..6893f137 100644 --- a/benches/bigint.rs +++ b/benches/bigint.rs @@ -11,7 +11,7 @@ use std::mem::replace; use test::Bencher; use num_bigint::{BigInt, BigUint, RandBigInt}; use num_traits::{Zero, One, FromPrimitive, Num}; -use rand::{SeedableRng, StdRng, Rng}; +use rand::{SeedableRng, StdRng}; fn get_rng() -> StdRng { let mut seed = [0; 32]; @@ -361,14 +361,9 @@ fn roots_cbrt(b: &mut Bencher) { } #[bench] -fn roots_nth(b: &mut Bencher) { +fn roots_nth_100(b: &mut Bencher) { let mut rng = get_rng(); let x = rng.gen_biguint(2048); - // Although n is u32, here we limit it to the set of u8 values since it - // hugely impacts the performance of nth_root due to exponentiation to - // the power of n-1. Using very large values for n is also not very realistic, - // and any n > x's bit size produces 1 as a result anyway. - let n: u8 = rng.gen(); - b.iter(|| { x.nth_root(n as u32) }); + b.iter(|| x.nth_root(100)); } diff --git a/src/bigint.rs b/src/bigint.rs index beeae2c0..c5195fe7 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -1805,10 +1805,20 @@ impl Integer for BigInt { impl Roots for BigInt { fn nth_root(&self, n: u32) -> Self { assert!(!(self.is_negative() && n.is_even()), - "n-th root is undefined for number (n={})", n); + "root of degree {} is imaginary", n); BigInt::from_biguint(self.sign, self.data.nth_root(n)) } + + fn sqrt(&self) -> Self { + assert!(!self.is_negative(), "square root is imaginary"); + + BigInt::from_biguint(self.sign, self.data.sqrt()) + } + + fn cbrt(&self) -> Self { + BigInt::from_biguint(self.sign, self.data.cbrt()) + } } impl ToPrimitive for BigInt { diff --git a/src/biguint.rs b/src/biguint.rs index 6e7c991c..87ed7531 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -1027,32 +1027,30 @@ impl Integer for BigUint { } impl Roots for BigUint { - fn nth_root(&self, n: u32) -> Self { - assert!(n > 0, "n must be at least 1"); + // nth_root, sqrt and cbrt use Newton's method to compute + // principal root of a given degree for a given integer. - let one = BigUint::one(); + // Reference: + // Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14 + fn nth_root(&self, n: u32) -> Self { + assert!(n > 0, "root degree n must be at least 1"); - // Trivial cases - if self.is_zero() { - return BigUint::zero(); + if self.is_zero() || self.is_one() { + return self.clone() } - if self.is_one() { - return one; + match n { // Optimize for small n + 1 => return self.clone(), + 2 => return self.sqrt(), + 3 => return self.cbrt(), + _ => (), } let n = n as usize; - let n_min_1 = (n as usize) - 1; - - // Newton's method to compute the nth root of an integer. - // - // Reference: - // Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14 - // - // Set initial guess to something definitely >= floor(nth_root of self) - // but as low as possible to speed up convergence. + let n_min_1 = n - 1; + let bit_len = self.len() * big_digit::BITS; - let guess = one << (bit_len/n + 1); + let guess = BigUint::one() << (bit_len/n + 1); let mut u = guess; let mut s: BigUint; @@ -1062,7 +1060,6 @@ impl Roots for BigUint { let q = self / pow(s.clone(), n_min_1); let t: BigUint = n_min_1 * &s + q; - // Compute the candidate value for next iteration u = t / n; if u >= s { break; } @@ -1070,6 +1067,54 @@ impl Roots for BigUint { s } + + // Reference: + // Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 + fn sqrt(&self) -> Self { + if self.is_zero() || self.is_one() { + return self.clone() + } + + let bit_len = self.len() * big_digit::BITS; + let guess = BigUint::one() << (bit_len/2 + 1); + + let mut u = guess; + let mut s: BigUint; + + loop { + s = u; + let q = self / &s; + let t: BigUint = &s + q; + u = t >> 1; + + if u >= s { break; } + } + + s + } + + fn cbrt(&self) -> Self { + if self.is_zero() || self.is_one() { + return self.clone() + } + + let bit_len = self.len() * big_digit::BITS; + let guess = BigUint::one() << (bit_len/3 + 1); + + let mut u = guess; + let mut s: BigUint; + + loop { + s = u; + let q = self / (&s * &s); + let t: BigUint = (&s << 1) + q; + u = t / 3u32; + + if u >= s { break; } + } + + s + } } fn high_bits_to_u64(v: &BigUint) -> u64 { @@ -1797,8 +1842,7 @@ impl BigUint { } /// Returns the truncated principal square root of `self` -- - /// see [Roots::sqrt](Roots::sqrt). - // struct.BigInt.html#trait.Roots + /// see [Roots::sqrt](Roots::sqrt) pub fn sqrt(&self) -> Self { Roots::sqrt(self) } @@ -1810,7 +1854,7 @@ impl BigUint { } /// Returns the truncated principal `n`th root of `self` -- - /// See [Roots::nth_root](Roots::nth_root). + /// see [Roots::nth_root](Roots::nth_root). pub fn nth_root(&self, n: u32) -> Self { Roots::nth_root(self, n) } diff --git a/tests/roots.rs b/tests/roots.rs index 4ceb5491..24f7ea70 100644 --- a/tests/roots.rs +++ b/tests/roots.rs @@ -4,46 +4,53 @@ extern crate num_traits; mod biguint { use num_bigint::BigUint; - use num_traits::FromPrimitive; + use num_traits::pow; use std::str::FromStr; - fn check(x: i32, n: u32, expected: i32) { - let big_x: BigUint = FromPrimitive::from_i32(x).unwrap(); - let big_expected: BigUint = FromPrimitive::from_i32(expected).unwrap(); + fn check(x: u64, n: u32) { + let big_x = BigUint::from(x); + let res = big_x.nth_root(n); - assert_eq!(big_x.nth_root(n), big_expected); + if n == 2 { + assert_eq!(&res, &big_x.sqrt()) + } else if n == 3 { + assert_eq!(&res, &big_x.cbrt()) + } + + assert!(pow(res.clone(), n as usize) <= big_x); + assert!(pow(res.clone() + 1u32, n as usize) > big_x); } #[test] fn test_sqrt() { - check(99, 2, 9); - check(100, 2, 10); - check(120, 2, 10); + check(99, 2); + check(100, 2); + check(120, 2); } #[test] fn test_cbrt() { - check(8, 3, 2); - check(26, 3, 2); + check(8, 3); + check(26, 3); } #[test] fn test_nth_root() { - check(0, 1, 0); - check(10, 1, 10); - check(100, 4, 3); + check(0, 1); + check(10, 1); + check(100, 4); } #[test] #[should_panic] fn test_nth_root_n_is_zero() { - check(4, 0, 0); + check(4, 0); } #[test] fn test_nth_root_big() { - let x: BigUint = FromStr::from_str("123_456_789").unwrap(); - let expected : BigUint = FromPrimitive::from_i32(6).unwrap(); + let x = BigUint::from_str("123_456_789").unwrap(); + let expected = BigUint::from(6u32); assert_eq!(x.nth_root(10), expected); } @@ -51,34 +58,47 @@ mod biguint { mod bigint { use num_bigint::BigInt; - use num_traits::FromPrimitive; - - fn check(x: i32, n: u32, expected: i32) { - let big_x: BigInt = FromPrimitive::from_i32(x).unwrap(); - let big_expected: BigInt = FromPrimitive::from_i32(expected).unwrap(); - - assert_eq!(big_x.nth_root(n), big_expected); + use num_traits::{Signed, pow}; + + fn check(x: i64, n: u32) { + let big_x = BigInt::from(x); + let res = big_x.nth_root(n); + + if n == 2 { + assert_eq!(&res, &big_x.sqrt()) + } else if n == 3 { + assert_eq!(&res, &big_x.cbrt()) + } + + if big_x.is_negative() { + assert!(pow(res.clone() - 1u32, n as usize) < big_x); + assert!(pow(res.clone(), n as usize) >= big_x); + } else { + assert!(pow(res.clone(), n as usize) <= big_x); + assert!(pow(res.clone() + 1u32, n as usize) > big_x); + } } #[test] fn test_nth_root() { - check(-100, 3, -4); + check(100, 4); } #[test] #[should_panic] fn test_nth_root_x_neg_n_even() { - check(-100, 4, 0); + check(-100, 4); } #[test] #[should_panic] fn test_sqrt_x_neg() { - check(-4, 2, -2); + check(-4, 2); } #[test] fn test_cbrt() { - check(-8, 3, -2); + check(8, 3); + check(-8, 3); } }