From d3659e748124f014e78943c2d557af81dd25d13c Mon Sep 17 00:00:00 2001 From: Ryan Kim Date: Thu, 20 Jun 2024 21:07:06 +0900 Subject: [PATCH] feat(math): add quartic extension for `BabyBear` and `KoalaBear` See https://github.com/Plonky3/Plonky3/blob/11ac8745b21295d5699f47089dbd927836f53fff/baby-bear/src/extension.rs#L6-L11 and https://github.com/Plonky3/Plonky3/blob/11ac8745b21295d5699f47089dbd927836f53fff/koala-bear/src/extension.rs#L40-L45. --- tachyon/math/circle/stark/BUILD.bazel | 1 + tachyon/math/finite_fields/BUILD.bazel | 15 + .../math/finite_fields/baby_bear/BUILD.bazel | 12 + .../finite_fields/finite_field_forwards.h | 2 +- tachyon/math/finite_fields/fp4.h | 112 ++- .../ext_prime_field_generator/build_defs.bzl | 1 - .../ext_prime_field_generator.cc | 7 +- .../ext_prime_field_generator/fq.h.tpl | 7 + .../math/finite_fields/koala_bear/BUILD.bazel | 13 +- .../finite_fields/quartic_extension_field.h | 662 ++++++++++++++++++ .../quartic_extension_field_unittest.cc | 208 ++++++ 11 files changed, 1034 insertions(+), 6 deletions(-) create mode 100644 tachyon/math/finite_fields/quartic_extension_field.h create mode 100644 tachyon/math/finite_fields/quartic_extension_field_unittest.cc diff --git a/tachyon/math/circle/stark/BUILD.bazel b/tachyon/math/circle/stark/BUILD.bazel index 903d4b2665..f1387b30e9 100644 --- a/tachyon/math/circle/stark/BUILD.bazel +++ b/tachyon/math/circle/stark/BUILD.bazel @@ -48,6 +48,7 @@ generate_fp2s( generate_fp4s( name = "fq4", base_field = "Fq2", + base_field_degree = 2, base_field_hdr = "tachyon/math/circle/stark/fq2.h", class_name = "Fq4", namespace = "tachyon::math::stark", diff --git a/tachyon/math/finite_fields/BUILD.bazel b/tachyon/math/finite_fields/BUILD.bazel index 01ab0b112e..0d60514ffe 100644 --- a/tachyon/math/finite_fields/BUILD.bazel +++ b/tachyon/math/finite_fields/BUILD.bazel @@ -72,6 +72,7 @@ tachyon_cc_library( hdrs = ["fp4.h"], deps = [ ":quadratic_extension_field", + ":quartic_extension_field", "//tachyon/math/base/gmp:gmp_util", ], ) @@ -235,6 +236,18 @@ tachyon_cc_library( ], ) +tachyon_cc_library( + name = "quartic_extension_field", + hdrs = ["quartic_extension_field.h"], + deps = [ + ":cyclotomic_multiplicative_subgroup", + "//tachyon/base/buffer:copyable", + "//tachyon/base/json", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + tachyon_cc_library( name = "small_prime_field", hdrs = ["small_prime_field.h"], @@ -275,6 +288,7 @@ tachyon_cc_unittest( "prime_field_generator_unittest.cc", "prime_field_unittest.cc", "quadratic_extension_field_unittest.cc", + "quartic_extension_field_unittest.cc", ] + select({ "@platforms//cpu:x86_64": ["packed_prime_field_unittest.cc"], "@platforms//cpu:aarch64": ["packed_prime_field_unittest.cc"], @@ -301,6 +315,7 @@ tachyon_cc_unittest( "//tachyon/math/elliptic_curves/secp/secp256k1:fq", "//tachyon/math/elliptic_curves/secp/secp256k1:fr", "//tachyon/math/finite_fields/baby_bear", + "//tachyon/math/finite_fields/baby_bear:baby_bear4", "//tachyon/math/finite_fields/binary_fields", "//tachyon/math/finite_fields/goldilocks:goldilocks_prime_field", "//tachyon/math/finite_fields/koala_bear", diff --git a/tachyon/math/finite_fields/baby_bear/BUILD.bazel b/tachyon/math/finite_fields/baby_bear/BUILD.bazel index 0d11b0f9c0..93ae5338db 100644 --- a/tachyon/math/finite_fields/baby_bear/BUILD.bazel +++ b/tachyon/math/finite_fields/baby_bear/BUILD.bazel @@ -1,5 +1,6 @@ load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64") load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library") +load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp4s") load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields") package(default_visibility = ["//visibility:public"]) @@ -14,6 +15,17 @@ generate_prime_fields( use_montgomery = True, ) +generate_fp4s( + name = "baby_bear4", + base_field = "BabyBear", + base_field_degree = 1, + base_field_hdr = "tachyon/math/finite_fields/baby_bear/baby_bear.h", + class_name = "BabyBear4", + namespace = "tachyon::math", + non_residue = ["11"], + deps = [":baby_bear"], +) + tachyon_cc_library( name = "packed_baby_bear", hdrs = ["packed_baby_bear.h"], diff --git a/tachyon/math/finite_fields/finite_field_forwards.h b/tachyon/math/finite_fields/finite_field_forwards.h index fc7292f731..950c54fb67 100644 --- a/tachyon/math/finite_fields/finite_field_forwards.h +++ b/tachyon/math/finite_fields/finite_field_forwards.h @@ -21,7 +21,7 @@ class Fp2; template class Fp3; -template +template class Fp4; template diff --git a/tachyon/math/finite_fields/fp4.h b/tachyon/math/finite_fields/fp4.h index aa892073d1..a5c29634cf 100644 --- a/tachyon/math/finite_fields/fp4.h +++ b/tachyon/math/finite_fields/fp4.h @@ -8,11 +8,13 @@ #include "tachyon/math/base/gmp/gmp_util.h" #include "tachyon/math/finite_fields/quadratic_extension_field.h" +#include "tachyon/math/finite_fields/quartic_extension_field.h" namespace tachyon::math { template -class Fp4 final : public QuadraticExtensionField> { +class Fp4> final + : public QuadraticExtensionField> { public: using BaseField = typename Config::BaseField; using BasePrimeField = typename Config::BasePrimeField; @@ -89,6 +91,114 @@ class Fp4 final : public QuadraticExtensionField> { } }; +template +class Fp4> final + : public QuarticExtensionField> { + public: + using BaseField = typename Config::BaseField; + using BasePrimeField = typename Config::BasePrimeField; + using FrobeniusCoefficient = typename Config::FrobeniusCoefficient; + + using CpuField = Fp4; + // TODO(chokobole): Implement Fp4Gpu + using GpuField = Fp4; + + using QuarticExtensionField>::QuarticExtensionField; + + static_assert(Config::kDegreeOverBaseField == 4); + static_assert(BaseField::ExtensionDegree() == 1); + + constexpr static uint64_t kDegreeOverBasePrimeField = 4; + + static void Init() { + Config::Init(); + // x⁴ = q = |Config::kNonResidue| + + // αᴾ = (α₀ + α₁x + α₂x² + α₃x³)ᴾ + // = α₀ᴾ + α₁ᴾxᴾ + α₂ᴾx²ᴾ + α₃ᴾx³ᴾ + // = α₀ + α₁xᴾ + α₂x²ᴾ + α₃x³ᴾ <- Fermat's little theorem + // = α₀ + α₁xᴾ⁻¹x + α₂x²ᴾ⁻²x² + α₃x³ᴾ⁻³x³ + // = α₀ + α₁(x⁴)^((P - 1) / 4) * x + α₂(x⁴)^(2 * (P - 1) / 4) * x² + + // α₃(x⁴)^(3 * (P - 1) / 4) * x³ + // = α₀ + α₁ωx + α₂ω²x² + α₃ω³x³, where ω is a quartic root of unity. + + constexpr uint64_t N = BasePrimeField::kLimbNums; + // m₁ = P + mpz_class m1; + if constexpr (BasePrimeField::Config::kModulusBits <= 32) { + m1 = mpz_class(BasePrimeField::Config::kModulus); + } else { + gmp::WriteLimbs(BasePrimeField::Config::kModulus.limbs, N, &m1); + } + +#define SET_M(d, d_prev) mpz_class m##d = m##d_prev * m1 + + // m₂ = m₁ * P = P² + SET_M(2, 1); + // m₃ = m₂ * P = P³ + SET_M(3, 2); + +#undef SET_M + +#define SET_EXP_GMP(d) mpz_class exp##d##_gmp = (m##d - 1) / mpz_class(4) + + // exp₁ = (m₁ - 1) / 4 = (P¹ - 1) / 4 + SET_EXP_GMP(1); + // exp₂ = (m₂ - 1) / 4 = (P² - 1) / 4 + SET_EXP_GMP(2); + // exp₃ = (m₃ - 1) / 4 = (P³ - 1) / 4 + SET_EXP_GMP(3); + +#undef SET_EXP_GMP + + // |kFrobeniusCoeffs[0]| = q^((P⁰ - 1) / 4) = 1 + Config::kFrobeniusCoeffs[0] = FrobeniusCoefficient::One(); +#define SET_FROBENIUS_COEFF(d) \ + BigInt exp##d; \ + gmp::CopyLimbs(exp##d##_gmp, exp##d.limbs); \ + Config::kFrobeniusCoeffs[d] = Config::kNonResidue.Pow(exp##d) + + // |kFrobeniusCoeffs[1]| = q^(exp₁) = q^((P¹ - 1) / 4) = ω + SET_FROBENIUS_COEFF(1); + // |kFrobeniusCoeffs[2]| = q^(exp₂) = q^((P² - 1) / 4) + SET_FROBENIUS_COEFF(2); + // |kFrobeniusCoeffs[3]| = q^(exp₃) = q^((P³ - 1) / 4) + SET_FROBENIUS_COEFF(3); + +#undef SET_FROBENIUS_COEFF + + // |kFrobeniusCoeffs2[0]| = q^(2 * (P⁰ - 1) / 4) = 1 + Config::kFrobeniusCoeffs2[0] = FrobeniusCoefficient::One(); +#define SET_FROBENIUS_COEFF2(d) \ + gmp::CopyLimbs(mpz_class(2) * exp##d##_gmp, exp##d.limbs); \ + Config::kFrobeniusCoeffs2[d] = Config::kNonResidue.Pow(exp##d) + + // |kFrobeniusCoeffs2[1]| = q^(2 * exp₁) = q^(2 * (P¹ - 1) / 4) = ω² + SET_FROBENIUS_COEFF2(1); + // |kFrobeniusCoeffs2[2]| = q^(2 * exp₂) = q^(2 * (P² - 1) / 4) + SET_FROBENIUS_COEFF2(2); + // |kFrobeniusCoeffs2[3]| = q^(2 * exp₃) = q^(2 * (P³ - 1) / 4) + SET_FROBENIUS_COEFF2(3); + +#undef SET_FROBENIUS_COEFF2 + + // |kFrobeniusCoeffs3[0]| = q^(3 * (P⁰ - 1) / 4) = 1 + Config::kFrobeniusCoeffs3[0] = FrobeniusCoefficient::One(); +#define SET_FROBENIUS_COEFF3(d) \ + gmp::CopyLimbs(mpz_class(3) * exp##d##_gmp, exp##d.limbs); \ + Config::kFrobeniusCoeffs3[d] = Config::kNonResidue.Pow(exp##d) + + // |kFrobeniusCoeffs3[1]| = q^(3 * exp₁) = q^(3 * (P¹ - 1) / 4) = ω³ + SET_FROBENIUS_COEFF3(1); + // |kFrobeniusCoeffs3[2]| = q^(3 * exp₂) = q^(3 * (P² - 1) / 4) + SET_FROBENIUS_COEFF3(2); + // |kFrobeniusCoeffs3[3]| = q^(3 * exp₃) = q^(3 * (P³ - 1) / 4) + SET_FROBENIUS_COEFF3(3); + +#undef SET_FROBENIUS_COEFF3 + } +}; + } // namespace tachyon::math #endif // TACHYON_MATH_FINITE_FIELDS_FP4_H_ diff --git a/tachyon/math/finite_fields/generator/ext_prime_field_generator/build_defs.bzl b/tachyon/math/finite_fields/generator/ext_prime_field_generator/build_defs.bzl index bbe3d2f510..e7a70b186d 100644 --- a/tachyon/math/finite_fields/generator/ext_prime_field_generator/build_defs.bzl +++ b/tachyon/math/finite_fields/generator/ext_prime_field_generator/build_defs.bzl @@ -122,7 +122,6 @@ def generate_fp4s( _generate_ext_prime_fields( name = name, degree = 4, - base_field_degree = 2, ext_prime_field_deps = ["//tachyon/math/finite_fields:fp4"], **kwargs ) diff --git a/tachyon/math/finite_fields/generator/ext_prime_field_generator/ext_prime_field_generator.cc b/tachyon/math/finite_fields/generator/ext_prime_field_generator/ext_prime_field_generator.cc index c909e96735..0185c2fbdc 100644 --- a/tachyon/math/finite_fields/generator/ext_prime_field_generator/ext_prime_field_generator.cc +++ b/tachyon/math/finite_fields/generator/ext_prime_field_generator/ext_prime_field_generator.cc @@ -112,7 +112,8 @@ int GenerationConfig::GenerateConfigHdr() const { } replacements["%{frobenius_coefficient}"] = - (degree == 4 || (degree == 6 && base_field_degree == 3) || degree == 12) + ((degree == 4 && base_field_degree == 2) || + (degree == 6 && base_field_degree == 3) || degree == 12) ? "typename BaseField::BaseField" : "BaseField"; @@ -122,7 +123,9 @@ int GenerationConfig::GenerateConfigHdr() const { std::vector tpl_lines = absl::StrSplit(tpl_content, '\n'); RemoveOptionalLines(tpl_lines, "FrobeniusCoefficient2", - degree_over_base_field == 3); + degree_over_base_field >= 3); + RemoveOptionalLines(tpl_lines, "FrobeniusCoefficient3", + degree_over_base_field >= 4); tpl_content = absl::StrJoin(tpl_lines, "\n"); std::string content = absl::StrReplaceAll(tpl_content, replacements); diff --git a/tachyon/math/finite_fields/generator/ext_prime_field_generator/fq.h.tpl b/tachyon/math/finite_fields/generator/ext_prime_field_generator/fq.h.tpl index b49ace1895..1dceda156f 100644 --- a/tachyon/math/finite_fields/generator/ext_prime_field_generator/fq.h.tpl +++ b/tachyon/math/finite_fields/generator/ext_prime_field_generator/fq.h.tpl @@ -18,6 +18,9 @@ class %{class}Config { %{if FrobeniusCoefficient2} static FrobeniusCoefficient kFrobeniusCoeffs2[%{frobenius_coeffs_size}]; %{endif FrobeniusCoefficient2} +%{if FrobeniusCoefficient3} + static FrobeniusCoefficient kFrobeniusCoeffs3[%{frobenius_coeffs_size}]; +%{endif FrobeniusCoefficient3} constexpr static bool kNonResidueIsMinusOne = %{non_residue_is_minus_one}; constexpr static uint64_t kDegreeOverBaseField = %{degree_over_base_field}; @@ -42,6 +45,10 @@ typename %{class}Config::FrobeniusCoefficient %{class}Config typename %{class}Config::FrobeniusCoefficient %{class}Config::kFrobeniusCoeffs2[%{frobenius_coeffs_size}]; %{endif FrobeniusCoefficient2} +%{if FrobeniusCoefficient3} +template +typename %{class}Config::FrobeniusCoefficient %{class}Config::kFrobeniusCoeffs3[%{frobenius_coeffs_size}]; +%{endif FrobeniusCoefficient3} using %{class} = Fp%{degree}<%{class}Config<%{base_field}>>; } // namespace %{namespace} diff --git a/tachyon/math/finite_fields/koala_bear/BUILD.bazel b/tachyon/math/finite_fields/koala_bear/BUILD.bazel index 9a8a6a842e..38238dbef0 100644 --- a/tachyon/math/finite_fields/koala_bear/BUILD.bazel +++ b/tachyon/math/finite_fields/koala_bear/BUILD.bazel @@ -1,6 +1,6 @@ load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64") load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library") -load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp2s") +load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp2s", "generate_fp4s") load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields") package(default_visibility = ["//visibility:public"]) @@ -25,6 +25,17 @@ generate_fp2s( deps = [":koala_bear"], ) +generate_fp4s( + name = "koala_bear4", + base_field = "KoalaBear", + base_field_degree = 1, + base_field_hdr = "tachyon/math/finite_fields/koala_bear/koala_bear.h", + class_name = "KoalaBear4", + namespace = "tachyon::math", + non_residue = ["3"], + deps = [":koala_bear"], +) + tachyon_cc_library( name = "packed_koala_bear", hdrs = ["packed_koala_bear.h"], diff --git a/tachyon/math/finite_fields/quartic_extension_field.h b/tachyon/math/finite_fields/quartic_extension_field.h new file mode 100644 index 0000000000..0a0f32fd54 --- /dev/null +++ b/tachyon/math/finite_fields/quartic_extension_field.h @@ -0,0 +1,662 @@ +#ifndef TACHYON_MATH_FINITE_FIELDS_QUARTIC_EXTENSION_FIELD_H_ +#define TACHYON_MATH_FINITE_FIELDS_QUARTIC_EXTENSION_FIELD_H_ + +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "absl/types/span.h" + +#include "tachyon/base/buffer/copyable.h" +#include "tachyon/base/json/json.h" +#include "tachyon/math/finite_fields/cyclotomic_multiplicative_subgroup.h" + +namespace tachyon { +namespace math { + +template +class QuarticExtensionField : public CyclotomicMultiplicativeSubgroup { + public: + using Config = typename FiniteField::Config; + using BaseField = typename Config::BaseField; + using BasePrimeField = typename Config::BasePrimeField; + + constexpr QuarticExtensionField() = default; + constexpr QuarticExtensionField(const BaseField& c0, const BaseField& c1, + const BaseField& c2, const BaseField& c3) + : c0_(c0), c1_(c1), c2_(c2), c3_(c3) {} + constexpr QuarticExtensionField(BaseField&& c0, BaseField&& c1, + BaseField&& c2, BaseField&& c3) + : c0_(std::move(c0)), + c1_(std::move(c1)), + c2_(std::move(c2)), + c3_(std::move(c3)) {} + + constexpr static Derived Zero() { + return {BaseField::Zero(), BaseField::Zero(), BaseField::Zero(), + BaseField::Zero()}; + } + + constexpr static Derived One() { + return {BaseField::One(), BaseField::Zero(), BaseField::Zero(), + BaseField::Zero()}; + } + + static Derived Random() { + return {BaseField::Random(), BaseField::Random(), BaseField::Random(), + BaseField::Random()}; + } + + static Derived FromBasePrimeFields( + absl::Span prime_fields) { + CHECK_EQ(prime_fields.size(), ExtensionDegree()); + constexpr size_t kBaseFieldDegree = BaseField::ExtensionDegree(); + if constexpr (kBaseFieldDegree == 1) { + return Derived(prime_fields[0], prime_fields[1], prime_fields[2], + prime_fields[3]); + } else { + BaseField c0 = BaseField::FromBasePrimeFields( + prime_fields.subspan(0, kBaseFieldDegree)); + prime_fields.remove_prefix(kBaseFieldDegree); + BaseField c1 = BaseField::FromBasePrimeFields( + prime_fields.subspan(0, kBaseFieldDegree)); + prime_fields.remove_prefix(kBaseFieldDegree); + BaseField c2 = BaseField::FromBasePrimeFields( + prime_fields.subspan(0, kBaseFieldDegree)); + prime_fields.remove_prefix(kBaseFieldDegree); + BaseField c3 = BaseField::FromBasePrimeFields( + prime_fields.subspan(kBaseFieldDegree)); + return Derived(std::move(c0), std::move(c1), std::move(c2), + std::move(c3)); + } + } + + constexpr bool IsZero() const { + return c0_.IsZero() && c1_.IsZero() && c2_.IsZero() && c3_.IsZero(); + } + + constexpr bool IsOne() const { + return c0_.IsOne() && c1_.IsZero() && c2_.IsZero() && c3_.IsZero(); + } + + constexpr static uint64_t ExtensionDegree() { + return 4 * BaseField::ExtensionDegree(); + } + + // Calculate the norm of an element with respect to |BaseField|. + // The norm maps an element |a| in the extension field + // Fqᵐ to an element in the |BaseField| Fq. + // |a.Norm() = a * a^q * a^q² * a^q³| + constexpr BaseField Norm() const { + // w.r.t to |BaseField|, we need the 0th, 1st, 2nd & 3rd powers of q. + // Since Frobenius coefficients on the towered extensions are + // indexed w.r.t. to |BasePrimeField|, we need to calculate the correct + // index. + // NOTE(chokobole): This assumes that |BaseField::ExtensionDegree()| + // never overflows even on 32 bit machine. + size_t index_multiplier = size_t{BaseField::ExtensionDegree()}; + Derived self_to_p = static_cast(*this); + self_to_p.FrobeniusMapInPlace(index_multiplier); + Derived self_to_p2 = static_cast(*this); + self_to_p2.FrobeniusMapInPlace(2 * index_multiplier); + Derived self_to_p3 = static_cast(*this); + self_to_p3.FrobeniusMapInPlace(3 * index_multiplier); + self_to_p *= (self_to_p2 * self_to_p3 * static_cast(*this)); + // NOTE(chokobole): below CHECK() is not a device code. + // See https://github.com/kroma-network/tachyon/issues/76 + CHECK(self_to_p.c1().IsZero() && self_to_p.c2().IsZero() && + self_to_p.c3().IsZero()); + return self_to_p.c0(); + } + + constexpr Derived& FrobeniusMapInPlace(uint64_t exponent) { + c0_.FrobeniusMapInPlace(exponent); + c1_.FrobeniusMapInPlace(exponent); + c2_.FrobeniusMapInPlace(exponent); + c3_.FrobeniusMapInPlace(exponent); + c1_ *= + Config::kFrobeniusCoeffs[exponent % Config::kDegreeOverBasePrimeField]; + c2_ *= + Config::kFrobeniusCoeffs2[exponent % Config::kDegreeOverBasePrimeField]; + c3_ *= + Config::kFrobeniusCoeffs3[exponent % Config::kDegreeOverBasePrimeField]; + return *static_cast(this); + } + + std::string ToString() const { + return absl::Substitute("($0, $1, $2, $3)", c0_.ToString(), c1_.ToString(), + c2_.ToString(), c3_.ToString()); + } + + std::string ToHexString(bool pad_zero = false) const { + return absl::Substitute("($0, $1, $2, $3)", c0_.ToHexString(pad_zero), + c1_.ToHexString(pad_zero), + c2_.ToHexString(pad_zero), + c3_.ToHexString(pad_zero)); + } + + constexpr const BaseField& c0() const { return c0_; } + constexpr const BaseField& c1() const { return c1_; } + constexpr const BaseField& c2() const { return c2_; } + constexpr const BaseField& c3() const { return c3_; } + + constexpr bool operator==(const Derived& other) const { + return c0_ == other.c0_ && c1_ == other.c1_ && c2_ == other.c2_ && + c3_ == other.c3_; + } + + constexpr bool operator!=(const Derived& other) const { + return c0_ != other.c0_ || c1_ != other.c1_ || c2_ != other.c2_ || + c3_ != other.c3_; + } + + constexpr bool operator<(const Derived& other) const { + if (c3_ == other.c3_) { + if (c2_ == other.c2_) { + if (c1_ == other.c1_) return c0_ < other.c0_; + return c1_ < other.c1_; + } + return c2_ < other.c2_; + } + return c3_ < other.c3_; + } + + constexpr bool operator>(const Derived& other) const { + if (c3_ == other.c3_) { + if (c2_ == other.c2_) { + if (c1_ == other.c1_) return c0_ > other.c0_; + return c1_ > other.c1_; + } + return c2_ > other.c2_; + } + return c3_ > other.c3_; + } + + constexpr bool operator<=(const Derived& other) const { + if (c3_ == other.c3_) { + if (c2_ == other.c2_) { + if (c1_ == other.c1_) return c0_ <= other.c0_; + return c1_ <= other.c1_; + } + return c2_ <= other.c2_; + } + return c3_ <= other.c3_; + } + + constexpr bool operator>=(const Derived& other) const { + if (c3_ == other.c3_) { + if (c2_ == other.c2_) { + if (c1_ == other.c1_) return c0_ >= other.c0_; + return c1_ >= other.c1_; + } + return c2_ >= other.c2_; + } + return c3_ >= other.c3_; + } + + // AdditiveSemigroup methods + constexpr Derived Add(const Derived& other) const { + return { + c0_ + other.c0_, + c1_ + other.c1_, + c2_ + other.c2_, + c3_ + other.c3_, + }; + } + + constexpr Derived& AddInPlace(const Derived& other) { + c0_ += other.c0_; + c1_ += other.c1_; + c2_ += other.c2_; + c3_ += other.c3_; + return *static_cast(this); + } + + constexpr Derived DoubleImpl() const { + return { + c0_.Double(), + c1_.Double(), + c2_.Double(), + c3_.Double(), + }; + } + + constexpr Derived& DoubleImplInPlace() { + c0_.DoubleInPlace(); + c1_.DoubleInPlace(); + c2_.DoubleInPlace(); + c3_.DoubleInPlace(); + return *static_cast(this); + } + + // AdditiveGroup methods + constexpr Derived Sub(const Derived& other) const { + return { + c0_ - other.c0_, + c1_ - other.c1_, + c2_ - other.c2_, + c3_ - other.c3_, + }; + } + + constexpr Derived& SubInPlace(const Derived& other) { + c0_ -= other.c0_; + c1_ -= other.c1_; + c2_ -= other.c2_; + c3_ -= other.c3_; + return *static_cast(this); + } + + constexpr Derived Negate() const { + return { + -c0_, + -c1_, + -c2_, + -c3_, + }; + } + + constexpr Derived& NegateInPlace() { + c0_.NegateInPlace(); + c1_.NegateInPlace(); + c2_.NegateInPlace(); + c3_.NegateInPlace(); + return *static_cast(this); + } + + // MultiplicativeSemigroup methods + constexpr Derived Mul(const Derived& other) const { + Derived ret{}; + DoMul(*static_cast(this), other, ret); + return ret; + } + + constexpr Derived& MulInPlace(const Derived& other) { + DoMul(*static_cast(this), other, + *static_cast(this)); + return *static_cast(this); + } + + constexpr Derived Mul(const BaseField& element) const { + return { + c0_ * element, + c1_ * element, + c2_ * element, + c3_ * element, + }; + } + + constexpr Derived& MulInPlace(const BaseField& element) { + c0_ *= element; + c1_ *= element; + c2_ *= element; + c3_ *= element; + return *static_cast(this); + } + + constexpr Derived SquareImpl() const { + Derived ret{}; + DoSquareImpl(*static_cast(this), ret); + return ret; + } + + constexpr Derived& SquareImplInPlace() { + DoSquareImpl(*static_cast(this), + *static_cast(this)); + return *static_cast(this); + } + + // MultiplicativeGroup methods + constexpr std::optional Inverse() const { + Derived ret{}; + if (LIKELY(DoInverse(*static_cast(this), ret))) { + return ret; + } + LOG_IF_NOT_GPU(ERROR) << "Inverse of zero attempted"; + return std::nullopt; + } + + [[nodiscard]] constexpr std::optional InverseInPlace() { + if (LIKELY(DoInverse(*static_cast(this), + *static_cast(this)))) { + return static_cast(this); + } + LOG_IF_NOT_GPU(ERROR) << "Inverse of zero attempted"; + return std::nullopt; + } + + protected: + constexpr static void DoMul(const Derived& a, const Derived& b, Derived& c) { + // clang-format off + // (a.c0, a.c1, a.c2, a.c3) * (b.c0, b.c1, b.c2, b.c3) + // = (a.c0 + a.c1 * x + a.c2 * x² + a.c3 * x³) * (b.c0 + b.c1 * x + b.c2 * x² + b.c3 * x³) + // = a.c0 * b.c0 + (a.c0 * b.c1 + a.c1 * b.c0) * x + (a.c0 * b.c2 + a.c1 * b.c1 + a.c2 * b.c0) * x² + + // (a.c0 * b.c3 + a.c1 * b.c2 + a.c2 * b.c1 * a.c3 * b.c0) * x³ + (a.c1 * b.c3 + a.c2 * b.c2 + a.c3 * b.c1) * x⁴ + + // (a.c2 * b.c3 + a.c3 * b.c2) * x⁵ + a.c3 * b.c3 * x⁶ + // = a.c0 * b.c0 + (a.c1 * b.c3 + a.c2 * b.c2 + a.c3 * b.c1) * x⁴ + + // (a.c0 * b.c1 + a.c1 * b.c0) * x + (a.c2 * b.c3 + a.c3 * b.c2) * x⁵ + + // (a.c0 * b.c2 + a.c1 * b.c1 + a.c2 * b.c0) * x² + a.c3 * b.c3 * x⁶ + + // (a.c0 * b.c3 + a.c1 * b.c2 + a.c2 * b.c1 * a.c3 * b.c0) * x³ + // = a.c0 * b.c0 + (a.c1 * b.c3 + a.c2 * b.c2 + a.c3 * b.c1) * q + + // (a.c0 * b.c1 + a.c1 * b.c0) * x + (a.c2 * b.c3 + a.c3 * b.c2) * q * x + + // (a.c0 * b.c2 + a.c1 * b.c1 + a.c2 * b.c0) * x² + a.c3 * b.c3 * q * x² + + // (a.c0 * b.c3 + a.c1 * b.c2 + a.c2 * b.c1 * a.c3 * b.c0) * x³ + // = (a.c0 * b.c0 + (a.c1 * b.c3 + a.c2 * b.c2 + a.c3 * b.c1) * q, + // a.c0 * b.c1 + a.c1 * b.c0 + (a.c2 * b.c3 + a.c3 * b.c2) * q, + // a.c0 * b.c2 + a.c1 * b.c1 + a.c2 * b.c0 + a.c3 * b.c3 * q, + // a.c0 * b.c3 + a.c1 * b.c2 + a.c2 * b.c1 * a.c3 * b.c0) + // where q is |Config::kNonResidue|. + + // See https://eprint.iacr.org/2006/471.pdf + // Devegili OhEig Scott Dahab --- Multiplication and Squaring on AbstractPairing-Friendly Fields.pdf; Section 5.2 + // clang-format on + + constexpr BaseField kInv2 = *BaseField(2).Inverse(); + constexpr BaseField kInv3 = *BaseField(3).Inverse(); + constexpr BaseField kInv4 = *BaseField(4).Inverse(); + constexpr BaseField kInv6 = *BaseField(6).Inverse(); + constexpr BaseField kInv12 = *BaseField(12).Inverse(); + constexpr BaseField kInv20 = *BaseField(20).Inverse(); + constexpr BaseField kInv24 = *BaseField(24).Inverse(); + constexpr BaseField kInv30 = *BaseField(30).Inverse(); + constexpr BaseField kInv120 = *BaseField(120).Inverse(); + constexpr BaseField kNeg5 = -BaseField(5); + constexpr BaseField kNegInv2 = -kInv2; + constexpr BaseField kNegInv3 = -kInv3; + constexpr BaseField kNegInv4 = -kInv4; + constexpr BaseField kNegInv6 = -kInv6; + constexpr BaseField kNegInv12 = -kInv12; + constexpr BaseField kNegInv24 = -kInv24; + constexpr BaseField kNegInv120 = -kInv120; + + // h1 = 2 * a.c1 + BaseField h1 = a.c1_.Double(); + // h2 = 4 * a.c2 + BaseField h2 = a.c2_.Double(); + h2.DoubleInPlace(); + // h3 = 8 * a.c3 + BaseField h3 = a.c3_.Double(); + h3.DoubleInPlace().DoubleInPlace(); + // h4 = 2 * b.c1 + BaseField h4 = b.c1_.Double(); + // h5 = 4 * b.c2 + BaseField h5 = b.c2_.Double(); + h5.DoubleInPlace(); + // h6 = 8 * b.c3 + BaseField h6 = b.c3_.Double(); + h6.DoubleInPlace().DoubleInPlace(); + + // v0 = a.c0 * b.c0 + BaseField v0 = a.c0_ * b.c0_; + // v1 = (a.c0 + a.c1 + a.c2 + a.c3) * (b.c0 + b.c1 + b.c2 + b.c3) + BaseField v1 = + (a.c0_ + a.c1_ + a.c2_ + a.c3_) * (b.c0_ + b.c1_ + b.c2_ + b.c3_); + // v2 = (a.c0 - a.c1 + a.c2 - a.c3) * (b.c0 - b.c1 + b.c2 - b.c3) + BaseField v2 = + (a.c0_ - a.c1_ + a.c2_ - a.c3_) * (b.c0_ - b.c1_ + b.c2_ - b.c3_); + // v3 = (a.c0 + 2 * a.c1 + 4 * a.c2 + 8 * a.c3) * + // (b.c0 + 2 * b.c1 + 4 * b.c2 + 8 * b.c3) + BaseField v3 = (a.c0_ + h1 + h2 + h3) * (b.c0_ + h4 + h5 + h6); + // v4 = (a.c0 - 2 * a.c1 + 4 * a.c2 - 8 * a.c3) * + // (b.c0 - 2 * b.c1 + 4 * b.c2 - 8 * b.c3) + BaseField v4 = (a.c0_ - h1 + h2 - h3) * (b.c0_ - h4 + h5 - h6); + // h1 = 3 * a.c1 + h1 += a.c1_; + // h2 = 9 * a.c2 + h2.DoubleInPlace().AddInPlace(a.c2_); + // h3 = 27 * a.c3 + h3 += a.c3_; + h3 += h3.Double(); + // h4 = 3 * b.c1 + h4 += b.c1_; + // h5 = 9 * b.c2 + h5.DoubleInPlace().AddInPlace(b.c2_); + // h6 = 27 * b.c3 + h6 += b.c3_; + h6 += h6.Double(); + // v5 = (a.c0 + 3 * a.c1 + 9 * a.c2 + 27 * a.c3) * + // (b.c0 + 3 * b.c1 + 9 * b.c2 + 27 * b.c3) + BaseField v5 = (a.c0_ + h1 + h2 + h3) * (b.c0_ + h4 + h5 + h6); + // v6 = a.c3 * b.c3 + BaseField v6 = a.c3_ * b.c3_; + + // v0_5 = 5 * v0 + BaseField v0_5 = v0.Double(); + v0_5.DoubleInPlace().AddInPlace(v0); + // v6_3 = 3 * v6 + BaseField v6_3 = v6.Double(); + v6_3 += v6; + + // clang-format off + // c.c0 = v0 + + // q * ((1 / 4) * v0 - (1 / 6) * (v1 + v2) + (1 / 24) * (v3 + v4) - 5 * v6) + c.c0_ = v0 + + Config::MulByNonResidue(kInv4 * v0 + kNegInv6 * (v1 + v2) + kInv24 * (v3 + v4) + kNeg5 * v6); + // c.c1 = -(1 / 3) * v0 + v1 - (1 / 2) * v2 + (1 / 20) * v4 + (1 / 30) * v5 - 12 * v6 + + // q * (-(1 / 12) * (v0 - v1) + (1 / 24) * (v2 - v3) - (1 / 120) * (v4 - v5) - 3 * v6) + c.c1_ = kNegInv3 * v0 + v1 + kNegInv2 * v2 + kNegInv4 * v3 + kInv20 * v4 + kInv30 * v5 - v6_3.Double().Double() + + Config::MulByNonResidue(kNegInv12 * (v0 - v1) + kInv24 * (v2 - v3) + kNegInv120 * (v4 - v5) - v6_3); + // c.c2 = -(5 / 4) * v0 + (2 / 3) * (v1 + v2) - (1 / 24) * (v3 + v4) + 4 * v6 + + // q * v6 + c.c2_ = kNegInv4 * v0_5 + kInv3 * (v1 + v2).Double() + kNegInv24 * (v3 + v4) + v6.Double().Double() + + Config::MulByNonResidue(v6); + // c.c3 = (1 / 12) * (5 * v0 - 7 * v1) - (1 / 24) * (v2 - 7 * v3 + v4 + v5) + 15 * v6 + c.c3_ = kInv12 * (v0_5 - v1.Double().Double().Double() + v1) + kNegInv24 * (v2 - v3.Double().Double().Double() + v3 + v4 + v5) + v6_3.Double().Double() + v6_3; + // clang-format on + } + + constexpr static void DoSquareImpl(const Derived& a, Derived& b) { + // clang-format off + // (c0, c1, c2, c3)² + // = (c0 + c1 * x + c2 * x² + c3 * x³)² + // = c0² + 2 * c0 * c1 * x + (c1² + 2 * c0 * c2) * x² + 2 * (c0 * c3 + c1 * c2) * x³ + (c2² + 2 * c1 * c3) * x⁴ + 2 * c2 * c3 * x⁵ + c3 * x⁶ + // = c0² + (c2² + 2 * c1 * c3) * x⁴ + 2 * c0 * c1 * x + 2 * c2 * c3 * x⁵ + (c1² + 2 * c0 * c2) * x² + c3 * x⁶ + 2 * (c0 * c3 + c1 * c2) * x³ + // = c0² + (c2² + 2 * c1 * c3) * q + 2 * (c0 * c1 + c2 * c3 * q) * x + (c1² + 2 * c0 * c2 + c3 * q) * x² + 2 * (c0 * c3 + c1 * c2) * x³ + // = (c0² + (c2² + 2 * c1 * c3) * q, 2 * (c0 * c1 + c2 * c3 * q), c1² + 2 * c0 * c2 + c3 * q, 2 * (c0 * c3 + c1 * c2)) + // where q is |Config::kNonResidue|. + + // See https://eprint.iacr.org/2006/471.pdf + // Devegili OhEig Scott Dahab --- Multiplication and Squaring on AbstractPairing-Friendly Fields.pdf; Section 5 + // clang-format on + + constexpr BaseField kInv2 = *BaseField(2).Inverse(); + constexpr BaseField kInv3 = *BaseField(3).Inverse(); + constexpr BaseField kInv4 = *BaseField(4).Inverse(); + constexpr BaseField kInv6 = *BaseField(6).Inverse(); + constexpr BaseField kInv12 = *BaseField(12).Inverse(); + constexpr BaseField kInv20 = *BaseField(20).Inverse(); + constexpr BaseField kInv24 = *BaseField(24).Inverse(); + constexpr BaseField kInv30 = *BaseField(30).Inverse(); + constexpr BaseField kInv120 = *BaseField(120).Inverse(); + constexpr BaseField kNeg5 = -BaseField(5); + constexpr BaseField kNegInv2 = -kInv2; + constexpr BaseField kNegInv3 = -kInv3; + constexpr BaseField kNegInv4 = -kInv4; + constexpr BaseField kNegInv6 = -kInv6; + constexpr BaseField kNegInv12 = -kInv12; + constexpr BaseField kNegInv24 = -kInv24; + constexpr BaseField kNegInv120 = -kInv120; + + // h1 = 2 * c1 + BaseField h1 = a.c1_.Double(); + // h2 = 4 * c2 + BaseField h2 = a.c2_.Double(); + h2.DoubleInPlace(); + // h3 = 8 * c3 + BaseField h3 = a.c3_.Double(); + h3.DoubleInPlace().DoubleInPlace(); + + // v0 = c0² + BaseField v0 = a.c0_.Square(); + // v1 = (c0 + c1 + c2 + c3)² + BaseField v1 = (a.c0_ + a.c1_ + a.c2_ + a.c3_).Square(); + // v2 = (c0 - c1 + c2 - c3)² + BaseField v2 = (a.c0_ - a.c1_ + a.c2_ - a.c3_).Square(); + // v3 = (c0 + 2 * c1 + 4 * c2 + 8 * c3)² + BaseField v3 = (a.c0_ + h1 + h2 + h3).Square(); + // v4 = (c0 - 2 * c1 + 4 * c2 - 8 * c3)² + BaseField v4 = (a.c0_ - h1 + h2 - h3).Square(); + // h1 = 3 * c1 + h1 += a.c1_; + // h2 = 9 * c2 + h2.DoubleInPlace().AddInPlace(a.c2_); + // h3 = 27 * c3 + h3 += a.c3_; + h3 += h3.Double(); + // v5 = (c0 + 3 * c1 + 9 * c2 + 27 * c3)² + BaseField v5 = (a.c0_ + h1 + h2 + h3).Square(); + // v6 = c3² + BaseField v6 = a.c3_.Square(); + + // v0_5 = 5 * v0 + BaseField v0_5 = v0.Double(); + v0_5.DoubleInPlace().AddInPlace(v0); + // v6_3 = 3 * v6 + BaseField v6_3 = v6.Double(); + v6_3 += v6; + + // clang-format off + // b.c0 = v0 + + // q * ((1 / 4) * v0 - (1 / 6) * (v1 + v2) + (1 / 24) * (v3 + v4) - 5 * v6) + b.c0_ = v0 + + Config::MulByNonResidue(kInv4 * v0 + kNegInv6 * (v1 + v2) + kInv24 * (v3 + v4) + kNeg5 * v6); + // b.c1 = -(1 / 3) * v0 + v1 - (1 / 2) * v2 + - (1 / 3) * v3 + (1 / 20) * v4 + (1 / 30) * v5 - 12 * v6 + + // q * (-(1 / 12) * (v0 - v1) + (1 / 24) * (v2 - v3) - (1 / 120) * (v4 - v5) - 3 * v6) + b.c1_ = kNegInv3 * v0 + v1 + kNegInv2 * v2 + kNegInv4 * v3 + kInv20 * v4 + kInv30 * v5 - v6_3.Double().Double() + + Config::MulByNonResidue(kNegInv12 * (v0 - v1) + kInv24 * (v2 - v3) + kNegInv120 * (v4 - v5) - v6_3); + // b.c2 = -(5 / 4) * v0 + (2 / 3) * (v1 + v2) - (1 / 24) * (v3 + v4) + 4 * v6 + + // q * v6 + b.c2_ = kNegInv4 * v0_5 + kInv3 * (v1 + v2).Double() + kNegInv24 * (v3 + v4) + v6.Double().Double() + + Config::MulByNonResidue(v6); + // b.c3 = (1 / 12) * (5 * v0 - 7 * v1) - (1 / 24) * (v2 - 7 * v3 + v4 + v5) + 15 * v6 + b.c3_ = kInv12 * (v0_5 - v1.Double().Double().Double() + v1) + kNegInv24 * (v2 - v3.Double().Double().Double() + v3 + v4 + v5) + v6_3.Double().Double() + v6_3; + // clang-format on + } + + [[nodiscard]] constexpr static bool DoInverse(const Derived& a, Derived& b) { + if (UNLIKELY(a.IsZero())) { + LOG_IF_NOT_GPU(ERROR) << "Inverse of zero attempted"; + return false; + } + + // See Algorithm 11.3.4 in Handbook of Elliptic and Hyperelliptic Curve + // Cryptography. + // Compute aʳ⁻¹, where r = (p⁴ - 1) / (p - 1) = p³ + p² + p + 1 + size_t index_multiplier = size_t{BaseField::ExtensionDegree()}; + // f = a^{p³ + p² + p} + Derived a_to_r_minus_1 = a; + a_to_r_minus_1.FrobeniusMapInPlace(index_multiplier); + Derived a_to_p2 = a; + a_to_p2.FrobeniusMapInPlace(2 * index_multiplier); + Derived a_to_p3 = a; + a_to_p3.FrobeniusMapInPlace(3 * index_multiplier); + a_to_r_minus_1 *= (a_to_p2 * a_to_p3); + + // Since aʳ which is |Norm()| is in the base field, + // computing the constant part is enough. + BaseField a_to_r = BaseField::Zero(); + a_to_r += a.c1_ * a_to_r_minus_1.c3_; + a_to_r += a.c2_ * a_to_r_minus_1.c2_; + a_to_r += a.c3_ * a_to_r_minus_1.c1_; + a_to_r = Config::MulByNonResidue(a_to_r); + a_to_r += a.c0_ * a_to_r_minus_1.c0_; + + // a⁻¹ = aʳ⁻¹ * a⁻ʳ + b = a_to_r_minus_1 * *a_to_r.Inverse(); + return true; + } + + // c = c0_ + c1_ * X + c2_ * X² + c3_ * X³ + BaseField c0_; + BaseField c1_; + BaseField c2_; + BaseField c3_; +}; + +template < + typename BaseField, typename Derived, + std::enable_if_t>* = + nullptr> +Derived operator*(const BaseField& element, + const QuarticExtensionField& f) { + return static_cast(f) * element; +} + +} // namespace math + +namespace base { + +template +class Copyable, Derived>>> { + public: + static bool WriteTo( + const math::QuarticExtensionField& quadratic_extension_field, + Buffer* buffer) { + return buffer->WriteMany( + quadratic_extension_field.c0(), quadratic_extension_field.c1(), + quadratic_extension_field.c2(), quadratic_extension_field.c3()); + } + + static bool ReadFrom( + const ReadOnlyBuffer& buffer, + math::QuarticExtensionField* quadratic_extension_field) { + typename Derived::BaseField c0; + typename Derived::BaseField c1; + typename Derived::BaseField c2; + typename Derived::BaseField c3; + if (!buffer.ReadMany(&c0, &c1, &c2, &c3)) return false; + + *quadratic_extension_field = math::QuarticExtensionField( + std::move(c0), std::move(c1), std::move(c2), std::move(c3)); + return true; + } + + static size_t EstimateSize( + const math::QuarticExtensionField& quadratic_extension_field) { + return base::EstimateSize( + quadratic_extension_field.c0(), quadratic_extension_field.c1(), + quadratic_extension_field.c2(), quadratic_extension_field.c3()); + } +}; + +template +class RapidJsonValueConverter< + Derived, std::enable_if_t, Derived>>> { + public: + using BaseField = typename math::QuarticExtensionField::BaseField; + + template + static rapidjson::Value From( + const math::QuarticExtensionField& value, Allocator& allocator) { + rapidjson::Value object(rapidjson::kObjectType); + AddJsonElement(object, "c0", value.c0(), allocator); + AddJsonElement(object, "c1", value.c1(), allocator); + AddJsonElement(object, "c2", value.c2(), allocator); + AddJsonElement(object, "c3", value.c3(), allocator); + return object; + } + + static bool To(const rapidjson::Value& json_value, std::string_view key, + math::QuarticExtensionField* value, + std::string* error) { + BaseField c0; + BaseField c1; + BaseField c2; + BaseField c3; + if (!ParseJsonElement(json_value, "c0", &c0, error)) return false; + if (!ParseJsonElement(json_value, "c1", &c1, error)) return false; + if (!ParseJsonElement(json_value, "c2", &c2, error)) return false; + if (!ParseJsonElement(json_value, "c3", &c3, error)) return false; + *value = math::QuarticExtensionField(std::move(c0), std::move(c1), + std::move(c2), std::move(c3)); + return true; + } +}; + +} // namespace base +} // namespace tachyon + +#endif // TACHYON_MATH_FINITE_FIELDS_QUARTIC_EXTENSION_FIELD_H_ diff --git a/tachyon/math/finite_fields/quartic_extension_field_unittest.cc b/tachyon/math/finite_fields/quartic_extension_field_unittest.cc new file mode 100644 index 0000000000..cf8602cc11 --- /dev/null +++ b/tachyon/math/finite_fields/quartic_extension_field_unittest.cc @@ -0,0 +1,208 @@ +#include + +#include "gtest/gtest.h" + +#include "tachyon/math/finite_fields/baby_bear/baby_bear4.h" +#include "tachyon/math/finite_fields/test/finite_field_test.h" + +namespace tachyon::math { + +namespace { + +using F4 = BabyBear4; +using F = BabyBear; + +class QuaticExtensionFieldTest : public FiniteFieldTest {}; + +} // namespace + +TEST_F(QuaticExtensionFieldTest, Zero) { + EXPECT_TRUE(F4::Zero().IsZero()); + EXPECT_FALSE(F4::One().IsZero()); +} + +TEST_F(QuaticExtensionFieldTest, One) { + EXPECT_TRUE(F4::One().IsOne()); + EXPECT_FALSE(F4::Zero().IsOne()); +} + +TEST_F(QuaticExtensionFieldTest, Random) { + bool success = false; + F4 r = F4::Random(); + for (size_t i = 0; i < 100; ++i) { + if (r != F4::Random()) { + success = true; + break; + } + } + EXPECT_TRUE(success); +} + +TEST_F(QuaticExtensionFieldTest, Norm) { + constexpr static uint32_t kModulus = BabyBear::Config::kModulus; + F4 r = F4::Random(); + F4 r_to_p = r.Pow(kModulus); + F4 r_to_p2 = r_to_p.Pow(kModulus); + F4 r_to_p3 = r_to_p2.Pow(kModulus); + EXPECT_EQ(r.Norm(), (r * r_to_p * r_to_p2 * r_to_p3).c0()); +} + +TEST_F(QuaticExtensionFieldTest, EqualityOperators) { + F4 f(F(3), F(4), F(5), F(6)); + F4 f2(F(4), F(4), F(5), F(6)); + EXPECT_FALSE(f == f2); + EXPECT_TRUE(f != f2); + + F4 f3(F(4), F(3), F(5), F(6)); + EXPECT_FALSE(f2 == f3); + EXPECT_TRUE(f2 != f3); + + F4 f4(F(3), F(4), F(5), F(7)); + EXPECT_FALSE(f == f4); + EXPECT_TRUE(f != f4); + + F4 f5(F(3), F(4), F(5), F(6)); + EXPECT_TRUE(f == f5); +} + +TEST_F(QuaticExtensionFieldTest, ComparisonOperator) { + F4 f(F(3), F(4), F(5), F(6)); + F4 f2(F(4), F(4), F(5), F(6)); + EXPECT_TRUE(f < f2); + EXPECT_TRUE(f <= f2); + EXPECT_FALSE(f > f2); + EXPECT_FALSE(f >= f2); + + F4 f3(F(4), F(3), F(5), F(6)); + F4 f4(F(3), F(4), F(5), F(6)); + EXPECT_TRUE(f3 < f4); + EXPECT_TRUE(f3 <= f4); + EXPECT_FALSE(f3 > f4); + EXPECT_FALSE(f3 >= f4); + + F4 f5(F(4), F(5), F(6), F(3)); + F4 f6(F(3), F(2), F(6), F(5)); + EXPECT_TRUE(f5 < f6); + EXPECT_TRUE(f5 <= f6); + EXPECT_FALSE(f5 > f6); + EXPECT_FALSE(f5 >= f6); +} + +TEST_F(QuaticExtensionFieldTest, AdditiveOperators) { + struct { + F4 a; + F4 b; + F4 sum; + F4 amb; + F4 bma; + } tests[] = { + { + {F(1), F(2), F(3), F(4)}, + {F(3), F(5), F(6), F(8)}, + {F(4), F(7), F(9), F(12)}, + {-F(2), -F(3), -F(3), -F(4)}, + {F(2), F(3), F(3), F(4)}, + }, + }; + + for (const auto& test : tests) { + EXPECT_EQ(test.a + test.b, test.sum); + EXPECT_EQ(test.b + test.a, test.sum); + EXPECT_EQ(test.a - test.b, test.amb); + EXPECT_EQ(test.b - test.a, test.bma); + + F4 tmp = test.a; + tmp += test.b; + EXPECT_EQ(tmp, test.sum); + tmp -= test.b; + EXPECT_EQ(tmp, test.a); + } +} + +TEST_F(QuaticExtensionFieldTest, AdditiveGroupOperators) { + F4 f(F(3), F(4), F(5), F(6)); + F4 f_neg(-F(3), -F(4), -F(5), -F(6)); + EXPECT_EQ(-f, f_neg); + f.NegateInPlace(); + EXPECT_EQ(f, f_neg); + + f = F4(F(3), F(4), F(5), F(6)); + F4 f_dbl(F(6), F(8), F(10), F(12)); + EXPECT_EQ(f.Double(), f_dbl); + f.DoubleInPlace(); + EXPECT_EQ(f, f_dbl); +} + +TEST_F(QuaticExtensionFieldTest, MultiplicativeOperators) { + struct { + F4 a; + F4 b; + F4 mul; + F4 adb; + F4 bda; + } tests[] = { + { + {F(1), F(2), F(3), F(4)}, + {F(3), F(5), F(6), F(8)}, + {F(597), F(539), F(377), F(47)}, + {F(1144494179), F(1502926259), F(1509084158), F(151175067)}, + {F(653096429), F(494869942), F(67683040), F(1807436149)}, + }, + }; + + for (const auto& test : tests) { + EXPECT_EQ(test.a * test.b, test.mul); + EXPECT_EQ(test.b * test.a, test.mul); + EXPECT_EQ(test.a / test.b, test.adb); + EXPECT_EQ(test.b / test.a, test.bda); + + F4 tmp = test.a; + tmp *= test.b; + EXPECT_EQ(tmp, test.mul); + ASSERT_TRUE(tmp /= test.b); + EXPECT_EQ(tmp, test.a); + } +} + +TEST_F(QuaticExtensionFieldTest, MultiplicativeOperators2) { + F4 f(F(3), F(4), F(5), F(6)); + F4 f_mul(F(6), F(8), F(10), F(12)); + EXPECT_EQ(f * F(2), f_mul); + f *= F(2); + EXPECT_EQ(f, f_mul); +} + +TEST_F(QuaticExtensionFieldTest, MultiplicativeGroupOperators) { + F4 f = F4::Random(); + std::optional f_inv = f.Inverse(); + if (UNLIKELY(f.IsZero())) { + ASSERT_FALSE(f_inv); + ASSERT_FALSE(f.InverseInPlace()); + } else { + EXPECT_EQ(f * *f_inv, F4::One()); + F4 f_tmp = f; + EXPECT_EQ(**f.InverseInPlace() * f_tmp, F4::One()); + } + + f = F4(F(3), F(4), F(5), F(6)); + F4 f_sqr = F4(F(812), F(684), F(442), F(76)); + EXPECT_EQ(f.Square(), f_sqr); + f.SquareInPlace(); + EXPECT_EQ(f, f_sqr); +} + +TEST_F(QuaticExtensionFieldTest, JsonValueConverter) { + F4 expected_point(F(1), F(2), F(3), F(4)); + std::string expected_json = R"({"c0":1,"c1":2,"c2":3,"c3":4})"; + + F4 p; + std::string error; + ASSERT_TRUE(base::ParseJson(expected_json, &p, &error)); + ASSERT_TRUE(error.empty()); + EXPECT_EQ(p, expected_point); + + std::string json = base::WriteToJson(p); + EXPECT_EQ(json, expected_json); +} + +} // namespace tachyon::math