Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-DSA: Key generation across all parameter sets. #292

Merged
merged 10 commits into from
Jun 4, 2024
82 changes: 73 additions & 9 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use crate::constants::{BITS_IN_LOWER_PART_OF_T, COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS};

/// Values having this type hold a representative 'x' of the ML-DSA field.
pub(crate) type FieldElement = i32;

#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug)]
pub struct PolynomialRingElement {
pub(crate) coefficients: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
}
Expand All @@ -15,16 +12,74 @@ impl PolynomialRingElement {
};
}

// Splits 0 ≤ t < Q into t0 and t1 with a = t1*2ᴰ + t0
pub(crate) fn add_to_ring_element(
mut lhs: PolynomialRingElement,
rhs: &PolynomialRingElement,
) -> PolynomialRingElement {
for i in 0..lhs.coefficients.len() {
lhs.coefficients[i] += rhs.coefficients[i];
}

lhs
}

pub(crate) fn get_n_least_significant_bits(n: u8, value: u64) -> u64 {
value & ((1 << n) - 1)
}

/// Values having this type hold a representative 'x' of the ML-DSA field.
pub(crate) type FieldElement = i32;

/// If 'x' denotes a value of type `fe`, values having this type hold a
/// representative y ≡ x·MONTGOMERY_R^(-1) (mod FIELD_MODULUS).
/// We use 'mfe' as a shorthand for this type
pub(crate) type MontgomeryFieldElement = i32;

/// If 'x' denotes a value of type `fe`, values having this type hold a
/// representative y ≡ x·MONTGOMERY_R (mod FIELD_MODULUS).
/// We use 'fer' as a shorthand for this type.
pub(crate) type FieldElementTimesMontgomeryR = i32;

const MONTGOMERY_SHIFT: u8 = 32;
const INVERSE_OF_MODULUS_MOD_MONTGOMERY_R: u64 = 58_728_449; // FIELD_MODULUS^{-1} mod 2^32

pub(crate) fn montgomery_reduce(value: i64) -> MontgomeryFieldElement {
let t = get_n_least_significant_bits(MONTGOMERY_SHIFT, value as u64)
* INVERSE_OF_MODULUS_MOD_MONTGOMERY_R;
let k = get_n_least_significant_bits(MONTGOMERY_SHIFT, t) as i32;

let k_times_modulus = (k as i64) * (FIELD_MODULUS as i64);

let c = (k_times_modulus >> MONTGOMERY_SHIFT) as i32;
let value_high = (value >> MONTGOMERY_SHIFT) as i32;

value_high - c
}

#[inline(always)]
pub(crate) fn montgomery_multiply_fe_by_fer(
fe: FieldElement,
fer: FieldElementTimesMontgomeryR,
) -> FieldElement {
montgomery_reduce((fe as i64) * (fer as i64))
}

// Splits t ∈ {0, ..., q-1} into t0 and t1 with a = t1*2ᴰ + t0
// and -2ᴰ⁻¹ < t0 < 2ᴰ⁻¹. Returns t0 and t1 computed as.
//
// - t0 = t mod± 2ᵈ
// - t1 = (t - t0) / 2ᵈ.
//
// We assume the input t is in the signed representative range and convert it
// to the standard unsigned range.
//
// This approach has been taken from:
// https://github.com/cloudflare/circl/blob/main/sign/dilithium/internal/common/field.go#L35
pub(crate) fn power2round(t: i32) -> (i32, i32) {
debug_assert!(t >= 0 && t < FIELD_MODULUS);
debug_assert!(t > -FIELD_MODULUS && t < FIELD_MODULUS, "t is {}", t);

// Convert the signed representative to the standard unsigned one.
let t = t + ((t >> 31) & FIELD_MODULUS);

// Compute t mod 2ᵈ
// t0 is now one of 0, 1, ..., 2ᵈ⁻¹-1, 2ᵈ⁻¹, 2ᵈ⁻¹+1, ..., 2ᵈ-1
Expand Down Expand Up @@ -54,10 +109,19 @@ pub(crate) fn t0_to_unsigned_representative(t0: i32) -> i32 {
mod tests {
use super::*;

#[test]
fn test_montgomery_reduce() {
assert_eq!(montgomery_reduce(10933346042510), -1553279);
assert_eq!(montgomery_reduce(-20392060523118), 1331779);
assert_eq!(montgomery_reduce(13704140696092), -1231016);
assert_eq!(montgomery_reduce(-631922212176), -2580954);
}

#[test]
fn test_power2round() {
assert_eq!(power2round(2898283), (-1685, 354));
assert_eq!(power2round(3821421), (3949, 466));
assert_eq!(power2round(2577417), (-3063, 315));
assert_eq!(power2round(669975), (-1769, 82));
assert_eq!(power2round(1843331), (131, 225));
assert_eq!(power2round(-1568816), (4049, 831));
assert_eq!(power2round(-4022142), (131, 532));
}
}
7 changes: 6 additions & 1 deletion libcrux-ml-dsa/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@ pub(crate) const COEFFICIENTS_IN_RING_ELEMENT: usize = 256;
pub(crate) const FIELD_MODULUS_MINUS_ONE_BIT_LENGTH: usize = 23;

pub(crate) const BITS_IN_LOWER_PART_OF_T: usize = 13;
pub(crate) const BYTES_FOR_RING_ELEMENT_OF_T0S: usize =
(BITS_IN_LOWER_PART_OF_T * COEFFICIENTS_IN_RING_ELEMENT) / 8;

pub(crate) const BITS_IN_UPPER_PART_OF_T: usize =
FIELD_MODULUS_MINUS_ONE_BIT_LENGTH - BITS_IN_LOWER_PART_OF_T;
pub(crate) const BYTES_FOR_RING_ELEMENT_OF_T1S: usize =
(BITS_IN_UPPER_PART_OF_T * COEFFICIENTS_IN_RING_ELEMENT) / 8;

pub(crate) const SEED_FOR_A_SIZE: usize = 32;
pub(crate) const HASH_OF_PUBLIC_KEY_SIZE: usize = 64;
pub(crate) const SEED_FOR_ERROR_VECTORS_SIZE: usize = 64;
pub(crate) const BYTES_FOR_VERIFICATION_KEY_HASH: usize = 64;
pub(crate) const SEED_FOR_SIGNING_SIZE: usize = 32;
6 changes: 4 additions & 2 deletions libcrux-ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ mod arithmetic;
mod constants;
mod hash_functions;
mod matrix;
mod ml_dsa_generic;
mod ntt;
mod sample;
mod serialize;
mod utils;

mod ml_dsa_generic;

pub mod ml_dsa_44;
pub mod ml_dsa_65;
pub mod ml_dsa_87;
80 changes: 69 additions & 11 deletions libcrux-ml-dsa/src/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,85 @@
use crate::{arithmetic::PolynomialRingElement, sample::sample_ring_element_uniform};
use crate::{
arithmetic::{add_to_ring_element, power2round, PolynomialRingElement},
ntt::{invert_ntt_montgomery, ntt, ntt_multiply_montgomery},
sample::{sample_error_ring_element_uniform, sample_ring_element_uniform},
};

pub(crate) fn power2round_vector<const ROWS_IN_A: usize>(
t: [PolynomialRingElement; ROWS_IN_A],
) -> (
[PolynomialRingElement; ROWS_IN_A],
[PolynomialRingElement; ROWS_IN_A],
) {
let mut vector_t0 = [PolynomialRingElement::ZERO; ROWS_IN_A];
let mut vector_t1 = [PolynomialRingElement::ZERO; ROWS_IN_A];

for i in 0..ROWS_IN_A {
for (j, coefficient) in t[i].coefficients.into_iter().enumerate() {
let (c0, c1) = power2round(coefficient);

vector_t0[i].coefficients[j] = c0;
vector_t1[i].coefficients[j] = c1;
}
}

(vector_t0, vector_t1)
}

#[inline(always)]
pub(crate) fn sample_error_vector<const DIMENSION: usize, const ETA: usize>(
mut seed: [u8; 66],
domain_separator: &mut u16,
) -> [PolynomialRingElement; DIMENSION] {
let mut error = [PolynomialRingElement::ZERO; DIMENSION];
for i in 0..DIMENSION {
seed[64] = *domain_separator as u8;
seed[65] = (*domain_separator >> 8) as u8;
*domain_separator += 1;

error[i] = sample_error_ring_element_uniform::<ETA>(seed);
}

error
}

#[allow(non_snake_case)]
#[inline(always)]
pub(crate) fn expand_to_A<const ROWS_IN_A: usize, const COLUMNS_IN_A: usize>(
mut seed: [u8; 34],
transposed: bool,
) -> [[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A] {
let mut A = [[PolynomialRingElement::ZERO; COLUMNS_IN_A]; ROWS_IN_A];

for i in 0..ROWS_IN_A {
for j in 0..COLUMNS_IN_A {
seed[32] = i as u8;
seed[33] = j as u8;

let sampled = sample_ring_element_uniform(seed);
seed[32] = j as u8;
seed[33] = i as u8;

if transposed {
A[j][i] = sampled;
} else {
A[i][j] = sampled;
}
A[i][j] = sample_ring_element_uniform(seed);
}
}

A
}

/// Compute InvertNTT(Â ◦ ŝ₁) + s₂
#[inline(always)]
#[allow(non_snake_case)]
pub(crate) fn compute_As1_plus_s2<const ROWS_IN_A: usize, const COLUMNS_IN_A: usize>(
A: &[[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A],
s1: &[PolynomialRingElement; COLUMNS_IN_A],
s2: &[PolynomialRingElement; ROWS_IN_A],
) -> [PolynomialRingElement; ROWS_IN_A] {
let mut result = [PolynomialRingElement::ZERO; ROWS_IN_A];

for (i, row) in A.iter().enumerate() {
for (j, ring_element) in row.iter().enumerate() {
let product = ntt_multiply_montgomery(ring_element, &ntt(s1[j]));
result[i] = add_to_ring_element(result[i], &product);
}

result[i] = invert_ntt_montgomery(result[i]);
result[i] = add_to_ring_element(result[i], &s2[i]);
}

result
}
46 changes: 46 additions & 0 deletions libcrux-ml-dsa/src/ml_dsa_44.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::constants::*;

// ML-DSA-44 parameters

const ROWS_IN_A: usize = 4;
const COLUMNS_IN_A: usize = 4;

const ETA: usize = 2;
const TWO_TIMES_ETA_BIT_SIZE: usize = 3; // ⌊log_2(2 * 2)⌋ + 1

const BYTES_FOR_ERROR_RING_ELEMENT: usize =
(TWO_TIMES_ETA_BIT_SIZE * COEFFICIENTS_IN_RING_ELEMENT) / 8;

const VERIFICATION_KEY_SIZE: usize = SEED_FOR_A_SIZE
+ (COEFFICIENTS_IN_RING_ELEMENT
* ROWS_IN_A
* (FIELD_MODULUS_MINUS_ONE_BIT_LENGTH - BITS_IN_LOWER_PART_OF_T))
/ 8;

const SIGNING_KEY_SIZE: usize = SEED_FOR_A_SIZE
+ SEED_FOR_SIGNING_SIZE
+ BYTES_FOR_VERIFICATION_KEY_HASH
+ (ROWS_IN_A + COLUMNS_IN_A) * BYTES_FOR_ERROR_RING_ELEMENT
+ ROWS_IN_A * BYTES_FOR_RING_ELEMENT_OF_T0S;

pub struct MLDSA65KeyPair {
pub signing_key: [u8; SIGNING_KEY_SIZE],
pub verification_key: [u8; VERIFICATION_KEY_SIZE],
}

/// Generate an ML-DSA-65 Key Pair
pub fn generate_key_pair(randomness: [u8; 32]) -> MLDSA65KeyPair {
let (signing_key, verification_key) = crate::ml_dsa_generic::generate_key_pair::<
ROWS_IN_A,
COLUMNS_IN_A,
ETA,
BYTES_FOR_ERROR_RING_ELEMENT,
SIGNING_KEY_SIZE,
VERIFICATION_KEY_SIZE,
>(randomness);

MLDSA65KeyPair {
signing_key,
verification_key,
}
}
17 changes: 11 additions & 6 deletions libcrux-ml-dsa/src/ml_dsa_65.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@ const ROWS_IN_A: usize = 6;
const COLUMNS_IN_A: usize = 5;

const ETA: usize = 4;
const TWO_TIMES_ETA_BIT_SIZE: usize = 4; // ⌊log_2(4)⌋ + 1
const TWO_TIMES_ETA_BIT_SIZE: usize = 4; // ⌊log_2(2 * 4)⌋ + 1

const BYTES_FOR_ERROR_RING_ELEMENT: usize =
(TWO_TIMES_ETA_BIT_SIZE * COEFFICIENTS_IN_RING_ELEMENT) / 8;

const VERIFICATION_KEY_SIZE: usize = SEED_FOR_A_SIZE
+ (COEFFICIENTS_IN_RING_ELEMENT
* ROWS_IN_A
* (FIELD_MODULUS_MINUS_ONE_BIT_LENGTH - BITS_IN_LOWER_PART_OF_T))
/ 8;

const SIGNING_KEY_SIZE: usize = (SEED_FOR_A_SIZE + SEED_FOR_SIGNING_SIZE + HASH_OF_PUBLIC_KEY_SIZE)
+ (COEFFICIENTS_IN_RING_ELEMENT
* (((ROWS_IN_A + COLUMNS_IN_A) * TWO_TIMES_ETA_BIT_SIZE)
+ (BITS_IN_LOWER_PART_OF_T * ROWS_IN_A)))
/ 8;
const SIGNING_KEY_SIZE: usize = SEED_FOR_A_SIZE
+ SEED_FOR_SIGNING_SIZE
+ BYTES_FOR_VERIFICATION_KEY_HASH
+ (ROWS_IN_A + COLUMNS_IN_A) * BYTES_FOR_ERROR_RING_ELEMENT
+ ROWS_IN_A * BYTES_FOR_RING_ELEMENT_OF_T0S;

pub struct MLDSA65KeyPair {
pub signing_key: [u8; SIGNING_KEY_SIZE],
Expand All @@ -30,6 +33,8 @@ pub fn generate_key_pair(randomness: [u8; 32]) -> MLDSA65KeyPair {
let (signing_key, verification_key) = crate::ml_dsa_generic::generate_key_pair::<
ROWS_IN_A,
COLUMNS_IN_A,
ETA,
BYTES_FOR_ERROR_RING_ELEMENT,
SIGNING_KEY_SIZE,
VERIFICATION_KEY_SIZE,
>(randomness);
Expand Down
46 changes: 46 additions & 0 deletions libcrux-ml-dsa/src/ml_dsa_87.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::constants::*;

// ML-DSA-65 parameters

const ROWS_IN_A: usize = 8;
const COLUMNS_IN_A: usize = 7;

const ETA: usize = 2;
const TWO_TIMES_ETA_BIT_SIZE: usize = 3; // ⌊log_2(2 * 2)⌋ + 1

const BYTES_FOR_ERROR_RING_ELEMENT: usize =
(TWO_TIMES_ETA_BIT_SIZE * COEFFICIENTS_IN_RING_ELEMENT) / 8;

const VERIFICATION_KEY_SIZE: usize = SEED_FOR_A_SIZE
+ (COEFFICIENTS_IN_RING_ELEMENT
* ROWS_IN_A
* (FIELD_MODULUS_MINUS_ONE_BIT_LENGTH - BITS_IN_LOWER_PART_OF_T))
/ 8;

const SIGNING_KEY_SIZE: usize = SEED_FOR_A_SIZE
+ SEED_FOR_SIGNING_SIZE
+ BYTES_FOR_VERIFICATION_KEY_HASH
+ (ROWS_IN_A + COLUMNS_IN_A) * BYTES_FOR_ERROR_RING_ELEMENT
+ ROWS_IN_A * BYTES_FOR_RING_ELEMENT_OF_T0S;

pub struct MLDSA65KeyPair {
pub signing_key: [u8; SIGNING_KEY_SIZE],
pub verification_key: [u8; VERIFICATION_KEY_SIZE],
}

/// Generate an ML-DSA-65 Key Pair
pub fn generate_key_pair(randomness: [u8; 32]) -> MLDSA65KeyPair {
let (signing_key, verification_key) = crate::ml_dsa_generic::generate_key_pair::<
ROWS_IN_A,
COLUMNS_IN_A,
ETA,
BYTES_FOR_ERROR_RING_ELEMENT,
SIGNING_KEY_SIZE,
VERIFICATION_KEY_SIZE,
>(randomness);

MLDSA65KeyPair {
signing_key,
verification_key,
}
}
Loading
Loading