Skip to content

Commit

Permalink
fix: review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ludamad0 committed Feb 12, 2024
1 parent 3347344 commit 2c2e5d3
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,12 @@ template <class Params_> struct alignas(32) field {
if constexpr (Params::modulus_3 >= 0x4000000000000000ULL) {
split_into_endomorphism_scalars_384(k, k1, k2);
} else {
std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> ret =
split_into_endomorphism_scalars_no_shift(k);
std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> ret = split_into_endomorphism_scalars(k);
k1.data[0] = ret.first[0];
k1.data[1] = ret.first[1];

// TODO(AD): We should move away from this hack by adapting split_into_endomorphism_scalars_no_shift
// TODO(https://github.com/AztecProtocol/barretenberg/issues/851): We should move away from this hack by
// returning pair of uint64_t[2] instead of a half-set field
#if !defined(__clang__) && defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
Expand All @@ -348,8 +348,10 @@ template <class Params_> struct alignas(32) field {
}
}

static std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> split_into_endomorphism_scalars_no_shift(
const field& k)
// NOTE: this form is only usable if the modulus is not a 256-bit integer, otherwise see

This comment has been minimized.

Copy link
@Rumata888

Rumata888 Feb 12, 2024

Contributor

You should probably specify that it has to be 254 bits or less. "Not" allows 255, for example

// split_into_endomorphism_scalars_384.
// TODO(https://github.com/AztecProtocol/barretenberg/issues/851): Unify these APIs.
static std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>> split_into_endomorphism_scalars(const field& k)
{
static_assert(Params::modulus_3 < 0x4000000000000000ULL);
field input = k.reduce_once();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ TEST(AffineElement, Msgpack)
}

namespace bb::group_elements {
// Kludge to access mul_without_endomorphism;
// mul_with_endomorphism and mul_without_endomorphism are private in affine_element.
// We could make those public to test or create other public utilities, but to keep the API intact we
// instead mark TestElementPrivate as a friend class so that our test functions can have access.
class TestElementPrivate {
public:
template <typename Element, typename Scalar>
Expand All @@ -148,7 +150,8 @@ class TestElementPrivate {
};
} // namespace bb::group_elements

TEST(AffineElement, EndoMulMatchesNonEndo)
// Our endomorphism-specialized multiplication should match our generic multiplication
TEST(AffineElement, MulWithEndomorphismMatchesMulWithoutEndomorphism)
{
for (int i = 0; i < 100; i++) {
auto x1 = bb::group_elements::element(grumpkin::g1::affine_element::random_element());
Expand All @@ -159,19 +162,23 @@ TEST(AffineElement, EndoMulMatchesNonEndo)
}
}

TEST(AffineElement, InfinityMul)
// Multiplication of a point at infinity by a scalar should be a point at infinity
TEST(AffineElement, InfinityMulByScalarIsInfinity)
{
auto result = grumpkin::g1::affine_element::infinity() * grumpkin::fr::random_element();
EXPECT_TRUE(result.is_point_at_infinity());
}

TEST(AffineElement, BatchMulMatchesMul)
// Batched multiplication of points should match
TEST(AffineElement, BatchMulMatchesNonBatchMul)
{
constexpr size_t num_points = 1024;
constexpr size_t num_points = 512;
std::vector<grumpkin::g1::affine_element> affine_points;
for (size_t i = 0; i < num_points; ++i) {
for (size_t i = 0; i < num_points - 1; ++i) {
affine_points.emplace_back(grumpkin::g1::affine_element::random_element());
}
// Include a point at infinity to test the mixed infinity + non-infinity case
affine_points.emplace_back(grumpkin::g1::affine_element::infinity());
grumpkin::fr exponent = grumpkin::fr::random_element();
std::vector<grumpkin::g1::affine_element> result =
grumpkin::g1::element::batch_mul_with_endomorphism(affine_points, exponent);
Expand All @@ -182,7 +189,8 @@ TEST(AffineElement, BatchMulMatchesMul)
}
}

TEST(AffineElement, InfinityBatchMul)
// Batched multiplication of a point at infinity by a scalar should result in points at infinity
TEST(AffineElement, InfinityBatchMulByScalarIsInfinity)
{
constexpr size_t num_points = 1024;
std::vector<grumpkin::g1::affine_element> affine_points;
Expand Down
8 changes: 4 additions & 4 deletions barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,17 @@ template <class Fq, class Fr, class Params> class alignas(32) element {
const std::span<affine_element<Fq, Fr, Params>>& second_group,
const std::span<affine_element<Fq, Fr, Params>>& results) noexcept;
static std::vector<affine_element<Fq, Fr, Params>> batch_mul_with_endomorphism(
const std::span<affine_element<Fq, Fr, Params>>& points, const Fr& exponent) noexcept;
const std::span<affine_element<Fq, Fr, Params>>& points, const Fr& scalar) noexcept;

Fq x;
Fq y;
Fq z;

private:
// For access to mul_without_endomorphism
// For test access to mul_without_endomorphism
friend class TestElementPrivate;
element mul_without_endomorphism(const Fr& exponent) const noexcept;
element mul_with_endomorphism(const Fr& exponent) const noexcept;
element mul_without_endomorphism(const Fr& scalar) const noexcept;
element mul_with_endomorphism(const Fr& scalar) const noexcept;

template <typename = typename std::enable_if<Params::can_hash_to_curve>>
static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept;
Expand Down
47 changes: 33 additions & 14 deletions barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,9 @@ element<Fq, Fr, T> element<Fq, Fr, T>::random_element(numeric::RNG* engine) noex
}

template <class Fq, class Fr, class T>
element<Fq, Fr, T> element<Fq, Fr, T>::mul_without_endomorphism(const Fr& exponent) const noexcept
element<Fq, Fr, T> element<Fq, Fr, T>::mul_without_endomorphism(const Fr& scalar) const noexcept
{
const uint256_t converted_scalar(exponent);
const uint256_t converted_scalar(scalar);

if (converted_scalar == 0) {
return element::infinity();
Expand All @@ -617,30 +617,49 @@ element<Fq, Fr, T> element<Fq, Fr, T>::mul_without_endomorphism(const Fr& expone
}

namespace detail {
// Represents the result of
using EndoScalars = std::pair<std::array<uint64_t, 2>, std::array<uint64_t, 2>>;
template <typename Element, std::size_t NUM_ROUNDS> struct EndomorphismWnaf {

/**
* @brief Handles the WNAF computation for scalars that are split using an endomorphism,
* achieved through `split_into_endomorphism_scalars`. It facilitates efficient computation of elliptic curve
* point multiplication by optimizing the representation of these scalars.
*
* @tparam Element The data type of elements in the elliptic curve.
* @tparam NUM_ROUNDS The number of computation rounds for WNAF.
*/
template <typename Element, std::size_t NUM_ROUNDS> struct EndomorphismWnaf {
// NUM_WNAF_BITS: Number of bits per window in the WNAF representation.
static constexpr size_t NUM_WNAF_BITS = 4;

// table: Stores the WNAF representation of the scalars.
std::array<uint64_t, NUM_ROUNDS * 2> table;
// skew and endo_skew: Indicate if our original scalar is even or odd.
bool skew = false;
bool endo_skew = false;

/**
* @param scalars A pair of 128-bit scalars (as two uint64_t arrays), split using an endomorphism.
*/
EndomorphismWnaf(const EndoScalars& scalars)
{
wnaf::fixed_wnaf(&scalars.first[0], &table[0], skew, 0, 2, NUM_WNAF_BITS);
wnaf::fixed_wnaf(&scalars.second[0], &table[1], endo_skew, 0, 2, NUM_WNAF_BITS);
}
};

} // namespace detail

template <class Fq, class Fr, class T>
element<Fq, Fr, T> element<Fq, Fr, T>::mul_with_endomorphism(const Fr& exponent) const noexcept
element<Fq, Fr, T> element<Fq, Fr, T>::mul_with_endomorphism(const Fr& scalar) const noexcept
{
// Consider the infinity flag, return infinity if set
if (is_point_at_infinity()) {
return element::infinity();
}
constexpr size_t NUM_ROUNDS = 32;
const Fr converted_scalar = exponent.from_montgomery_form();
const Fr converted_scalar = scalar.from_montgomery_form();

if (converted_scalar.is_zero() || is_point_at_infinity()) {
if (converted_scalar.is_zero()) {
return element::infinity();
}
static constexpr size_t LOOKUP_SIZE = 8;
Expand All @@ -652,7 +671,7 @@ element<Fq, Fr, T> element<Fq, Fr, T>::mul_with_endomorphism(const Fr& exponent)
lookup_table[i] = lookup_table[i - 1] + d2;
}

detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars_no_shift(converted_scalar);
detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar);
detail::EndomorphismWnaf<element, NUM_ROUNDS> wnaf{ endo_scalars };
element accumulator{ T::one_x, T::one_y, Fq::one() };
accumulator.self_set_infinity();
Expand Down Expand Up @@ -771,18 +790,18 @@ void element<Fq, Fr, T>::batch_affine_add(const std::span<affine_element<Fq, Fr,
}

/**
* @brief Multiply each point by the same exponent
* @brief Multiply each point by the same scalar
*
* @details We use the fact that all points are being multiplied by the same exponent to batch the operations (perform
* @details We use the fact that all points are being multiplied by the same scalar to batch the operations (perform
* batch affine additions and doublings with batch inversion trick)
*
* @param points The span of individual points that need to be scaled
* @param exponent The scalar we multiply all the points by
* @param scalar The scalar we multiply all the points by
* @return std::vector<affine_element<Fq, Fr, T>> Vector of new points where each point is exponent⋅points[i]
*/
template <class Fq, class Fr, class T>
std::vector<affine_element<Fq, Fr, T>> element<Fq, Fr, T>::batch_mul_with_endomorphism(
const std::span<affine_element<Fq, Fr, T>>& points, const Fr& exponent) noexcept
const std::span<affine_element<Fq, Fr, T>>& points, const Fr& scalar) noexcept
{
BB_OP_COUNT_TIME();
typedef affine_element<Fq, Fr, T> affine_element;
Expand Down Expand Up @@ -883,7 +902,7 @@ std::vector<affine_element<Fq, Fr, T>> element<Fq, Fr, T>::batch_mul_with_endomo
/*finite_field_multiplications_per_iteration=*/6);
};
// Compute wnaf for scalar
const Fr converted_scalar = exponent.from_montgomery_form();
const Fr converted_scalar = scalar.from_montgomery_form();

// If the scalar is zero, just set results to the point at infinity
if (converted_scalar.is_zero()) {
Expand Down Expand Up @@ -953,7 +972,7 @@ std::vector<affine_element<Fq, Fr, T>> element<Fq, Fr, T>::batch_mul_with_endomo
batch_affine_add_internal(&temp_point_vector[0], &lookup_table[j][0]);
}

detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars_no_shift(converted_scalar);
detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar);
detail::EndomorphismWnaf<element, NUM_ROUNDS> wnaf{ endo_scalars };

std::vector<affine_element> work_elements(num_points);
Expand Down
13 changes: 13 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ inline void fixed_wnaf_packed(
wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
}

/**
* @brief Performs fixed-window non-adjacent form (WNAF) computation for scalar multiplication.
*
* WNAF is a method for representing integers which optimizes the number of non-zero terms, which in turn optimizes
* the number of point doublings in scalar multiplication, in turn aiding efficiency.
*
* @param scalar Pointer to 128-bit scalar for which WNAF is to be computed.
* @param wnaf Pointer to num_points+1 size array where the computed WNAF will be stored.
* @param skew_map Reference to a boolean variable which will be set based on the least significant bit of the scalar.
* @param point_index The index of the point being computed in the context of multiple point multiplication.
* @param num_points The number of points being computed in parallel.
* @param wnaf_bits The number of bits to use in each window of the WNAF representation.
*/
inline void fixed_wnaf(const uint64_t* scalar,
uint64_t* wnaf,
bool& skew_map,
Expand Down

0 comments on commit 2c2e5d3

Please sign in to comment.