Skip to content

Commit

Permalink
Merge pull request kroma-network#443 from kroma-network/feat/implemen…
Browse files Browse the repository at this point in the history
…t-field-extension-for-stark-fields

feat(math): implement field extension for stark fields
  • Loading branch information
Ryan Kim authored Jun 26, 2024
2 parents d11a0f4 + b178601 commit 58b3f55
Show file tree
Hide file tree
Showing 42 changed files with 1,247 additions and 219 deletions.
2 changes: 1 addition & 1 deletion tachyon/math/base/arithmetics.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ ALWAYS_INLINE SubResult<uint32_t> SubWithBorrow(uint32_t a, uint32_t b,
ALWAYS_INLINE constexpr MulResult<uint32_t> MulAddWithCarry(
uint32_t a, uint32_t b, uint32_t c, uint32_t carry = 0) {
uint64_t tmp = uint64_t{a} + uint64_t{b} * uint64_t{c} + uint64_t{carry};
MulResult<uint32_t> result;
MulResult<uint32_t> result{};
result.lo = static_cast<uint32_t>(tmp);
result.hi = static_cast<uint32_t>(tmp >> 32);
return result;
Expand Down
28 changes: 14 additions & 14 deletions tachyon/math/base/big_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {

// Returns the maximum representable value for BigInt.
constexpr static BigInt Max() {
BigInt ret;
BigInt ret{};
for (uint64_t& limb : ret.limbs) {
limb = std::numeric_limits<uint64_t>::max();
}
Expand All @@ -112,7 +112,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {

// Generate a random BigInt between [0, |max|).
constexpr static BigInt Random(const BigInt& max = Max()) {
BigInt ret;
BigInt ret{};
for (size_t i = 0; i < N; ++i) {
ret[i] = base::Uniform(base::Range<uint64_t>::All());
}
Expand All @@ -123,14 +123,14 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

// Convert a decimal string to a BigInt.
constexpr static std::optional<BigInt> FromDecString(std::string_view str) {
static std::optional<BigInt> FromDecString(std::string_view str) {
BigInt ret;
if (!internal::StringToLimbs(str, ret.limbs, N)) return std::nullopt;
return ret;
}

// Convert a hexadecimal string to a BigInt.
constexpr static std::optional<BigInt> FromHexString(std::string_view str) {
static std::optional<BigInt> FromHexString(std::string_view str) {
BigInt ret;
if (!(internal::HexStringToLimbs(str, ret.limbs, N))) return std::nullopt;
return ret;
Expand All @@ -141,7 +141,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
template <size_t BitNums = kBitNums>
constexpr static BigInt FromBitsLE(const std::bitset<BitNums>& bits) {
static_assert(BitNums <= kBitNums);
BigInt ret;
BigInt ret{};
size_t bit_idx = 0;
size_t limb_idx = 0;
std::bitset<kLimbBitNums> limb_bits;
Expand All @@ -167,7 +167,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
template <size_t BitNums = kBitNums>
constexpr static BigInt FromBitsBE(const std::bitset<BitNums>& bits) {
static_assert(BitNums <= kBitNums);
BigInt ret;
BigInt ret{};
std::bitset<kLimbBitNums> limb_bits;
size_t bit_idx = 0;
size_t limb_idx = 0;
Expand Down Expand Up @@ -196,7 +196,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
// ordering.
template <typename ByteContainer>
constexpr static BigInt FromBytesLE(const ByteContainer& bytes) {
BigInt ret;
BigInt ret{};
size_t byte_idx = 0;
size_t limb_idx = 0;
uint64_t limb = 0;
Expand Down Expand Up @@ -224,7 +224,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
// ordering.
template <typename ByteContainer>
constexpr static BigInt FromBytesBE(const ByteContainer& bytes) {
BigInt ret;
BigInt ret{};
size_t byte_idx = 0;
size_t limb_idx = 0;
uint64_t limb = 0;
Expand Down Expand Up @@ -446,7 +446,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt operator&(const BigInt& other) const {
BigInt ret;
BigInt ret{};
if constexpr (N == 1) {
ret[0] = limbs[0] & other[0];
} else if constexpr (N == 2) {
Expand All @@ -471,7 +471,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt operator|(const BigInt& other) const {
BigInt ret;
BigInt ret{};
if constexpr (N == 1) {
ret[0] = limbs[0] | other[0];
} else if constexpr (N == 2) {
Expand All @@ -496,7 +496,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt operator^(const BigInt& other) const {
BigInt ret;
BigInt ret{};
if constexpr (N == 1) {
ret[0] = limbs[0] ^ other[0];
} else if constexpr (N == 2) {
Expand Down Expand Up @@ -548,7 +548,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt Add(const BigInt& other, uint64_t& carry) const {
BigInt ret;
BigInt ret{};
DoAdd(*this, other, carry, ret);
return ret;
}
Expand All @@ -569,7 +569,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt Sub(const BigInt& other, uint64_t& borrow) const {
BigInt ret;
BigInt ret{};
DoSub(*this, other, borrow, ret);
return ret;
}
Expand All @@ -590,7 +590,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt MulBy2(uint64_t& carry) const {
BigInt ret;
BigInt ret{};
DoMulBy2(*this, carry, ret);
return ret;
}
Expand Down
1 change: 1 addition & 0 deletions tachyon/math/circle/stark/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ generate_fp2s(
generate_fp4s(
name = "fq4",
base_field = "Fq2",
base_field_degree = 2,
base_field_hdr = "tachyon/math/circle/stark/fq2.h",
class_name = "Fq4",
namespace = "tachyon::math::stark",
Expand Down
2 changes: 1 addition & 1 deletion tachyon/math/elliptic_curves/msm/msm_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct TACHYON_EXPORT MSMCtx {

template <typename ScalarField>
constexpr static MSMCtx CreateDefault(size_t size) {
MSMCtx ctx;
MSMCtx ctx{};
ctx.window_bits = ComputeWindowsBits(size);
ctx.window_count = ComputeWindowsCount<ScalarField>(ctx.window_bits);
ctx.size = size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class AffinePoint<

constexpr static std::optional<AffinePoint> CreateFromX(const BaseField& x,
bool pick_odd) {
AffinePoint point;
AffinePoint point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class JacobianPoint<

constexpr static std::optional<JacobianPoint> CreateFromX(const BaseField& x,
bool pick_odd) {
JacobianPoint point;
JacobianPoint point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class PointXYZZ<_Curve,

constexpr static std::optional<PointXYZZ> CreateFromX(const BaseField& x,
bool pick_odd) {
PointXYZZ point;
PointXYZZ point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ProjectivePoint<

constexpr static std::optional<ProjectivePoint> CreateFromX(
const BaseField& x, bool pick_odd) {
ProjectivePoint point;
ProjectivePoint point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
4 changes: 2 additions & 2 deletions tachyon/math/elliptic_curves/short_weierstrass/sw_curve.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class SWCurve {
template <typename Point>
constexpr static bool GetPointFromX(const BaseField& x, bool pick_odd,
Point* point) {
BaseField even_y;
BaseField odd_y;
BaseField even_y{};
BaseField odd_y{};
if (!GetYsFromX(x, &even_y, &odd_y)) return false;
if constexpr (std::is_same_v<Point, AffinePoint>) {
*point = AffinePoint(x, pick_odd ? odd_y : even_y);
Expand Down
17 changes: 15 additions & 2 deletions tachyon/math/finite_fields/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ tachyon_cc_library(
":cyclotomic_multiplicative_subgroup",
"//tachyon/base/buffer:copyable",
"//tachyon/base/json",
"//tachyon/math/geometry:point3",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
Expand Down Expand Up @@ -73,6 +72,7 @@ tachyon_cc_library(
hdrs = ["fp4.h"],
deps = [
":quadratic_extension_field",
":quartic_extension_field",
"//tachyon/math/base/gmp:gmp_util",
],
)
Expand Down Expand Up @@ -231,7 +231,18 @@ tachyon_cc_library(
":cyclotomic_multiplicative_subgroup",
"//tachyon/base/buffer:copyable",
"//tachyon/base/json",
"//tachyon/math/geometry:point2",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

tachyon_cc_library(
name = "quartic_extension_field",
hdrs = ["quartic_extension_field.h"],
deps = [
":cyclotomic_multiplicative_subgroup",
"//tachyon/base/buffer:copyable",
"//tachyon/base/json",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
Expand Down Expand Up @@ -277,6 +288,7 @@ tachyon_cc_unittest(
"prime_field_generator_unittest.cc",
"prime_field_unittest.cc",
"quadratic_extension_field_unittest.cc",
"quartic_extension_field_unittest.cc",
] + select({
"@platforms//cpu:x86_64": ["packed_prime_field_unittest.cc"],
"@platforms//cpu:aarch64": ["packed_prime_field_unittest.cc"],
Expand All @@ -303,6 +315,7 @@ tachyon_cc_unittest(
"//tachyon/math/elliptic_curves/secp/secp256k1:fq",
"//tachyon/math/elliptic_curves/secp/secp256k1:fr",
"//tachyon/math/finite_fields/baby_bear",
"//tachyon/math/finite_fields/baby_bear:baby_bear4",
"//tachyon/math/finite_fields/binary_fields",
"//tachyon/math/finite_fields/goldilocks:goldilocks_prime_field",
"//tachyon/math/finite_fields/koala_bear",
Expand Down
12 changes: 12 additions & 0 deletions tachyon/math/finite_fields/baby_bear/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp4s")
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")

package(default_visibility = ["//visibility:public"])
Expand All @@ -14,6 +15,17 @@ generate_prime_fields(
use_montgomery = True,
)

generate_fp4s(
name = "baby_bear4",
base_field = "BabyBear",
base_field_degree = 1,
base_field_hdr = "tachyon/math/finite_fields/baby_bear/baby_bear.h",
class_name = "BabyBear4",
namespace = "tachyon::math",
non_residue = ["11"],
deps = [":baby_bear"],
)

tachyon_cc_library(
name = "packed_baby_bear",
hdrs = ["packed_baby_bear.h"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ TYPED_TEST(BinaryFieldsTest, ComparisonOperator) {
if constexpr (BinaryField::Config::kModulusBits > 3) {
BinaryField f(3);
BinaryField f2(4);
EXPECT_TRUE(f < f2);
EXPECT_TRUE(f <= f2);
EXPECT_FALSE(f > f2);
EXPECT_FALSE(f >= f2);
EXPECT_LT(f, f2);
EXPECT_LE(f, f2);
EXPECT_GT(f2, f);
EXPECT_GE(f2, f);
} else {
GTEST_SKIP() << "Modulus is too small";
}
Expand Down
Loading

0 comments on commit 58b3f55

Please sign in to comment.