diff --git a/zebra-chain/src/sapling/address.rs b/zebra-chain/src/sapling/address.rs index 2bc1736b96b..7b5121ee933 100644 --- a/zebra-chain/src/sapling/address.rs +++ b/zebra-chain/src/sapling/address.rs @@ -1,6 +1,7 @@ //! Shielded addresses. use std::{ + convert::TryFrom, fmt, io::{self, Read, Write}, }; @@ -112,7 +113,8 @@ impl std::str::FromStr for Address { _ => Network::Testnet, }, diversifier: keys::Diversifier::from(diversifier_bytes), - transmission_key: keys::TransmissionKey::from(transmission_key_bytes), + transmission_key: keys::TransmissionKey::try_from(transmission_key_bytes) + .unwrap(), }) } _ => Err(SerializationError::Parse("bech32 decoding error")), @@ -178,7 +180,8 @@ mod tests { keys::IncomingViewingKey::from((authorizing_key, nullifier_deriving_key)); let diversifier = keys::Diversifier::new(&mut OsRng); - let transmission_key = keys::TransmissionKey::from((incoming_viewing_key, diversifier)); + let transmission_key = keys::TransmissionKey::try_from((incoming_viewing_key, diversifier)) + .expect("should be a valid transmission key"); let _sapling_shielded_address = Address { network: Network::Mainnet, diff --git a/zebra-chain/src/sapling/keys.rs b/zebra-chain/src/sapling/keys.rs index cb6806f10b6..617bc203d83 100644 --- a/zebra-chain/src/sapling/keys.rs +++ b/zebra-chain/src/sapling/keys.rs @@ -837,14 +837,20 @@ impl fmt::Debug for TransmissionKey { impl Eq for TransmissionKey {} -impl From<[u8; 32]> for TransmissionKey { - /// Attempts to interpret a byte representation of an - /// affine point, failing if the element is not on - /// the curve or non-canonical. +impl TryFrom<[u8; 32]> for TransmissionKey { + type Error = &'static str; + + /// Attempts to interpret a byte representation of an affine Jubjub point, failing if the + /// element is not on the curve, non-prime, the identity, or non-canonical. /// /// https://github.com/zkcrypto/jubjub/blob/master/src/lib.rs#L411 - fn from(bytes: [u8; 32]) -> Self { - Self(jubjub::AffinePoint::from_bytes(bytes).unwrap()) + fn try_from(bytes: [u8; 32]) -> Result { + let affine_point = jubjub::AffinePoint::from_bytes(bytes).unwrap(); + if affine_point.is_prime_order().into() { + Ok(Self(affine_point)) + } else { + Err("derived an invalid Sapling transmission key") + } } } @@ -854,16 +860,22 @@ impl From for [u8; 32] { } } -impl From<(IncomingViewingKey, Diversifier)> for TransmissionKey { +impl TryFrom<(IncomingViewingKey, Diversifier)> for TransmissionKey { + type Error = &'static str; + /// This includes _KA^Sapling.DerivePublic(ivk, G_d)_, which is just a /// scalar mult _\[ivk\]G_d_. /// /// https://zips.z.cash/protocol/protocol.pdf#saplingkeycomponents /// https://zips.z.cash/protocol/protocol.pdf#concretesaplingkeyagreement - fn from((ivk, d): (IncomingViewingKey, Diversifier)) -> Self { - Self(jubjub::AffinePoint::from( - diversify_hash(d.0).unwrap() * ivk.scalar, - )) + fn try_from((ivk, d): (IncomingViewingKey, Diversifier)) -> Result { + let affine_point = jubjub::AffinePoint::from(diversify_hash(d.0).unwrap() * ivk.scalar); + + if affine_point.is_prime_order().into() { + Ok(Self(affine_point)) + } else { + Err("derived an invalid Sapling transmission key") + } } } diff --git a/zebra-chain/src/sapling/keys/tests.rs b/zebra-chain/src/sapling/keys/tests.rs index 8b8f43228ff..9b4ba3170f6 100644 --- a/zebra-chain/src/sapling/keys/tests.rs +++ b/zebra-chain/src/sapling/keys/tests.rs @@ -22,7 +22,7 @@ impl Arbitrary for TransmissionKey { let diversifier = Diversifier::from(spending_key); - Self::from((incoming_viewing_key, diversifier)) + Self::try_from((incoming_viewing_key, diversifier)).unwrap() }) .boxed() } @@ -60,7 +60,8 @@ mod tests { let diversifier = Diversifier::from(spending_key); assert_eq!(diversifier, test_vector.default_d); - let transmission_key = TransmissionKey::from((incoming_viewing_key, diversifier)); + let transmission_key = TransmissionKey::try_from((incoming_viewing_key, diversifier)) + .expect("should be a valid transmission key"); assert_eq!(transmission_key, test_vector.default_pk_d); let _full_viewing_key = FullViewingKey {