Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular inverse for BigUint and BigInt #288

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions ci/big_quickcheck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,63 @@ fn quickcheck_modpow() {
qc.quickcheck(test_modpow as fn(i128, u128, i128) -> TestResult);
}

#[test]
fn quickcheck_modinv() {
let gen = Gen::new(usize::max_value());
let mut qc = QuickCheck::new().gen(gen);

fn test_modinv(value: i128, modulus: i128) -> TestResult {
if modulus.is_zero() {
TestResult::discard()
} else {
let value = BigInt::from(value);
let modulus = BigInt::from(modulus);
match (value.modinv(&modulus), value.gcd(&modulus).is_one()) {
(None, false) => TestResult::passed(),
(None, true) => {
eprintln!("{}.modinv({}) -> None, expected Some(_)", value, modulus);
TestResult::failed()
}
(Some(inverse), false) => {
eprintln!(
"{}.modinv({}) -> Some({}), expected None",
value, modulus, inverse
);
TestResult::failed()
}
(Some(inverse), true) => {
// The inverse should either be in [0,m) or (m,0]
let zero = BigInt::zero();
if (modulus.is_positive() && !(zero <= inverse && inverse < modulus))
|| (modulus.is_negative() && !(modulus < inverse && inverse <= zero))
{
eprintln!(
"{}.modinv({}) -> Some({}) is out of range",
value, modulus, inverse
);
return TestResult::failed();
}

// We don't know the expected inverse, but we can verify the product ≡ 1
let product = (&value * &inverse).mod_floor(&modulus);
let mod_one = BigInt::one().mod_floor(&modulus);
if product != mod_one {
eprintln!("{}.modinv({}) -> Some({})", value, modulus, inverse);
eprintln!(
"{} * {} ≡ {}, expected {}",
value, inverse, product, mod_one
);
return TestResult::failed();
}
TestResult::passed()
}
}
}
}

qc.quickcheck(test_modinv as fn(i128, i128) -> TestResult);
}

#[test]
fn quickcheck_to_float_equals_i128_cast() {
let gen = Gen::new(usize::max_value());
Expand Down
58 changes: 58 additions & 0 deletions src/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,64 @@ impl BigInt {
power::modpow(self, exponent, modulus)
}

/// Returns the modular multiplicative inverse if it exists, otherwise `None`.
///
/// This solves for `x` such that `self * x ≡ 1 (mod modulus)`.
/// Note that this rounds like `mod_floor`, not like the `%` operator,
/// which makes a difference when given a negative `self` or `modulus`.
/// The solution will be in the interval `[0, modulus)` for `modulus > 0`,
/// or in the interval `(modulus, 0]` for `modulus < 0`,
/// and it exists if and only if `gcd(self, modulus) == 1`.
///
/// ```
/// use num_bigint::BigInt;
/// use num_integer::Integer;
/// use num_traits::{One, Zero};
///
/// let m = BigInt::from(383);
///
/// // Trivial cases
/// assert_eq!(BigInt::zero().modinv(&m), None);
/// assert_eq!(BigInt::one().modinv(&m), Some(BigInt::one()));
/// let neg1 = &m - 1u32;
/// assert_eq!(neg1.modinv(&m), Some(neg1));
///
/// // Positive self and modulus
/// let a = BigInt::from(271);
/// let x = a.modinv(&m).unwrap();
/// assert_eq!(x, BigInt::from(106));
/// assert_eq!(x.modinv(&m).unwrap(), a);
/// assert_eq!((&a * x).mod_floor(&m), BigInt::one());
///
/// // Negative self and positive modulus
/// let b = -&a;
/// let x = b.modinv(&m).unwrap();
/// assert_eq!(x, BigInt::from(277));
/// assert_eq!((&b * x).mod_floor(&m), BigInt::one());
///
/// // Positive self and negative modulus
/// let n = -&m;
/// let x = a.modinv(&n).unwrap();
/// assert_eq!(x, BigInt::from(-277));
/// assert_eq!((&a * x).mod_floor(&n), &n + 1);
///
/// // Negative self and modulus
/// let x = b.modinv(&n).unwrap();
/// assert_eq!(x, BigInt::from(-106));
/// assert_eq!((&b * x).mod_floor(&n), &n + 1);
/// ```
pub fn modinv(&self, modulus: &Self) -> Option<Self> {
let result = self.data.modinv(&modulus.data)?;
// The sign of the result follows the modulus, like `mod_floor`.
let (sign, mag) = match (self.is_negative(), modulus.is_negative()) {
(false, false) => (Plus, result),
(true, false) => (Plus, &modulus.data - result),
(false, true) => (Minus, &modulus.data - result),
(true, true) => (Minus, result),
};
Some(BigInt::from_biguint(sign, mag))
}

/// Returns the truncated principal square root of `self` --
/// see [`num_integer::Roots::sqrt()`].
pub fn sqrt(&self) -> Self {
Expand Down
80 changes: 80 additions & 0 deletions src/biguint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,86 @@ impl BigUint {
power::modpow(self, exponent, modulus)
}

/// Returns the modular multiplicative inverse if it exists, otherwise `None`.
///
/// This solves for `x` in the interval `[0, modulus)` such that `self * x ≡ 1 (mod modulus)`.
/// The solution exists if and only if `gcd(self, modulus) == 1`.
///
/// ```
/// use num_bigint::BigUint;
/// use num_traits::{One, Zero};
///
/// let m = BigUint::from(383_u32);
///
/// // Trivial cases
/// assert_eq!(BigUint::zero().modinv(&m), None);
/// assert_eq!(BigUint::one().modinv(&m), Some(BigUint::one()));
/// let neg1 = &m - 1u32;
/// assert_eq!(neg1.modinv(&m), Some(neg1));
///
/// let a = BigUint::from(271_u32);
/// let x = a.modinv(&m).unwrap();
/// assert_eq!(x, BigUint::from(106_u32));
/// assert_eq!(x.modinv(&m).unwrap(), a);
/// assert!((a * x % m).is_one());
/// ```
pub fn modinv(&self, modulus: &Self) -> Option<Self> {
// Based on the inverse pseudocode listed here:
// https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers
// TODO: consider Binary or Lehmer's GCD algorithms for optimization.

assert!(
!modulus.is_zero(),
"attempt to calculate with zero modulus!"
);
if modulus.is_one() {
return Some(Self::zero());
}

let mut r0; // = modulus.clone();
let mut r1 = self % modulus;
let mut t0; // = Self::zero();
let mut t1; // = Self::one();

// Lift and simplify the first iteration to avoid some initial allocations.
if r1.is_zero() {
return None;
} else if r1.is_one() {
return Some(r1);
} else {
let (q, r2) = modulus.div_rem(&r1);
if r2.is_zero() {
return None;
}
r0 = r1;
r1 = r2;
t0 = Self::one();
t1 = modulus - q;
}

while !r1.is_zero() {
let (q, r2) = r0.div_rem(&r1);
r0 = r1;
r1 = r2;

// let t2 = (t0 - q * t1) % modulus;
let qt1 = q * &t1 % modulus;
let t2 = if t0 < qt1 {
t0 + (modulus - qt1)
} else {
t0 - qt1
};
t0 = t1;
t1 = t2;
}

if r0.is_one() {
Some(t0)
} else {
None
}
}

/// Returns the truncated principal square root of `self` --
/// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt)
pub fn sqrt(&self) -> Self {
Expand Down