diff --git a/src/algorithms.rs b/src/algorithms.rs index 8bdc1d9a..ab803fb0 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -1,6 +1,10 @@ //! Useful algorithms related to RSA. -use digest::{Digest, DynDigest, FixedOutputReset}; +pub(crate) mod mgf; +pub(crate) mod oaep; +pub(crate) mod pkcs1v15; +pub(crate) mod pss; + use num_bigint::traits::ModInverse; use num_bigint::{BigUint, RandPrime}; #[allow(unused_imports)] @@ -134,75 +138,3 @@ pub fn generate_multi_prime_key_with_exp( RsaPrivateKey::from_components(n_final, exp.clone(), d_final, primes) } - -/// Mask generation function. -/// -/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 -pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) { - let mut counter = [0u8; 4]; - let mut i = 0; - - const MAX_LEN: u64 = core::u32::MAX as u64 + 1; - assert!(out.len() as u64 <= MAX_LEN); - - while i < out.len() { - let mut digest_input = vec![0u8; seed.len() + 4]; - digest_input[0..seed.len()].copy_from_slice(seed); - digest_input[seed.len()..].copy_from_slice(&counter); - - digest.update(digest_input.as_slice()); - let digest_output = &*digest.finalize_reset(); - let mut j = 0; - loop { - if j >= digest_output.len() || i >= out.len() { - break; - } - - out[i] ^= digest_output[j]; - j += 1; - i += 1; - } - inc_counter(&mut counter); - } -} - -/// Mask generation function. -/// -/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 -pub fn mgf1_xor_digest(out: &mut [u8], digest: &mut D, seed: &[u8]) -where - D: Digest + FixedOutputReset, -{ - let mut counter = [0u8; 4]; - let mut i = 0; - - const MAX_LEN: u64 = core::u32::MAX as u64 + 1; - assert!(out.len() as u64 <= MAX_LEN); - - while i < out.len() { - Digest::update(digest, seed); - Digest::update(digest, counter); - - let digest_output = digest.finalize_reset(); - let mut j = 0; - loop { - if j >= digest_output.len() || i >= out.len() { - break; - } - - out[i] ^= digest_output[j]; - j += 1; - i += 1; - } - inc_counter(&mut counter); - } -} -fn inc_counter(counter: &mut [u8; 4]) { - for i in (0..4).rev() { - counter[i] = counter[i].wrapping_add(1); - if counter[i] != 0 { - // No overflow - return; - } - } -} diff --git a/src/algorithms/mgf.rs b/src/algorithms/mgf.rs new file mode 100644 index 00000000..aa8fb2a3 --- /dev/null +++ b/src/algorithms/mgf.rs @@ -0,0 +1,75 @@ +//! Mask generation function common to both PSS and OAEP padding + +use digest::{Digest, DynDigest, FixedOutputReset}; + +/// Mask generation function. +/// +/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 +pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) { + let mut counter = [0u8; 4]; + let mut i = 0; + + const MAX_LEN: u64 = core::u32::MAX as u64 + 1; + assert!(out.len() as u64 <= MAX_LEN); + + while i < out.len() { + let mut digest_input = vec![0u8; seed.len() + 4]; + digest_input[0..seed.len()].copy_from_slice(seed); + digest_input[seed.len()..].copy_from_slice(&counter); + + digest.update(digest_input.as_slice()); + let digest_output = &*digest.finalize_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + inc_counter(&mut counter); + } +} + +/// Mask generation function. +/// +/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1 +pub fn mgf1_xor_digest(out: &mut [u8], digest: &mut D, seed: &[u8]) +where + D: Digest + FixedOutputReset, +{ + let mut counter = [0u8; 4]; + let mut i = 0; + + const MAX_LEN: u64 = core::u32::MAX as u64 + 1; + assert!(out.len() as u64 <= MAX_LEN); + + while i < out.len() { + Digest::update(digest, seed); + Digest::update(digest, counter); + + let digest_output = digest.finalize_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + inc_counter(&mut counter); + } +} +fn inc_counter(counter: &mut [u8; 4]) { + for i in (0..4).rev() { + counter[i] = counter[i].wrapping_add(1); + if counter[i] != 0 { + // No overflow + return; + } + } +} diff --git a/src/algorithms/oaep.rs b/src/algorithms/oaep.rs new file mode 100644 index 00000000..989d29b3 --- /dev/null +++ b/src/algorithms/oaep.rs @@ -0,0 +1,246 @@ +//! Encryption and Decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1). +//! +use alloc::string::String; +use alloc::vec::Vec; + +use digest::{Digest, DynDigest, FixedOutputReset}; +use rand_core::CryptoRngCore; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; +use zeroize::Zeroizing; + +use super::mgf::{mgf1_xor, mgf1_xor_digest}; +use crate::errors::{Error, Result}; + +// 2**61 -1 (pow is not const yet) +// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions. +const MAX_LABEL_LEN: u64 = 2_305_843_009_213_693_951; + +#[inline] +fn encrypt_internal( + rng: &mut R, + msg: &[u8], + p_hash: &[u8], + h_size: usize, + k: usize, + mut mgf: MGF, +) -> Result>> { + if msg.len() + 2 * h_size + 2 > k { + return Err(Error::MessageTooLong); + } + + let mut em = Zeroizing::new(vec![0u8; k]); + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + rng.fill_bytes(seed); + + // Data block DB = pHash || PS || 01 || M + let db_len = k - h_size - 1; + + db[0..h_size].copy_from_slice(p_hash); + db[db_len - msg.len() - 1] = 1; + db[db_len - msg.len()..].copy_from_slice(msg); + + mgf(seed, db); + + Ok(Zeroizing::new(em.to_vec())) +} + +/// Encrypts the given message with RSA and the padding scheme from +/// [PKCS#1 OAEP]. +/// +/// The message must be no longer than the length of the public modulus minus +/// `2 + (2 * hash.size())`. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_encrypt( + rng: &mut R, + msg: &[u8], + digest: &mut dyn DynDigest, + mgf_digest: &mut dyn DynDigest, + label: Option, + k: usize, +) -> Result>> { + let h_size = digest.output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + digest.update(label.as_bytes()); + let p_hash = digest.finalize_reset(); + + encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| { + mgf1_xor(db, mgf_digest, seed); + mgf1_xor(seed, mgf_digest, db); + }) +} + +/// Encrypts the given message with RSA and the padding scheme from +/// [PKCS#1 OAEP]. +/// +/// The message must be no longer than the length of the public modulus minus +/// `2 + (2 * hash.size())`. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_encrypt_digest< + R: CryptoRngCore + ?Sized, + D: Digest, + MGD: Digest + FixedOutputReset, +>( + rng: &mut R, + msg: &[u8], + label: Option, + k: usize, +) -> Result>> { + let h_size = ::output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + let p_hash = D::digest(label.as_bytes()); + + encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| { + let mut mgf_digest = MGD::new(); + mgf1_xor_digest(db, &mut mgf_digest, seed); + mgf1_xor_digest(seed, &mut mgf_digest, db); + }) +} + +///Decrypts OAEP padding. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. +/// +/// See `decrypt_session_key` for a way of solving this problem. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_decrypt( + em: &mut [u8], + digest: &mut dyn DynDigest, + mgf_digest: &mut dyn DynDigest, + label: Option, + k: usize, +) -> Result> { + let h_size = digest.output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::Decryption); + } + + digest.update(label.as_bytes()); + + let expected_p_hash = digest.finalize_reset(); + + let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| { + mgf1_xor(seed, mgf_digest, db); + mgf1_xor(db, mgf_digest, seed); + })?; + if res.is_none().into() { + return Err(Error::Decryption); + } + + let (out, index) = res.unwrap(); + + Ok(out[index as usize..].to_vec()) +} + +///Decrypts OAEP padding. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. +/// +/// See `decrypt_session_key` for a way of solving this problem. +/// +/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 +#[inline] +pub(crate) fn oaep_decrypt_digest( + em: &mut [u8], + label: Option, + k: usize, +) -> Result> { + let h_size = ::output_size(); + + let label = label.unwrap_or_default(); + if label.len() as u64 > MAX_LABEL_LEN { + return Err(Error::LabelTooLong); + } + + let expected_p_hash = D::digest(label.as_bytes()); + + let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| { + let mut mgf_digest = MGD::new(); + mgf1_xor_digest(seed, &mut mgf_digest, db); + mgf1_xor_digest(db, &mut mgf_digest, seed); + })?; + if res.is_none().into() { + return Err(Error::Decryption); + } + + let (out, index) = res.unwrap(); + + Ok(out[index as usize..].to_vec()) +} + +/// Decrypts OAEP padding. It returns one or zero in valid that indicates whether the +/// plaintext was correctly structured. +#[inline] +fn decrypt_inner( + em: &mut [u8], + h_size: usize, + expected_p_hash: &[u8], + k: usize, + mut mgf: MGF, +) -> Result, u32)>> { + if k < 11 { + return Err(Error::Decryption); + } + + if k < h_size * 2 + 2 { + return Err(Error::Decryption); + } + + let first_byte_is_zero = em[0].ct_eq(&0u8); + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + + mgf(seed, db); + + let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash); + + // The remainder of the plaintext must be zero or more 0x00, followed + // by 0x01, followed by the message. + // looking_for_index: 1 if we are still looking for the 0x01 + // index: the offset of the first 0x01 byte + // zero_before_one: 1 if we saw a non-zero byte before the 1 + let mut looking_for_index = Choice::from(1u8); + let mut index = 0u32; + let mut nonzero_before_one = Choice::from(0u8); + + for (i, el) in db.iter().skip(h_size).enumerate() { + let equals0 = el.ct_eq(&0u8); + let equals1 = el.ct_eq(&1u8); + index.conditional_assign(&(i as u32), looking_for_index & equals1); + looking_for_index &= !equals1; + nonzero_before_one |= looking_for_index & !equals0; + } + + let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index; + + Ok(CtOption::new( + (em.to_vec(), index + 2 + (h_size * 2) as u32), + valid, + )) +} diff --git a/src/algorithms/pkcs1v15.rs b/src/algorithms/pkcs1v15.rs new file mode 100644 index 00000000..c1f0779a --- /dev/null +++ b/src/algorithms/pkcs1v15.rs @@ -0,0 +1,198 @@ +//! PKCS#1 v1.5 support as described in [RFC8017 § 8.2]. +//! +//! # Usage +//! +//! See [code example in the toplevel rustdoc](../index.html#pkcs1-v15-signatures). +//! +//! [RFC8017 § 8.2]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.2 + +use alloc::vec::Vec; +use digest::Digest; +use pkcs8::AssociatedOid; +use rand_core::CryptoRngCore; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; +use zeroize::Zeroizing; + +use crate::errors::{Error, Result}; + +/// Fills the provided slice with random values, which are guaranteed +/// to not be zero. +#[inline] +fn non_zero_random_bytes(rng: &mut R, data: &mut [u8]) { + rng.fill_bytes(data); + + for el in data { + if *el == 0u8 { + // TODO: break after a certain amount of time + while *el == 0u8 { + rng.fill_bytes(core::slice::from_mut(el)); + } + } + } +} + +/// Applied the padding scheme from PKCS#1 v1.5 for encryption. The message must be no longer than +/// the length of the public modulus minus 11 bytes. +pub(crate) fn pkcs1v15_encrypt_pad( + rng: &mut R, + msg: &[u8], + k: usize, +) -> Result>> +where + R: CryptoRngCore + ?Sized, +{ + if msg.len() > k - 11 { + return Err(Error::MessageTooLong); + } + + // EM = 0x00 || 0x02 || PS || 0x00 || M + let mut em = Zeroizing::new(vec![0u8; k]); + em[1] = 2; + non_zero_random_bytes(rng, &mut em[2..k - msg.len() - 1]); + em[k - msg.len() - 1] = 0; + em[k - msg.len()..].copy_from_slice(msg); + Ok(em) +} + +/// Removes the encryption padding scheme from PKCS#1 v1.5. +/// +/// Note that whether this function returns an error or not discloses secret +/// information. If an attacker can cause this function to run repeatedly and +/// learn whether each instance returned an error then they can decrypt and +/// forge signatures as if they had the private key. See +/// `decrypt_session_key` for a way of solving this problem. +#[inline] +pub(crate) fn pkcs1v15_encrypt_unpad(em: Vec, k: usize) -> Result> { + let (valid, out, index) = decrypt_inner(em, k)?; + if valid == 0 { + return Err(Error::Decryption); + } + + Ok(out[index as usize..].to_vec()) +} + +/// Removes the PKCS1v15 padding It returns one or zero in valid that indicates whether the +/// plaintext was correctly structured. In either case, the plaintext is +/// returned in em so that it may be read independently of whether it was valid +/// in order to maintain constant memory access patterns. If the plaintext was +/// valid then index contains the index of the original message in em. +#[inline] +fn decrypt_inner(em: Vec, k: usize) -> Result<(u8, Vec, u32)> { + if k < 11 { + return Err(Error::Decryption); + } + + let first_byte_is_zero = em[0].ct_eq(&0u8); + let second_byte_is_two = em[1].ct_eq(&2u8); + + // The remainder of the plaintext must be a string of non-zero random + // octets, followed by a 0, followed by the message. + // looking_for_index: 1 iff we are still looking for the zero. + // index: the offset of the first zero byte. + let mut looking_for_index = 1u8; + let mut index = 0u32; + + for (i, el) in em.iter().enumerate().skip(2) { + let equals0 = el.ct_eq(&0u8); + index.conditional_assign(&(i as u32), Choice::from(looking_for_index) & equals0); + looking_for_index.conditional_assign(&0u8, equals0); + } + + // The PS padding must be at least 8 bytes long, and it starts two + // bytes into em. + // TODO: WARNING: THIS MUST BE CONSTANT TIME CHECK: + // Ref: https://github.com/dalek-cryptography/subtle/issues/20 + // This is currently copy & paste from the constant time impl in + // go, but very likely not sufficient. + let valid_ps = Choice::from((((2i32 + 8i32 - index as i32 - 1i32) >> 31) & 1) as u8); + let valid = + first_byte_is_zero & second_byte_is_two & Choice::from(!looking_for_index & 1) & valid_ps; + index = u32::conditional_select(&0, &(index + 1), valid); + + Ok((valid.unwrap_u8(), em, index)) +} + +#[inline] +pub(crate) fn pkcs1v15_sign_pad(prefix: &[u8], hashed: &[u8], k: usize) -> Result> { + let hash_len = hashed.len(); + let t_len = prefix.len() + hashed.len(); + if k < t_len + 11 { + return Err(Error::MessageTooLong); + } + + // EM = 0x00 || 0x01 || PS || 0x00 || T + let mut em = vec![0xff; k]; + em[0] = 0; + em[1] = 1; + em[k - t_len - 1] = 0; + em[k - t_len..k - hash_len].copy_from_slice(prefix); + em[k - hash_len..k].copy_from_slice(hashed); + + Ok(em) +} + +#[inline] +pub(crate) fn pkcs1v15_sign_unpad(prefix: &[u8], hashed: &[u8], em: &[u8], k: usize) -> Result<()> { + let hash_len = hashed.len(); + let t_len = prefix.len() + hashed.len(); + if k < t_len + 11 { + return Err(Error::Verification); + } + + // EM = 0x00 || 0x01 || PS || 0x00 || T + let mut ok = em[0].ct_eq(&0u8); + ok &= em[1].ct_eq(&1u8); + ok &= em[k - hash_len..k].ct_eq(hashed); + ok &= em[k - t_len..k - hash_len].ct_eq(prefix); + ok &= em[k - t_len - 1].ct_eq(&0u8); + + for el in em.iter().skip(2).take(k - t_len - 3) { + ok &= el.ct_eq(&0xff) + } + + if ok.unwrap_u8() != 1 { + return Err(Error::Verification); + } + + Ok(()) +} + +/// prefix = 0x30 0x30 0x06 oid 0x05 0x00 0x04 +#[inline] +pub(crate) fn pkcs1v15_generate_prefix() -> Vec +where + D: Digest + AssociatedOid, +{ + let oid = D::OID.as_bytes(); + let oid_len = oid.len() as u8; + let digest_len = ::output_size() as u8; + let mut v = vec![ + 0x30, + oid_len + 8 + digest_len, + 0x30, + oid_len + 4, + 0x6, + oid_len, + ]; + v.extend_from_slice(oid); + v.extend_from_slice(&[0x05, 0x00, 0x04, digest_len]); + v +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; + + #[test] + fn test_non_zero_bytes() { + for _ in 0..10 { + let mut rng = ChaCha8Rng::from_seed([42; 32]); + let mut b = vec![0u8; 512]; + non_zero_random_bytes(&mut rng, &mut b); + for el in &b { + assert_ne!(*el, 0u8); + } + } + } +} diff --git a/src/algorithms/pss.rs b/src/algorithms/pss.rs new file mode 100644 index 00000000..5f59d19c --- /dev/null +++ b/src/algorithms/pss.rs @@ -0,0 +1,334 @@ +//! Support for the [Probabilistic Signature Scheme] (PSS) a.k.a. RSASSA-PSS. +//! +//! Designed by Mihir Bellare and Phillip Rogaway. Specified in [RFC8017 § 8.1]. +//! +//! # Usage +//! +//! See [code example in the toplevel rustdoc](../index.html#pss-signatures). +//! +//! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme +//! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1 + +use alloc::vec::Vec; +use digest::{Digest, DynDigest, FixedOutputReset}; +use subtle::{Choice, ConstantTimeEq}; + +use super::mgf::{mgf1_xor, mgf1_xor_digest}; +use crate::errors::{Error, Result}; + +pub(crate) fn emsa_pss_encode( + m_hash: &[u8], + em_bits: usize, + salt: &[u8], + hash: &mut dyn DynDigest, +) -> Result> { + // See [1], section 9.1.1 + let h_len = hash.output_size(); + let s_len = salt.len(); + let em_len = (em_bits + 7) / 8; + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if m_hash.len() != h_len { + return Err(Error::InputNotHashed); + } + + // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. + if em_len < h_len + s_len + 2 { + // TODO: Key size too small + return Err(Error::Internal); + } + + let mut em = vec![0; em_len]; + + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..(em_len - 1) - db.len()]; + + // 4. Generate a random octet string salt of length s_len; if s_len = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; + // + // M' is an octet string of length 8 + h_len + s_len with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length h_len. + let prefix = [0u8; 8]; + + hash.update(&prefix); + hash.update(m_hash); + hash.update(salt); + + let hashed = hash.finalize_reset(); + h.copy_from_slice(&hashed); + + // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + db[em_len - s_len - h_len - 2] = 0x01; + db[em_len - s_len - h_len - 1..].copy_from_slice(salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + mgf1_xor(db, hash, h); + + // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB to zero. + db[0] &= 0xFF >> (8 * em_len - em_bits); + + // 12. Let EM = maskedDB || H || 0xbc. + em[em_len - 1] = 0xBC; + + Ok(em) +} + +pub(crate) fn emsa_pss_encode_digest( + m_hash: &[u8], + em_bits: usize, + salt: &[u8], +) -> Result> +where + D: Digest + FixedOutputReset, +{ + // See [1], section 9.1.1 + let h_len = ::output_size(); + let s_len = salt.len(); + let em_len = (em_bits + 7) / 8; + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + if m_hash.len() != h_len { + return Err(Error::InputNotHashed); + } + + // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. + if em_len < h_len + s_len + 2 { + // TODO: Key size too small + return Err(Error::Internal); + } + + let mut em = vec![0; em_len]; + + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..(em_len - 1) - db.len()]; + + // 4. Generate a random octet string salt of length s_len; if s_len = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; + // + // M' is an octet string of length 8 + h_len + s_len with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length h_len. + let prefix = [0u8; 8]; + + let mut hash = D::new(); + + Digest::update(&mut hash, prefix); + Digest::update(&mut hash, m_hash); + Digest::update(&mut hash, salt); + + let hashed = hash.finalize_reset(); + h.copy_from_slice(&hashed); + + // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + db[em_len - s_len - h_len - 2] = 0x01; + db[em_len - s_len - h_len - 1..].copy_from_slice(salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + mgf1_xor_digest(db, &mut hash, h); + + // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB to zero. + db[0] &= 0xFF >> (8 * em_len - em_bits); + + // 12. Let EM = maskedDB || H || 0xbc. + em[em_len - 1] = 0xBC; + + Ok(em) +} + +fn emsa_pss_verify_pre<'a>( + m_hash: &[u8], + em: &'a mut [u8], + em_bits: usize, + s_len: usize, + h_len: usize, +) -> Result<(&'a mut [u8], &'a mut [u8])> { + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" + // and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen + if m_hash.len() != h_len { + return Err(Error::Verification); + } + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + let em_len = em.len(); //(em_bits + 7) / 8; + if em_len < h_len + s_len + 2 { + return Err(Error::Verification); + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if em[em.len() - 1] != 0xBC { + return Err(Error::Verification); + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and + // let H be the next hLen octets. + let (db, h) = em.split_at_mut(em_len - h_len - 1); + let h = &mut h[..h_len]; + + // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + if db[0] + & (0xFF_u8 + .checked_shl(8 - (8 * em_len - em_bits) as u32) + .unwrap_or(0)) + != 0 + { + return Err(Error::Verification); + } + + Ok((db, h)) +} + +fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); + let valid: Choice = zeroes + .iter() + .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); + + valid & rest[0].ct_eq(&0x01) +} + +pub(crate) fn emsa_pss_verify( + m_hash: &[u8], + em: &mut Vec, + s_len: usize, + hash: &mut dyn DynDigest, + key_bits: usize, +) -> Result<()> { + let em_bits = key_bits - 1; + let em_len = (em_bits + 7) / 8; + let key_len = (key_bits + 7) / 8; + let h_len = hash.output_size(); + + let em = &mut em[key_len - em_len..]; + + let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; + + // 7. Let dbMask = MGF(H, em_len - h_len - 1) + // + // 8. Let DB = maskedDB \xor dbMask + mgf1_xor(db, hash, &*h); + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); + + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + let prefix = [0u8; 8]; + + hash.update(&prefix[..]); + hash.update(m_hash); + hash.update(salt); + let h0 = hash.finalize_reset(); + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if (salt_valid & h0.ct_eq(h)).into() { + Ok(()) + } else { + Err(Error::Verification) + } +} + +pub(crate) fn emsa_pss_verify_digest( + m_hash: &[u8], + em: &mut Vec, + s_len: usize, + key_bits: usize, +) -> Result<()> +where + D: Digest + FixedOutputReset, +{ + let em_bits = key_bits - 1; + let em_len = (em_bits + 7) / 8; + let key_len = (key_bits + 7) / 8; + let h_len = ::output_size(); + + let em = &mut em[key_len - em_len..]; + + let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; + + let mut hash = D::new(); + + // 7. Let dbMask = MGF(H, em_len - h_len - 1) + // + // 8. Let DB = maskedDB \xor dbMask + mgf1_xor_digest::(db, &mut hash, &*h); + + // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB + // to zero. + db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); + + let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + + // 11. Let salt be the last s_len octets of DB. + let salt = &db[db.len() - s_len..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + let prefix = [0u8; 8]; + + Digest::update(&mut hash, &prefix[..]); + Digest::update(&mut hash, m_hash); + Digest::update(&mut hash, salt); + let h0 = hash.finalize_reset(); + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if (salt_valid & h0.ct_eq(h)).into() { + Ok(()) + } else { + Err(Error::Verification) + } +} diff --git a/src/oaep.rs b/src/oaep.rs index cc0c488d..07755cd9 100644 --- a/src/oaep.rs +++ b/src/oaep.rs @@ -5,7 +5,6 @@ //! See [code example in the toplevel rustdoc](../index.html#oaep-encryption). use alloc::boxed::Box; use alloc::string::{String, ToString}; -use alloc::vec; use alloc::vec::Vec; use core::fmt; use core::marker::PhantomData; @@ -13,10 +12,9 @@ use rand_core::CryptoRngCore; use digest::{Digest, DynDigest, FixedOutputReset}; use num_bigint::BigUint; -use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use zeroize::Zeroizing; -use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; +use crate::algorithms::oaep::*; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; @@ -24,10 +22,6 @@ use crate::key::{self, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use crate::padding::PaddingScheme; use crate::traits::{Decryptor, RandomizedDecryptor, RandomizedEncryptor}; -// 2**61 -1 (pow is not const yet) -// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions. -const MAX_LABEL_LEN: u64 = 2_305_843_009_213_693_951; - /// Encryption and Decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1). /// /// - `digest` is used to hash the label. The maximum possible plaintext length is `m = k - 2 * h_len - 2`, @@ -176,42 +170,6 @@ impl fmt::Debug for Oaep { } } -#[inline] -fn encrypt_internal( - rng: &mut R, - pub_key: &RsaPublicKey, - msg: &[u8], - p_hash: &[u8], - h_size: usize, - mut mgf: MGF, -) -> Result> { - key::check_public(pub_key)?; - - let k = pub_key.size(); - - if msg.len() + 2 * h_size + 2 > k { - return Err(Error::MessageTooLong); - } - - let mut em = Zeroizing::new(vec![0u8; k]); - - let (_, payload) = em.split_at_mut(1); - let (seed, db) = payload.split_at_mut(h_size); - rng.fill_bytes(seed); - - // Data block DB = pHash || PS || 01 || M - let db_len = k - h_size - 1; - - db[0..h_size].copy_from_slice(p_hash); - db[db_len - msg.len() - 1] = 1; - db[db_len - msg.len()..].copy_from_slice(msg); - - mgf(seed, db); - - let int = Zeroizing::new(BigUint::from_bytes_be(&em)); - uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) -} - /// Encrypts the given message with RSA and the padding scheme from /// [PKCS#1 OAEP]. /// @@ -228,20 +186,12 @@ fn encrypt( mgf_digest: &mut dyn DynDigest, label: Option, ) -> Result> { - let h_size = digest.output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } + key::check_public(pub_key)?; - digest.update(label.as_bytes()); - let p_hash = digest.finalize_reset(); + let em = oaep_encrypt(rng, msg, digest, mgf_digest, label, pub_key.size())?; - encrypt_internal(rng, pub_key, msg, &p_hash, h_size, |seed, db| { - mgf1_xor(db, mgf_digest, seed); - mgf1_xor(seed, mgf_digest, db); - }) + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Encrypts the given message with RSA and the padding scheme from @@ -251,27 +201,18 @@ fn encrypt( /// `2 + (2 * hash.size())`. /// /// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1 -#[inline] fn encrypt_digest( rng: &mut R, pub_key: &RsaPublicKey, msg: &[u8], label: Option, ) -> Result> { - let h_size = ::output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } + key::check_public(pub_key)?; - let p_hash = D::digest(label.as_bytes()); + let em = oaep_encrypt_digest::<_, D, MGD>(rng, msg, label, pub_key.size())?; - encrypt_internal(rng, pub_key, msg, &p_hash, h_size, |seed, db| { - let mut mgf_digest = MGD::new(); - mgf1_xor_digest(db, &mut mgf_digest, seed); - mgf1_xor_digest(seed, &mut mgf_digest, db); - }) + let int = Zeroizing::new(BigUint::from_bytes_be(&em)); + uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from [PKCS#1 OAEP]. @@ -297,35 +238,14 @@ fn decrypt( ) -> Result> { key::check_public(priv_key)?; - let h_size = digest.output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::Decryption); - } - - digest.update(label.as_bytes()); - - let expected_p_hash = digest.finalize_reset(); - - let res = decrypt_inner( - rng, - priv_key, - ciphertext, - h_size, - &expected_p_hash, - |seed, db| { - mgf1_xor(seed, mgf_digest, db); - mgf1_xor(db, mgf_digest, seed); - }, - )?; - if res.is_none().into() { + if ciphertext.len() != priv_key.size() { return Err(Error::Decryption); } - let (out, index) = res.unwrap(); + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let mut em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - Ok(out[index as usize..].to_vec()) + oaep_decrypt(&mut em, digest, mgf_digest, label, priv_key.size()) } /// Decrypts a plaintext using RSA and the padding scheme from [PKCS#1 OAEP]. @@ -349,89 +269,14 @@ fn decrypt_digest Result> { key::check_public(priv_key)?; - let h_size = ::output_size(); - - let label = label.unwrap_or_default(); - if label.len() as u64 > MAX_LABEL_LEN { - return Err(Error::LabelTooLong); - } - - let expected_p_hash = D::digest(label.as_bytes()); - - let res = decrypt_inner( - rng, - priv_key, - ciphertext, - h_size, - &expected_p_hash, - |seed, db| { - let mut mgf_digest = MGD::new(); - mgf1_xor_digest(seed, &mut mgf_digest, db); - mgf1_xor_digest(db, &mut mgf_digest, seed); - }, - )?; - if res.is_none().into() { - return Err(Error::Decryption); - } - - let (out, index) = res.unwrap(); - - Ok(out[index as usize..].to_vec()) -} - -/// Decrypts ciphertext using `priv_key` and blinds the operation if -/// `rng` is given. It returns one or zero in valid that indicates whether the -/// plaintext was correctly structured. -#[inline] -fn decrypt_inner( - rng: Option<&mut R>, - priv_key: &RsaPrivateKey, - ciphertext: &[u8], - h_size: usize, - expected_p_hash: &[u8], - mut mgf: MGF, -) -> Result, u32)>> { - let k = priv_key.size(); - if k < 11 { - return Err(Error::Decryption); - } - - if ciphertext.len() != k || k < h_size * 2 + 2 { + if ciphertext.len() != priv_key.size() { return Err(Error::Decryption); } let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; let mut em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - let first_byte_is_zero = em[0].ct_eq(&0u8); - - let (_, payload) = em.split_at_mut(1); - let (seed, db) = payload.split_at_mut(h_size); - - mgf(seed, db); - - let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash); - - // The remainder of the plaintext must be zero or more 0x00, followed - // by 0x01, followed by the message. - // looking_for_index: 1 if we are still looking for the 0x01 - // index: the offset of the first 0x01 byte - // zero_before_one: 1 if we saw a non-zero byte before the 1 - let mut looking_for_index = Choice::from(1u8); - let mut index = 0u32; - let mut nonzero_before_one = Choice::from(0u8); - - for (i, el) in db.iter().skip(h_size).enumerate() { - let equals0 = el.ct_eq(&0u8); - let equals1 = el.ct_eq(&1u8); - index.conditional_assign(&(i as u32), looking_for_index & equals1); - looking_for_index &= !equals1; - nonzero_before_one |= looking_for_index & !equals0; - } - - let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index; - - Ok(CtOption::new((em, index + 2 + (h_size * 2) as u32), valid)) + oaep_decrypt_digest::(&mut em, label, priv_key.size()) } /// Encryption key for PKCS#1 v1.5 encryption as described in [RFC8017 § 7.1]. diff --git a/src/pkcs1v15.rs b/src/pkcs1v15.rs index 9973761d..d82a7303 100644 --- a/src/pkcs1v15.rs +++ b/src/pkcs1v15.rs @@ -24,9 +24,9 @@ use signature::{ DigestSigner, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Signer, Verifier, }; -use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use zeroize::Zeroizing; +use crate::algorithms::pkcs1v15::*; use crate::dummy_rng::DummyRng; use crate::errors::{Error, Result}; use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; @@ -80,7 +80,7 @@ impl Pkcs1v15Sign { { Self { hash_len: Some(::output_size()), - prefix: generate_prefix::().into_boxed_slice(), + prefix: pkcs1v15_generate_prefix::().into_boxed_slice(), } } @@ -193,25 +193,14 @@ impl Display for Signature { /// scheme from PKCS#1 v1.5. The message must be no longer than the /// length of the public modulus minus 11 bytes. #[inline] -pub(crate) fn encrypt( +fn encrypt( rng: &mut R, pub_key: &RsaPublicKey, msg: &[u8], ) -> Result> { key::check_public(pub_key)?; - let k = pub_key.size(); - if msg.len() > k - 11 { - return Err(Error::MessageTooLong); - } - - // EM = 0x00 || 0x02 || PS || 0x00 || M - let mut em = Zeroizing::new(vec![0u8; k]); - em[1] = 2; - non_zero_random_bytes(rng, &mut em[2..k - msg.len() - 1]); - em[k - msg.len() - 1] = 0; - em[k - msg.len()..].copy_from_slice(msg); - + let em = pkcs1v15_encrypt_pad(rng, msg, pub_key.size())?; let int = Zeroizing::new(BigUint::from_bytes_be(&em)); uint_to_be_pad(pub_key.raw_int_encryption_primitive(&int)?, pub_key.size()) } @@ -226,19 +215,17 @@ pub(crate) fn encrypt( /// forge signatures as if they had the private key. See /// `decrypt_session_key` for a way of solving this problem. #[inline] -pub(crate) fn decrypt( +fn decrypt( rng: Option<&mut R>, priv_key: &RsaPrivateKey, ciphertext: &[u8], ) -> Result> { key::check_public(priv_key)?; - let (valid, out, index) = decrypt_inner(rng, priv_key, ciphertext)?; - if valid == 0 { - return Err(Error::Decryption); - } + let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; + let em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - Ok(out[index as usize..].to_vec()) + pkcs1v15_encrypt_unpad(em, priv_key.size()) } /// Calculates the signature of hashed using @@ -255,26 +242,13 @@ pub(crate) fn decrypt( /// messages to signatures and identify the signed messages. As ever, /// signatures provide authenticity, not confidentiality. #[inline] -pub(crate) fn sign( +fn sign( rng: Option<&mut R>, priv_key: &RsaPrivateKey, prefix: &[u8], hashed: &[u8], ) -> Result> { - let hash_len = hashed.len(); - let t_len = prefix.len() + hashed.len(); - let k = priv_key.size(); - if k < t_len + 11 { - return Err(Error::MessageTooLong); - } - - // EM = 0x00 || 0x01 || PS || 0x00 || T - let mut em = vec![0xff; k]; - em[0] = 0; - em[1] = 1; - em[k - t_len - 1] = 0; - em[k - t_len..k - hash_len].copy_from_slice(prefix); - em[k - hash_len..k].copy_from_slice(hashed); + let em = pkcs1v15_sign_pad(prefix, hashed, priv_key.size())?; uint_to_zeroizing_be_pad( priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(&em))?, @@ -284,125 +258,10 @@ pub(crate) fn sign( /// Verifies an RSA PKCS#1 v1.5 signature. #[inline] -pub(crate) fn verify( - pub_key: &RsaPublicKey, - prefix: &[u8], - hashed: &[u8], - sig: &BigUint, -) -> Result<()> { - let hash_len = hashed.len(); - let t_len = prefix.len() + hashed.len(); - let k = pub_key.size(); - if k < t_len + 11 { - return Err(Error::Verification); - } - +fn verify(pub_key: &RsaPublicKey, prefix: &[u8], hashed: &[u8], sig: &BigUint) -> Result<()> { let em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - // EM = 0x00 || 0x01 || PS || 0x00 || T - let mut ok = em[0].ct_eq(&0u8); - ok &= em[1].ct_eq(&1u8); - ok &= em[k - hash_len..k].ct_eq(hashed); - ok &= em[k - t_len..k - hash_len].ct_eq(prefix); - ok &= em[k - t_len - 1].ct_eq(&0u8); - - for el in em.iter().skip(2).take(k - t_len - 3) { - ok &= el.ct_eq(&0xff) - } - - if ok.unwrap_u8() != 1 { - return Err(Error::Verification); - } - - Ok(()) -} - -/// prefix = 0x30 0x30 0x06 oid 0x05 0x00 0x04 -#[inline] -pub(crate) fn generate_prefix() -> Vec -where - D: Digest + AssociatedOid, -{ - let oid = D::OID.as_bytes(); - let oid_len = oid.len() as u8; - let digest_len = ::output_size() as u8; - let mut v = vec![ - 0x30, - oid_len + 8 + digest_len, - 0x30, - oid_len + 4, - 0x6, - oid_len, - ]; - v.extend_from_slice(oid); - v.extend_from_slice(&[0x05, 0x00, 0x04, digest_len]); - v -} - -/// Decrypts ciphertext using `priv_key` and blinds the operation if -/// `rng` is given. It returns one or zero in valid that indicates whether the -/// plaintext was correctly structured. In either case, the plaintext is -/// returned in em so that it may be read independently of whether it was valid -/// in order to maintain constant memory access patterns. If the plaintext was -/// valid then index contains the index of the original message in em. -#[inline] -fn decrypt_inner( - rng: Option<&mut R>, - priv_key: &RsaPrivateKey, - ciphertext: &[u8], -) -> Result<(u8, Vec, u32)> { - let k = priv_key.size(); - if k < 11 { - return Err(Error::Decryption); - } - - let em = priv_key.raw_int_decryption_primitive(rng, &BigUint::from_bytes_be(ciphertext))?; - let em = uint_to_zeroizing_be_pad(em, priv_key.size())?; - - let first_byte_is_zero = em[0].ct_eq(&0u8); - let second_byte_is_two = em[1].ct_eq(&2u8); - - // The remainder of the plaintext must be a string of non-zero random - // octets, followed by a 0, followed by the message. - // looking_for_index: 1 iff we are still looking for the zero. - // index: the offset of the first zero byte. - let mut looking_for_index = 1u8; - let mut index = 0u32; - - for (i, el) in em.iter().enumerate().skip(2) { - let equals0 = el.ct_eq(&0u8); - index.conditional_assign(&(i as u32), Choice::from(looking_for_index) & equals0); - looking_for_index.conditional_assign(&0u8, equals0); - } - - // The PS padding must be at least 8 bytes long, and it starts two - // bytes into em. - // TODO: WARNING: THIS MUST BE CONSTANT TIME CHECK: - // Ref: https://github.com/dalek-cryptography/subtle/issues/20 - // This is currently copy & paste from the constant time impl in - // go, but very likely not sufficient. - let valid_ps = Choice::from((((2i32 + 8i32 - index as i32 - 1i32) >> 31) & 1) as u8); - let valid = - first_byte_is_zero & second_byte_is_two & Choice::from(!looking_for_index & 1) & valid_ps; - index = u32::conditional_select(&0, &(index + 1), valid); - - Ok((valid.unwrap_u8(), em, index)) -} - -/// Fills the provided slice with random values, which are guaranteed -/// to not be zero. -#[inline] -fn non_zero_random_bytes(rng: &mut R, data: &mut [u8]) { - rng.fill_bytes(data); - - for el in data { - if *el == 0u8 { - // TODO: break after a certain amount of time - while *el == 0u8 { - rng.fill_bytes(core::slice::from_mut(el)); - } - } - } + pkcs1v15_sign_unpad(prefix, hashed, &em, pub_key.size()) } /// Signing key for PKCS#1 v1.5 signatures as described in [RFC8017 § 8.2]. @@ -496,7 +355,7 @@ where pub fn new(key: RsaPrivateKey) -> Self { Self { inner: key, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), } } @@ -505,7 +364,7 @@ where pub fn random(rng: &mut R, bit_size: usize) -> Result { Ok(Self { inner: RsaPrivateKey::new(rng, bit_size)?, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), }) } @@ -700,7 +559,7 @@ where pub fn new(key: RsaPublicKey) -> Self { Self { inner: key, - prefix: generate_prefix::(), + prefix: pkcs1v15_generate_prefix::(), phantom: Default::default(), } } @@ -907,18 +766,6 @@ mod tests { use crate::{PublicKeyParts, RsaPrivateKey, RsaPublicKey}; - #[test] - fn test_non_zero_bytes() { - for _ in 0..10 { - let mut rng = ChaCha8Rng::from_seed([42; 32]); - let mut b = vec![0u8; 512]; - non_zero_random_bytes(&mut rng, &mut b); - for el in &b { - assert_ne!(*el, 0u8); - } - } - } - fn get_private_key() -> RsaPrivateKey { // In order to generate new test vectors you'll need the PEM form of this key: // -----BEGIN RSA PRIVATE KEY----- diff --git a/src/pss.rs b/src/pss.rs index a6ce9f93..8713a8dd 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -30,9 +30,8 @@ use signature::{ hazmat::{PrehashVerifier, RandomizedPrehashSigner}, DigestVerifier, Keypair, RandomizedDigestSigner, RandomizedSigner, SignatureEncoding, Verifier, }; -use subtle::{Choice, ConstantTimeEq}; -use crate::algorithms::{mgf1_xor, mgf1_xor_digest}; +use crate::algorithms::pss::*; use crate::errors::{Error, Result}; use crate::internals::{uint_to_be_pad, uint_to_zeroizing_be_pad}; use crate::key::PublicKeyParts; @@ -196,18 +195,9 @@ pub(crate) fn verify( return Err(Error::Verification); } - let em_bits = pub_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let key_len = pub_key.size(); - let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - emsa_pss_verify( - hashed, - &mut em[key_len - em_len..], - em_bits, - salt_len, - digest, - ) + emsa_pss_verify(hashed, &mut em, salt_len, digest, pub_key.n().bits()) } pub(crate) fn verify_digest( @@ -224,12 +214,9 @@ where return Err(Error::Verification); } - let em_bits = pub_key.n().bits() - 1; - let em_len = (em_bits + 7) / 8; - let key_len = pub_key.size(); - let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, key_len)?; + let mut em = uint_to_be_pad(pub_key.raw_int_encryption_primitive(sig)?, pub_key.size())?; - emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len..], em_bits, salt_len) + emsa_pss_verify_digest::(hashed, &mut em, salt_len, pub_key.n().bits()) } /// SignPSS calculates the signature of hashed using RSASSA-PSS. @@ -300,311 +287,6 @@ fn sign_pss_with_salt_digest Result> { - // See [1], section 9.1.1 - let h_len = hash.output_size(); - let s_len = salt.len(); - let em_len = (em_bits + 7) / 8; - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "message too - // long" and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if m_hash.len() != h_len { - return Err(Error::InputNotHashed); - } - - // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. - if em_len < h_len + s_len + 2 { - // TODO: Key size too small - return Err(Error::Internal); - } - - let mut em = vec![0; em_len]; - - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..(em_len - 1) - db.len()]; - - // 4. Generate a random octet string salt of length s_len; if s_len = 0, - // then salt is the empty string. - // - // 5. Let - // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; - // - // M' is an octet string of length 8 + h_len + s_len with eight - // initial zero octets. - // - // 6. Let H = Hash(M'), an octet string of length h_len. - let prefix = [0u8; 8]; - - hash.update(&prefix); - hash.update(m_hash); - hash.update(salt); - - let hashed = hash.finalize_reset(); - h.copy_from_slice(&hashed); - - // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 - // zero octets. The length of PS may be 0. - // - // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length - // emLen - hLen - 1. - db[em_len - s_len - h_len - 2] = 0x01; - db[em_len - s_len - h_len - 1..].copy_from_slice(salt); - - // 9. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 10. Let maskedDB = DB \xor dbMask. - mgf1_xor(db, hash, h); - - // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB to zero. - db[0] &= 0xFF >> (8 * em_len - em_bits); - - // 12. Let EM = maskedDB || H || 0xbc. - em[em_len - 1] = 0xBC; - - Ok(em) -} - -fn emsa_pss_encode_digest(m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Result> -where - D: Digest + FixedOutputReset, -{ - // See [1], section 9.1.1 - let h_len = ::output_size(); - let s_len = salt.len(); - let em_len = (em_bits + 7) / 8; - - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "message too - // long" and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen. - if m_hash.len() != h_len { - return Err(Error::InputNotHashed); - } - - // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop. - if em_len < h_len + s_len + 2 { - // TODO: Key size too small - return Err(Error::Internal); - } - - let mut em = vec![0; em_len]; - - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..(em_len - 1) - db.len()]; - - // 4. Generate a random octet string salt of length s_len; if s_len = 0, - // then salt is the empty string. - // - // 5. Let - // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt; - // - // M' is an octet string of length 8 + h_len + s_len with eight - // initial zero octets. - // - // 6. Let H = Hash(M'), an octet string of length h_len. - let prefix = [0u8; 8]; - - let mut hash = D::new(); - - Digest::update(&mut hash, prefix); - Digest::update(&mut hash, m_hash); - Digest::update(&mut hash, salt); - - let hashed = hash.finalize_reset(); - h.copy_from_slice(&hashed); - - // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2 - // zero octets. The length of PS may be 0. - // - // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length - // emLen - hLen - 1. - db[em_len - s_len - h_len - 2] = 0x01; - db[em_len - s_len - h_len - 1..].copy_from_slice(salt); - - // 9. Let dbMask = MGF(H, emLen - hLen - 1). - // - // 10. Let maskedDB = DB \xor dbMask. - mgf1_xor_digest(db, &mut hash, h); - - // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB to zero. - db[0] &= 0xFF >> (8 * em_len - em_bits); - - // 12. Let EM = maskedDB || H || 0xbc. - em[em_len - 1] = 0xBC; - - Ok(em) -} - -fn emsa_pss_verify_pre<'a>( - m_hash: &[u8], - em: &'a mut [u8], - em_bits: usize, - s_len: usize, - h_len: usize, -) -> Result<(&'a mut [u8], &'a mut [u8])> { - // 1. If the length of M is greater than the input limitation for the - // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" - // and stop. - // - // 2. Let mHash = Hash(M), an octet string of length hLen - if m_hash.len() != h_len { - return Err(Error::Verification); - } - - // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. - let em_len = em.len(); //(em_bits + 7) / 8; - if em_len < h_len + s_len + 2 { - return Err(Error::Verification); - } - - // 4. If the rightmost octet of EM does not have hexadecimal value - // 0xbc, output "inconsistent" and stop. - if em[em.len() - 1] != 0xBC { - return Err(Error::Verification); - } - - // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and - // let H be the next hLen octets. - let (db, h) = em.split_at_mut(em_len - h_len - 1); - let h = &mut h[..h_len]; - - // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in - // maskedDB are not all equal to zero, output "inconsistent" and - // stop. - if db[0] - & (0xFF_u8 - .checked_shl(8 - (8 * em_len - em_bits) as u32) - .unwrap_or(0)) - != 0 - { - return Err(Error::Verification); - } - - Ok((db, h)) -} - -fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice { - // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero - // or if the octet at position emLen - hLen - sLen - 1 (the leftmost - // position is "position 1") does not have hexadecimal value 0x01, - // output "inconsistent" and stop. - let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2); - let valid: Choice = zeroes - .iter() - .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00)); - - valid & rest[0].ct_eq(&0x01) -} - -fn emsa_pss_verify( - m_hash: &[u8], - em: &mut [u8], - em_bits: usize, - s_len: usize, - hash: &mut dyn DynDigest, -) -> Result<()> { - let em_len = em.len(); //(em_bits + 7) / 8; - let h_len = hash.output_size(); - - let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; - - // 7. Let dbMask = MGF(H, em_len - h_len - 1) - // - // 8. Let DB = maskedDB \xor dbMask - mgf1_xor(db, hash, &*h); - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); - - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - let prefix = [0u8; 8]; - - hash.update(&prefix[..]); - hash.update(m_hash); - hash.update(salt); - let h0 = hash.finalize_reset(); - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if (salt_valid & h0.ct_eq(h)).into() { - Ok(()) - } else { - Err(Error::Verification) - } -} - -fn emsa_pss_verify_digest( - m_hash: &[u8], - em: &mut [u8], - em_bits: usize, - s_len: usize, -) -> Result<()> -where - D: Digest + FixedOutputReset, -{ - let em_len = em.len(); //(em_bits + 7) / 8; - let h_len = ::output_size(); - - let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?; - - let mut hash = D::new(); - - // 7. Let dbMask = MGF(H, em_len - h_len - 1) - // - // 8. Let DB = maskedDB \xor dbMask - mgf1_xor_digest::(db, &mut hash, &*h); - - // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB - // to zero. - db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); - - // 11. Let salt be the last s_len octets of DB. - let salt = &db[db.len() - s_len..]; - - // 12. Let - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; - // M' is an octet string of length 8 + hLen + sLen with eight - // initial zero octets. - // - // 13. Let H' = Hash(M'), an octet string of length hLen. - let prefix = [0u8; 8]; - - Digest::update(&mut hash, &prefix[..]); - Digest::update(&mut hash, m_hash); - Digest::update(&mut hash, salt); - let h0 = hash.finalize_reset(); - - // 14. If H = H', output "consistent." Otherwise, output "inconsistent." - if (salt_valid & h0.ct_eq(h)).into() { - Ok(()) - } else { - Err(Error::Verification) - } -} - /// Signing key for producing RSASSA-PSS signatures as described in /// [RFC8017 § 8.1]. ///