Skip to content

Commit

Permalink
Modular inverse for BigUint and BigInt
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Nov 9, 2023
1 parent 3674819 commit b655e8d
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 0 deletions.
57 changes: 57 additions & 0 deletions ci/big_quickcheck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,63 @@ fn quickcheck_modpow() {
qc.quickcheck(test_modpow as fn(i128, u128, i128) -> TestResult);
}

#[test]
fn quickcheck_modinv() {
let gen = Gen::new(usize::max_value());
let mut qc = QuickCheck::new().gen(gen);

fn test_modinv(value: i128, modulus: i128) -> TestResult {
if modulus.is_zero() {
TestResult::discard()
} else {
let value = BigInt::from(value);
let modulus = BigInt::from(modulus);
match (value.modinv(&modulus), value.gcd(&modulus).is_one()) {
(None, false) => TestResult::passed(),
(None, true) => {
eprintln!("{}.modinv({}) -> None, expected Some(_)", value, modulus);
TestResult::failed()
}
(Some(inverse), false) => {
eprintln!(
"{}.modinv({}) -> Some({}), expected None",
value, modulus, inverse
);
TestResult::failed()
}
(Some(inverse), true) => {
// The inverse should either be in [0,m) or (m,0]
let zero = BigInt::zero();
if (modulus.is_positive() && !(zero <= inverse && inverse < modulus))
|| (modulus.is_negative() && !(modulus < inverse && inverse <= zero))
{
eprintln!(
"{}.modinv({}) -> Some({}) is out of range",
value, modulus, inverse
);
return TestResult::failed();
}

// We don't know the expected inverse, but we can verify the product ≡ 1
let product = (&value * &inverse).mod_floor(&modulus);
let mod_one = BigInt::one().mod_floor(&modulus);
if product != mod_one {
eprintln!("{}.modinv({}) -> Some({})", value, modulus, inverse);
eprintln!(
"{} * {} ≡ {}, expected {}",
value, inverse, product, mod_one
);
return TestResult::failed();
}
TestResult::passed()
}
}
}
}

qc.quickcheck(test_modinv as fn(i128, i128) -> TestResult);
}

#[test]
fn quickcheck_to_float_equals_i128_cast() {
let gen = Gen::new(usize::max_value());
Expand Down
58 changes: 58 additions & 0 deletions src/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,64 @@ impl BigInt {
power::modpow(self, exponent, modulus)
}

/// Returns the modular multiplicative inverse if it exists, otherwise `None`.
///
/// This solves for `x` such that `self * x ≡ 1 (mod modulus)`.
/// Note that this rounds like `mod_floor`, not like the `%` operator,
/// which makes a difference when given a negative `self` or `modulus`.
/// The solution will be in the interval `[0, modulus)` for `modulus > 0`,
/// or in the interval `(modulus, 0]` for `modulus < 0`,
/// and it exists if and only if `gcd(self, modulus) == 1`.
///
/// ```
/// use num_bigint::BigInt;
/// use num_integer::Integer;
/// use num_traits::{One, Zero};
///
/// let m = BigInt::from(383);
///
/// // Trivial cases
/// assert_eq!(BigInt::zero().modinv(&m), None);
/// assert_eq!(BigInt::one().modinv(&m), Some(BigInt::one()));
/// let neg1 = &m - 1u32;
/// assert_eq!(neg1.modinv(&m), Some(neg1));
///
/// // Positive self and modulus
/// let a = BigInt::from(271);
/// let x = a.modinv(&m).unwrap();
/// assert_eq!(x, BigInt::from(106));
/// assert_eq!(x.modinv(&m).unwrap(), a);
/// assert_eq!((&a * x).mod_floor(&m), BigInt::one());
///
/// // Negative self and positive modulus
/// let b = -&a;
/// let x = b.modinv(&m).unwrap();
/// assert_eq!(x, BigInt::from(277));
/// assert_eq!((&b * x).mod_floor(&m), BigInt::one());
///
/// // Positive self and negative modulus
/// let n = -&m;
/// let x = a.modinv(&n).unwrap();
/// assert_eq!(x, BigInt::from(-277));
/// assert_eq!((&a * x).mod_floor(&n), &n + 1);
///
/// // Negative self and modulus
/// let x = b.modinv(&n).unwrap();
/// assert_eq!(x, BigInt::from(-106));
/// assert_eq!((&b * x).mod_floor(&n), &n + 1);
/// ```
pub fn modinv(&self, modulus: &Self) -> Option<Self> {
let result = self.data.modinv(&modulus.data)?;
// The sign of the result follows the modulus, like `mod_floor`.
let (sign, mag) = match (self.is_negative(), modulus.is_negative()) {
(false, false) => (Plus, result),
(true, false) => (Plus, &modulus.data - result),
(false, true) => (Minus, &modulus.data - result),
(true, true) => (Minus, result),
};
Some(BigInt::from_biguint(sign, mag))
}

/// Returns the truncated principal square root of `self` --
/// see [`num_integer::Roots::sqrt()`].
pub fn sqrt(&self) -> Self {
Expand Down
80 changes: 80 additions & 0 deletions src/biguint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,86 @@ impl BigUint {
power::modpow(self, exponent, modulus)
}

/// Returns the modular multiplicative inverse if it exists, otherwise `None`.
///
/// This solves for `x` in the interval `[0, modulus)` such that `self * x ≡ 1 (mod modulus)`.
/// The solution exists if and only if `gcd(self, modulus) == 1`.
///
/// ```
/// use num_bigint::BigUint;
/// use num_traits::{One, Zero};
///
/// let m = BigUint::from(383_u32);
///
/// // Trivial cases
/// assert_eq!(BigUint::zero().modinv(&m), None);
/// assert_eq!(BigUint::one().modinv(&m), Some(BigUint::one()));
/// let neg1 = &m - 1u32;
/// assert_eq!(neg1.modinv(&m), Some(neg1));
///
/// let a = BigUint::from(271_u32);
/// let x = a.modinv(&m).unwrap();
/// assert_eq!(x, BigUint::from(106_u32));
/// assert_eq!(x.modinv(&m).unwrap(), a);
/// assert!((a * x % m).is_one());
/// ```
pub fn modinv(&self, modulus: &Self) -> Option<Self> {
// Based on the inverse pseudocode listed here:
// https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers
// TODO: convert to extended *binary* GCD so we're shifting instead of dividing.

assert!(
!modulus.is_zero(),
"attempt to calculate with zero modulus!"
);
if modulus.is_one() {
return Some(Self::zero());
}

let mut r0; // = modulus.clone();
let mut r1 = self % modulus;
let mut t0; // = Self::zero();
let mut t1; // = Self::one();

// Lift and simplify the first iteration to avoid some initial allocations.
if r1.is_zero() {
return None;
} else if r1.is_one() {
return Some(r1);
} else {
let (q, r2) = modulus.div_rem(&r1);
if r2.is_zero() {
return None;
}
r0 = r1;
r1 = r2;
t0 = Self::one();
t1 = modulus - q;
}

while !r1.is_zero() {
let (q, r2) = r0.div_rem(&r1);
r0 = r1;
r1 = r2;

// let t2 = (t0 - q * t1) % modulus;
let qt1 = q * &t1 % modulus;
let t2 = if t0 < qt1 {
t0 + (modulus - qt1)
} else {
t0 - qt1
};
t0 = t1;
t1 = t2;
}

if r0.is_one() {
Some(t0)
} else {
None
}
}

/// Returns the truncated principal square root of `self` --
/// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt)
pub fn sqrt(&self) -> Self {
Expand Down

0 comments on commit b655e8d

Please sign in to comment.