Skip to content

Commit

Permalink
Add generator wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Denis committed Jul 3, 2021
1 parent ef8d6f9 commit fe0083e
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 58 deletions.
6 changes: 3 additions & 3 deletions src/elliptic/curves/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub trait Curve {
///
/// Trait exposes various methods to manipulate scalars. Scalar can be zero. Scalar must zeroize its
/// value on drop.
pub trait ECScalar: Clone + PartialEq + fmt::Debug {
pub trait ECScalar: Clone + PartialEq + fmt::Debug + 'static {
/// Underlying scalar type that can be retrieved in case of missing methods in this trait
type Underlying;

Expand Down Expand Up @@ -101,8 +101,8 @@ pub trait ECScalar: Clone + PartialEq + fmt::Debug {
///
/// Trait exposes various methods that make elliptic curve arithmetic. The point can
/// be [zero](ECPoint::zero). Unlike [ECScalar], ECPoint isn't required to zeroize its value on drop,
/// but it implementы [Zeroize] trait so you can force zeroizing policy on your own.
pub trait ECPoint: Zeroize + Clone + PartialEq + fmt::Debug {
/// but it implements [Zeroize] trait so you can force zeroizing policy on your own.
pub trait ECPoint: Zeroize + Clone + PartialEq + fmt::Debug + 'static {
/// Scalar value the point can be multiplied at
type Scalar: ECScalar;
/// Underlying curve implementation that can be retrieved in case of missing methods in this trait
Expand Down
246 changes: 191 additions & 55 deletions src/elliptic/curves/wrappers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;
use std::convert::TryFrom;
use std::marker::PhantomData;
use std::{fmt, iter, ops};

use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -159,13 +160,24 @@ impl<E: Curve> fmt::Debug for PointZ<E> {
///
/// ```rust
/// # use curv::elliptic::curves::{PointZ, Point, Scalar, Secp256k1};
/// let s = Scalar::<Secp256k1>::random(); // Non-zero scalar
/// let g = Point::<Secp256k1>::generator(); // Non-zero point (curve generator)
/// let result: PointZ<Secp256k1> = s * g; // Multiplication of two non-zero points
/// // might produce zero-point
/// let p1: Point<Secp256k1> = Point::generator() * Scalar::random(); // Non-zero point
/// let p2: Point<Secp256k1> = Point::generator() * Scalar::random(); // Non-zero point
/// let result: PointZ<Secp256k1> = p1 + p2; // Addition of two (even non-zero)
/// // points might produce zero point
/// let nonzero_result: Option<Point<Secp256k1>> = result.ensure_nonzero();
/// ```
///
/// Exception is [curve generator](Self::generator) that can be multiplied at non-zero scalar, and
/// resulting point is guaranteed to be non-zero:
///
/// ```rust
/// # use curv::elliptic::curves::{PointZ, Point, Scalar, Secp256k1};
/// let s = Scalar::<Secp256k1>::random(); // Non-zero scalar
/// let g = Point::<Secp256k1>::generator(); // Curve generator
/// let result: Point<Secp256k1> = s * g; // Generator multiplied at non-zero scalar is
/// // always a non-zero point
/// ```
///
/// When evaluating complex expressions, you typically need to ensure that none of intermediate
/// results are zero-points:
///
Expand All @@ -184,9 +196,8 @@ impl<E: Curve> Point<E> {
///
/// Returns a static reference on actual value because in most cases referenced value is fine.
/// Use [`.to_point_owned()`](PointRef::to_point_owned) if you need to take it by value.
pub fn generator() -> PointRef<'static, E> {
let p = E::Point::generator();
PointRef::from_raw(p).expect("generator must be non-zero")
pub fn generator() -> Generator<E> {
Generator::default()
}

/// Curve second generator
Expand Down Expand Up @@ -308,6 +319,51 @@ impl<E: Curve> TryFrom<PointZ<E>> for Point<E> {
}
}

/// Elliptic curve generator
///
/// Holds internally a static reference on curve generator. Can be used in arithmetic interchangeably
/// as [`PointRef<E>`](PointRef).
///
/// Generator multiplied at non-zero scalar always produce non-zero point, thus output type of
/// the multiplication is [`Point<E>`](Point). This is the only difference compared to `Point<E>` and
/// `PointRef<E>`. Use [`to_point_owned`](Self::to_point_owned) and [`as_point_ref`](Self::as_point_ref)
/// methods to convert the generator into `Point<E>` and `PointRef<E>`.
pub struct Generator<E: Curve> {
_ph: PhantomData<&'static E::Point>,
}

impl<E: Curve> Default for Generator<E> {
fn default() -> Self {
Self { _ph: PhantomData }
}
}

impl<E: Curve> Generator<E> {
fn as_raw(self) -> &'static E::Point {
E::Point::generator()
}

/// Clones generator point, returns `Point<E>`
pub fn to_point_owned(self) -> Point<E> {
// Safety: curve generator must be non-zero point, otherwise nothing will work at all
unsafe { Point::from_raw_unchecked(self.as_raw().clone()) }
}

/// Converts generator into `PointRef<E>`
pub fn as_point_ref(self) -> PointRef<'static, E> {
// Safety: curve generator must be non-zero point, otherwise nothing will work at all
unsafe { PointRef::from_raw_unchecked(self.as_raw()) }
}
}

impl<E: Curve> Clone for Generator<E> {
fn clone(&self) -> Self {
Self { _ph: PhantomData }
}
}

impl<E: Curve> Copy for Generator<E> {}

/// Reference on elliptic point, _guaranteed_ to be non-zero
///
/// Holds internally a reference on [`Point<E>`](Point), refer to its documentation to learn
Expand Down Expand Up @@ -918,24 +974,29 @@ matrix! {
point_fn = add_point,
point_assign_fn = add_point_assign,
pairs = {
(o_<> Point<E>, Point<E>), (o_<> Point<E>, PointZ<E>),
(o_<> Point<E>, &Point<E>), (o_<> Point<E>, &PointZ<E>),
(r_<> &Point<E>, &Point<E>), (r_<> &Point<E>, &PointZ<E>),
(o_<'p> Point<E>, PointRef<'p, E>), (o_<> Point<E>, Generator<E>),

(o_<> PointZ<E>, Point<E>), (o_<> PointZ<E>, PointZ<E>),
(o_<> PointZ<E>, &Point<E>), (o_<> PointZ<E>, &PointZ<E>),
(r_<> &PointZ<E>, &Point<E>), (r_<> &PointZ<E>, &PointZ<E>),
(o_<> Point<E>, Point<E>), (o_<> PointZ<E>, PointZ<E>),
(o_<> Point<E>, PointZ<E>), (o_<> PointZ<E>, Point<E>),
(o_<'p> PointZ<E>, PointRef<'p, E>), (o_<> PointZ<E>, Generator<E>),

(_o<> &Point<E>, Point<E>), (_o<> &Point<E>, PointZ<E>),
(r_<> &Point<E>, &Point<E>), (r_<> &Point<E>, &PointZ<E>),
(r_<'p> &Point<E>, PointRef<'p, E>), (r_<> &Point<E>, Generator<E>),

(_o<> &PointZ<E>, Point<E>), (_o<> &PointZ<E>, PointZ<E>),
(r_<> &PointZ<E>, &Point<E>), (r_<> &PointZ<E>, &PointZ<E>),
(r_<'p> &PointZ<E>, PointRef<'p, E>), (r_<> &PointZ<E>, Generator<E>),

// The same as above, but replacing &Point<E> with PointRef<E>
(o_<'r> Point<E>, PointRef<'r, E>),
(r_<'a, 'b> PointRef<'a, E>, PointRef<'b, E>), (r_<'r> PointRef<'r, E>, &PointZ<E>),
(o_<'r> PointZ<E>, PointRef<'r, E>),
(r_<'r> &PointZ<E>, PointRef<'r, E>),
(_o<'r> PointRef<'r, E>, Point<E>), (_o<'r> PointRef<'r, E>, PointZ<E>),
(_o<'p> PointRef<'p, E>, Point<E>), (_o<'p> PointRef<'p, E>, PointZ<E>),
(r_<'p> PointRef<'p, E>, &Point<E>), (r_<'p> PointRef<'p, E>, &PointZ<E>),
(r_<'a, 'b> PointRef<'a, E>, PointRef<'b, E>), (r_<'p> PointRef<'p, E>, Generator<E>),

// And define trait between &Point<E> and PointRef<E>
(r_<'r> &Point<E>, PointRef<'r, E>), (r_<'r> PointRef<'r, E>, &Point<E>),
(_o<> Generator<E>, Point<E>), (_o<> Generator<E>, PointZ<E>),
(r_<> Generator<E>, &Point<E>), (r_<> Generator<E>, &PointZ<E>),
(r_<'p> Generator<E>, PointRef<'p, E>), (r_<> Generator<E>, Generator<E>),
}
}

Expand All @@ -947,24 +1008,29 @@ matrix! {
point_fn = sub_point,
point_assign_fn = sub_point_assign,
pairs = {
(o_<> Point<E>, Point<E>), (o_<> Point<E>, PointZ<E>),
(o_<> Point<E>, &Point<E>), (o_<> Point<E>, &PointZ<E>),
(r_<> &Point<E>, &Point<E>), (r_<> &Point<E>, &PointZ<E>),
(o_<'p> Point<E>, PointRef<'p, E>), (o_<> Point<E>, Generator<E>),

(o_<> PointZ<E>, Point<E>), (o_<> PointZ<E>, PointZ<E>),
(o_<> PointZ<E>, &Point<E>), (o_<> PointZ<E>, &PointZ<E>),
(r_<> &PointZ<E>, &Point<E>), (r_<> &PointZ<E>, &PointZ<E>),
(o_<> Point<E>, Point<E>), (o_<> PointZ<E>, PointZ<E>),
(o_<> Point<E>, PointZ<E>), (o_<> PointZ<E>, Point<E>),
(o_<'p> PointZ<E>, PointRef<'p, E>), (o_<> PointZ<E>, Generator<E>),

(_o<> &Point<E>, Point<E>), (_o<> &Point<E>, PointZ<E>),
(r_<> &Point<E>, &Point<E>), (r_<> &Point<E>, &PointZ<E>),
(r_<'p> &Point<E>, PointRef<'p, E>), (r_<> &Point<E>, Generator<E>),

(_o<> &PointZ<E>, Point<E>), (_o<> &PointZ<E>, PointZ<E>),
(r_<> &PointZ<E>, &Point<E>), (r_<> &PointZ<E>, &PointZ<E>),
(r_<'p> &PointZ<E>, PointRef<'p, E>), (r_<> &PointZ<E>, Generator<E>),

// The same as above, but replacing &Point<E> with PointRef<E>
(o_<'r> Point<E>, PointRef<'r, E>),
(r_<'a, 'b> PointRef<'a, E>, PointRef<'b, E>), (r_<'r> PointRef<'r, E>, &PointZ<E>),
(o_<'r> PointZ<E>, PointRef<'r, E>),
(r_<'r> &PointZ<E>, PointRef<'r, E>),
(_o<'r> PointRef<'r, E>, Point<E>), (_o<'r> PointRef<'r, E>, PointZ<E>),
(_o<'p> PointRef<'p, E>, Point<E>), (_o<'p> PointRef<'p, E>, PointZ<E>),
(r_<'p> PointRef<'p, E>, &Point<E>), (r_<'p> PointRef<'p, E>, &PointZ<E>),
(r_<'a, 'b> PointRef<'a, E>, PointRef<'b, E>), (r_<'p> PointRef<'p, E>, Generator<E>),

// And define trait between &Point<E> and PointRef<E>
(r_<'r> &Point<E>, PointRef<'r, E>), (r_<'r> PointRef<'r, E>, &Point<E>),
(_o<> Generator<E>, Point<E>), (_o<> Generator<E>, PointZ<E>),
(r_<> Generator<E>, &Point<E>), (r_<> Generator<E>, &PointZ<E>),
(r_<'p> Generator<E>, PointRef<'p, E>), (r_<> Generator<E>, Generator<E>),
}
}

Expand All @@ -976,33 +1042,41 @@ matrix! {
point_fn = scalar_mul,
point_assign_fn = scalar_mul_assign,
pairs = {
(_o<> Scalar<E>, Point<E>), (_o<> Scalar<E>, PointZ<E>),
(_r<> Scalar<E>, &Point<E>), (_r<> Scalar<E>, &PointZ<E>),
(_r<'p> Scalar<E>, PointRef<'p, E>), /*(_r<> Scalar<E>, Generator<E>),*/

(_o<> ScalarZ<E>, Point<E>), (_o<> ScalarZ<E>, PointZ<E>),
(_r<> ScalarZ<E>, &Point<E>), (_r<> ScalarZ<E>, &PointZ<E>),
(_r<'p> ScalarZ<E>, PointRef<'p, E>), (_r<> ScalarZ<E>, Generator<E>),

(_o<> &Scalar<E>, Point<E>), (_o<> &Scalar<E>, PointZ<E>),
(_r<> &Scalar<E>, &Point<E>), (_r<> &Scalar<E>, &PointZ<E>),
(_r<'p> &Scalar<E>, PointRef<'p, E>), /*(_r<> &Scalar<E>, Generator<E>),*/

(_o<> &ScalarZ<E>, Point<E>), (_o<> &ScalarZ<E>, PointZ<E>),
(_r<> &ScalarZ<E>, &Point<E>), (_r<> &ScalarZ<E>, &PointZ<E>),
(_r<'p> &ScalarZ<E>, PointRef<'p, E>), (_r<> &ScalarZ<E>, Generator<E>),

// --- and vice-versa ---

(o_<> Point<E>, Scalar<E>), (o_<> Point<E>, ScalarZ<E>),
(o_<> Point<E>, &Scalar<E>), (o_<> Point<E>, &ScalarZ<E>),
(r_<> &Point<E>, &Scalar<E>), (r_<> &Point<E>, &ScalarZ<E>),

(o_<> PointZ<E>, Scalar<E>), (o_<> PointZ<E>, ScalarZ<E>),
(o_<> PointZ<E>, &Scalar<E>), (o_<> PointZ<E>, &ScalarZ<E>),
(r_<> &PointZ<E>, &Scalar<E>), (r_<> &PointZ<E>, &ScalarZ<E>),
(o_<> Point<E>, Scalar<E>), (o_<> Point<E>, ScalarZ<E>),

(r_<> &Point<E>, Scalar<E>), (r_<> &Point<E>, ScalarZ<E>),
(o_<> PointZ<E>, Scalar<E>), (o_<> PointZ<E>, ScalarZ<E>),
(r_<> &Point<E>, &Scalar<E>), (r_<> &Point<E>, &ScalarZ<E>),

(r_<> &PointZ<E>, Scalar<E>), (r_<> &PointZ<E>, ScalarZ<E>),
(r_<> &PointZ<E>, &Scalar<E>), (r_<> &PointZ<E>, &ScalarZ<E>),

// The same as above but replacing &Point with PointRef
(r_<'p> PointRef<'p, E>, &Scalar<E>), (r_<'p> PointRef<'p, E>, &ScalarZ<E>),
(r_<'p> PointRef<'p, E>, Scalar<E>), (r_<'p> PointRef<'p, E>, ScalarZ<E>),
(r_<'p> PointRef<'p, E>, &Scalar<E>), (r_<'p> PointRef<'p, E>, &ScalarZ<E>),

// --- And vice-versa ---

(_o<> &Scalar<E>, Point<E>), (_o<> &ScalarZ<E>, Point<E>),
(_r<> &Scalar<E>, &Point<E>), (_r<> &ScalarZ<E>, &Point<E>),
(_o<> &Scalar<E>, PointZ<E>), (_o<> &ScalarZ<E>, PointZ<E>),
(_r<> &Scalar<E>, &PointZ<E>), (_r<> &ScalarZ<E>, &PointZ<E>),
(_o<> Scalar<E>, Point<E>), (_o<> ScalarZ<E>, Point<E>),
(_r<> Scalar<E>, &Point<E>), (_r<> ScalarZ<E>, &Point<E>),
(_o<> Scalar<E>, PointZ<E>), (_o<> ScalarZ<E>, PointZ<E>),
(_r<> Scalar<E>, &PointZ<E>), (_r<> ScalarZ<E>, &PointZ<E>),

// The same as above but replacing &Point with PointRef
(_r<'p> &Scalar<E>, PointRef<'p, E>), (_r<'p> &ScalarZ<E>, PointRef<'p, E>),
(_r<'p> Scalar<E>, PointRef<'p, E>), (_r<'p> ScalarZ<E>, PointRef<'p, E>),
/*(r_<> Generator<E>, Scalar<E>),*/ (r_<> Generator<E>, ScalarZ<E>),
/*(r_<> Generator<E>, &Scalar<E>),*/ (r_<> Generator<E>, &ScalarZ<E>),
}
}

Expand Down Expand Up @@ -1063,6 +1137,35 @@ matrix! {
}
}

impl<E: Curve> ops::Mul<&Scalar<E>> for Generator<E> {
type Output = Point<E>;
fn mul(self, rhs: &Scalar<E>) -> Self::Output {
Point::from_raw(self.as_raw().scalar_mul(&rhs.as_raw()))
.expect("generator multiplied by non-zero scalar is always non-zero point")
}
}

impl<E: Curve> ops::Mul<Scalar<E>> for Generator<E> {
type Output = Point<E>;
fn mul(self, rhs: Scalar<E>) -> Self::Output {
self.mul(&rhs)
}
}

impl<E: Curve> ops::Mul<Generator<E>> for &Scalar<E> {
type Output = Point<E>;
fn mul(self, rhs: Generator<E>) -> Self::Output {
rhs.mul(self)
}
}

impl<E: Curve> ops::Mul<Generator<E>> for Scalar<E> {
type Output = Point<E>;
fn mul(self, rhs: Generator<E>) -> Self::Output {
rhs.mul(self)
}
}

impl<E: Curve> ops::Neg for Scalar<E> {
type Output = Scalar<E>;

Expand Down Expand Up @@ -1279,8 +1382,8 @@ mod test {
fn _curve<E: Curve>() {
assert_operator_defined_for! {
assert_fn = assert_point_addition_defined,
lhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>},
rhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>},
lhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>, Generator<E>},
rhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>, Generator<E>},
}
}
}
Expand All @@ -1301,8 +1404,8 @@ mod test {
fn _curve<E: Curve>() {
assert_operator_defined_for! {
assert_fn = assert_point_subtraction_defined,
lhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>},
rhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>},
lhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>, Generator<E>},
rhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>, Generator<E>},
}
}
}
Expand All @@ -1318,6 +1421,17 @@ mod test {
// no-op
}

/// Function asserts that M can be multiplied by N (ie. M * N) and result is **non-zero** Point.
/// If any condition doesn't meet, function won't compile.
#[allow(dead_code)]
fn assert_point_nonzero_multiplication_defined<E, M, N>()
where
M: ops::Mul<N, Output = Point<E>>,
E: Curve,
{
// no-op
}

#[test]
fn test_point_multiplication_defined() {
fn _curve<E: Curve>() {
Expand All @@ -1331,6 +1445,28 @@ mod test {
lhs = {Scalar<E>, ScalarZ<E>, &Scalar<E>, &ScalarZ<E>},
rhs = {Point<E>, PointZ<E>, &Point<E>, &PointZ<E>, PointRef<E>},
}

// Checking generator's arithmetic
assert_operator_defined_for! {
assert_fn = assert_point_multiplication_defined,
lhs = {Generator<E>},
rhs = {ScalarZ<E>, &ScalarZ<E>},
}
assert_operator_defined_for! {
assert_fn = assert_point_nonzero_multiplication_defined,
lhs = {Generator<E>},
rhs = {Scalar<E>, &Scalar<E>},
}
assert_operator_defined_for! {
assert_fn = assert_point_multiplication_defined,
lhs = {ScalarZ<E>, &ScalarZ<E>},
rhs = {Generator<E>},
}
assert_operator_defined_for! {
assert_fn = assert_point_nonzero_multiplication_defined,
lhs = {Scalar<E>, &Scalar<E>},
rhs = {Generator<E>},
}
}
}

Expand Down

0 comments on commit fe0083e

Please sign in to comment.