diff --git a/benches/bigint.rs b/benches/bigint.rs index 6893f137..bba683bc 100644 --- a/benches/bigint.rs +++ b/benches/bigint.rs @@ -10,7 +10,7 @@ extern crate rand; use std::mem::replace; use test::Bencher; use num_bigint::{BigInt, BigUint, RandBigInt}; -use num_traits::{Zero, One, FromPrimitive, Num}; +use num_traits::{Zero, One, FromPrimitive, Num, Pow}; use rand::{SeedableRng, StdRng}; fn get_rng() -> StdRng { @@ -301,7 +301,7 @@ fn pow_bench(b: &mut Bencher) { for i in 2..upper + 1 { for j in 2..upper + 1 { let i_big = BigUint::from_usize(i).unwrap(); - num_traits::pow(i_big, j); + i_big.pow(j); } } }); diff --git a/src/bigint.rs b/src/bigint.rs index 93bb6b26..520e37dc 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -18,7 +18,7 @@ use serde; use integer::{Integer, Roots}; use traits::{ToPrimitive, FromPrimitive, Num, CheckedAdd, CheckedSub, - CheckedMul, CheckedDiv, Signed, Zero, One}; + CheckedMul, CheckedDiv, Signed, Zero, One, Pow}; use self::Sign::{Minus, NoSign, Plus}; @@ -811,6 +811,54 @@ impl Signed for BigInt { } } + +/// Help function for pow +/// +/// Computes the effect of the exponent on the sign. +#[inline] +fn powsign(sign: Sign, other: &T) -> Sign { + if other.is_zero() { + Plus + } else if sign != Minus { + sign + } else if other.is_odd() { + sign + } else { + -sign + } +} + +macro_rules! pow_impl { + ($T:ty) => { + impl<'a> Pow<$T> for &'a BigInt { + type Output = BigInt; + + #[inline] + fn pow(self, rhs: $T) -> BigInt { + BigInt::from_biguint(powsign(self.sign, &rhs), (&self.data).pow(rhs)) + } + } + + impl<'a, 'b> Pow<&'b $T> for &'a BigInt { + type Output = BigInt; + + #[inline] + fn pow(self, rhs: &$T) -> BigInt { + BigInt::from_biguint(powsign(self.sign, rhs), (&self.data).pow(rhs)) + } + } + } +} + +pow_impl!(u8); +pow_impl!(u16); +pow_impl!(u32); +pow_impl!(u64); +pow_impl!(usize); +#[cfg(has_i128)] +pow_impl!(u128); + + // A convenience method for getting the absolute value of an i32 in a u32. #[inline] fn i32_abs_as_u32(a: i32) -> u32 { diff --git a/src/biguint.rs b/src/biguint.rs index e7a3ce19..795ae323 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -19,7 +19,7 @@ use serde; use integer::{Integer, Roots}; use traits::{ToPrimitive, FromPrimitive, Float, Num, Unsigned, CheckedAdd, CheckedSub, CheckedMul, - CheckedDiv, Zero, One, pow}; + CheckedDiv, Zero, One, Pow}; use big_digit::{self, BigDigit, DoubleBigDigit}; @@ -434,6 +434,55 @@ impl One for BigUint { impl Unsigned for BigUint {} +macro_rules! pow_impl { + ($T:ty) => { + impl<'a> Pow<$T> for &'a BigUint { + type Output = BigUint; + + #[inline] + fn pow(self, mut exp: $T) -> Self::Output { + if exp == 0 { return BigUint::one(); } + let mut base = self.clone(); + + + while exp & 1 == 0 { + base = &base * &base; + exp >>= 1; + } + + if exp == 1 { return base; } + + let mut acc = base.clone(); + while exp > 1 { + exp >>= 1; + base = &base * &base; + if exp & 1 == 1 { + acc = &acc * &base; + } + } + acc + } + } + + impl<'a, 'b> Pow<&'b $T> for &'a BigUint { + type Output = BigUint; + + #[inline] + fn pow(self, exp: &$T) -> Self::Output { + self.pow(*exp) + } + } + } +} + +pow_impl!(u8); +pow_impl!(u16); +pow_impl!(u32); +pow_impl!(u64); +pow_impl!(usize); +#[cfg(has_i128)] +pow_impl!(u128); + forward_all_binop_to_val_ref_commutative!(impl Add for BigUint, add); forward_val_assign!(impl AddAssign for BigUint, add_assign); @@ -1056,7 +1105,7 @@ impl Roots for BigUint { loop { s = u; - let q = self / pow(s.clone(), n_min_1); + let q = self / s.pow(n_min_1); let t: BigUint = n_min_1 * &s + q; u = t / n; diff --git a/tests/bigint.rs b/tests/bigint.rs index be9c871d..67c14921 100644 --- a/tests/bigint.rs +++ b/tests/bigint.rs @@ -20,7 +20,7 @@ use std::hash::{BuildHasher, Hasher, Hash}; use std::collections::hash_map::RandomState; use num_integer::Integer; -use num_traits::{Zero, One, Signed, ToPrimitive, FromPrimitive, Num, Float}; +use num_traits::{Zero, One, Signed, ToPrimitive, FromPrimitive, Num, Float, Pow}; mod consts; use consts::*; @@ -1092,3 +1092,30 @@ fn test_iter_product_generic() { assert_eq!(result, data.iter().product()); assert_eq!(result, data.into_iter().product()); } + +#[test] +fn test_pow() { + let one = BigInt::from(1i32); + let two = BigInt::from(2i32); + let four = BigInt::from(4i32); + let eight = BigInt::from(8i32); + let minus_two = BigInt::from(-2i32); + macro_rules! check { + ($t:ty) => { + assert_eq!(two.pow(0 as $t), one); + assert_eq!(two.pow(1 as $t), two); + assert_eq!(two.pow(2 as $t), four); + assert_eq!(two.pow(3 as $t), eight); + assert_eq!(two.pow(&(3 as $t)), eight); + assert_eq!(minus_two.pow(0 as $t), one, "-2^0"); + assert_eq!(minus_two.pow(1 as $t), minus_two, "-2^1"); + assert_eq!(minus_two.pow(2 as $t), four, "-2^2"); + assert_eq!(minus_two.pow(3 as $t), -&eight, "-2^3"); + } + } + check!(u8); + check!(u16); + check!(u32); + check!(u64); + check!(usize); +} diff --git a/tests/biguint.rs b/tests/biguint.rs index 92c8ce9e..35c54c70 100644 --- a/tests/biguint.rs +++ b/tests/biguint.rs @@ -19,7 +19,7 @@ use std::hash::{BuildHasher, Hasher, Hash}; use std::collections::hash_map::RandomState; use num_traits::{Num, Zero, One, CheckedAdd, CheckedSub, CheckedMul, CheckedDiv, ToPrimitive, - FromPrimitive, Float}; + FromPrimitive, Float, Pow}; mod consts; use consts::*; @@ -766,7 +766,7 @@ fn test_sub() { #[should_panic] fn test_sub_fail_on_underflow() { let (a, b): (BigUint, BigUint) = (Zero::zero(), One::one()); - a - b; + let _ = a - b; } #[test] @@ -1530,3 +1530,31 @@ fn test_iter_product_generic() { assert_eq!(result, data.iter().product()); assert_eq!(result, data.into_iter().product()); } + +#[test] +fn test_pow() { + let one = BigUint::from(1u32); + let two = BigUint::from(2u32); + let four = BigUint::from(4u32); + let eight = BigUint::from(8u32); + let tentwentyfour = BigUint::from(1024u32); + let twentyfourtyeight = BigUint::from(2048u32); + macro_rules! check { + ($t:ty) => { + assert_eq!(two.pow(0 as $t), one); + assert_eq!(two.pow(1 as $t), two); + assert_eq!(two.pow(2 as $t), four); + assert_eq!(two.pow(3 as $t), eight); + assert_eq!(two.pow(10 as $t), tentwentyfour); + assert_eq!(two.pow(11 as $t), twentyfourtyeight); + assert_eq!(two.pow(&(11 as $t)), twentyfourtyeight); + } + } + check!(u8); + check!(u16); + check!(u32); + check!(u64); + check!(usize); + #[cfg(has_i128)] + check!(u128); +} diff --git a/tests/roots.rs b/tests/roots.rs index 58838b76..442f40a9 100644 --- a/tests/roots.rs +++ b/tests/roots.rs @@ -4,7 +4,7 @@ extern crate num_traits; mod biguint { use num_bigint::BigUint; - use num_traits::pow; + use num_traits::Pow; use std::str::FromStr; fn check(x: u64, n: u32) { @@ -17,8 +17,8 @@ mod biguint { assert_eq!(&res, &big_x.cbrt()) } - assert!(pow(res.clone(), n as usize) <= big_x); - assert!(pow(res.clone() + 1u32, n as usize) > big_x); + assert!(res.pow(n) <= big_x); + assert!((res + 1u32).pow(n) > big_x); } #[test] @@ -58,7 +58,7 @@ mod biguint { mod bigint { use num_bigint::BigInt; - use num_traits::{Signed, pow}; + use num_traits::{Signed, Pow}; fn check(x: i64, n: u32) { let big_x = BigInt::from(x); @@ -71,11 +71,11 @@ mod bigint { } if big_x.is_negative() { - assert!(pow(res.clone() - 1u32, n as usize) < big_x); - assert!(pow(res.clone(), n as usize) >= big_x); + assert!(res.pow(n) >= big_x); + assert!((res - 1u32).pow(n) < big_x); } else { - assert!(pow(res.clone(), n as usize) <= big_x); - assert!(pow(res.clone() + 1u32, n as usize) > big_x); + assert!(res.pow(n) <= big_x); + assert!((res + 1u32).pow(n) > big_x); } }