diff --git a/src/key.rs b/src/key.rs index 8166673e2..8a0dfaf69 100644 --- a/src/key.rs +++ b/src/key.rs @@ -215,54 +215,14 @@ impl ::serde::Serialize for SecretKey { impl<'de> ::serde::Deserialize<'de> for SecretKey { fn deserialize>(d: D) -> Result { if d.is_human_readable() { - struct HexVisitor; - - impl<'de> ::serde::de::Visitor<'de> for HexVisitor { - type Value = SecretKey; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a hex string representing 32 byte SecretKey") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - if let Ok(hex) = str::from_utf8(v) { - str::FromStr::from_str(hex).map_err(E::custom) - } else { - Err(E::invalid_value(::serde::de::Unexpected::Bytes(v), &self)) - } - } - - fn visit_str(self, v: &str) -> Result - where - E: ::serde::de::Error, - { - str::FromStr::from_str(v).map_err(E::custom) - } - } - - d.deserialize_str(HexVisitor) + d.deserialize_str(crate::serde_util::HexVisitor::new( + "a hex string representing 32 byte SecretKey" + )) } else { - struct BytesVisitor; - - impl<'de> ::serde::de::Visitor<'de> for BytesVisitor { - type Value = SecretKey; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("raw 32 bytes SecretKey") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - SecretKey::from_slice(v).map_err(E::custom) - } - } - - d.deserialize_bytes(BytesVisitor) + d.deserialize_bytes(crate::serde_util::BytesVisitor::new( + "raw 32 bytes SecretKey", + SecretKey::from_slice + )) } } } @@ -459,53 +419,14 @@ impl ::serde::Serialize for PublicKey { impl<'de> ::serde::Deserialize<'de> for PublicKey { fn deserialize>(d: D) -> Result { if d.is_human_readable() { - struct HexVisitor; - - impl<'de> ::serde::de::Visitor<'de> for HexVisitor { - type Value = PublicKey; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("an ASCII hex string") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - if let Ok(hex) = str::from_utf8(v) { - str::FromStr::from_str(hex).map_err(E::custom) - } else { - Err(E::invalid_value(::serde::de::Unexpected::Bytes(v), &self)) - } - } - - fn visit_str(self, v: &str) -> Result - where - E: ::serde::de::Error, - { - str::FromStr::from_str(v).map_err(E::custom) - } - } - d.deserialize_str(HexVisitor) + d.deserialize_str(crate::serde_util::HexVisitor::new( + "an ASCII hex string representing a public key" + )) } else { - struct BytesVisitor; - - impl<'de> ::serde::de::Visitor<'de> for BytesVisitor { - type Value = PublicKey; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a bytestring") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - PublicKey::from_slice(v).map_err(E::custom) - } - } - - d.deserialize_bytes(BytesVisitor) + d.deserialize_bytes(crate::serde_util::BytesVisitor::new( + "a bytestring representing a public key", + PublicKey::from_slice + )) } } } diff --git a/src/lib.rs b/src/lib.rs index 44a1b3e6d..f37c33a55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -165,6 +165,8 @@ pub mod ecdh; pub mod key; #[cfg(feature = "recovery")] pub mod recovery; +#[cfg(feature = "serde")] +mod serde_util; pub use key::SecretKey; pub use key::PublicKey; @@ -434,54 +436,14 @@ impl ::serde::Serialize for Signature { impl<'de> ::serde::Deserialize<'de> for Signature { fn deserialize>(d: D) -> Result { if d.is_human_readable() { - struct HexVisitor; - - impl<'de> ::serde::de::Visitor<'de> for HexVisitor { - type Value = Signature; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a hex string representing a DER encoded Signature") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - if let Ok(hex) = str::from_utf8(v) { - str::FromStr::from_str(hex).map_err(E::custom) - } else { - Err(E::invalid_value(::serde::de::Unexpected::Bytes(v), &self)) - } - } - - fn visit_str(self, v: &str) -> Result - where - E: ::serde::de::Error, - { - str::FromStr::from_str(v).map_err(E::custom) - } - } - - d.deserialize_str(HexVisitor) + d.deserialize_str(crate::serde_util::HexVisitor::new( + "a hex string representing a DER encoded Signature" + )) } else { - struct BytesVisitor; - - impl<'de> ::serde::de::Visitor<'de> for BytesVisitor { - type Value = Signature; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("raw byte stream, that represents a DER encoded Signature") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - Signature::from_der(v).map_err(E::custom) - } - } - - d.deserialize_bytes(BytesVisitor) + d.deserialize_bytes(crate::serde_util::BytesVisitor::new( + "raw byte stream, that represents a DER encoded Signature", + Signature::from_der + )) } } } diff --git a/src/serde_util.rs b/src/serde_util.rs new file mode 100644 index 000000000..2fbae46a4 --- /dev/null +++ b/src/serde_util.rs @@ -0,0 +1,78 @@ +use core::fmt; +use core::marker::PhantomData; +use core::str; + +pub struct HexVisitor { + expectation: &'static str, + _pd: PhantomData, +} + +impl HexVisitor { + pub fn new(expectation: &'static str) -> Self { + HexVisitor { + expectation, + _pd: PhantomData + } + } +} + +impl<'de, T: str::FromStr> ::serde::de::Visitor<'de> for HexVisitor + where ::Err: fmt::Display +{ + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(self.expectation) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: ::serde::de::Error, + { + if let Ok(hex) = str::from_utf8(v) { + str::FromStr::from_str(hex).map_err(E::custom) + } else { + Err(E::invalid_value(::serde::de::Unexpected::Bytes(v), &self)) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: ::serde::de::Error, + { + str::FromStr::from_str(v).map_err(E::custom) + } +} + +pub struct BytesVisitor { + expectation: &'static str, + parse_fn: F, +} + +impl BytesVisitor { + pub fn new(expectation: &'static str, parse_fn: F) -> Self { + BytesVisitor { + expectation, + parse_fn + } + } +} + +impl<'de, F, T, Err> ::serde::de::Visitor<'de> for BytesVisitor + where F: Fn(&[u8]) -> Result, + Err: fmt::Display +{ + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(self.expectation) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: ::serde::de::Error, + { + (self.parse_fn)(v).map_err(E::custom) + } +} +