Skip to content

Commit

Permalink
Improve scalar serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Denis committed Jul 26, 2021
1 parent 4a33058 commit 1357bd5
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 24 deletions.
8 changes: 5 additions & 3 deletions src/elliptic/curves/bls12_381/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ pub struct FieldScalar {
impl ECScalar for FieldScalar {
type Underlying = SK;

type ScalarBytes = [u8; 32];

fn random() -> FieldScalar {
FieldScalar {
purpose: "random",
Expand Down Expand Up @@ -73,10 +75,10 @@ impl ECScalar for FieldScalar {
BigInt::from_bytes(&bytes)
}

fn serialize(&self) -> Vec<u8> {
fn serialize(&self) -> Self::ScalarBytes {
let repr = self.fe.into_repr();
let mut bytes = Vec::with_capacity(SECRET_KEY_SIZE);
repr.write_be(&mut bytes).unwrap();
let mut bytes = [0u8; SECRET_KEY_SIZE];
repr.write_be(&mut bytes[..]).unwrap();
bytes
}

Expand Down
6 changes: 4 additions & 2 deletions src/elliptic/curves/curve_ristretto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ impl Curve for Ristretto {
impl ECScalar for RistrettoScalar {
type Underlying = SK;

type ScalarBytes = [u8; 32];

fn random() -> RistrettoScalar {
RistrettoScalar {
purpose: "random",
Expand Down Expand Up @@ -116,8 +118,8 @@ impl ECScalar for RistrettoScalar {
BigInt::from_bytes(&t)
}

fn serialize(&self) -> Vec<u8> {
self.fe.to_bytes().to_vec()
fn serialize(&self) -> Self::ScalarBytes {
self.fe.to_bytes()
}

fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError> {
Expand Down
6 changes: 4 additions & 2 deletions src/elliptic/curves/ed25519.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ impl Curve for Ed25519 {
impl ECScalar for Ed25519Scalar {
type Underlying = SK;

type ScalarBytes = [u8; 32];

// we chose to multiply by 8 (co-factor) all group elements to work in the prime order sub group.
// each random fe is having its 3 first bits zeroed
fn random() -> Ed25519Scalar {
Expand Down Expand Up @@ -171,8 +173,8 @@ impl ECScalar for Ed25519Scalar {
BigInt::from_bytes(&t)
}

fn serialize(&self) -> Vec<u8> {
self.fe.to_bytes().to_vec()
fn serialize(&self) -> Self::ScalarBytes {
self.fe.to_bytes()
}

fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError> {
Expand Down
6 changes: 4 additions & 2 deletions src/elliptic/curves/p256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ impl Curve for Secp256r1 {
impl ECScalar for Secp256r1Scalar {
type Underlying = SK;

type ScalarBytes = FieldBytes;

fn random() -> Secp256r1Scalar {
let mut rng = thread_rng();
let scalar = loop {
Expand Down Expand Up @@ -127,8 +129,8 @@ impl ECScalar for Secp256r1Scalar {
BigInt::from_bytes(self.fe.to_bytes().as_slice())
}

fn serialize(&self) -> Vec<u8> {
self.fe.to_bytes().to_vec()
fn serialize(&self) -> Self::ScalarBytes {
self.fe.to_bytes()
}

fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError> {
Expand Down
17 changes: 9 additions & 8 deletions src/elliptic/curves/secp256_k1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ type FE = Secp256k1Scalar;
impl ECScalar for Secp256k1Scalar {
type Underlying = Option<SK>;

type ScalarBytes = [u8; 32];

fn random() -> Secp256k1Scalar {
let sk = SK(SecretKey::new(&mut rand_legacy::thread_rng()));
Secp256k1Scalar {
Expand Down Expand Up @@ -201,16 +203,15 @@ impl ECScalar for Secp256k1Scalar {
}
}

fn serialize(&self) -> Vec<u8> {
let scalar = match &*self.fe {
Some(s) => s,
None => return vec![0u8],
};
scalar.0[..].to_vec()
fn serialize(&self) -> Self::ScalarBytes {
match &*self.fe {
Some(s) => *s.as_ref(),
None => [0u8; 32],
}
}

fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError> {
let pk = if bytes == [0] {
let sk = if bytes == [0; 32] {
None
} else {
Some(SK(
Expand All @@ -219,7 +220,7 @@ impl ECScalar for Secp256k1Scalar {
};
Ok(Secp256k1Scalar {
purpose: "deserialize",
fe: pk.into(),
fe: sk.into(),
})
}

Expand Down
2 changes: 1 addition & 1 deletion src/elliptic/curves/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ fn serialize_deserialize_scalar<E: Curve>() {
let zero = E::Scalar::zero();
for scalar in [rand_point, zero] {
let bytes = scalar.serialize();
let deserialized = <E::Scalar as ECScalar>::deserialize(&bytes).unwrap();
let deserialized = <E::Scalar as ECScalar>::deserialize(bytes.as_ref()).unwrap();
assert_eq!(scalar, deserialized);
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/elliptic/curves/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ pub trait ECScalar: Clone + PartialEq + fmt::Debug + 'static {
/// Underlying scalar type that can be retrieved in case of missing methods in this trait
type Underlying;

/// Serialized scalar
type ScalarBytes: AsRef<[u8]>;

/// Samples a random scalar
fn random() -> Self;

Expand All @@ -52,7 +55,7 @@ pub trait ECScalar: Clone + PartialEq + fmt::Debug + 'static {
/// Converts a scalar to BigInt
fn to_bigint(&self) -> BigInt;
/// Serializes scalar into bytes
fn serialize(&self) -> Vec<u8>;
fn serialize(&self) -> Self::ScalarBytes;
/// Deserializes scalar from bytes
fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError>;

Expand Down
29 changes: 29 additions & 0 deletions src/elliptic/curves/wrappers/encoded_scalar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use std::ops::Deref;

use crate::elliptic::curves::{Curve, ECScalar, Scalar};

/// Encoded scalar
pub struct EncodedScalar<E: Curve> {
bytes: <E::Scalar as ECScalar>::ScalarBytes,
}

impl<E: Curve> Deref for EncodedScalar<E> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
self.bytes.as_ref()
}
}

impl<'s, E: Curve> From<&'s Scalar<E>> for EncodedScalar<E> {
fn from(s: &'s Scalar<E>) -> Self {
Self {
bytes: s.as_raw().serialize(),
}
}
}

impl<E: Curve> From<Scalar<E>> for EncodedScalar<E> {
fn from(s: Scalar<E>) -> Self {
Self::from(&s)
}
}
1 change: 1 addition & 0 deletions src/elliptic/curves/wrappers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod arithmetic;
mod encoded_point;
mod encoded_scalar;
pub mod error;
mod format;
mod generator;
Expand Down
2 changes: 1 addition & 1 deletion src/elliptic/curves/wrappers/point_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl<'p, E: Curve> Serialize for PointRef<'p, E> {
s.serialize_field(
"compressed_point",
// Serializes bytes efficiently
&Bytes::new(self.to_bytes(true).as_ref()),
Bytes::new(self.to_bytes(true).as_ref()),
)?;
s.end()
}
Expand Down
164 changes: 160 additions & 4 deletions src/elliptic/curves/wrappers/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use std::marker::PhantomData;
use std::{fmt, iter};

use serde::de::{Error, MapAccess, Visitor};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};

use crate::elliptic::curves::traits::{Curve, ECScalar};
use crate::BigInt;

use super::format::ScalarFormat;
use crate::elliptic::curves::ZeroScalarError;
use super::point::CurveNameGuard;
use crate::elliptic::curves::wrappers::encoded_scalar::EncodedScalar;
use crate::elliptic::curves::{DeserializationError, ZeroScalarError};

/// Scalar value in a prime field
///
Expand Down Expand Up @@ -35,8 +40,6 @@ use crate::elliptic::curves::ZeroScalarError;
/// a + b * c
/// }
/// ```
#[derive(Serialize, Deserialize)]
#[serde(try_from = "ScalarFormat<E>", into = "ScalarFormat<E>", bound = "")]
pub struct Scalar<E: Curve> {
raw_scalar: E::Scalar,
}
Expand Down Expand Up @@ -81,6 +84,16 @@ impl<E: Curve> Scalar<E> {
Self::from_raw(E::Scalar::from_bigint(n))
}

/// Serializes a scalar to bytes
pub fn to_bytes(&self) -> EncodedScalar<E> {
EncodedScalar::from(self)
}

/// Constructs a scalar from bytes
pub fn from_bytes(bytes: &[u8]) -> Result<Self, DeserializationError> {
ECScalar::deserialize(bytes).map(Self::from_raw)
}

/// Returns an order of generator point
pub fn group_order() -> &'static BigInt {
E::Scalar::group_order()
Expand Down Expand Up @@ -200,3 +213,146 @@ impl<'s, E: Curve> iter::Product<&'s Scalar<E>> for Scalar<E> {
iter.fold(Scalar::from(1), |acc, s| acc * s)
}
}

impl<E: Curve> Serialize for Scalar<E> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut s = serializer.serialize_struct("Scalar", 2)?;
s.serialize_field("curve", E::CURVE_NAME)?;
s.serialize_field(
"scalar",
// Serializes bytes efficiently
serde_bytes::Bytes::new(&self.to_bytes()),
)?;
s.end()
}
}

impl<'de, E: Curve> Deserialize<'de> for Scalar<E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ScalarVisitor<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Visitor<'de> for ScalarVisitor<E> {
type Value = Scalar<E>;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "scalar of {} curve", E::CURVE_NAME)
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut curve_name: Option<CurveNameGuard<E>> = None;
let mut scalar: Option<ScalarFromBytes<E>> = None;

while let Some(key) = map.next_key()? {
match key {
Field::Curve => {
if curve_name.is_some() {
return Err(A::Error::duplicate_field("curve_name"));
}
curve_name = Some(map.next_value()?)
}
Field::Scalar => {
if scalar.is_some() {
return Err(A::Error::duplicate_field("scalar"));
}
scalar = Some(map.next_value()?)
}
}
}
let _curve_name =
curve_name.ok_or_else(|| A::Error::missing_field("curve_name"))?;
let scalar = scalar.ok_or_else(|| A::Error::missing_field("scalar"))?;
Ok(scalar.0)
}
}

deserializer.deserialize_struct("Scalar", &["curve", "scalar"], ScalarVisitor(PhantomData))
}
}

struct ScalarFromBytes<E: Curve>(Scalar<E>);

impl<'de, E: Curve> Deserialize<'de> for ScalarFromBytes<E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ScalarBytesVisitor<E: Curve>(PhantomData<E>);

impl<'de, E: Curve> Visitor<'de> for ScalarBytesVisitor<E> {
type Value = Scalar<E>;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "scalar value of {} curve", E::CURVE_NAME)
}

fn visit_bytes<Err>(self, v: &[u8]) -> Result<Self::Value, Err>
where
Err: Error,
{
Scalar::from_bytes(v).map_err(|_| Err::custom(format!("invalid scalar")))
}
}

deserializer
.deserialize_bytes(ScalarBytesVisitor(PhantomData))
.map(ScalarFromBytes)
}
}

#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
Curve,
Scalar,
}

#[cfg(test)]
mod serde_tests {
use serde_test::{assert_tokens, Token::*};

use crate::elliptic::curves::*;

#[test]
fn test_serde_scalar() {
fn generic<E: Curve>(scalar: Scalar<E>) {
let bytes = scalar.to_bytes().to_vec();
let tokens = vec![
Struct {
name: "Scalar",
len: 2,
},
Str("curve"),
Str(E::CURVE_NAME),
Str("scalar"),
Bytes(bytes.leak()),
StructEnd,
];
assert_tokens(&scalar, &tokens);
}

// Test **zero scalars** (de)serializing
generic::<Secp256k1>(Scalar::zero());
generic::<Secp256r1>(Scalar::zero());
generic::<Ed25519>(Scalar::zero());
generic::<Ristretto>(Scalar::zero());
generic::<Bls12_381_1>(Scalar::zero());
generic::<Bls12_381_2>(Scalar::zero());

// Test **random scalars** (de)serializing
generic::<Secp256k1>(Scalar::random());
generic::<Secp256r1>(Scalar::random());
generic::<Ed25519>(Scalar::random());
generic::<Ristretto>(Scalar::random());
generic::<Bls12_381_1>(Scalar::random());
generic::<Bls12_381_2>(Scalar::random());
}
}

0 comments on commit 1357bd5

Please sign in to comment.