Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(math): implement field extension for stark fields #443

Merged
merged 15 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tachyon/math/base/arithmetics.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ ALWAYS_INLINE SubResult<uint32_t> SubWithBorrow(uint32_t a, uint32_t b,
ALWAYS_INLINE constexpr MulResult<uint32_t> MulAddWithCarry(
uint32_t a, uint32_t b, uint32_t c, uint32_t carry = 0) {
uint64_t tmp = uint64_t{a} + uint64_t{b} * uint64_t{c} + uint64_t{carry};
MulResult<uint32_t> result;
MulResult<uint32_t> result{};
fakedev9999 marked this conversation as resolved.
Show resolved Hide resolved
result.lo = static_cast<uint32_t>(tmp);
result.hi = static_cast<uint32_t>(tmp >> 32);
return result;
Expand Down
28 changes: 14 additions & 14 deletions tachyon/math/base/big_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {

// Returns the maximum representable value for BigInt.
constexpr static BigInt Max() {
BigInt ret;
BigInt ret{};
for (uint64_t& limb : ret.limbs) {
limb = std::numeric_limits<uint64_t>::max();
}
Expand All @@ -112,7 +112,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {

// Generate a random BigInt between [0, |max|).
constexpr static BigInt Random(const BigInt& max = Max()) {
BigInt ret;
BigInt ret{};
for (size_t i = 0; i < N; ++i) {
ret[i] = base::Uniform(base::Range<uint64_t>::All());
}
Expand All @@ -123,14 +123,14 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

// Convert a decimal string to a BigInt.
constexpr static std::optional<BigInt> FromDecString(std::string_view str) {
static std::optional<BigInt> FromDecString(std::string_view str) {
BigInt ret;
if (!internal::StringToLimbs(str, ret.limbs, N)) return std::nullopt;
return ret;
}

// Convert a hexadecimal string to a BigInt.
constexpr static std::optional<BigInt> FromHexString(std::string_view str) {
static std::optional<BigInt> FromHexString(std::string_view str) {
BigInt ret;
if (!(internal::HexStringToLimbs(str, ret.limbs, N))) return std::nullopt;
return ret;
Expand All @@ -141,7 +141,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
template <size_t BitNums = kBitNums>
constexpr static BigInt FromBitsLE(const std::bitset<BitNums>& bits) {
static_assert(BitNums <= kBitNums);
BigInt ret;
BigInt ret{};
size_t bit_idx = 0;
size_t limb_idx = 0;
std::bitset<kLimbBitNums> limb_bits;
Expand All @@ -167,7 +167,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
template <size_t BitNums = kBitNums>
constexpr static BigInt FromBitsBE(const std::bitset<BitNums>& bits) {
static_assert(BitNums <= kBitNums);
BigInt ret;
BigInt ret{};
std::bitset<kLimbBitNums> limb_bits;
size_t bit_idx = 0;
size_t limb_idx = 0;
Expand Down Expand Up @@ -196,7 +196,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
// ordering.
template <typename ByteContainer>
constexpr static BigInt FromBytesLE(const ByteContainer& bytes) {
BigInt ret;
BigInt ret{};
size_t byte_idx = 0;
size_t limb_idx = 0;
uint64_t limb = 0;
Expand Down Expand Up @@ -224,7 +224,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
// ordering.
template <typename ByteContainer>
constexpr static BigInt FromBytesBE(const ByteContainer& bytes) {
BigInt ret;
BigInt ret{};
size_t byte_idx = 0;
size_t limb_idx = 0;
uint64_t limb = 0;
Expand Down Expand Up @@ -446,7 +446,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt operator&(const BigInt& other) const {
BigInt ret;
BigInt ret{};
if constexpr (N == 1) {
ret[0] = limbs[0] & other[0];
} else if constexpr (N == 2) {
Expand All @@ -471,7 +471,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt operator|(const BigInt& other) const {
BigInt ret;
BigInt ret{};
if constexpr (N == 1) {
ret[0] = limbs[0] | other[0];
} else if constexpr (N == 2) {
Expand All @@ -496,7 +496,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt operator^(const BigInt& other) const {
BigInt ret;
BigInt ret{};
if constexpr (N == 1) {
ret[0] = limbs[0] ^ other[0];
} else if constexpr (N == 2) {
Expand Down Expand Up @@ -548,7 +548,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt Add(const BigInt& other, uint64_t& carry) const {
BigInt ret;
BigInt ret{};
DoAdd(*this, other, carry, ret);
return ret;
}
Expand All @@ -569,7 +569,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt Sub(const BigInt& other, uint64_t& borrow) const {
BigInt ret;
BigInt ret{};
DoSub(*this, other, borrow, ret);
return ret;
}
Expand All @@ -590,7 +590,7 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
}

constexpr BigInt MulBy2(uint64_t& carry) const {
BigInt ret;
BigInt ret{};
DoMulBy2(*this, carry, ret);
return ret;
}
Expand Down
1 change: 1 addition & 0 deletions tachyon/math/circle/stark/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ generate_fp2s(
generate_fp4s(
name = "fq4",
base_field = "Fq2",
base_field_degree = 2,
base_field_hdr = "tachyon/math/circle/stark/fq2.h",
class_name = "Fq4",
namespace = "tachyon::math::stark",
Expand Down
2 changes: 1 addition & 1 deletion tachyon/math/elliptic_curves/msm/msm_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct TACHYON_EXPORT MSMCtx {

template <typename ScalarField>
constexpr static MSMCtx CreateDefault(size_t size) {
MSMCtx ctx;
MSMCtx ctx{};
ctx.window_bits = ComputeWindowsBits(size);
ctx.window_count = ComputeWindowsCount<ScalarField>(ctx.window_bits);
ctx.size = size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class AffinePoint<

constexpr static std::optional<AffinePoint> CreateFromX(const BaseField& x,
bool pick_odd) {
AffinePoint point;
AffinePoint point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class JacobianPoint<

constexpr static std::optional<JacobianPoint> CreateFromX(const BaseField& x,
bool pick_odd) {
JacobianPoint point;
JacobianPoint point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class PointXYZZ<_Curve,

constexpr static std::optional<PointXYZZ> CreateFromX(const BaseField& x,
bool pick_odd) {
PointXYZZ point;
PointXYZZ point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ProjectivePoint<

constexpr static std::optional<ProjectivePoint> CreateFromX(
const BaseField& x, bool pick_odd) {
ProjectivePoint point;
ProjectivePoint point{};
if (!Curve::GetPointFromX(x, pick_odd, &point)) return std::nullopt;
return point;
}
Expand Down
4 changes: 2 additions & 2 deletions tachyon/math/elliptic_curves/short_weierstrass/sw_curve.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class SWCurve {
template <typename Point>
constexpr static bool GetPointFromX(const BaseField& x, bool pick_odd,
Point* point) {
BaseField even_y;
BaseField odd_y;
BaseField even_y{};
BaseField odd_y{};
if (!GetYsFromX(x, &even_y, &odd_y)) return false;
if constexpr (std::is_same_v<Point, AffinePoint>) {
*point = AffinePoint(x, pick_odd ? odd_y : even_y);
Expand Down
17 changes: 15 additions & 2 deletions tachyon/math/finite_fields/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ tachyon_cc_library(
":cyclotomic_multiplicative_subgroup",
"//tachyon/base/buffer:copyable",
"//tachyon/base/json",
"//tachyon/math/geometry:point3",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
Expand Down Expand Up @@ -73,6 +72,7 @@ tachyon_cc_library(
hdrs = ["fp4.h"],
deps = [
":quadratic_extension_field",
":quartic_extension_field",
"//tachyon/math/base/gmp:gmp_util",
],
)
Expand Down Expand Up @@ -231,7 +231,18 @@ tachyon_cc_library(
":cyclotomic_multiplicative_subgroup",
"//tachyon/base/buffer:copyable",
"//tachyon/base/json",
"//tachyon/math/geometry:point2",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

tachyon_cc_library(
name = "quartic_extension_field",
hdrs = ["quartic_extension_field.h"],
deps = [
":cyclotomic_multiplicative_subgroup",
"//tachyon/base/buffer:copyable",
"//tachyon/base/json",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
Expand Down Expand Up @@ -277,6 +288,7 @@ tachyon_cc_unittest(
"prime_field_generator_unittest.cc",
"prime_field_unittest.cc",
"quadratic_extension_field_unittest.cc",
"quartic_extension_field_unittest.cc",
] + select({
"@platforms//cpu:x86_64": ["packed_prime_field_unittest.cc"],
"@platforms//cpu:aarch64": ["packed_prime_field_unittest.cc"],
Expand All @@ -303,6 +315,7 @@ tachyon_cc_unittest(
"//tachyon/math/elliptic_curves/secp/secp256k1:fq",
"//tachyon/math/elliptic_curves/secp/secp256k1:fr",
"//tachyon/math/finite_fields/baby_bear",
"//tachyon/math/finite_fields/baby_bear:baby_bear4",
"//tachyon/math/finite_fields/binary_fields",
"//tachyon/math/finite_fields/goldilocks:goldilocks_prime_field",
"//tachyon/math/finite_fields/koala_bear",
Expand Down
12 changes: 12 additions & 0 deletions tachyon/math/finite_fields/baby_bear/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp4s")
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")

package(default_visibility = ["//visibility:public"])
Expand All @@ -14,6 +15,17 @@ generate_prime_fields(
use_montgomery = True,
)

generate_fp4s(
name = "baby_bear4",
base_field = "BabyBear",
base_field_degree = 1,
base_field_hdr = "tachyon/math/finite_fields/baby_bear/baby_bear.h",
class_name = "BabyBear4",
namespace = "tachyon::math",
non_residue = ["11"],
deps = [":baby_bear"],
)

tachyon_cc_library(
name = "packed_baby_bear",
hdrs = ["packed_baby_bear.h"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ TYPED_TEST(BinaryFieldsTest, ComparisonOperator) {
if constexpr (BinaryField::Config::kModulusBits > 3) {
BinaryField f(3);
BinaryField f2(4);
EXPECT_TRUE(f < f2);
EXPECT_TRUE(f <= f2);
EXPECT_FALSE(f > f2);
EXPECT_FALSE(f >= f2);
EXPECT_LT(f, f2);
EXPECT_LE(f, f2);
EXPECT_GT(f2, f);
EXPECT_GE(f2, f);
} else {
GTEST_SKIP() << "Modulus is too small";
}
Expand Down
Loading
Loading