Skip to content

Commit

Permalink
Ensure that sapling::keys::TransmissionKey jubjub point is always in …
Browse files Browse the repository at this point in the history
…the prime order group
  • Loading branch information
dconnolly committed Dec 3, 2021
1 parent 224265d commit 04203f8
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
7 changes: 5 additions & 2 deletions zebra-chain/src/sapling/address.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Shielded addresses.
use std::{
convert::TryFrom,
fmt,
io::{self, Read, Write},
};
Expand Down Expand Up @@ -112,7 +113,8 @@ impl std::str::FromStr for Address {
_ => Network::Testnet,
},
diversifier: keys::Diversifier::from(diversifier_bytes),
transmission_key: keys::TransmissionKey::from(transmission_key_bytes),
transmission_key: keys::TransmissionKey::try_from(transmission_key_bytes)
.unwrap(),
})
}
_ => Err(SerializationError::Parse("bech32 decoding error")),
Expand Down Expand Up @@ -178,7 +180,8 @@ mod tests {
keys::IncomingViewingKey::from((authorizing_key, nullifier_deriving_key));

let diversifier = keys::Diversifier::new(&mut OsRng);
let transmission_key = keys::TransmissionKey::from((incoming_viewing_key, diversifier));
let transmission_key = keys::TransmissionKey::try_from((incoming_viewing_key, diversifier))
.expect("should be a valid transmission key");

let _sapling_shielded_address = Address {
network: Network::Mainnet,
Expand Down
34 changes: 23 additions & 11 deletions zebra-chain/src/sapling/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,14 +837,20 @@ impl fmt::Debug for TransmissionKey {

impl Eq for TransmissionKey {}

impl From<[u8; 32]> for TransmissionKey {
/// Attempts to interpret a byte representation of an
/// affine point, failing if the element is not on
/// the curve or non-canonical.
impl TryFrom<[u8; 32]> for TransmissionKey {
type Error = &'static str;

/// Attempts to interpret a byte representation of an affine Jubjub point, failing if the
/// element is not on the curve, non-prime, the identity, or non-canonical.
///
/// https://github.com/zkcrypto/jubjub/blob/master/src/lib.rs#L411
fn from(bytes: [u8; 32]) -> Self {
Self(jubjub::AffinePoint::from_bytes(bytes).unwrap())
fn try_from(bytes: [u8; 32]) -> Result<Self, Self::Error> {
let affine_point = jubjub::AffinePoint::from_bytes(bytes).unwrap();
if affine_point.is_prime_order().into() {
Ok(Self(affine_point))
} else {
Err("derived an invalid Sapling transmission key")
}
}
}

Expand All @@ -854,16 +860,22 @@ impl From<TransmissionKey> for [u8; 32] {
}
}

impl From<(IncomingViewingKey, Diversifier)> for TransmissionKey {
impl TryFrom<(IncomingViewingKey, Diversifier)> for TransmissionKey {
type Error = &'static str;

/// This includes _KA^Sapling.DerivePublic(ivk, G_d)_, which is just a
/// scalar mult _\[ivk\]G_d_.
///
/// https://zips.z.cash/protocol/protocol.pdf#saplingkeycomponents
/// https://zips.z.cash/protocol/protocol.pdf#concretesaplingkeyagreement
fn from((ivk, d): (IncomingViewingKey, Diversifier)) -> Self {
Self(jubjub::AffinePoint::from(
diversify_hash(d.0).unwrap() * ivk.scalar,
))
fn try_from((ivk, d): (IncomingViewingKey, Diversifier)) -> Result<Self, Self::Error> {
let affine_point = jubjub::AffinePoint::from(diversify_hash(d.0).unwrap() * ivk.scalar);

if affine_point.is_prime_order().into() {
Ok(Self(affine_point))
} else {
Err("derived an invalid Sapling transmission key")
}
}
}

Expand Down
5 changes: 3 additions & 2 deletions zebra-chain/src/sapling/keys/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Arbitrary for TransmissionKey {

let diversifier = Diversifier::from(spending_key);

Self::from((incoming_viewing_key, diversifier))
Self::try_from((incoming_viewing_key, diversifier)).unwrap()
})
.boxed()
}
Expand Down Expand Up @@ -60,7 +60,8 @@ mod tests {
let diversifier = Diversifier::from(spending_key);
assert_eq!(diversifier, test_vector.default_d);

let transmission_key = TransmissionKey::from((incoming_viewing_key, diversifier));
let transmission_key = TransmissionKey::try_from((incoming_viewing_key, diversifier))
.expect("should be a valid transmission key");
assert_eq!(transmission_key, test_vector.default_pk_d);

let _full_viewing_key = FullViewingKey {
Expand Down

0 comments on commit 04203f8

Please sign in to comment.