Skip to content

Commit

Permalink
Move serde-related stuff to dedicated module
Browse files Browse the repository at this point in the history
  • Loading branch information
Denis committed Jul 26, 2021
1 parent fc6e894 commit 5abd70a
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 329 deletions.
1 change: 1 addition & 0 deletions src/elliptic/curves/wrappers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod generator;
mod point;
mod point_ref;
mod scalar;
mod serde_support;

pub use self::{
encoded_point::EncodedPoint, encoded_scalar::EncodedScalar, generator::Generator, point::Point,
Expand Down
179 changes: 0 additions & 179 deletions src/elliptic/curves/wrappers/point.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
use std::marker::PhantomData;
use std::{fmt, iter};

use serde::de::{Deserializer, Error, MapAccess, Visitor};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};

use crate::elliptic::curves::traits::*;
use crate::BigInt;

Expand Down Expand Up @@ -283,177 +278,3 @@ impl<'p, E: Curve> iter::Sum<PointRef<'p, E>> for Point<E> {
iter.fold(Point::zero(), |acc, p| acc + p)
}
}

impl<E: Curve> Serialize for Point<E> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.as_point().serialize(serializer)
}
}

impl<'de, E: Curve> Deserialize<'de> for Point<E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct PointVisitor<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Visitor<'de> for PointVisitor<E> {
type Value = Point<E>;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "point of {} curve", E::CURVE_NAME)
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut curve_name: Option<CurveNameGuard<E>> = None;
let mut point: Option<PointFromBytes<E>> = None;

while let Some(key) = map.next_key()? {
match key {
Field::Curve => {
if curve_name.is_some() {
return Err(A::Error::duplicate_field("curve_name"));
}
curve_name = Some(map.next_value()?)
}
Field::Point => {
if point.is_some() {
return Err(A::Error::duplicate_field("point"));
}
point = Some(map.next_value()?)
}
}
}
let _curve_name =
curve_name.ok_or_else(|| A::Error::missing_field("curve_name"))?;
let point = point.ok_or_else(|| A::Error::missing_field("point"))?;
Ok(point.0)
}
}

deserializer.deserialize_struct("Point", &["curve", "point"], PointVisitor(PhantomData))
}
}

#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
Curve,
Point,
}

/// Efficient guard for asserting that deserialized `&str`/`String` is `E::CURVE_NAME`
pub(super) struct CurveNameGuard<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Deserialize<'de> for CurveNameGuard<E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct CurveNameVisitor<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Visitor<'de> for CurveNameVisitor<E> {
type Value = ();

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "curve name (constrained to be '{}')", E::CURVE_NAME)
}

fn visit_str<Err>(self, v: &str) -> Result<Self::Value, Err>
where
Err: Error,
{
if v == E::CURVE_NAME {
Ok(())
} else {
Err(Err::invalid_value(
serde::de::Unexpected::Str(v),
&E::CURVE_NAME,
))
}
}
}

deserializer
.deserialize_str(CurveNameVisitor(PhantomData::<E>))
.map(|_| CurveNameGuard(PhantomData))
}
}

struct PointFromBytes<E: Curve>(Point<E>);

impl<'de, E: Curve> Deserialize<'de> for PointFromBytes<E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct PointBytesVisitor<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Visitor<'de> for PointBytesVisitor<E> {
type Value = Point<E>;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "point of {} curve", E::CURVE_NAME)
}

fn visit_bytes<Err>(self, v: &[u8]) -> Result<Self::Value, Err>
where
Err: Error,
{
Point::from_bytes(v).map_err(|e| Err::custom(format!("invalid point: {}", e)))
}
}

deserializer
.deserialize_bytes(PointBytesVisitor(PhantomData))
.map(PointFromBytes)
}
}

#[cfg(test)]
mod serde_tests {
use serde_test::{assert_tokens, Token::*};

use crate::elliptic::curves::*;

#[test]
fn test_serde_point() {
fn generic<E: Curve>(point: Point<E>) {
let bytes = point.to_bytes(true).to_vec();
let tokens = vec![
Struct {
name: "Point",
len: 2,
},
Str("curve"),
Str(E::CURVE_NAME),
Str("point"),
Bytes(bytes.leak()),
StructEnd,
];
assert_tokens(&point, &tokens);
}

// Test **zero points** (de)serializing
generic::<Secp256k1>(Point::zero());
generic::<Secp256r1>(Point::zero());
generic::<Ed25519>(Point::zero());
generic::<Ristretto>(Point::zero());
generic::<Bls12_381_1>(Point::zero());
generic::<Bls12_381_2>(Point::zero());

// Test **random point** (de)serializing
generic::<Secp256k1>(Point::generator() * Scalar::random());
generic::<Secp256r1>(Point::generator() * Scalar::random());
generic::<Ed25519>(Point::generator() * Scalar::random());
generic::<Ristretto>(Point::generator() * Scalar::random());
generic::<Bls12_381_1>(Point::generator() * Scalar::random());
generic::<Bls12_381_2>(Point::generator() * Scalar::random());
}
}
150 changes: 0 additions & 150 deletions src/elliptic/curves/wrappers/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
use std::marker::PhantomData;
use std::{fmt, iter};

use serde::de::{Error, MapAccess, Visitor};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};

use crate::elliptic::curves::traits::{Curve, ECScalar};
use crate::BigInt;

use super::point::CurveNameGuard;
use crate::elliptic::curves::wrappers::encoded_scalar::EncodedScalar;
use crate::elliptic::curves::{DeserializationError, ZeroScalarError};

Expand Down Expand Up @@ -213,146 +206,3 @@ impl<'s, E: Curve> iter::Product<&'s Scalar<E>> for Scalar<E> {
iter.fold(Scalar::from(1), |acc, s| acc * s)
}
}

impl<E: Curve> Serialize for Scalar<E> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut s = serializer.serialize_struct("Scalar", 2)?;
s.serialize_field("curve", E::CURVE_NAME)?;
s.serialize_field(
"scalar",
// Serializes bytes efficiently
serde_bytes::Bytes::new(&self.to_bytes()),
)?;
s.end()
}
}

impl<'de, E: Curve> Deserialize<'de> for Scalar<E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ScalarVisitor<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Visitor<'de> for ScalarVisitor<E> {
type Value = Scalar<E>;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "scalar of {} curve", E::CURVE_NAME)
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut curve_name: Option<CurveNameGuard<E>> = None;
let mut scalar: Option<ScalarFromBytes<E>> = None;

while let Some(key) = map.next_key()? {
match key {
Field::Curve => {
if curve_name.is_some() {
return Err(A::Error::duplicate_field("curve_name"));
}
curve_name = Some(map.next_value()?)
}
Field::Scalar => {
if scalar.is_some() {
return Err(A::Error::duplicate_field("scalar"));
}
scalar = Some(map.next_value()?)
}
}
}
let _curve_name =
curve_name.ok_or_else(|| A::Error::missing_field("curve_name"))?;
let scalar = scalar.ok_or_else(|| A::Error::missing_field("scalar"))?;
Ok(scalar.0)
}
}

deserializer.deserialize_struct("Scalar", &["curve", "scalar"], ScalarVisitor(PhantomData))
}
}

struct ScalarFromBytes<E: Curve>(Scalar<E>);

impl<'de, E: Curve> Deserialize<'de> for ScalarFromBytes<E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ScalarBytesVisitor<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Visitor<'de> for ScalarBytesVisitor<E> {
type Value = Scalar<E>;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "scalar value of {} curve", E::CURVE_NAME)
}

fn visit_bytes<Err>(self, v: &[u8]) -> Result<Self::Value, Err>
where
Err: Error,
{
Scalar::from_bytes(v).map_err(|_| Err::custom("invalid scalar"))
}
}

deserializer
.deserialize_bytes(ScalarBytesVisitor(PhantomData))
.map(ScalarFromBytes)
}
}

#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
Curve,
Scalar,
}

#[cfg(test)]
mod serde_tests {
use serde_test::{assert_tokens, Token::*};

use crate::elliptic::curves::*;

#[test]
fn test_serde_scalar() {
fn generic<E: Curve>(scalar: Scalar<E>) {
let bytes = scalar.to_bytes().to_vec();
let tokens = vec![
Struct {
name: "Scalar",
len: 2,
},
Str("curve"),
Str(E::CURVE_NAME),
Str("scalar"),
Bytes(bytes.leak()),
StructEnd,
];
assert_tokens(&scalar, &tokens);
}

// Test **zero scalars** (de)serializing
generic::<Secp256k1>(Scalar::zero());
generic::<Secp256r1>(Scalar::zero());
generic::<Ed25519>(Scalar::zero());
generic::<Ristretto>(Scalar::zero());
generic::<Bls12_381_1>(Scalar::zero());
generic::<Bls12_381_2>(Scalar::zero());

// Test **random scalars** (de)serializing
generic::<Secp256k1>(Scalar::random());
generic::<Secp256r1>(Scalar::random());
generic::<Ed25519>(Scalar::random());
generic::<Ristretto>(Scalar::random());
generic::<Bls12_381_1>(Scalar::random());
generic::<Bls12_381_2>(Scalar::random());
}
}
Loading

0 comments on commit 5abd70a

Please sign in to comment.