Skip to content

Commit

Permalink
lms: convert to hybrid-array
Browse files Browse the repository at this point in the history
  • Loading branch information
baloo committed Sep 29, 2024
1 parent e87f60d commit e8bd790
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 125 deletions.
12 changes: 5 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ members = [

[profile.dev]
opt-level = 2

[patch.crates-io]
# https://github.com/RustCrypto/hybrid-array/pull/92
hybrid-array = { git = "https://github.com/baloo/hybrid-array.git", branch = "baloo/lms-sizes" }
6 changes: 3 additions & 3 deletions lms/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ categories = ["cryptography"]
keywords = ["crypto", "signature"]

[dependencies]
digest = "0.10.7"
generic-array = { version = "0.14.4", features = ["zeroize"] }
digest = "=0.11.0-pre.9"
hybrid-array = { version = "0.2.0-rc.10", features = ["extra-sizes", "zeroize"] }
rand = "0.8.5"
sha2 = "0.10.8"
sha2 = "=0.11.0-pre.4"
static_assertions = "1.1.0"
rand_core = "0.6.4"
signature = { version = "2.3.0-pre.0", features = ["digest", "std", "rand_core"] }
Expand Down
19 changes: 9 additions & 10 deletions lms/src/lms/modes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
use crate::ots::modes::LmsOtsMode;
use crate::types::Typecode;
use digest::Digest;
use digest::Output;
use generic_array::ArrayLength;
use hybrid_array::ArraySize;
use std::ops::Add;
use std::{
marker::PhantomData,
Expand All @@ -18,9 +17,9 @@ pub trait LmsMode: Typecode + Clone {
/// The underlying LM-OTS mode
type OtsMode: LmsOtsMode;
/// Length of the internal Merkle tree, computed as `2^(h+1)-1`
type TreeLen: ArrayLength<Output<Self::Hasher>>;
type TreeLen: ArraySize;
/// `h` as a type
type HLen: ArrayLength<Output<Self::Hasher>>;
type HLen: ArraySize;
/// The length of the hash function output as a type
const M: usize;
/// `h` as a [usize]
Expand All @@ -35,7 +34,7 @@ pub trait LmsMode: Typecode + Clone {
pub struct LmsModeInternal<
OtsMode: LmsOtsMode,
Hasher: Digest,
HLen: ArrayLength<Output<Hasher>>,
HLen: ArraySize,
const M: usize,
const H: usize,
const TC: u32,
Expand All @@ -48,7 +47,7 @@ pub struct LmsModeInternal<
impl<
OtsMode: LmsOtsMode,
Hasher: Digest,
TreeLen: ArrayLength<Output<Hasher>>,
TreeLen: ArraySize,
const M: usize,
const H: usize,
const TC: u32,
Expand All @@ -62,7 +61,7 @@ impl<
impl<
OtsMode: LmsOtsMode,
Hasher: Digest,
TreeLen: ArrayLength<Output<Hasher>>,
TreeLen: ArraySize,
const M: usize,
const H: usize,
const TC: u32,
Expand All @@ -73,7 +72,7 @@ impl<
impl<
OtsMode: LmsOtsMode,
Hasher: Digest,
HLen: ArrayLength<Output<Hasher>>,
HLen: ArraySize,
const M: usize,
const H: usize,
const TC: u32,
Expand All @@ -82,7 +81,7 @@ where
HLen: Add<typenum::B1>,
U1: Shl<<HLen as Add<B1>>::Output>,
Shleft<U1, <HLen as Add<B1>>::Output>: Sub<B1>,
Sub1<Shleft<U1, <HLen as Add<B1>>::Output>>: ArrayLength<Output<Hasher>>,
Sub1<Shleft<U1, <HLen as Add<B1>>::Output>>: ArraySize,
{
type OtsMode = OtsMode;
type Hasher = Hasher;
Expand All @@ -97,7 +96,7 @@ where
impl<
Hasher: Digest,
OtsMode: LmsOtsMode,
TreeLen: ArrayLength<Output<Hasher>>,
TreeLen: ArraySize,
const M: usize,
const H: usize,
const TC: u32,
Expand Down
41 changes: 22 additions & 19 deletions lms/src/lms/private.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use crate::ots::SigningKey as OtsPrivateKey;
use crate::types::{Identifier, Typecode};

use digest::{Digest, Output, OutputSizeUser};
use generic_array::{ArrayLength, GenericArray};
use hybrid_array::{Array, ArraySize};
use rand::{CryptoRng, Rng};
use signature::{Error, RandomizedSignerMut};

use core::array::TryFromSliceError;
use std::cmp::Ordering;
use std::ops::Add;
use typenum::{Sum, U28};
Expand All @@ -23,7 +24,7 @@ use typenum::{Sum, U28};
pub struct SigningKey<Mode: LmsMode> {
id: Identifier,
seed: Output<Mode::Hasher>, // Re-generate the leaf privkeys as-needed from a seed
auth_tree: GenericArray<Output<Mode::Hasher>, Mode::TreeLen>, // TODO: Decide whether/when to precompute
auth_tree: Array<Output<Mode::Hasher>, Mode::TreeLen>, // TODO: Decide whether/when to precompute
q: u32,
}

Expand All @@ -36,27 +37,26 @@ impl<Mode: LmsMode> SigningKey<Mode> {

let mut seed = Output::<Mode::Hasher>::default();
rng.fill_bytes(seed.as_mut());
Self::new_from_seed(id, seed)
Self::new_from_seed(id, seed).expect("size invariant violation")
}

// Returns a new LMS private key generated pseudorandomly from an identifier
// and secret seed. The seed must be equal to the hash output length of the
// LMS mode ([Mode::M])
//
// TODO: Return error rather than panic? Or just make the input a
// GenericArray? This is the algorithm from Appendix A of
// <https://datatracker.ietf.org/doc/html/rfc8554#appendix-A>
pub fn new_from_seed(id: Identifier, seed: impl AsRef<[u8]>) -> Self {
pub fn new_from_seed(
id: Identifier,
seed: impl AsRef<[u8]>,
) -> Result<Self, TryFromSliceError> {
//let seed = seed.as_ref();
let seed = GenericArray::clone_from_slice(seed.as_ref());
let seed = Array::try_from(seed.as_ref())?;
let mut sk = Self {
id,
seed,
auth_tree: GenericArray::default(),
auth_tree: Array::default(),
q: 0, // we set q = 0 when generating keys; it will change
};
sk.gen_pk_tree(); // TODO: Use lazy generation / MTT
sk
Ok(sk)
}

/// Generates a Merkle tree of OTS public key hashes, using the indexing scheme of RFC 8554 offset by 1
Expand Down Expand Up @@ -135,14 +135,14 @@ impl<Mode: LmsMode> RandomizedSignerMut<Signature<Mode>> for SigningKey<Mode> {

/// Converts a [PrivateKey] into its byte representation
impl<Mode: LmsMode> From<SigningKey<Mode>>
for GenericArray<u8, Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U28>>
for Array<u8, Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U28>>
where
<Mode::Hasher as OutputSizeUser>::OutputSize: Add<U28>,
Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U28>: ArrayLength<u8>,
Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U28>: ArraySize,
{
fn from(pk: SigningKey<Mode>) -> Self {
// Return u32(type) || u32(otstype) || u32(q) || id || seed
GenericArray::from_exact_iter(
Array::try_from_iter(
std::iter::empty()
.chain(Mode::TYPECODE.to_be_bytes())
.chain(Mode::OtsMode::TYPECODE.to_be_bytes())
Expand Down Expand Up @@ -188,8 +188,8 @@ impl<'a, Mode: LmsMode> TryFrom<&'a [u8]> for SigningKey<Mode> {
let mut key = Self {
q: u32::from_be_bytes(q.try_into().expect("ok")),
id: id.try_into().expect("ok"),
seed: GenericArray::clone_from_slice(seed),
auth_tree: GenericArray::default(),
seed: Array::try_from(seed).expect("ok"),
auth_tree: Array::default(),
};
key.gen_pk_tree();
Ok(key)
Expand All @@ -215,7 +215,8 @@ mod tests {
let id = hex!("d08fabd4a2091ff0a8cb4ed834e74534");
let expected_k = hex!("32a58885cd9ba0431235466bff9651c6c92124404d45fa53cf161c28f1ad5a8e");

let lms_priv = SigningKey::<LmsSha256M32H10<LmsOtsSha256N32W4>>::new_from_seed(id, seed);
let lms_priv =
SigningKey::<LmsSha256M32H10<LmsOtsSha256N32W4>>::new_from_seed(id, seed).unwrap();
let lms_pub = lms_priv.public();
assert_eq!(lms_pub.k(), expected_k);
assert_eq!(lms_pub.id(), &id);
Expand All @@ -229,7 +230,8 @@ mod tests {
let id = hex!("215f83b7ccb9acbcd08db97b0d04dc2b");
let expected_k = hex!("a1cd035833e0e90059603f26e07ad2aad152338e7a5e5984bcd5f7bb4eba40b7");

let lms_priv = SigningKey::<LmsSha256M32H5<LmsOtsSha256N32W8>>::new_from_seed(id, seed);
let lms_priv =
SigningKey::<LmsSha256M32H5<LmsOtsSha256N32W8>>::new_from_seed(id, seed).unwrap();
let lms_pub = lms_priv.public();
assert_eq!(lms_pub.k(), expected_k);
assert_eq!(lms_pub.id(), &id);
Expand Down Expand Up @@ -339,7 +341,8 @@ mod tests {
let id = hex!("215f83b7ccb9acbcd08db97b0d04dc2b");
let _expected_k = hex!("a1cd035833e0e90059603f26e07ad2aad152338e7a5e5984bcd5f7bb4eba40b7");

let mut lms_priv = SigningKey::<LmsSha256M32H5<LmsOtsSha256N32W8>>::new_from_seed(id, seed);
let mut lms_priv =
SigningKey::<LmsSha256M32H5<LmsOtsSha256N32W8>>::new_from_seed(id, seed).unwrap();
lms_priv.q = 4;
let _lms_pub = lms_priv.public();

Expand Down
27 changes: 13 additions & 14 deletions lms/src/lms/public.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::lms::Signature;
use crate::types::Typecode;
use crate::{constants::D_INTR, lms::LmsMode};
use digest::{Digest, OutputSizeUser};
use generic_array::{ArrayLength, GenericArray};
use hybrid_array::{Array, ArraySize};
use signature::{Error, Verifier};
use typenum::{Sum, U24};

Expand Down Expand Up @@ -99,14 +99,14 @@ impl<Mode: LmsMode> Verifier<Signature<Mode>> for VerifyingKey<Mode> {

/// Converts a [`VerifyingKey`] into its byte representation
impl<Mode: LmsMode> From<VerifyingKey<Mode>>
for GenericArray<u8, Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>>
for Array<u8, Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>>
where
<Mode::Hasher as OutputSizeUser>::OutputSize: Add<U24>,
Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>: ArrayLength<u8>,
Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>: ArraySize,
{
fn from(pk: VerifyingKey<Mode>) -> Self {
// Return u32(type) || u32(otstype) || id || k
GenericArray::from_exact_iter(
Array::try_from_iter(
std::iter::empty()
.chain(Mode::TYPECODE.to_be_bytes())
.chain(Mode::OtsMode::TYPECODE.to_be_bytes())
Expand Down Expand Up @@ -149,7 +149,7 @@ impl<'a, Mode: LmsMode> TryFrom<&'a [u8]> for VerifyingKey<Mode> {

Ok(Self {
id: id.try_into().unwrap(),
k: GenericArray::clone_from_slice(k),
k: Array::try_from(k).expect("size invariant violation"),
})
}
}
Expand All @@ -165,8 +165,8 @@ mod tests {
ots::{LmsOtsSha256N32W4, LmsOtsSha256N32W8},
};
use digest::OutputSizeUser;
use generic_array::{ArrayLength, GenericArray};
use hex_literal::hex;
use hybrid_array::{Array, ArraySize};
use typenum::{Sum, U24};

// RFC 8554 Appendix F. Test Case 1
Expand Down Expand Up @@ -222,7 +222,7 @@ mod tests {
#[test]
fn test_kat1_round_trip() {
let pk = VerifyingKey::<LmsSha256M32H5<LmsOtsSha256N32W8>>::try_from(&KAT1[..]).unwrap();
let pk_serialized: GenericArray<u8, _> = pk.clone().into();
let pk_serialized: Array<u8, _> = pk.clone().into();
let bytes = pk_serialized.as_slice();
assert_eq!(bytes, &KAT1[..]);
}
Expand All @@ -244,9 +244,10 @@ mod tests {
c92124404d45fa53cf161c28f1ad5a8e
"
);
let lms_priv = SigningKey::<LmsSha256M32H10<LmsOtsSha256N32W4>>::new_from_seed(id, seed);
let lms_priv =
SigningKey::<LmsSha256M32H10<LmsOtsSha256N32W4>>::new_from_seed(id, seed).unwrap();
let lms_pub = lms_priv.public();
let lms_pub_serialized: GenericArray<u8, _> = lms_pub.into();
let lms_pub_serialized: Array<u8, _> = lms_pub.into();
let bytes = lms_pub_serialized.as_slice();
assert_eq!(bytes, &expected_pubkey[..]);
}
Expand All @@ -255,15 +256,13 @@ mod tests {
where
VerifyingKey<Mode>: std::fmt::Debug,
<Mode::Hasher as OutputSizeUser>::OutputSize: Add<U24>,
Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>: ArrayLength<u8>,
Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>: ArraySize,
{
let rng = rand::thread_rng();
let lms_priv = SigningKey::<Mode>::new(rng);
let lms_pub = lms_priv.public();
let lms_pub_serialized: GenericArray<
u8,
Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>,
> = lms_pub.clone().into();
let lms_pub_serialized: Array<u8, Sum<<Mode::Hasher as OutputSizeUser>::OutputSize, U24>> =
lms_pub.clone().into();
let bytes = lms_pub_serialized.as_slice();
let lms_pub_deserialized = VerifyingKey::<Mode>::try_from(bytes).unwrap();
assert_eq!(lms_pub, lms_pub_deserialized);
Expand Down
Loading

0 comments on commit e8bd790

Please sign in to comment.