diff --git a/third_party/tsl/tsl/framework/type_traits.h b/third_party/tsl/tsl/framework/type_traits.h index 9012a4e598cd2..566ec6d9309fa 100644 --- a/third_party/tsl/tsl/framework/type_traits.h +++ b/third_party/tsl/tsl/framework/type_traits.h @@ -72,7 +72,9 @@ struct is_simple_type { is_quantized::value || std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value; + std::is_same::value || + std::is_same::value || + std::is_same::value; }; } // namespace tsl diff --git a/third_party/tsl/tsl/platform/float8.h b/third_party/tsl/tsl/platform/float8.h index 52ba7a9ad272b..4ea4ff1fc9173 100644 --- a/third_party/tsl/tsl/platform/float8.h +++ b/third_party/tsl/tsl/platform/float8.h @@ -25,7 +25,7 @@ using float8_e4m3b11fnuz = ml_dtypes::float8_e4m3b11fnuz; // Deprecated: old name for backward-compatibility only. using float8_e4m3b11 = float8_e4m3b11fnuz; using float8_e5m2 = ml_dtypes::float8_e5m2; -using float8_e5m2funz = ml_dtypes::float8_e5m2fnuz; +using float8_e5m2fnuz = ml_dtypes::float8_e5m2fnuz; } // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_FLOAT8_H_ diff --git a/third_party/tsl/tsl/protobuf/dnn.proto b/third_party/tsl/tsl/protobuf/dnn.proto index daad67f448b67..6dbce61f14b38 100644 --- a/third_party/tsl/tsl/protobuf/dnn.proto +++ b/third_party/tsl/tsl/protobuf/dnn.proto @@ -19,6 +19,8 @@ enum DataType { kBF16 = 7; kF8E5M2 = 8; kF8E4M3FN = 9; + kF8E5M2FNUZ = 10; + kF8E4M3FNUZ = 11; } // Describes how a convolution input or output layer's data is formatted. diff --git a/xla/client/lib/math.cc b/xla/client/lib/math.cc index 826f7f6759791..0c6f6d38f1143 100644 --- a/xla/client/lib/math.cc +++ b/xla/client/lib/math.cc @@ -161,6 +161,8 @@ XlaOp IsNegZero(XlaOp operand) { case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: + case F8E5M2FNUZ: + case F8E4M3FNUZ: case F16: case BF16: // Not all XLA backends handle U16 well, so we convert to F32/U32. 
@@ -302,7 +304,8 @@ XlaOp Erfc(XlaOp x) { // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. return DoWithUpcastToF32( - x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, [](XlaOp x) { + x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + [](XlaOp x) { return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), ScalarLike(x, 1) - ErfImpl32Cephes(x)); }); @@ -347,8 +350,9 @@ XlaOp Erf(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, - [](XlaOp x) { return ErfImpl32(x); }); + return DoWithUpcastToF32( + x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + [](XlaOp x) { return ErfImpl32(x); }); }); } @@ -496,8 +500,9 @@ XlaOp ErfInv(XlaOp x) { if (shape.element_type() == F64) { return ErfInv64(x); } - return DoWithUpcastToF32(x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, - [](XlaOp x) { return ErfInv32(x); }); + return DoWithUpcastToF32( + x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + [](XlaOp x) { return ErfInv32(x); }); }); } @@ -626,7 +631,9 @@ XlaOp Lgamma(XlaOp input) { // here (although it's better than you might expect!), so do the // computations in F32. 
return DoWithUpcastToF32( - input, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, do_it); + input, + {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + do_it); }); } @@ -722,7 +729,9 @@ XlaOp Digamma(XlaOp input) { return b.ReportErrorOrReturn([&]() -> StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input)); return DoWithUpcastToF32( - input, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, do_it); + input, + {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, + do_it); }); } @@ -977,7 +986,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1029,7 +1039,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 89c22c1fcdaa6..12dc8d0c5f738 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -896,7 +896,13 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations) this); typed_visitors_[F8E4M3B11FNUZ] = std::make_unique>( - this); + this); + typed_visitors_[F8E5M2FNUZ] = + std::make_unique>( + this); + typed_visitors_[F8E4M3FNUZ] = + std::make_unique>( + this); typed_visitors_[TUPLE] = std::make_unique([](const HloInstruction*) { diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h 
b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 3f129b7724125..d3f74eaa99b63 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1653,6 +1653,8 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index 36710e571d6f1..18a7b2c576000 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -20,4 +20,6 @@ namespace xla { template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/literal.cc b/xla/literal.cc index 6774dae91382e..2b0e78d9847ad 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -88,7 +88,8 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.f16s().empty() || !proto.bf16s().empty() || !proto.u16s().empty() || !proto.s16s().empty() || !proto.f8e5m2s().empty() || !proto.f8e4m3fns().empty() || - !proto.f8e4m3b11fnuzs().empty(); + !proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() || + !proto.f8e4m3fnuzs().empty(); } // Lazy getter for the interned scalar shape in static storage. 
We reuse this @@ -2243,6 +2244,16 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E5M2FNUZ: + *proto->mutable_f8e5m2fnuzs() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; + case F8E4M3FNUZ: + *proto->mutable_f8e4m3fnuzs() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F32: CopyToRepeatedField(proto->mutable_f32s(), data()); break; @@ -2386,6 +2397,20 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { s.size()); memcpy(untyped_data(), s.data(), s.size()); } break; + case F8E5M2FNUZ: { + const std::string& s(proto.f8e5m2fnuzs()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float8_e5m2fnuz) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + } break; + case F8E4M3FNUZ: { + const std::string& s(proto.f8e4m3fnuzs()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float8_e4m3fnuz) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + } break; case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index aaae073023c5c..e5d06c5234804 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -94,6 +94,23 @@ bool CompareEqual(tsl::float8_e4m3b11 lhs, return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } + +template<> +bool CompareEqual(tsl::float8_e5m2fnuz lhs, + tsl::float8_e5m2fnuz rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, + multi_index); +} + +template <> +bool CompareEqual(tsl::float8_e4m3fnuz lhs, + tsl::float8_e4m3fnuz rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, + multi_index); +} + template <> bool CompareEqual(bfloat16 lhs, bfloat16 rhs, absl::Span multi_index) { @@ -188,6 +205,18 @@ Status MakeErrorStatus(tsl::float8_e4m3b11 lhs, tsl::float8_e4m3b11 rhs, 
return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } +template<> +Status MakeErrorStatus(tsl::float8_e5m2fnuz lhs, tsl::float8_e5m2fnuz rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, + multi_index); +} +template <> +Status MakeErrorStatus(tsl::float8_e4m3fnuz lhs, tsl::float8_e4m3fnuz rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, + multi_index); +} template <> Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, absl::Span multi_index) { @@ -312,6 +341,14 @@ std::string FpValueToString(tsl::float8_e4m3b11 value) { return absl::StrFormat("%5.3g", static_cast(value)); } +std::string FpValueToString(tsl::float8_e5m2fnuz value) { + return absl::StrFormat("%5.3g", static_cast(value)); +} + +std::string FpValueToString(tsl::float8_e4m3fnuz value) { + return absl::StrFormat("%5.3g", static_cast(value)); +} + std::string FpValueToString(bfloat16 value) { return absl::StrFormat("%10.4g", static_cast(value)); } @@ -370,6 +407,16 @@ double FpAbsoluteValue(tsl::float8_e4m3b11 value) { return FpAbsoluteValue(static_cast(value)); } +template<> +double FpAbsoluteValue(tsl::float8_e5m2fnuz value) { + return FpAbsoluteValue(static_cast(value)); +} + +template <> +double FpAbsoluteValue(tsl::float8_e4m3fnuz value) { + return FpAbsoluteValue(static_cast(value)); +} + // Helper class for comparing floating-point literals within an error bound. 
template class NearComparator { @@ -914,6 +961,16 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, expected, actual, shape_index, error, use_detailed_message, miscompare_callback); break; + case F8E5M2FNUZ: + return NearComparator::Compare( + expected, actual, shape_index, error, use_detailed_message, + miscompare_callback); + break; + case F8E4M3FNUZ: + return NearComparator::Compare( + expected, actual, shape_index, error, use_detailed_message, + miscompare_callback); + break; case BF16: return NearComparator::Compare(expected, actual, shape_index, error, use_detailed_message, diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 2843db869fe45..42ea4ecff087c 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -158,6 +158,14 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e4m3b11fnuz_lit = LiteralUtil::CreateR0(tsl::float8_e4m3b11(0.5)); EXPECT_EQ("f8e4m3b11fnuz[] 0.5", f8e4m3b11fnuz_lit.ToString()); + + auto f8e4m3fnuz_lit = + LiteralUtil::CreateR0(tsl::float8_e4m3fnuz(0.5)); + EXPECT_EQ("f8e4m3fnuz[] 0.5", f8e4m3fnuz_lit.ToString()); + + auto f8e5m2fnuz_lit = + LiteralUtil::CreateR0(tsl::float8_e5m2fnuz(0.5)); + EXPECT_EQ("f8e5m2fnuz[] 0.5", f8e5m2fnuz_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -593,6 +601,15 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(LiteralUtil::CreateR1({s16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({s16}).IsAll(9)); + tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3 + EXPECT_FALSE(LiteralUtil::CreateR1({t16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({t16}).IsAll(9)); + + tsl::float8_e5m2fnuz u16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({u16}).IsAll(8)); + // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({u16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -1099,6 +1116,22 @@ TEST_F(LiteralUtilTest, 
PopulateWithValueR1F8e4m3b11) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3fnuz) { + Literal output(ShapeUtil::MakeShape(F8E4M3FNUZ, {3})); + tsl::float8_e4m3fnuz x(0.5f); + output.PopulateWithValue(x); + auto expected = LiteralUtil::CreateR1({x, x, x}); + EXPECT_EQ(output, expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F8e5m2fnuz) { + Literal output(ShapeUtil::MakeShape(F8E5M2FNUZ, {3})); + tsl::float8_e5m2fnuz x(0.5f); + output.PopulateWithValue(x); + auto expected = LiteralUtil::CreateR1({x, x, x}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); @@ -1595,6 +1628,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { using b11 = tsl::float8_e4m3b11; auto f8e4m3b11 = LiteralUtil::CreateR2WithLayout( {{b11{0.}, b11{1.}}, {b11{2.}, b11{3.}}}, layout_r2_dim0major_); + using e5f = tsl::float8_e5m2fnuz; + auto f8e5m2fnuz = LiteralUtil::CreateR2WithLayout( + {{e5f{0.}, e5f{1.}}, {e5f{2.}, e5f{3.}}}, layout_r2_dim0major_); + using e4f = tsl::float8_e4m3fnuz; + auto f8e4m3fnuz = LiteralUtil::CreateR2WithLayout( + {{e4f{0.}, e4f{1.}}, {e4f{2.}, e4f{3.}}}, layout_r2_dim0major_); Literal conv; conv = s8.Convert(F8E5M2).value(); @@ -1641,6 +1680,24 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { conv = f8e4m3b11.Convert(C128).value(); EXPECT_EQ(conv, c128); + + conv = f8e5m2fnuz.Convert(S8).value(); + EXPECT_EQ(conv, s8); + + conv = f8e5m2fnuz.Convert(F32).value(); + EXPECT_EQ(conv, f32); + + conv = f8e5m2fnuz.Convert(C128).value(); + EXPECT_EQ(conv, c128); + + conv = f8e4m3fnuz.Convert(S8).value(); + EXPECT_EQ(conv, s8); + + conv = f8e4m3fnuz.Convert(F32).value(); + EXPECT_EQ(conv, f32); + + conv = f8e4m3fnuz.Convert(C128).value(); + EXPECT_EQ(conv, c128); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -2040,6 +2097,12 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using b11 = tsl::float8_e4m3b11; auto 
vector_f8e4m3b11 = LiteralUtil::CreateR1({b11{10.0}, b11{20.0}, b11{-30.0}}); + using e5f = tsl::float8_e5m2fnuz; + auto vector_f8e5m2fnuz = + LiteralUtil::CreateR1({e5f{10.0}, e5f{20.0}, e5f{-30.0}}); + using e4f = tsl::float8_e4m3fnuz; + auto vector_f8e4m3fnuz = + LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2063,6 +2126,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); + EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2328,6 +2393,14 @@ TEST_F(LiteralUtilTest, IsEqualAt) { Literal c5 = LiteralUtil::CreateR0(val_true_complex); EXPECT_TRUE(c5.IsEqualAt({}, val_true_complex)); EXPECT_TRUE(c5.IsEqualAt({}, val_smaller_complex)); + Literal c6 = + LiteralUtil::CreateR0(tsl::float8_e5m2fnuz{val_double}); + EXPECT_TRUE(c6.IsEqualAt({}, val_double)); + EXPECT_TRUE(c6.IsEqualAt({}, val_integral)); + Literal c7 = + LiteralUtil::CreateR0(tsl::float8_e4m3fnuz{val_double}); + EXPECT_TRUE(c7.IsEqualAt({}, val_double)); + EXPECT_TRUE(c7.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { diff --git a/xla/literal_util.cc b/xla/literal_util.cc index a4ff48d31a953..c906f0634c11f 100644 --- a/xla/literal_util.cc +++ b/xla/literal_util.cc @@ -126,8 +126,10 @@ struct IsReal { std::is_integral::value || std::is_floating_point::value || std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || - std::is_same::value; + 
std::is_same::value || + std::is_same::value; }; template @@ -246,6 +248,16 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, return ConvertType(bf16_literal); } +/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FNUZ( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + +/* static */ Literal LiteralUtil::ConvertF32ToF8E5M2FNUZ( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + /* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); diff --git a/xla/literal_util.h b/xla/literal_util.h index 45e0db31a9b03..960a29ea9d153 100644 --- a/xla/literal_util.h +++ b/xla/literal_util.h @@ -226,6 +226,8 @@ class LiteralUtil { // recursively converts its elements. static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal); + static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E5M2FNUZ(const LiteralSlice& f32_literal); static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); static Literal ConvertF32ToS8(const LiteralSlice& f32_literal); static Literal ConvertF32ToF64(const LiteralSlice& f32_literal); diff --git a/xla/mlir/runtime/transforms/custom_call_encoding.cc b/xla/mlir/runtime/transforms/custom_call_encoding.cc index 9f1526a7f745d..9482370efe758 100644 --- a/xla/mlir/runtime/transforms/custom_call_encoding.cc +++ b/xla/mlir/runtime/transforms/custom_call_encoding.cc @@ -636,6 +636,8 @@ static PrimitiveType ScalarPrimitiveType(Type type) { if (type.isFloat8E4M3FN()) return PrimitiveType::F8E4M3FN; if (type.isFloat8E4M3B11FNUZ()) return PrimitiveType::F8E4M3B11FNUZ; if (type.isFloat8E5M2()) return PrimitiveType::F8E5M2; + if (type.isFloat8E4M3FNUZ()) return PrimitiveType::F8E4M3FNUZ; + if (type.isFloat8E5M2FNUZ()) return PrimitiveType::F8E5M2FNUZ; if (type.isF16()) return PrimitiveType::F16; if (type.isF32()) return 
PrimitiveType::F32; if (type.isF64()) return PrimitiveType::F64; diff --git a/xla/mlir/runtime/transforms/type_converter.cc b/xla/mlir/runtime/transforms/type_converter.cc index da5b0ca10819d..4967bcf2640f5 100644 --- a/xla/mlir/runtime/transforms/type_converter.cc +++ b/xla/mlir/runtime/transforms/type_converter.cc @@ -111,7 +111,9 @@ static std::unique_ptr ConvertCanonicalType( mlir::Type type) { if (type.isFloat8E4M3FN()) return PrimitiveType::F8E4M3FN; if (type.isFloat8E4M3B11FNUZ()) return PrimitiveType::F8E4M3B11FNUZ; + if (type.isFloat8E4M3FNUZ()) return PrimitiveType::F8E4M3FNUZ; if (type.isFloat8E5M2()) return PrimitiveType::F8E5M2; + if (type.isFloat8E5M2FNUZ()) return PrimitiveType::F8E5M2FNUZ; if (type.isIndex()) return PrimitiveType::S64; if (type.isBF16()) return PrimitiveType::BF16; if (type.isF16()) return PrimitiveType::F16; diff --git a/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc b/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc index 5bea10558afba..7659bdcf02975 100644 --- a/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc +++ b/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc @@ -278,6 +278,10 @@ tsl::StatusOr LiteralToValue(const xla::Literal& literal) { return tsl::errors::Unimplemented("F8E4M3FN not implemented"); case xla::F8E4M3B11FNUZ: return tsl::errors::Unimplemented("F8E4M3B11FNUZ not implemented"); + case xla::F8E5M2FNUZ: + return tsl::errors::Unimplemented("F8E5M2FNUZ not implemented"); + case xla::F8E4M3FNUZ: + return tsl::errors::Unimplemented("F8E4M3FNUZ not implemented"); case xla::C64: return {{ArrayLiteralToTensor>(literal)}}; case xla::C128: diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index a811c556ca7ab..13f008972db1f 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -547,6 +547,8 @@ typedef enum { PJRT_Buffer_Type_F8E5M2, PJRT_Buffer_Type_F8E4M3FN, PJRT_Buffer_Type_F8E4M3B11FNUZ, + PJRT_Buffer_Type_F8E5M2FNUZ, + PJRT_Buffer_Type_F8E4M3FNUZ, 
// 4-bit integer types PJRT_Buffer_Type_S4, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 0af53eaf8d582..570b7d53000b0 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -249,6 +249,10 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN; case xla::PrimitiveType::F8E4M3B11FNUZ: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3B11FNUZ; + case xla::PrimitiveType::F8E5M2FNUZ: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2FNUZ; + case xla::PrimitiveType::F8E4M3FNUZ: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -302,6 +306,10 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E4M3FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3B11FNUZ: return xla::PrimitiveType::F8E4M3B11FNUZ; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2FNUZ: + return xla::PrimitiveType::F8E5M2FNUZ; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ: + return xla::PrimitiveType::F8E4M3FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/primitive_util.cc b/xla/primitive_util.cc index 740583efac92c..6d52d17b4e50b 100644 --- a/xla/primitive_util.cc +++ b/xla/primitive_util.cc @@ -44,6 +44,10 @@ int SignificandWidth(PrimitiveType type) { return std::numeric_limits::digits; case F8E4M3B11FNUZ: return std::numeric_limits::digits; + case F8E5M2FNUZ: + return std::numeric_limits::digits; + case F8E4M3FNUZ: + return std::numeric_limits::digits; default: LOG(FATAL) << "Not a floating data type " << type; } @@ -83,6 +87,10 @@ int UnderflowExponent(PrimitiveType type) { return std::numeric_limits::min_exponent; case F8E4M3B11FNUZ: return std::numeric_limits::min_exponent; + case F8E5M2FNUZ: + return 
std::numeric_limits::min_exponent; + case F8E4M3FNUZ: + return std::numeric_limits::min_exponent; default: LOG(FATAL) << "Not a floating data type " << type; } @@ -109,11 +117,62 @@ int OverflowExponent(PrimitiveType type) { return std::numeric_limits::max_exponent; case F8E4M3B11FNUZ: return std::numeric_limits::max_exponent; + case F8E5M2FNUZ: + return std::numeric_limits::max_exponent; + case F8E4M3FNUZ: + return std::numeric_limits::max_exponent; default: LOG(FATAL) << "Not a floating data type " << type; } } +int ExponentBias(PrimitiveType type) { + switch (type) { + case F32: + case BF16: + case F16: + case F64: + case F8E5M2: + case F8E4M3FN: + return (1 << (ExponentWidth(type) - 1)) - 1; + case F8E4M3B11FNUZ: + return 11; + case F8E4M3FNUZ: + return 8; + case F8E5M2FNUZ: + return 16; + default: + LOG(FATAL) << "Not a floating data type " << type; + } +} + +bool HasInfinity(PrimitiveType type) { + switch (type) { + case F32: + return std::numeric_limits::has_infinity; + case F64: + return std::numeric_limits::has_infinity; + case BF16: + return std::numeric_limits::has_infinity; + case F16: + return std::numeric_limits::has_infinity; + case F8E5M2: + return std::numeric_limits::has_infinity; + case F8E4M3FN: + return std::numeric_limits::has_infinity; + case F8E4M3B11FNUZ: + return std::numeric_limits::has_infinity; + case F8E5M2FNUZ: + return std::numeric_limits::has_infinity; + case F8E4M3FNUZ: + return std::numeric_limits::has_infinity; + // Assumes types not enumerated are non-floating point types without an + // infinity. + default: + return false; + } +} + xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) { switch (src_bitwidth) { case 4: diff --git a/xla/primitive_util.h b/xla/primitive_util.h index dd516ddbc58f0..f0de21cd7109e 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -54,6 +54,13 @@ int UnderflowExponent(PrimitiveType type); // results in a LOG(FATAL). 
int OverflowExponent(PrimitiveType type); +// Returns the exponent bias of the given floating point type. +// For non-float datatypes, results in a LOG(FATAL). +int ExponentBias(PrimitiveType type); + +// Returns whether the type has a value for infinity. +bool HasInfinity(PrimitiveType type); + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template @@ -163,6 +170,16 @@ inline PrimitiveType NativeToPrimitiveType() { return F8E4M3B11FNUZ; } +template <> +inline PrimitiveType NativeToPrimitiveType() { + return F8E5M2FNUZ; +} + +template <> +inline PrimitiveType NativeToPrimitiveType() { + return F8E4M3FNUZ; +} + // Complex template <> inline PrimitiveType NativeToPrimitiveType() { @@ -174,9 +191,14 @@ inline PrimitiveType NativeToPrimitiveType() { return C128; } +constexpr bool IsF8Type(PrimitiveType type) { + return type == F8E5M2 || type == F8E4M3FN || type == F8E4M3B11FNUZ || + type == F8E5M2FNUZ || type == F8E4M3FNUZ; +} + constexpr bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16 || - type == F8E5M2 || type == F8E4M3FN || type == F8E4M3B11FNUZ; + IsF8Type(type); } constexpr bool IsComplexType(PrimitiveType type) { @@ -195,10 +217,6 @@ constexpr bool IsIntegralType(PrimitiveType type) { return IsUnsignedIntegralType(type) || IsSignedIntegralType(type); } -inline bool IsF8Type(PrimitiveType type) { - return type == F8E5M2 || type == F8E4M3FN || type == F8E4M3B11FNUZ; -} - // Returns true if values of the given primitive type are held in array shapes. 
inline constexpr bool IsArrayType(PrimitiveType primitive_type) { return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && @@ -220,6 +238,8 @@ constexpr ABSL_ATTRIBUTE_ALWAYS_INLINE inline int BitWidth(PrimitiveType type) { case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return 8; case S16: @@ -268,6 +288,8 @@ ABSL_ATTRIBUTE_ALWAYS_INLINE inline int ByteWidth(PrimitiveType type) { case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return 1; case S16: @@ -413,15 +435,20 @@ inline bool CastPreservesValues(PrimitiveType from_type, if (primitive_util::IsComplexType(from_type)) { return false; } - // F -> F is safe if the exponent and significand are preserved. + // F -> F is safe if the exponent/significand are preserved and `to_type` + // preserves infinities in `from_type. if (primitive_util::IsFloatingPointType(from_type) && primitive_util::IsFloatingPointType(to_type)) { - return primitive_util::SignificandWidth(from_type) <= + return (!primitive_util::HasInfinity(from_type) || + primitive_util::HasInfinity(to_type)) && + primitive_util::SignificandWidth(from_type) <= primitive_util::SignificandWidth(to_type) && primitive_util::ExponentWidth(from_type) <= primitive_util::ExponentWidth(to_type) && - primitive_util::UnderflowExponent(from_type) >= - primitive_util::UnderflowExponent(to_type) && + (primitive_util::UnderflowExponent(from_type) - + primitive_util::SignificandWidth(from_type)) >= + (primitive_util::UnderflowExponent(to_type) - + primitive_util::SignificandWidth(to_type)) && primitive_util::OverflowExponent(from_type) <= primitive_util::OverflowExponent(to_type); } @@ -557,6 +584,16 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e4m3b11; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e5m2fnuz; +}; + +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e4m3fnuz; +}; + // Complex template <> struct 
PrimitiveTypeToNative { @@ -594,6 +631,8 @@ bool IsCanonicalRepresentation(PrimitiveType type) { case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: + case F8E5M2FNUZ: + case F8E4M3FNUZ: case C64: case C128: return NativeToPrimitiveType() == type; @@ -666,8 +705,12 @@ R PrimitiveTypeSwitch(F&& f, PrimitiveType type) { case F8E4M3B11FNUZ: return std::invoke(f, PrimitiveTypeConstant()); + case F8E4M3FNUZ: + return std::invoke(f, PrimitiveTypeConstant()); case F8E5M2: return std::invoke(f, PrimitiveTypeConstant()); + case F8E5M2FNUZ: + return std::invoke(f, PrimitiveTypeConstant()); case F16: return std::invoke(f, PrimitiveTypeConstant()); case BF16: diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index b65628c9826f1..8f9a67ff37556 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -78,6 +78,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][F8E5M2] = true; expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = true; + expecteds[PRED][F8E5M2FNUZ] = true; + expecteds[PRED][F8E4M3FNUZ] = true; expecteds[S4][PRED] = false; expecteds[S4][S4] = true; expecteds[S4][S8] = true; @@ -98,6 +100,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][F8E5M2] = true; expecteds[S4][F8E4M3FN] = true; expecteds[S4][F8E4M3B11FNUZ] = true; + expecteds[S4][F8E5M2FNUZ] = true; + expecteds[S4][F8E4M3FNUZ] = true; expecteds[S8][PRED] = false; expecteds[S8][S4] = false; expecteds[S8][S8] = true; @@ -118,6 +122,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][F8E5M2] = false; expecteds[S8][F8E4M3FN] = false; expecteds[S8][F8E4M3B11FNUZ] = false; + expecteds[S8][F8E5M2FNUZ] = false; + expecteds[S8][F8E4M3FNUZ] = false; expecteds[S16][PRED] = false; expecteds[S16][S4] = false; expecteds[S16][S8] = false; @@ -138,6 +144,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][F8E5M2] = false; expecteds[S16][F8E4M3FN] = false; expecteds[S16][F8E4M3B11FNUZ] = false; + 
expecteds[S16][F8E5M2FNUZ] = false; + expecteds[S16][F8E4M3FNUZ] = false; expecteds[S32][PRED] = false; expecteds[S32][S4] = false; expecteds[S32][S8] = false; @@ -158,6 +166,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][F8E5M2] = false; expecteds[S32][F8E4M3FN] = false; expecteds[S32][F8E4M3B11FNUZ] = false; + expecteds[S32][F8E5M2FNUZ] = false; + expecteds[S32][F8E4M3FNUZ] = false; expecteds[S64][PRED] = false; expecteds[S64][S4] = false; expecteds[S64][S8] = false; @@ -178,6 +188,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][F8E5M2] = false; expecteds[S64][F8E4M3FN] = false; expecteds[S64][F8E4M3B11FNUZ] = false; + expecteds[S64][F8E5M2FNUZ] = false; + expecteds[S64][F8E4M3FNUZ] = false; expecteds[U4][PRED] = false; expecteds[U4][S4] = false; expecteds[U4][S8] = true; @@ -200,6 +212,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][F8E5M2] = false; expecteds[U4][F8E4M3FN] = true; expecteds[U4][F8E4M3B11FNUZ] = true; + expecteds[U4][F8E5M2FNUZ] = false; + expecteds[U4][F8E4M3FNUZ] = true; expecteds[U8][PRED] = false; expecteds[U8][S4] = false; expecteds[U8][S8] = false; @@ -222,6 +236,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][F8E5M2] = false; expecteds[U8][F8E4M3FN] = false; expecteds[U8][F8E4M3B11FNUZ] = false; + expecteds[U8][F8E5M2FNUZ] = false; + expecteds[U8][F8E4M3FNUZ] = false; expecteds[U16][PRED] = false; expecteds[U16][S4] = false; expecteds[U16][S8] = false; @@ -242,6 +258,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][F8E5M2] = false; expecteds[U16][F8E4M3FN] = false; expecteds[U16][F8E4M3B11FNUZ] = false; + expecteds[U16][F8E5M2FNUZ] = false; + expecteds[U16][F8E4M3FNUZ] = false; expecteds[U32][PRED] = false; expecteds[U32][S4] = false; expecteds[U32][S8] = false; @@ -262,6 +280,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][F8E5M2] = false; expecteds[U32][F8E4M3FN] = false; expecteds[U32][F8E4M3B11FNUZ] = false; + 
expecteds[U32][F8E5M2FNUZ] = false; + expecteds[U32][F8E4M3FNUZ] = false; expecteds[U64][PRED] = false; expecteds[U64][S4] = false; expecteds[U64][S8] = false; @@ -282,6 +302,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][F8E5M2] = false; expecteds[U64][F8E4M3FN] = false; expecteds[U64][F8E4M3B11FNUZ] = false; + expecteds[U64][F8E5M2FNUZ] = false; + expecteds[U64][F8E4M3FNUZ] = false; expecteds[F16][PRED] = false; expecteds[F16][S4] = false; expecteds[F16][S8] = false; @@ -302,6 +324,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][F8E5M2] = false; expecteds[F16][F8E4M3FN] = false; expecteds[F16][F8E4M3B11FNUZ] = false; + expecteds[F16][F8E5M2FNUZ] = false; + expecteds[F16][F8E4M3FNUZ] = false; expecteds[F32][PRED] = false; expecteds[F32][S4] = false; expecteds[F32][S8] = false; @@ -322,6 +346,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][F8E5M2] = false; expecteds[F32][F8E4M3FN] = false; expecteds[F32][F8E4M3B11FNUZ] = false; + expecteds[F32][F8E5M2FNUZ] = false; + expecteds[F32][F8E4M3FNUZ] = false; expecteds[F64][PRED] = false; expecteds[F64][S4] = false; expecteds[F64][S8] = false; @@ -342,6 +368,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][F8E5M2] = false; expecteds[F64][F8E4M3FN] = false; expecteds[F64][F8E4M3B11FNUZ] = false; + expecteds[F64][F8E5M2FNUZ] = false; + expecteds[F64][F8E4M3FNUZ] = false; expecteds[C64][PRED] = false; expecteds[C64][S4] = false; expecteds[C64][S8] = false; @@ -362,6 +390,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][F8E5M2] = false; expecteds[C64][F8E4M3FN] = false; expecteds[C64][F8E4M3B11FNUZ] = false; + expecteds[C64][F8E5M2FNUZ] = false; + expecteds[C64][F8E4M3FNUZ] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S4] = false; expecteds[BF16][S8] = false; @@ -382,6 +412,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][F8E5M2] = false; expecteds[BF16][F8E4M3FN] = false; expecteds[BF16][F8E4M3B11FNUZ] 
= false; + expecteds[BF16][F8E5M2FNUZ] = false; + expecteds[BF16][F8E4M3FNUZ] = false; expecteds[C128][PRED] = false; expecteds[C128][S4] = false; expecteds[C128][S8] = false; @@ -402,6 +434,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][F8E5M2] = false; expecteds[C128][F8E4M3FN] = false; expecteds[C128][F8E4M3B11FNUZ] = false; + expecteds[C128][F8E5M2FNUZ] = false; + expecteds[C128][F8E4M3FNUZ] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S4] = false; expecteds[F8E5M2][S8] = false; @@ -422,6 +456,8 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][F8E5M2] = true; expecteds[F8E5M2][F8E4M3FN] = false; expecteds[F8E5M2][F8E4M3B11FNUZ] = false; + expecteds[F8E5M2][F8E5M2FNUZ] = false; + expecteds[F8E5M2][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S4] = false; expecteds[F8E4M3FN][S8] = false; @@ -462,6 +498,54 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][F8E5M2] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; expecteds[F8E4M3B11FNUZ][F8E4M3B11FNUZ] = true; + expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; + expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; + expecteds[F8E4M3FN][F8E5M2FNUZ] = false; + expecteds[F8E4M3FN][F8E4M3FNUZ] = false; + expecteds[F8E5M2FNUZ][PRED] = false; + expecteds[F8E5M2FNUZ][S4] = false; + expecteds[F8E5M2FNUZ][S8] = false; + expecteds[F8E5M2FNUZ][S16] = false; + expecteds[F8E5M2FNUZ][S32] = false; + expecteds[F8E5M2FNUZ][S64] = false; + expecteds[F8E5M2FNUZ][U4] = false; + expecteds[F8E5M2FNUZ][U8] = false; + expecteds[F8E5M2FNUZ][U16] = false; + expecteds[F8E5M2FNUZ][U32] = false; + expecteds[F8E5M2FNUZ][U64] = false; + expecteds[F8E5M2FNUZ][F16] = true; + expecteds[F8E5M2FNUZ][F32] = true; + expecteds[F8E5M2FNUZ][F64] = true; + expecteds[F8E5M2FNUZ][C64] = true; + expecteds[F8E5M2FNUZ][BF16] = true; + expecteds[F8E5M2FNUZ][C128] = true; + expecteds[F8E5M2FNUZ][F8E5M2] = false; + expecteds[F8E5M2FNUZ][F8E4M3FN] = false; + 
expecteds[F8E5M2FNUZ][F8E4M3B11FNUZ] = false; + expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; + expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; + expecteds[F8E4M3FNUZ][PRED] = false; + expecteds[F8E4M3FNUZ][S4] = false; + expecteds[F8E4M3FNUZ][S8] = false; + expecteds[F8E4M3FNUZ][S16] = false; + expecteds[F8E4M3FNUZ][S32] = false; + expecteds[F8E4M3FNUZ][S64] = false; + expecteds[F8E4M3FNUZ][U4] = false; + expecteds[F8E4M3FNUZ][U8] = false; + expecteds[F8E4M3FNUZ][U16] = false; + expecteds[F8E4M3FNUZ][U32] = false; + expecteds[F8E4M3FNUZ][U64] = false; + expecteds[F8E4M3FNUZ][F16] = true; + expecteds[F8E4M3FNUZ][F32] = true; + expecteds[F8E4M3FNUZ][F64] = true; + expecteds[F8E4M3FNUZ][C64] = true; + expecteds[F8E4M3FNUZ][BF16] = true; + expecteds[F8E4M3FNUZ][C128] = true; + expecteds[F8E4M3FNUZ][F8E5M2] = false; + expecteds[F8E4M3FNUZ][F8E4M3FN] = false; + expecteds[F8E4M3FNUZ][F8E4M3B11FNUZ] = false; + expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; + expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index 485b4293608b0..b5073965b406a 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -74,9 +74,11 @@ class DType { kF8E4M3FN = 19, kF8E4M3B11FNUZ = 23, + kF8E4M3FNUZ = 24, kF8E5M2 = 20, + kF8E5M2FNUZ = 25, - // Next = 24 + // Next = 26 // String is not support in XLA. DType.Kind needs to match xla.PrimitiveType // enum, so choose a large enum to avoid collision. 
diff --git a/xla/python/pjrt_ifrt/pjrt_array.cc b/xla/python/pjrt_ifrt/pjrt_array.cc index a215cf8976f13..bcc2613fe5770 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/xla/python/pjrt_ifrt/pjrt_array.cc @@ -53,7 +53,9 @@ StatusOr ToPrimitiveType(DType dtype) { case DType::kU64: case DType::kF8E4M3FN: case DType::kF8E4M3B11FNUZ: + case DType::kF8E4M3FNUZ: case DType::kF8E5M2: + case DType::kF8E5M2FNUZ: case DType::kF16: case DType::kF32: case DType::kBF16: @@ -85,7 +87,9 @@ StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U64: case xla::PrimitiveType::F8E4M3FN: case xla::PrimitiveType::F8E4M3B11FNUZ: + case xla::PrimitiveType::F8E4M3FNUZ: case xla::PrimitiveType::F8E5M2: + case xla::PrimitiveType::F8E5M2FNUZ: case xla::PrimitiveType::F16: case xla::PrimitiveType::F32: case xla::PrimitiveType::BF16: diff --git a/xla/python/py_buffer.cc b/xla/python/py_buffer.cc index 545d44c73e01f..1cbde5e31ffbc 100644 --- a/xla/python/py_buffer.cc +++ b/xla/python/py_buffer.cc @@ -279,6 +279,14 @@ StatusOr IfrtHelpers::CudaArrayInterface( return InvalidArgument( "__cuda_array_interface__ is not supported for F8E5M2 buffers."); } + if (pjrt_buffer->on_device_shape().element_type() == F8E4M3FNUZ) { + return InvalidArgument( + "__cuda_array_interface__ is not supported for F8E4M3FNUZ buffers."); + } + if (pjrt_buffer->on_device_shape().element_type() == F8E5M2FNUZ) { + return InvalidArgument( + "__cuda_array_interface__ is not supported for F8E5M2FNUZ buffers."); + } TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( pjrt_buffer->on_device_shape().layout())); diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index e9a46344ac9ed..8e6068b81e874 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -172,6 +172,12 @@ StatusOr HandleNumpyScalar(py::handle h, ifrt::Client* client, } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &ptr); type = F8E5M2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), 
&ptr); + type = F8E4M3FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &ptr); + type = F8E5M2FNUZ; } else if (std::is_same() || !options.squash_64bit_types) { PyArray_ScalarAsCtype(h.ptr(), &data); ptr = &data; @@ -332,6 +338,10 @@ StatusOr DevicePut(py::handle arg, ifrt::Client* client, HandleNumpyScalar; } (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = + HandleNumpyScalar; (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; @@ -511,6 +521,8 @@ StatusOr PyArgSignatureOfValue(py::handle arg, (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz->ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 4f6f42b34d465..c585f3818abe5 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -40,7 +40,9 @@ struct CustomDtypes { py::dtype bfloat16; py::dtype float8_e4m3fn; std::optional float8_e4m3b11fnuz; + py::dtype float8_e4m3fnuz; py::dtype float8_e5m2; + py::dtype float8_e5m2fnuz; std::optional int4; std::optional uint4; }; @@ -57,6 +59,10 @@ const CustomDtypes& GetCustomDtypes() { dtypes->float8_e4m3b11fnuz = py::dtype::from_args(ml_dtypes.attr("float8_e4m3b11fnuz")); } + dtypes->float8_e4m3fnuz = + py::dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); + dtypes->float8_e5m2fnuz = + py::dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); if (py::hasattr(ml_dtypes, "int4")) { dtypes->int4 = 
py::dtype::from_args(ml_dtypes.attr("int4")); } @@ -115,7 +121,9 @@ xla::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { if (custom_dtypes.float8_e4m3b11fnuz) { map->emplace(*custom_dtypes.float8_e4m3b11fnuz, F8E4M3B11FNUZ); } + map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); map->emplace(custom_dtypes.float8_e5m2, F8E5M2); + map->emplace(custom_dtypes.float8_e5m2fnuz, F8E5M2FNUZ); if (custom_dtypes.int4) { map->emplace(*custom_dtypes.int4, S4); } @@ -172,8 +180,12 @@ xla::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { return *custom_dtypes.float8_e4m3b11fnuz; } return InvalidArgument("ml_dtypes.float8_e4m3b11fnuz not found"); + case F8E4M3FNUZ: + return custom_dtypes.float8_e4m3fnuz; case F8E5M2: return custom_dtypes.float8_e5m2; + case F8E5M2FNUZ: + return custom_dtypes.float8_e5m2fnuz; case BF16: return custom_dtypes.bfloat16; case F16: @@ -219,6 +231,8 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { py::object(ml_dtypes.attr("float8_e4m3b11fnuz")); } dtypes->np_float8_e5m2 = py::object(ml_dtypes.attr("float8_e5m2")); + dtypes->np_float8_e4m3fnuz = py::object(ml_dtypes.attr("float8_e4m3fnuz")); + dtypes->np_float8_e5m2fnuz = py::object(ml_dtypes.attr("float8_e5m2fnuz")); dtypes->np_float16 = py::object(numpy.attr("float16")); dtypes->np_float32 = py::object(numpy.attr("float32")); dtypes->np_float64 = py::object(numpy.attr("float64")); diff --git a/xla/python/types.h b/xla/python/types.h index 1c6257fe75532..8ea823ad6486f 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -66,7 +66,9 @@ struct NumpyScalarTypes { pybind11::object np_bfloat16; pybind11::object np_float8_e4m3fn; std::optional np_float8_e4m3b11fnuz; + pybind11::object np_float8_e4m3fnuz; pybind11::object np_float8_e5m2; + pybind11::object np_float8_e5m2fnuz; pybind11::object np_float16; pybind11::object np_float32; pybind11::object np_float64; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 364eae41c5e82..67032777f92b9 100644 --- a/xla/python/xla.cc +++ 
b/xla/python/xla.cc @@ -157,7 +157,9 @@ PYBIND11_MODULE(xla_extension, m) { .value("F16", F16) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) + .value("F8E4M3FNUZ", F8E4M3FNUZ) .value("F8E5M2", F8E5M2) + .value("F8E5M2FNUZ", F8E5M2FNUZ) .value("BF16", BF16) .value("F32", F32) .value("F64", F64) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 61fd9624e7f42..8b1deb8abf737 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -214,7 +214,9 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): if hasattr(ml_dtypes, 'float8_e4m3b11fnuz') else ml_dtypes.float8_e4m3fn ) +float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz float8_e5m2 = ml_dtypes.float8_e5m2 +float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz XLA_ELEMENT_TYPE_TO_DTYPE = { PrimitiveType.PRED: np.dtype('bool'), @@ -229,6 +231,8 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), + PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz), + PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz), PrimitiveType.BF16: np.dtype(bfloat16), PrimitiveType.F16: np.dtype('float16'), PrimitiveType.F32: np.dtype('float32'), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index 0788957f7a5bd..3a06eb1d8c576 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -55,7 +55,9 @@ mlir_api_version: int bfloat16: Type[numpy.generic] float8_e4m3fn: Type[numpy.generic] float8_e4m3b11fnuz: Type[numpy.generic] +float8_e4m3fnuz: Type[numpy.generic] float8_e5m2: Type[numpy.generic] +float8_e5m2fnuz: Type[numpy.generic] XLA_ELEMENT_TYPE_TO_DTYPE: Dict[PrimitiveType, numpy.dtype] _NameValueMapping = Mapping[str, Union[str, int, List[int], float]] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index e87fa90085bce..f5f2d57b627fc 100644 
--- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -39,6 +39,8 @@ bfloat16 = xla_client.bfloat16 float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz +float8_e5m2fnuz = xla_client.float8_e5m2fnuz +float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e5m2 = xla_client.float8_e5m2 ops = xla_client.ops xla_computation_to_mlir_module = ( @@ -91,7 +93,8 @@ def TestFactory(xla_backend, # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. - standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2, + float8_e4m3fnuz, float8_e5m2fnuz] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index cb7df18aca94d..38cef1047a4d4 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -55,7 +55,9 @@ class PrimitiveType(enum.IntEnum): U64: PrimitiveType F8_E4M3FN: PrimitiveType F8_E4M3B11FNUZ: PrimitiveType + F8_E4M3FNUZ: PrimitiveType F8_E5M2: PrimitiveType + F8_E5M2FNUZ: PrimitiveType BF16: PrimitiveType F16: PrimitiveType F32: PrimitiveType diff --git a/xla/service/BUILD b/xla/service/BUILD index 337ba8a541e4f..d774a1ddfb079 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -5013,11 +5013,29 @@ cc_library( ], ) +cc_library( + name = "float8_fnuz_ir_emitter", + srcs = [ + "float8_fnuz_ir_emitter.cc", + ], + hdrs = [ + "float8_fnuz_ir_emitter.h", + ], + deps = [ + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "@llvm-project//llvm:Core", + ], +) + cc_library( name = "elemental_ir_emitter", srcs = ["elemental_ir_emitter.cc"], hdrs = ["elemental_ir_emitter.h"], deps = [ + 
":float8_fnuz_ir_emitter", "//xla:permutation_util", "//xla:shape_util", "//xla:status_macros", diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index cb635bb304540..522e8c289fe73 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -663,6 +663,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&f8e4m3fn_support); FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); pipeline.AddPass(&f8e4m3b11fnuz_support); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + pipeline.AddPass(&f8e5m2fnuz_support); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + pipeline.AddPass(&f8e4m3fnuz_support); // After canonicalization, there may be more batch dots that can be // simplified. pipeline.AddPass(); diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 8dfb50afdb65f..51b2d688119bb 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/float8_fnuz_ir_emitter.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/statusor.h" @@ -57,6 +58,8 @@ using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; +using xla::float8_fnuz_ir_emitter::EmitFloatingToF8fnuz; +using xla::float8_fnuz_ir_emitter::EmitF8fnuzToFloating; namespace { @@ -625,6 +628,13 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { + return EmitFloatingToF8fnuz( + F16, + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + to_type, b_); + } return EmitIntegralToFloating(operand_value, from_type, to_type, module_, b_); } @@ -759,6 +769,18 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) { + TF_RET_CHECK(to_type != from_type); + PrimitiveType cast_type = + primitive_util::IsFloatingPointType(to_type) ? 
to_type : F16; + TF_ASSIGN_OR_RETURN(operand_value, + EmitF8fnuzToFloating(from_type, operand_value, + cast_type, b_, module_)); + from_type = cast_type; + if (from_type == to_type) { + return operand_value; + } + } if (primitive_util::IsComplexType(to_type)) { PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); @@ -804,6 +826,9 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e4m3b11fnuz(operand_value, b_); } + if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { + return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); + } if (to_type == PRED) { return b_->CreateZExt( FCmpUNE(operand_value, @@ -1292,6 +1317,13 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); + } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { + TF_ASSIGN_OR_RETURN( + lhs_value, + EmitF8fnuzToFloating(operand_type, lhs_value, F16, b_, module_)); + TF_ASSIGN_OR_RETURN( + rhs_value, + EmitF8fnuzToFloating(operand_type, rhs_value, F16, b_, module_)); } switch (op->comparison_direction()) { case ComparisonDirection::kEq: @@ -3032,6 +3064,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( llvm::Type* float_ir_type; if (component_element_type == BF16) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); + } else if (component_element_type == F8E4M3FNUZ) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); + } else if (component_element_type == F8E5M2FNUZ) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); } else { float_ir_type = llvm_ir::PrimitiveTypeToIrType(component_element_type, module_); @@ -3040,6 +3076,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( b_->CreateUIToFP(elem_index_linear, float_ir_type); if (component_element_type == BF16) { TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val)); 
+ } else if (component_element_type == F8E4M3FNUZ || + component_element_type == F8E5M2FNUZ) { + TF_ASSIGN_OR_RETURN( + iota_result, + EmitFloatingToF8fnuz(F16, float_val, component_element_type, b_)); } else { iota_result = float_val; } diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index ed82a52f58431..9ca6f0a918ba6 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -367,5 +367,307 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, BatchDotBF16) { EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); } +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E4FNUZ) { + RunTypeConversionTest(R"( + HloModule convertToF8E4FNUZ + ENTRY ConvertToF8E4FNUZ + (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { + f16_ = f16[] parameter(0) + f32_ = f32[] parameter(1) + f64_ = f64[] parameter(2) + bf16_ = bf16[] parameter(3) + converted_f16 = f8e4m3fnuz[] convert(f16[] f16_) + converted_f32 = f8e4m3fnuz[] convert(f32[] f32_) + converted_f64 = f8e4m3fnuz[] convert(f64[] f64_) + converted_bf16 = f8e4m3fnuz[] convert(bf16[] bf16_) + ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_f16, converted_f32, converted_f64, converted_bf16) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E4FNUZ) { + RunTypeConversionTest(R"( + HloModule convertToF8E4FNUZ + ENTRY ConvertToF8E4FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> + (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { + s8_ = s8[] parameter(0) + s16_ = s16[] parameter(1) + s32_ = s32[] parameter(2) + s64_ = s64[] parameter(3) + converted_s8 = f8e4m3fnuz[] convert(s8[] s8_) + converted_s16 = f8e4m3fnuz[] convert(s16[] s16_) + converted_s32 = f8e4m3fnuz[] convert(s32[] s32_) + converted_s64 = f8e4m3fnuz[] convert(s64[] s64_) + ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], 
f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_s8, converted_s16, converted_s32, converted_s64) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E4FNUZ) { + RunTypeConversionTest(R"( + HloModule convertToF8E4FNUZ + ENTRY ConvertToF8E4FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> + (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { + u8_ = u8[] parameter(0) + u16_ = u16[] parameter(1) + u32_ = u32[] parameter(2) + u64_ = u64[] parameter(3) + converted_u8 = f8e4m3fnuz[] convert(u8[] u8_) + converted_u16 = f8e4m3fnuz[] convert(u16[] u16_) + converted_u32 = f8e4m3fnuz[] convert(u32[] u32_) + converted_u64 = f8e4m3fnuz[] convert(u64[] u64_) + ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_u8, converted_u16, converted_u32, converted_u64) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToFloat) { + RunTypeConversionTest(R"( + HloModule convertFromF8E4FNUZ + ENTRY ConvertFromF8E4FNUZ + (to_f16 f8e4m3fnuz[], to_f32 f8e4m3fnuz[], to_f64 f8e4m3fnuz[], to_bf16 f8e4m3fnuz[]) -> (f16[], f32[], f64[], bf16[]) { + to_f16 = f8e4m3fnuz[] parameter(0) + to_f32 = f8e4m3fnuz[] parameter(1) + to_f64 = f8e4m3fnuz[] parameter(2) + to_bf16 = f8e4m3fnuz[] parameter(3) + f16_ = f16[] convert(f8e4m3fnuz[] to_f16) + f32_ = f32[] convert(f8e4m3fnuz[] to_f32) + f64_ = f64[] convert(f8e4m3fnuz[] to_f64) + bf16_ = bf16[] convert(f8e4m3fnuz[] to_f64) + ROOT tuple = (f16[], f32[], f64[], bf16[]) tuple(f16_, f32_, f64_, bf16_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToSigned) { + RunTypeConversionTest(R"( + HloModule convertFromF8E4FNUZ + ENTRY ConvertFromF8E4FNUZ(to_s8 f8e4m3fnuz[], to_s16 f8e4m3fnuz[], to_s32 f8e4m3fnuz[], + to_s64 f8e4m3fnuz[]) -> (s8[], s16[], s32[], s64[]) { + to_s8 = f8e4m3fnuz[] parameter(0) + to_s16 = f8e4m3fnuz[] parameter(1) + to_s32 = f8e4m3fnuz[] parameter(2) + to_s64 = f8e4m3fnuz[] parameter(3) + s8_ = s8[] 
convert(f8e4m3fnuz[] to_s8) + s16_ = s16[] convert(f8e4m3fnuz[] to_s16) + s32_ = s32[] convert(f8e4m3fnuz[] to_s32) + s64_ = s64[] convert(f8e4m3fnuz[] to_s64) + ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToUnsigned) { + RunTypeConversionTest(R"( + HloModule convertFromF8E4FNUZ + ENTRY ConvertFromF8E4FNUZ(to_u8 f8e4m3fnuz[], to_u16 f8e4m3fnuz[], to_u32 f8e4m3fnuz[], + to_u64 f8e4m3fnuz[]) -> (u8[], u16[], u32[], u64[]) { + to_u8 = f8e4m3fnuz[] parameter(0) + to_u16 = f8e4m3fnuz[] parameter(1) + to_u32 = f8e4m3fnuz[] parameter(2) + to_u64 = f8e4m3fnuz[] parameter(3) + u8_ = u8[] convert(f8e4m3fnuz[] to_u8) + u16_ = u16[] convert(f8e4m3fnuz[] to_u16) + u32_ = u32[] convert(f8e4m3fnuz[] to_u32) + u64_ = u64[] convert(f8e4m3fnuz[] to_u64) + ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToComplex) { + RunTypeConversionTest(R"( + HloModule convertFromF8E4FNUZ + ENTRY ConvertFromF8E4FNUZ + (to_c64 f8e4m3fnuz[], to_c128 f8e4m3fnuz[]) -> (c64[], c128[]) { + to_c64 = f8e4m3fnuz[] parameter(0) + to_c128 = f8e4m3fnuz[] parameter(1) + c64_ = c64[] convert(f8e4m3fnuz[] to_c64) + c128_ = c128[] convert(f8e4m3fnuz[] to_c128) + ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E4FNUZ) { + constexpr char hlo_text[] = R"( + HloModule compareF8E4FNUZ + ENTRY main { + p0 = f8e4m3fnuz[4] parameter(0) + p1 = f8e4m3fnuz[4] parameter(1) + ROOT cmp = pred[4] compare(p0, p1), direction=LT +})"; + + Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); + lhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(lhs); + rhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(rhs); + RunTest(hlo_text, {&lhs, &rhs}); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E4FNUZ) { + constexpr char hlo_text[] 
= R"( + HloModule IotaF8E4FNUZ + ENTRY main { + ROOT iota_ = f8e4m3fnuz[4] iota(), iota_dimension=0 + } + )"; + + RunTest(hlo_text, {}); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E5FNUZ) { + RunTypeConversionTest(R"( + HloModule convertToF8E5FNUZ + ENTRY ConvertToF8E5FNUZ + (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { + f16_ = f16[] parameter(0) + f32_ = f32[] parameter(1) + f64_ = f64[] parameter(2) + bf16_ = bf16[] parameter(3) + converted_f16 = f8e5m2fnuz[] convert(f16[] f16_) + converted_f32 = f8e5m2fnuz[] convert(f32[] f32_) + converted_f64 = f8e5m2fnuz[] convert(f64[] f64_) + converted_bf16 = f8e5m2fnuz[] convert(bf16[] bf16_) + ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( + converted_f16, converted_f32, converted_f64, converted_bf16) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E5FNUZ) { + RunTypeConversionTest(R"( + HloModule convertToF8E5FNUZ + ENTRY ConvertToF8E5FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> + (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { + s8_ = s8[] parameter(0) + s16_ = s16[] parameter(1) + s32_ = s32[] parameter(2) + s64_ = s64[] parameter(3) + converted_s8 = f8e5m2fnuz[] convert(s8[] s8_) + converted_s16 = f8e5m2fnuz[] convert(s16[] s16_) + converted_s32 = f8e5m2fnuz[] convert(s32[] s32_) + converted_s64 = f8e5m2fnuz[] convert(s64[] s64_) + ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( + converted_s8, converted_s16, converted_s32, converted_s64) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E5FNUZ) { + RunTypeConversionTest(R"( + HloModule convertToF8E5FNUZ + ENTRY ConvertToF8E5FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> + (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { + u8_ = u8[] parameter(0) + u16_ = u16[] parameter(1) + u32_ = u32[] parameter(2) + u64_ = u64[] 
parameter(3) + converted_u8 = f8e5m2fnuz[] convert(u8[] u8_) + converted_u16 = f8e5m2fnuz[] convert(u16[] u16_) + converted_u32 = f8e5m2fnuz[] convert(u32[] u32_) + converted_u64 = f8e5m2fnuz[] convert(u64[] u64_) + ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( + converted_u8, converted_u16, converted_u32, converted_u64) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToFloat) { + RunTypeConversionTest(R"( + HloModule convertFromF8E5FNUZ + ENTRY ConvertFromF8E5FNUZ + (to_f16 f8e5m2fnuz[], to_f32 f8e5m2fnuz[], to_f64 f8e5m2fnuz[]) -> (f16[], f32[], f64[]) { + to_f16 = f8e5m2fnuz[] parameter(0) + to_f32 = f8e5m2fnuz[] parameter(1) + to_f64 = f8e5m2fnuz[] parameter(2) + f16_ = f16[] convert(f8e5m2fnuz[] to_f16) + f32_ = f32[] convert(f8e5m2fnuz[] to_f32) + f64_ = f64[] convert(f8e5m2fnuz[] to_f64) + ROOT tuple = (f16[], f32[], f64[]) tuple(f16_, f32_, f64_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToSigned) { + RunTypeConversionTest(R"( + HloModule convertFromF8E5FNUZ + ENTRY ConvertFromF8E5FNUZ(to_s8 f8e5m2fnuz[], to_s16 f8e5m2fnuz[], to_s32 f8e5m2fnuz[], + to_s64 f8e5m2fnuz[]) -> (s8[], s16[], s32[], s64[]) { + to_s8 = f8e5m2fnuz[] parameter(0) + to_s16 = f8e5m2fnuz[] parameter(1) + to_s32 = f8e5m2fnuz[] parameter(2) + to_s64 = f8e5m2fnuz[] parameter(3) + s8_ = s8[] convert(f8e5m2fnuz[] to_s8) + s16_ = s16[] convert(f8e5m2fnuz[] to_s16) + s32_ = s32[] convert(f8e5m2fnuz[] to_s32) + s64_ = s64[] convert(f8e5m2fnuz[] to_s64) + ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToUnsigned) { + RunTypeConversionTest(R"( + HloModule convertFromF8E5FNUZ + ENTRY ConvertFromF8E5FNUZ(to_u8 f8e5m2fnuz[], to_u16 f8e5m2fnuz[], to_u32 f8e5m2fnuz[], + to_u64 f8e5m2fnuz[]) -> (u8[], u16[], u32[], u64[]) { + to_u8 = f8e5m2fnuz[] parameter(0) + to_u16 = f8e5m2fnuz[] parameter(1) + to_u32 = 
f8e5m2fnuz[] parameter(2) + to_u64 = f8e5m2fnuz[] parameter(3) + u8_ = u8[] convert(f8e5m2fnuz[] to_u8) + u16_ = u16[] convert(f8e5m2fnuz[] to_u16) + u32_ = u32[] convert(f8e5m2fnuz[] to_u32) + u64_ = u64[] convert(f8e5m2fnuz[] to_u64) + ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToComplex) { + RunTypeConversionTest(R"( + HloModule convertFromF8E5FNUZ + ENTRY ConvertFromF8E5FNUZ + (to_c64 f8e5m2fnuz[], to_c128 f8e5m2fnuz[]) -> (c64[], c128[]) { + to_c64 = f8e5m2fnuz[] parameter(0) + to_c128 = f8e5m2fnuz[] parameter(1) + c64_ = c64[] convert(f8e5m2fnuz[] to_c64) + c128_ = c128[] convert(f8e5m2fnuz[] to_c128) + ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) + } + )"); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E5FNUZ) { + constexpr char hlo_text[] = R"( + HloModule compareF8E5FNUZ + ENTRY main { + p0 = f8e5m2fnuz[4] parameter(0) + p1 = f8e5m2fnuz[4] parameter(1) + ROOT cmp = pred[4] compare(p0, p1), direction=LT +})"; + + Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); + lhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(lhs); + rhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(rhs); + RunTest(hlo_text, {&lhs, &rhs}); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E5FNUZ) { + constexpr char hlo_text[] = R"( + HloModule IotaF8E5FNUZ + ENTRY main { + ROOT iota_ = f8e5m2fnuz[4] iota(), iota_dimension=0 + } + )"; + + RunTest(hlo_text, {}); +} + } // namespace } // namespace xla diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc new file mode 100644 index 0000000000000..d2d81b5a8ffe7 --- /dev/null +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -0,0 +1,665 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/float8_fnuz_ir_emitter.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Intrinsics.h" +#include "xla/primitive_util.h" +#include "xla/status_macros.h" +#include "xla/util.h" + +namespace xla { +namespace float8_fnuz_ir_emitter { + +using primitive_util::BitWidth; +using primitive_util::ExponentBias; +using primitive_util::ExponentWidth; +using primitive_util::OverflowExponent; +using primitive_util::SignificandWidth; +using primitive_util::UnderflowExponent; + +namespace { + +StatusOr PrimitiveTypeToAPFloatSemantics( + PrimitiveType type) { + switch (type) { + case F8E4M3B11FNUZ: + return &llvm::APFloat::Float8E4M3B11FNUZ(); + case F8E4M3FN: + return &llvm::APFloat::Float8E4M3FN(); + case F8E4M3FNUZ: + return &llvm::APFloat::Float8E4M3FNUZ(); + case F8E5M2: + return &llvm::APFloat::Float8E5M2(); + case F8E5M2FNUZ: + return &llvm::APFloat::Float8E5M2FNUZ(); + case BF16: + return &llvm::APFloat::BFloat(); + case F16: + return &llvm::APFloat::IEEEhalf(); + case F32: + return &llvm::APFloat::IEEEsingle(); + case F64: + return &llvm::APFloat::IEEEdouble(); + default: + return Unimplemented( + "PrimitiveTypeToAPFloatSemantics has no semantics for %s.", + PrimitiveType_Name(type)); + } +} + +StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, + PrimitiveType type) { + switch (type) { + case F8E4M3B11FNUZ: + case F8E4M3FN: + case F8E4M3FNUZ: + case F8E5M2: + case F8E5M2FNUZ: + return b->getInt8Ty(); + case BF16: + return b->getInt16Ty(); + case F16: + return 
b->getHalfTy(); + case F32: + return b->getFloatTy(); + case F64: + return b->getDoubleTy(); + default: + return Unimplemented("PrimitiveTypeToLLVMType has no LLVM type for %s.", + PrimitiveType_Name(type)); + } +} + +// Compute the maximum value in the input type that is a finite value when +// converted to the output type. This takes into account rounding. This +// supports floating point types, and assumes the input type is wider than +// the output type. +// +// The result is provided as a uint64_t containing the bit encoding of the +// maximum value. +StatusOr ComputeMaximumValue(PrimitiveType input_type, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { + // Sanity check inputs. + TF_RET_CHECK(primitive_util::IsFloatingPointType(input_type)); + TF_RET_CHECK(primitive_util::IsFloatingPointType(output_type)); + TF_RET_CHECK(BitWidth(input_type) > BitWidth(output_type)); + + llvm::Type* uint_type = b->getIntNTy(BitWidth(input_type)); + + TF_ASSIGN_OR_RETURN(auto output_semantics, + PrimitiveTypeToAPFloatSemantics(output_type)); + + TF_ASSIGN_OR_RETURN(auto input_semantics, + PrimitiveTypeToAPFloatSemantics(input_type)); + + // Compute the largest number of the output type and convert it to the input + // type. + bool losesInfo; + llvm::APFloat largest_output_value = + llvm::APFloat::getLargest(*output_semantics); + largest_output_value.convert( + *input_semantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo); + + llvm::APInt maximum_value = largest_output_value.bitcastToAPInt(); + + // The maximum value in the input type that converts to a finite value in the + // output type has the suffix 0b0111... after the last 1 in the encoding. + // This is the maximum input value that will round down to the maximum finite + // output value. + // + // To find where to put that suffix, count the trailing zeros. Subtract 1 + // from the trailing zero count to ensure there is a 0 between the current + // encoding and the new suffix. 
+ const int trailing_zeros = maximum_value.countTrailingZeros() - 1; + + // Create the 1s that will go in the suffix. + const uint64_t lower_bits = (0x1ull << trailing_zeros) - 1; + + // Or the suffix into the maximum value. + return maximum_value.getZExtValue() | lower_bits; +} + +// Tests whether the input value can be represented in the output type as a +// finite value. This takes into account rounding. +StatusOr IsInputOutsideOutputRange(PrimitiveType input_type, + llvm::Value* value, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { + const uint64_t shift = BitWidth(input_type) - 1; + const uint64_t bit_mask = (0x1ull << shift) - 1; + + // Ignore the sign bit. + llvm::Value* non_sign_bits = b->CreateAnd(value, bit_mask); + + TF_ASSIGN_OR_RETURN(uint64_t maximum_value, + ComputeMaximumValue(input_type, output_type, b)); + + // Compare against the maximum value. + llvm::Type* uint_type = b->getIntNTy(BitWidth(input_type)); + return b->CreateICmpUGT(non_sign_bits, + llvm::ConstantInt::get(uint_type, maximum_value)); +} + +llvm::Value* IsZero(PrimitiveType type, llvm::Value* value, + llvm::IRBuilder<>* b) { + const uint64_t shift = BitWidth(type) - 1; + const uint64_t bit_mask = (0x1ull << shift) - 1; + + // Assuming the input is finite, so we can ignore the sign bit. 
+ llvm::Value* non_sign_bits = b->CreateAnd(value, bit_mask); + + llvm::Type* uint_type = b->getIntNTy(BitWidth(type)); + return b->CreateICmpEQ(non_sign_bits, + llvm::ConstantInt::get(uint_type, 0x0u)); +} + +llvm::Value* IsNormalNumber(PrimitiveType type, llvm::Value* value, + llvm::IRBuilder<>* b) { + const uint64_t width = ExponentWidth(type); + const uint64_t position = SignificandWidth(type) - 1; + const uint64_t exponent_bit_mask = ((0x1ull << width) - 0x1ull) << position; + + llvm::Value* exponent_bits = b->CreateAnd(value, exponent_bit_mask); + llvm::Type* uint_type = b->getIntNTy(BitWidth(type)); + return b->CreateICmpNE(exponent_bits, + llvm::ConstantInt::get(uint_type, 0x0u)); +} + +llvm::Value* IsOutputNormal(PrimitiveType input_type, llvm::Value* exponent, + PrimitiveType output_type, llvm::IRBuilder<>* b) { + const uint64_t denorm_exponent = UnderflowExponent(output_type) - 1; + + llvm::Type* uint_type = b->getIntNTy(BitWidth(input_type)); + return b->CreateICmpSGE(exponent, + llvm::ConstantInt::get(uint_type, denorm_exponent)); +} + +llvm::Value* Max(llvm::Type* type, llvm::Value* x, uint64_t y, + llvm::IRBuilder<>* b) { + return b->CreateBinaryIntrinsic(llvm::Intrinsic::smax, x, + llvm::ConstantInt::get(type, y)); +} + +llvm::Value* Min(llvm::Type* type, llvm::Value* x, uint64_t y, + llvm::IRBuilder<>* b) { + return b->CreateBinaryIntrinsic(llvm::Intrinsic::smin, x, + llvm::ConstantInt::get(type, y)); +} + +llvm::Value* Clamp(llvm::Type* type, llvm::Value* value, uint64_t min, + uint64_t max, llvm::IRBuilder<>* b) { + return Min(type, Max(type, value, min, b), max, b); +} + +// Returns the sign bit of the input value shifted down to the least +// significant bit. 
+llvm::Value* ExtractSign(PrimitiveType type, llvm::Value* value, + bool preserve_signed_zero, llvm::IRBuilder<>* b) { + const uint64_t shift = BitWidth(type) - 1; + const uint64_t sign_bit_mask = 0x1ull << shift; + + llvm::Value* sign = b->CreateAnd(value, sign_bit_mask); + + llvm::Type* uint_type = b->getIntNTy(BitWidth(type)); + sign = b->CreateLShr(sign, llvm::ConstantInt::get(uint_type, shift)); + + if (preserve_signed_zero) { + return sign; + } + + llvm::Value* is_zero_pred = IsZero(type, value, b); + return b->CreateSelect(is_zero_pred, llvm::ConstantInt::get(uint_type, 0x0u), + sign); +} + +// Returns the exponent of the input value shifted down to the least +// significant bits and without any bias. +llvm::Value* ExtractExponent(PrimitiveType type, llvm::Value* value, + llvm::IRBuilder<>* b) { + const uint64_t shift = BitWidth(type) - 1; + const uint64_t bit_mask = (0x1ull << shift) - 0x1ull; + llvm::Type* uint_type = b->getIntNTy(BitWidth(type)); + + // Mask out sign bit. + llvm::Value* exponent = b->CreateAnd(value, bit_mask); + + // Shift the mantissa bits away, leaving the exponent. + exponent = b->CreateLShr( + exponent, llvm::ConstantInt::get(uint_type, SignificandWidth(type) - 1)); + + // Subtract the exponent bias. + exponent = b->CreateSub( + exponent, llvm::ConstantInt::get(uint_type, ExponentBias(type))); + + // If the input number is not a normal number, return the subnormal exponent. + llvm::Value* input_normal_pred = IsNormalNumber(type, value, b); + return b->CreateSelect( + input_normal_pred, exponent, + llvm::ConstantInt::get(uint_type, UnderflowExponent(type) - 1)); +} + +// Returns the mantissa of the input value with all bits explicitly +// represented. For normal numbers, the implicit leading 1 is in the +// returned value. 
+llvm::Value* ExtractMantissa(PrimitiveType type, llvm::Value* value, + llvm::IRBuilder<>* b) { + const uint64_t shift = SignificandWidth(type) - 1; + const uint64_t mantissa_bit_mask = (0x1ull << shift) - 0x1ull; + + llvm::Value* mantissa = b->CreateAnd(value, mantissa_bit_mask); + + llvm::Value* input_normal_pred = IsNormalNumber(type, value, b); + llvm::Value* mantissa_normal = b->CreateOr(mantissa, (0x1ull << shift)); + + return b->CreateSelect(input_normal_pred, mantissa_normal, mantissa); +} + +// Identifies the index of the last bit of the input that can be represented +// in the output type. The index starts from the least significant bit +// (index 0) to the most significant bit (bit n-1). This takes into account +// whether the input value is a normal number. +// +// Example 1: +// input_type = F16 +// output_type = F8E5M2FNUZ +// value = 0.00002664 +// = 0x1.BFp-16 +// = 0x01BF +// = 0b0|00000|0110111111 +// = 0b0.0110111111 * 2^(-14) +// Given the input and output is a denorm, we are looking for the bit that +// corresponds to the smallest non-zero value. For F8E5M2FNUZ that's 2^(-17). +// ExtractMantissa(value) = 0b0000000110111111 +// ^- 2^(-17) is here at bit 7. +// result = LastMantissaBit(F16, 0.00002664, F8E5M2FNUZ, b) = 7 +// +// Example 2: +// input_type = BF16 +// output_type = F8E4M3FNUZ +// value = 247.0 +// = 0x1.EEp7 +// = 0x4377 +// = 0b0|10000110|1110111 +// = 0b1.1110111 * 2^7 +// Given the input and output is a normal number, we are looking for the bit +// that corresponds to the third bit of the mantissa. +// ExtractMantissa(value) = 0b0000000011110111 +// ^- third mantissa bit is at bit 4. 
+// result = LastMantissaBit(BF16, 247.0, F8E4M3FNUZ, b) = 4 +StatusOr LastMantissaBit(PrimitiveType input_type, + llvm::Value* value, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { + const int src_mantissa_bits = SignificandWidth(input_type) - 1; + const int dest_mantissa_bits = SignificandWidth(output_type) - 1; + llvm::Type* int_type = b->getIntNTy(BitWidth(input_type)); + + llvm::Value* exponent = ExtractExponent(input_type, value, b); + + // The index when the input/output is normal. + llvm::Value* last_bit_index = + llvm::ConstantInt::get(int_type, src_mantissa_bits - dest_mantissa_bits); + + // Increase the index if the output will be denormal given the exponent. + llvm::Value* denormal_shift = b->CreateSub( + llvm::ConstantInt::get(int_type, UnderflowExponent(output_type) - 1), + exponent); + denormal_shift = Max(int_type, denormal_shift, 0, b); + last_bit_index = b->CreateAdd(last_bit_index, denormal_shift); + + // Check the output type exponent bias is not greater than the input type by + // more than 1. + TF_RET_CHECK(ExponentBias(input_type) >= (ExponentBias(output_type) - 1)); + + // The log_2(x) of the smallest denorm value. This gives us the exponent n + // that produces that number. This corresponds to the encoding with a single + // bit set in the last significant bit (0b0000...01). + const int input_log_minimum = + UnderflowExponent(input_type) - SignificandWidth(input_type); + const int output_log_minimum = + UnderflowExponent(output_type) - SignificandWidth(output_type); + + // Alternatively, the input might be a denorm. This directly computes the + // last mantissa bit when the input is a denorm. Suppose we have 2.664E-5 + // (0x1.BFp-16) as the input number. The bit encoding for this in F16 is: + // S|EEEEE|MMMMMMMMMM + // 0b0|00000|0110111111 = 2^(-14) * 0b0.0110111111 + // + // To cast this to F8E5M2FNUZ, we would find the smallest denorm encoding and + // find the corresponding bit in the input. 
+ // S|EEEEE|MM + // 0b0|00000|01 = 2^(-15) * 0b0.01 + // We can see the "last" bit of an output denorm represents 2^(-17), so we + // find that corresponding bit in the input. In this F16 example it is + // highlighted here: + // S|EEEEE|MMMMMMMMMM + // 0b0|00000|0110111111 + // ^-- last mantissa bit + llvm::Value* denorm_last_mantissa_bit = + llvm::ConstantInt::get(int_type, output_log_minimum - input_log_minimum); + + // Select the last mantissa bit based on whether the input is a normal number. + llvm::Value* normal_pred = IsNormalNumber(input_type, value, b); + + // For the purposes of this function, consider zero a normal number. + normal_pred = b->CreateOr(normal_pred, IsZero(input_type, value, b)); + + // Select the normal or denorm case. + llvm::Value* last_mantissa_bit = + b->CreateSelect(normal_pred, last_bit_index, denorm_last_mantissa_bit); + + // Ensure the last_mantissa_bit is a valid bit in the input mantissa. + // last_mantissa_bit is allowed to be the "2s" bit of the exponent. + // This means the maximum value is 0b10.000...0. This corresponds to the + // maximum possible rounding. + return Min(int_type, last_mantissa_bit, src_mantissa_bits + 1, b); +} + +// Compute the rounding bias for round-to-nearest-even for the input value. +// This takes into account whether the input value is a normal number and +// whether it will map to a normal number in the output type. +StatusOr DynamicRoundingBias(PrimitiveType input_type, + llvm::Value* value, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { + llvm::Type* int_type = b->getIntNTy(BitWidth(input_type)); + + // Find the bit position of the last mantissa bit. + TF_ASSIGN_OR_RETURN(llvm::Value * shift, + LastMantissaBit(input_type, value, output_type, b)); + + // Compute the mask to select that bit. + llvm::Value* last_mantissa_bit_mask = + b->CreateShl(llvm::ConstantInt::get(int_type, 0x1u), shift); + + // Given the mantissa bit mask, compute the rounding bias bits. 
+ llvm::Value* base_rounding_bias = b->CreateLShr(last_mantissa_bit_mask, 0x1u); + base_rounding_bias = + b->CreateSub(base_rounding_bias, llvm::ConstantInt::get(int_type, 0x1u)); + + // Select the last mantissa bit, and shift it down to the lsb. + llvm::Value* mantissa = ExtractMantissa(input_type, value, b); + llvm::Value* x_last_mantissa_bit = + b->CreateLShr(b->CreateAnd(mantissa, last_mantissa_bit_mask), shift); + + // Add the last mantissa lsb into the rounding bias. + return b->CreateAdd(x_last_mantissa_bit, base_rounding_bias); +} + +// Given an unbiased exponent and mantissa with no implicit bits, returns the +// mantissa for the output type. The exponent is expected to be unbiased and +// the mantissa should have all bits explicitly represented, including the +// normally implicit leading 1 for a normal number. +llvm::Value* BuildOutputMantissa(PrimitiveType input_type, + llvm::Value* exponent, llvm::Value* mantissa, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { + llvm::Type* input_int_type = b->getIntNTy(BitWidth(input_type)); + + // Count the number of leading zeros, excluding the bits that would contain + // the exponent. + llvm::Value* zero_count = + b->CreateBinaryIntrinsic(llvm::Intrinsic::ctlz, mantissa, + llvm::ConstantInt::get(b->getInt1Ty(), 0x0u)); + zero_count = b->CreateSub( + zero_count, + llvm::ConstantInt::get(input_int_type, ExponentWidth(input_type))); + + // The amount to shift the normal mantissa down. + llvm::Value* shift = b->CreateSub( + llvm::ConstantInt::get(input_int_type, SignificandWidth(input_type) - + SignificandWidth(output_type)), + zero_count); + + // Shift the mantissa into its "normal" position. + mantissa = b->CreateLShr(mantissa, shift); + exponent = b->CreateSub(exponent, zero_count); + + // Additional shifting required to account for the denorm exponent. 
+ shift = b->CreateSub(llvm::ConstantInt::get( + input_int_type, UnderflowExponent(output_type) - 1), + exponent); + + // Avoid shift > BitWidth(input_type) which is UB for lshr. This can happen + // with a large negative input exponent. This will shift all the bits out, + // which is equivalent. + shift = Min(input_int_type, shift, BitWidth(input_type) - 1, b); + llvm::Value* mantissa_denorm = b->CreateLShr(mantissa, shift); + + // Test whether the output will be a normal number. + llvm::Value* output_normal_pred = + IsOutputNormal(input_type, exponent, output_type, b); + + // Select the normal or subnormal mantissa. + mantissa = b->CreateSelect(output_normal_pred, mantissa, mantissa_denorm); + + // Mask out any additional bits. This includes the now implicit leading 1. + const uint64_t mantissa_bit_mask = + (0x1ull << (SignificandWidth(output_type) - 1)) - 1; + return b->CreateAnd(mantissa, mantissa_bit_mask); +} + +// Given an unbiased exponent and mantissa with no implicit bits, returns the +// exponent for the output type. The result is shifted into the correct +// position. +llvm::Value* BuildOutputExponent(PrimitiveType input_type, + llvm::Value* exponent, llvm::Value* mantissa, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { + llvm::Type* input_int_type = b->getIntNTy(BitWidth(input_type)); + + // Count the number of leading zeros, excluding the bits that would contain + // the exponent. + llvm::Value* zero_count = + b->CreateBinaryIntrinsic(llvm::Intrinsic::ctlz, mantissa, + llvm::ConstantInt::get(b->getInt1Ty(), 0x0u)); + zero_count = b->CreateSub( + zero_count, + llvm::ConstantInt::get(input_int_type, ExponentWidth(input_type))); + + // Lower the exponent value by the number of additional leading zeros. + exponent = b->CreateSub(exponent, zero_count); + + // Check whether this would lead to a normal number output. 
+ llvm::Value* output_normal_pred = + IsOutputNormal(input_type, exponent, output_type, b); + + // If this would lead to a subnormal output, use the subnormal exponent. + exponent = b->CreateSelect( + output_normal_pred, exponent, + llvm::ConstantInt::get(input_int_type, -OverflowExponent(output_type))); + + // Bias the exponent. + exponent = b->CreateAdd( + exponent, + llvm::ConstantInt::get(input_int_type, OverflowExponent(output_type))); + + // Shift the exponent into the appropriate position. + return b->CreateShl(exponent, SignificandWidth(output_type) - 1); +} + +// Returns the sign for the output type. The result is shifted into the correct +// position. +llvm::Value* BuildOutputSign(llvm::Value* sign, PrimitiveType output_type, + llvm::IRBuilder<>* b) { + // Shift the sign bit into the msb. + return b->CreateShl(sign, BitWidth(output_type) - 1); +} + +StatusOr GetQNaN(PrimitiveType type) { + TF_ASSIGN_OR_RETURN(auto semantics, PrimitiveTypeToAPFloatSemantics(type)); + + return llvm::APFloat::getQNaN(*semantics).bitcastToAPInt().getZExtValue(); +} +} // namespace + +StatusOr EmitFloatingToF8fnuz(PrimitiveType input_type, + llvm::Value* input_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { + // Sanity check for supported types. + TF_RET_CHECK(input_type == BF16 || input_type == F16 || input_type == F32 || + input_type == F64); + TF_RET_CHECK(output_type == F8E4M3FNUZ || output_type == F8E5M2FNUZ); + + llvm::IntegerType* input_int_type = b->getIntNTy(BitWidth(input_type)); + llvm::Value* input_uint = b->CreateBitCast(input_value, input_int_type); + + TF_ASSIGN_OR_RETURN( + llvm::Value * out_of_range_pred, + IsInputOutsideOutputRange(input_type, input_uint, output_type, b)); + // We may now assume there won't be any further overflow issues. They will be + // handled in the final select. + + // Compute rounding bias for round-to-nearest with ties to even. 
+ TF_ASSIGN_OR_RETURN( + llvm::Value * input_rounding_bias, + DynamicRoundingBias(input_type, input_uint, output_type, b)); + + // Apply the rounding bias to the input. This won't carry into the sign bit. + llvm::Value* input_uint_rounded = + b->CreateAdd(input_uint, input_rounding_bias); + + // The input value is broken down and in a canonical form. Appropriate + // rounding has been applied, exponent is not biased, and there are no + // implicit bits in the mantissa. + llvm::Value* sign = + ExtractSign(input_type, input_uint, /*preserve_signed_zero=*/false, b); + llvm::Value* exponent = ExtractExponent(input_type, input_uint_rounded, b); + llvm::Value* mantissa = ExtractMantissa(input_type, input_uint_rounded, b); + + // The component parts of the output value. + llvm::Value* output_sign = BuildOutputSign(sign, output_type, b); + llvm::Value* output_exponent = + BuildOutputExponent(input_type, exponent, mantissa, output_type, b); + llvm::Value* output_mantissa = + BuildOutputMantissa(input_type, exponent, mantissa, output_type, b); + + // Bitwise or the output components together. + llvm::Value* result = b->CreateOr(output_exponent, output_mantissa); + + // Check for output underflow before adding a sign bit. There's no -0 in + // fnuz types. + llvm::Value* is_zero_pred = IsZero(input_type, result, b); + output_sign = b->CreateSelect( + is_zero_pred, llvm::ConstantInt::get(input_int_type, 0x0u), output_sign); + + // Bitwise or the sign bit into the result. + result = b->CreateOr(result, output_sign); + + // Truncate down to int8. + result = b->CreateTrunc(result, b->getInt8Ty()); + + // Select based on whether the value was in range. 
+ TF_ASSIGN_OR_RETURN(const uint64_t output_qnan, GetQNaN(output_type)); + return b->CreateSelect(out_of_range_pred, + llvm::ConstantInt::get(b->getInt8Ty(), output_qnan), + result); +} + +StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, + llvm::Value* f8_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b, + llvm::Module* module) { + // Sanity check for supported types. + TF_RET_CHECK(input_type == F8E4M3FNUZ || input_type == F8E5M2FNUZ); + TF_RET_CHECK(primitive_util::IsFloatingPointType(output_type)); + + const int output_type_bit_width = BitWidth(output_type); + llvm::IntegerType* output_int_type = b->getIntNTy(output_type_bit_width); + + llvm::ArrayType* result_lut_array_type = + llvm::ArrayType::get(output_int_type, 128); + + const std::string lut_name = PrimitiveType_Name(input_type) + "To" + + PrimitiveType_Name(output_type) + "LUT"; + TF_ASSIGN_OR_RETURN(auto input_semantics, + PrimitiveTypeToAPFloatSemantics(input_type)); + TF_ASSIGN_OR_RETURN(auto output_semantics, + PrimitiveTypeToAPFloatSemantics(output_type)); + + llvm::Constant* global_result_lut_array = module->getOrInsertGlobal( + lut_name, result_lut_array_type, [&]() -> llvm::GlobalVariable* { + // Since the function range is only 2^8 and symmetric on the sign bit, + // this is implemented as a table lookup. + llvm::Constant* result_lut[128]; + + // Populate the table with values computed using llvm APFloat. 
+ for (uint8_t i = 0; i < 128; ++i) { + llvm::APFloat value(*input_semantics, llvm::APInt(8, i)); + + bool losesInfo; + value.convert(*output_semantics, llvm::APFloat::rmNearestTiesToEven, + &losesInfo); + + result_lut[i] = llvm::ConstantInt::get( + output_int_type, value.bitcastToAPInt().getZExtValue()); + } + + llvm::Constant* result_lut_array = + llvm::ConstantArray::get(result_lut_array_type, result_lut); + + return new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/result_lut_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/result_lut_array, + /*Name=*/lut_name); + }); + + // Check for NaN, since it's a special case. + TF_ASSIGN_OR_RETURN(const uint64_t input_qnan, GetQNaN(input_type)); + llvm::Value* nan_pred = b->CreateICmpEQ( + f8_value, llvm::ConstantInt::get(b->getInt8Ty(), input_qnan)); + + // Extract the sign, which will be added back to the result of the table + // lookup. + llvm::Value* sign = b->CreateAnd(f8_value, 0x80); + + // The lower 7 bits used as the index for the table lookup. + llvm::Value* f8_abs = b->CreateAnd(f8_value, 0x7F); + + // Fetch the value from the lookup table. + llvm::Value* result_abs = + b->CreateGEP(output_int_type, global_result_lut_array, f8_abs); + result_abs = b->CreateLoad(output_int_type, result_abs); + + // Never output a negative zero. + llvm::Value* is_output_zero_pred = IsZero(output_type, result_abs, b); + sign = b->CreateSelect(is_output_zero_pred, + llvm::ConstantInt::get(b->getInt8Ty(), 0x0u), sign); + + // Bitwise or the sign bit back in. + sign = b->CreateZExt(sign, output_int_type); + sign = b->CreateShl(sign, output_type_bit_width - BitWidth(input_type)); + llvm::Value* result = b->CreateOr(sign, result_abs); + + // Bitcast to the output type. 
+ TF_ASSIGN_OR_RETURN(auto type, PrimitiveTypeToLLVMType(b, output_type)); + TF_ASSIGN_OR_RETURN(const uint64_t output_qnan, GetQNaN(output_type)); + return b->CreateBitCast( + b->CreateSelect(nan_pred, + llvm::ConstantInt::get(output_int_type, output_qnan), + result), + type); +} + +} // namespace float8_fnuz_ir_emitter +} // namespace xla diff --git a/xla/service/float8_fnuz_ir_emitter.h b/xla/service/float8_fnuz_ir_emitter.h new file mode 100644 index 0000000000000..6aaad33767233 --- /dev/null +++ b/xla/service/float8_fnuz_ir_emitter.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_FLOAT8_FNUZ_IR_EMITTER_H_ +#define XLA_SERVICE_FLOAT8_FNUZ_IR_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "xla/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace float8_fnuz_ir_emitter { + +// Convert the given floating point input to the output type. input_type must +// be one of BF16, F16, F32, and F64. output_type must be one of F8E4M3FNUZ and +// F8E5M2FNUZ. +StatusOr EmitFloatingToF8fnuz(PrimitiveType input_type, + llvm::Value* input_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b); + +// Convert the given floating point input to the output type. input_type must +// be one of F8E4M3FNUZ and F8E5M2FNUZ. 
output_type must be one of BF16, F16, +// F32, and F64. +StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, + llvm::Value* f8_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b, + llvm::Module* module); +} // namespace float8_fnuz_ir_emitter +} // namespace xla + +#endif // XLA_SERVICE_FLOAT8_FNUZ_IR_EMITTER_H_ diff --git a/xla/service/float_support.h b/xla/service/float_support.h index 0b8b1e482c5f3..9a2691be1f27f 100644 --- a/xla/service/float_support.h +++ b/xla/service/float_support.h @@ -40,7 +40,9 @@ class FloatSupport { // instruction. PrimitiveType HighPrecisionType() const { if (low_precision_type_ == F8E5M2 || low_precision_type_ == F8E4M3FN || - low_precision_type_ == F8E4M3B11FNUZ) { + low_precision_type_ == F8E4M3B11FNUZ || + low_precision_type_ == F8E5M2FNUZ || + low_precision_type_ == F8E4M3FNUZ) { return F16; } DCHECK_EQ(low_precision_type_, BF16); diff --git a/xla/service/gpu/compile_module_to_llvm_ir.cc b/xla/service/gpu/compile_module_to_llvm_ir.cc index 2e0573ec45b86..f9853d811d385 100644 --- a/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -92,8 +92,10 @@ static bool HasFp8(const HloModule& hlo_module) { for (const HloComputation* computation : hlo_module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { if (ShapeUtil::HasPrimitiveType(instruction->shape(), F8E5M2) || + ShapeUtil::HasPrimitiveType(instruction->shape(), F8E5M2FNUZ) || ShapeUtil::HasPrimitiveType(instruction->shape(), F8E4M3FN) || - ShapeUtil::HasPrimitiveType(instruction->shape(), F8E4M3B11FNUZ)) { + ShapeUtil::HasPrimitiveType(instruction->shape(), F8E4M3B11FNUZ) || + ShapeUtil::HasPrimitiveType(instruction->shape(), F8E4M3FNUZ)) { return true; } } diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 05f0f872fd540..452f79c8db5bb 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1022,6 +1022,8 @@ Status 
GpuCompiler::OptimizeHloPostLayoutAssignment( GpuFloatSupport f8e5m2_support(F8E5M2); GpuFloatSupport f8e4m3fn_support(F8E4M3FN); FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index 867815ec24754..9c51715511aa4 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -3980,6 +3980,14 @@ template <> struct MinMaxFiniteValue : MinMaxFiniteValueCustomFloat {}; +template<> +struct MinMaxFiniteValue + : MinMaxFiniteValueCustomFloat {}; + +template <> +struct MinMaxFiniteValue + : MinMaxFiniteValueCustomFloat {}; + // MSVC's standard C++ library does not define isnan/isfinite for integer types. // To work around that we will need to provide our own. template diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 97e6b63e3bbb6..5cd37e0c18192 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -181,8 +181,10 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // used for storage. return llvm::Type::getInt16Ty(module->getContext()); case F8E5M2: + case F8E5M2FNUZ: case F8E4M3FN: case F8E4M3B11FNUZ: + case F8E4M3FNUZ: // Similarly as with BF16, we represent F8 as an int since there is no // LLVM F8 dtype. 
return llvm::Type::getInt8Ty(module->getContext()); diff --git a/xla/shape_util.cc b/xla/shape_util.cc index 5570285ff26cc..148362aac5637 100644 --- a/xla/shape_util.cc +++ b/xla/shape_util.cc @@ -73,6 +73,8 @@ constexpr uint8_t primitive_byte_size[PrimitiveType_ARRAYSIZE] = { sizeof(int8_t), // S4 = 21 sizeof(int8_t), // U4 = 22 sizeof(float) / 4, // F8E4M3B11FNUZ = 23 + sizeof(float) / 4, // F8E4M3FNUZ = 24 + sizeof(float) / 4, // F8E5M2FNUZ = 25 }; constexpr int64_t kAnnotationPrintInterval = 5; diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index 894b5db95dee1..c5a5858e08cfe 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -46,6 +46,14 @@ struct ToDataType { static constexpr DataType value = DataType::kF8E5M2; }; template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E4M3FNUZ; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E5M2FNUZ; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kFloat; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 4dcd3012b49f6..7004c9203f6ee 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -51,7 +51,9 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/tests/client_library_test_base.h b/xla/tests/client_library_test_base.h index 0679d0ce44c3d..765fd10340ade 100644 --- a/xla/tests/client_library_test_base.h +++ b/xla/tests/client_library_test_base.h @@ -464,8 +464,10 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || 
std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); @@ -492,8 +494,10 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); @@ -521,8 +525,10 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); @@ -551,8 +557,10 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); @@ -581,8 +589,10 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); @@ -610,8 +620,10 @@ void ClientLibraryTestBase::ComputeAndCompare( std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index bc99fa5a3c94c..6c4a35410cfa3 
100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -132,6 +132,26 @@ TEST_F(ConstantsTest, OneCellF8e4m3b11fnuz) { ComputeAndCompareR1(&builder, {2.0f}, {}, error_spec_); } +TEST_F(ConstantsTest, OneCellF8e5m2fnuz) { + std::vector constant = {tsl::float8_e5m2fnuz{2.0}}; + + XlaBuilder builder(TestName()); + auto c = ConstantR1(&builder, constant); + + ComputeAndCompareR1(&builder, constant, {}, + error_spec_); +} + +TEST_F(ConstantsTest, OneCellF8e4m3fnuz) { + std::vector constant = {tsl::float8_e4m3fnuz{2.0}}; + + XlaBuilder builder(TestName()); + auto c = ConstantR1(&builder, constant); + + ComputeAndCompareR1(&builder, constant, {}, + error_spec_); +} + TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 29bb21d8988ab..0a93487c33e89 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -45,6 +45,17 @@ class ConvertTest : public ClientLibraryTestBase { } }; +template +class ConvertTestT : public ConvertTest { + public: + using ConvertTest::ConvertTest; +}; +using FloatingPointTypeList = + ::testing::Types; +TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); + TEST_F(ConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); @@ -630,6 +641,36 @@ XLA_TEST_F(ConvertTest, ConvertF8e5m2F16RoundtripExhaustive) { ComputeAndCompareR1(&builder, all_f8, {}, ErrorSpec(0.)); } +XLA_TEST_F(ConvertTest, ConvertF8e5m2F16RoundtripExhaustive2) { + // Convert from F16 to FP8. 
+ XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e5m2BF16RoundtripExhaustive3) { + // Convert from BF16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_bf16_to_f8, F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { // Convert from FP16 to FP8, then back to FP16 XlaBuilder builder(TestName()); @@ -732,6 +773,36 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive3) { ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive4) { + // Convert from F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E4M3FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3fnBF16RoundtripExhaustive5) { + // Convert from BF16 to FP8. 
+ XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_bf16_to_f8, F8E4M3FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { // Convert from FP16 to FP8, then back to FP16 XlaBuilder builder(TestName()); @@ -835,6 +906,408 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive3) { ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, 0.0}, // No signed zero in F8E5M2FNUZ + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {inf, nan}, // No Inf in F8E4M3FNUZ + // clang-format on + {0x1.2p0, 0x1p0}, // Round-to-even down + {0x1.6p0, 0x1.8p0}, // Round-to-even up + {0x1.Cp15, 0x1.Cp15}, // Max value + {0x1.DFCp15, 0x1.Cp15}, // Largest number that doesn't overflow + {0x1.Ep15, nan}, // Smallest number that overflows + {0x1p16, nan}, // Overflow + {0x1p-15, 0x1p-15}, // Smallest F8 normal + {0x1.Cp-16, 0x1p-15}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-16, 0x1.0p-16}, // Denormal without rounding + {0x1.4p-16, 0x1.0p-16}, // Round-to-even down + {0x1.Cp-16, 0x1.0p-15}, // Round-to-even up + {0x1.3p-16, 0x1.0p-16}, // Round-to-nearest down + {0x1.5p-16, 0x1.8p-16}, // Round-to-nearest up + {0x1p-18, 0}, // Largest number that underflows + {0x1.04p-18, 0x1p-17}, // Smallest number that doesn't underflow + {0x1.BFp-16, 0x1.8p-16}, // Largest number that rounds to denormal + }; + + std::vector inputs; + 
std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E5M2FNUZ); + ConvertElementType(f8, F16); + const bool saved = + execution_options_.debug_options().xla_allow_excess_precision(); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + false); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + saved); +} + +XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { + // Convert from FP32 to FP8, then back to FP32 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, 0.0}, // No signed zero in F8E5M2FNUZ + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {inf, nan}, // No Inf in F8E4M3FNUZ + // clang-format on + {0x1.2p0, 0x1p0}, // Round-to-even down + {0x1.6p0, 0x1.8p0}, // Round-to-even up + {0x1.Cp15, 0x1.Cp15}, // Max value + {0x1.DFFFFEp15, 0x1.Cp15}, // Largest number that doesn't overflow + {0x1.Ep15, nan}, // Smallest number that overflows + {0x1p16, nan}, // Overflow + {0x1p-15, 0x1p-15}, // Smallest F8 normal + {0x1.Cp-16, 0x1p-15}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-16, 0x1.0p-16}, // Denormal without rounding + {0x1.4p-16, 0x1.0p-16}, // Round-to-even down + {0x1.Cp-16, 0x1.0p-15}, // Round-to-even up + {0x1.3FFFFEp-16, 0x1.0p-16}, // Round-to-nearest down + {0x1.5FFFFEp-16, 0x1.8p-16}, // Round-to-nearest up + {0x1p-18, 0}, // Largest number that underflows + {0x1.000002p-18, 0x1p-17}, // Smallest number that doesn't underflow + {0x1.BFFFFEp-16, 0x1.8p-16}, // Largest number that 
rounds to denormal + {0x1.FFFFFEp-50, 0}, // A very small input that should underflow + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E5M2FNUZ); + ConvertElementType(f8, F32); + const bool saved = + execution_options_.debug_options().xla_allow_excess_precision(); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + false); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + saved); +} + +XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive) { + // Convert from FP8 to each supported floating type, then back to FP8. + XlaBuilder builder(TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + for (auto type : {F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, + F16, BF16, F32, F64}) { + xla::XlaOp all_f8_as_f8 = + ConstantR1(&builder, all_f8); + xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, type); + ConvertElementType(all_f8_as_f16, F8E5M2FNUZ); + ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + static_cast(Eigen::numext::bit_cast( + static_cast(i)))); + } + + xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); + ConvertElementType(all_f8_as_f32, F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point types. 
+ XlaBuilder builder(TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, + F16, BF16, F32, F64}) { + xla::XlaOp all_f8_as_f8 = + ConstantR1(&builder, all_f8); + ConvertElementType(all_f8_as_f8, type); + ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } +} + +XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzF16RoundtripExhaustive4) { + // Convert from F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzBF16RoundtripExhaustive5) { + // Convert from BF16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_bf16_to_f8, F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, 0.0}, // No signed zero in F8E4M3FNUZ + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, nan}, // No Inf in F8E4M3FNUZ + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFCp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, 
nan}, // Smallest number that overflows + {0x1p8, nan}, // Overflow + {0x1p-7, 0x1p-7}, // Smallest F8 normal + {0x1.Ep-8, 0x1p-7}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding + {0x1.4p-9, 0x1.0p-9}, // Round-to-even down + {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up + {0x1.5p-8, 0x1.4p-8}, // Round-to-nearest down + {0x1.3p-8, 0x1.4p-8}, // Round-to-nearest up + {0x1p-11, 0}, // Largest number that underflows + {0x1.004p-11, 0x1p-10}, // Smallest number that doesn't underflow + {0x1.DFCp-8, 0x1.Cp-8}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FNUZ); + ConvertElementType(f8, F16); + const bool saved = + execution_options_.debug_options().xla_allow_excess_precision(); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + false); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + saved); +} + +XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, 0.0}, // No signed zero in F8E4M3FNUZ + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, nan}, // No Inf in F8E4M3FNUZ + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFFFFEp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, nan}, 
// Smallest number that overflows + {0x1p8, nan}, // Overflow + {0x1p-7, 0x1p-7}, // Smallest F8 normal + {0x1.Ep-8, 0x1p-7}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding + {0x1.4p-9, 0x1.0p-9}, // Round-to-even down + {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up + {0x1.5p-8, 0x1.4p-8}, // Round-to-nearest down + {0x1.3p-8, 0x1.4p-8}, // Round-to-nearest up + {0x1p-11, 0}, // Largest number that underflows + {0x1.000002p-11, 0x1p-10}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-8, 0x1.Cp-8}, // Largest number that rounds to denormal + {0x1.FFFFFEp-50, 0}, // A very small input that should underflow + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FNUZ); + ConvertElementType(f8, F32); + const bool saved = + execution_options_.debug_options().xla_allow_excess_precision(); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + false); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); + execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( + saved); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive) { + // Convert from FP8 to each supported floating type, then back to FP8. 
+ XlaBuilder builder(TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, + F16, BF16, F32, F64}) { + xla::XlaOp all_f8_as_f8 = + ConstantR1(&builder, all_f8); + xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, type); + ConvertElementType(all_f8_as_f16, F8E4M3FNUZ); + ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive2) { + // Convert from support floating types to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + static_cast(Eigen::numext::bit_cast( + static_cast(i)))); + } + + xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); + ConvertElementType(all_f8_as_f32, F8E4M3FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point types. + XlaBuilder builder(TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, + F16, BF16, F32, F64}) { + xla::XlaOp all_f8_as_f8 = + ConstantR1(&builder, all_f8); + ConvertElementType(all_f8_as_f8, type); + ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzF16RoundtripExhaustive4) { + // Convert from F16 to FP8. 
+ XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E4M3FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzBF16RoundtripExhaustive5) { + // Convert from BF16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_bf16_to_f8, F8E4M3FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + XLA_TEST_F(ConvertTest, ConvertF8e5m2ToPred) { XlaBuilder builder(TestName()); using F8 = tsl::float8_e5m2; @@ -855,5 +1328,25 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3fnToPred) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzToPred) { + XlaBuilder builder(TestName()); + using F8 = tsl::float8_e5m2fnuz; + auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); + ConvertElementType(a, PRED); + + std::array expected = {false, true, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzToPred) { + XlaBuilder builder(TestName()); + using F8 = tsl::float8_e4m3fnuz; + auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); + ConvertElementType(a, PRED); + + std::array expected = {false, true, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + } // namespace } // namespace xla diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 26517ba945a56..1c31712a95248 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -119,13 +119,15 @@ enum PrimitiveType { S4, U4, F8E4M3B11FNUZ, + F8E5M2FNUZ, + F8E4M3FNUZ, }; const std::vector& primitive_strings() { static auto vec = new std::vector( {"s16", 
"s32", "s64", "u8", "u16", "u32", "u64", "f16", "bf16", "f32", "f64", "c64", "c128", "f8e5m2", "f8e4m3fn", "s4", "u4", - "f8e4m3b11fnuz"}); + "f8e4m3b11fnuz", "f8e5m2fnuz", "f8e4m3fnuz"}); return *vec; } @@ -403,6 +405,8 @@ void Fill(void* buffer, const ArrayShape& shape) { case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: + case F8E5M2FNUZ: + case F8E4M3FNUZ: case F16: case BF16: case C64: @@ -453,6 +457,8 @@ void Display(const void* buffer, const ArrayShape& shape) { case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: + case F8E5M2FNUZ: + case F8E4M3FNUZ: case F16: case BF16: case C64: diff --git a/xla/translate/hlo_to_mhlo/hlo_utils.cc b/xla/translate/hlo_to_mhlo/hlo_utils.cc index 5bdd0fc8e132d..3c1decc42f319 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/xla/translate/hlo_to_mhlo/hlo_utils.cc @@ -160,6 +160,14 @@ Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data, CopyDenseElementsBy(data, output); return OkStatus(); } + if (element_type.isFloat8E5M2FNUZ()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isFloat8E4M3FNUZ()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } if (element_type.isBF16()) { CopyDenseElementsBy(data, output); return OkStatus(); @@ -223,6 +231,10 @@ StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, return builder.getFloat8E4M3FNType(); case PrimitiveType::F8E4M3B11FNUZ: return builder.getFloat8E4M3B11FNUZType(); + case PrimitiveType::F8E5M2FNUZ: + return builder.getFloat8E5M2FNUZType(); + case PrimitiveType::F8E4M3FNUZ: + return builder.getFloat8E4M3FNUZType(); case PrimitiveType::F16: return builder.getF16Type(); case PrimitiveType::BF16: diff --git a/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/xla/translate/hlo_to_mhlo/tests/import.hlotxt index dbfb75bebd796..ee0e40d983002 100644 --- a/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -327,6 +327,12 @@ add { // CHECK: %[[VAL_9:.*]] = 
mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3B11FNUZ> %constant.9 = f8e4m3b11fnuz[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> + %constant.10 = f8e4m3fnuz[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + %constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -421,7 +427,19 @@ add { %convert.7 = f8e4m3fn[4] convert(f32[4] %convert.6) // CHECK-NEXT: %5 = mhlo.convert %4 : (tensor<4xf8E4M3FN>) -> tensor<4xf32> - ROOT %convert.8 = f32[4] convert(f8e4m3fn[4] %convert.7) + %convert.8 = f32[4] convert(f8e4m3fn[4] %convert.7) + + // CHECK-NEXT: %6 = mhlo.convert %5 : (tensor<4xf32>) -> tensor<4xf8E4M3FNUZ> + %convert.9 = f8e4m3fnuz[4] convert(f32[4] %convert.8) + + // CHECK-NEXT: %7 = mhlo.convert %6 : (tensor<4xf8E4M3FNUZ>) -> tensor<4xf32> + %convert.10 = f32[4] convert(f8e4m3fnuz[4] %convert.9) + + // CHECK-NEXT: %8 = mhlo.convert %7 : (tensor<4xf32>) -> tensor<4xf8E5M2FNUZ> + %convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10) + + // CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32> + ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/translate/mhlo_to_hlo/tests/export.mlir b/xla/translate/mhlo_to_hlo/tests/export.mlir index 9187b1c0c7f48..ecb8df2f29087 100644 --- a/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -653,6 +653,12 @@ func.func @main() { // CHECK: f8e4m3b11fnuz[4] constant({1, 2, 3, 4}) %cst_13 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : 
tensor<4xf8E4M3B11FNUZ> + // CHECK: f8e4m3fnuz[4] constant({1, 2, 3, 4}) + %cst_14 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> + + // CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4}) + %cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + func.return } @@ -778,15 +784,23 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %1 = "mhlo.convert"(%0) : (tensor<2xf8E5M2>) -> tensor<2xf32> %2 = "mhlo.convert"(%1) : (tensor<2xf32>) -> tensor<2xf8E4M3FN> %3 = "mhlo.convert"(%2) : (tensor<2xf8E4M3FN>) -> tensor<2xf32> - func.return %3 : tensor<2xf32> + %4 = "mhlo.convert"(%3) : (tensor<2xf32>) -> tensor<2xf8E4M3FNUZ> + %5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32> + %6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ> + %7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32> + func.return %7 : tensor<2xf32> } // CHECK: ENTRY // CHECK: %[[ARG:.*]] = f32[2] parameter(0) // CHECK: %[[E5M2_VAL:.*]] = f8e5m2[2] convert(f32[2] %[[ARG]]) // CHECK: %[[F32_VAL:.*]] = f32[2] convert(f8e5m2[2] %[[E5M2_VAL]]) -// CHECK: %[[E4M3_VAL:.*]] = f8e4m3fn[2] convert(f32[2] %[[F32_VAL]]) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e4m3fn[2] %[[E4M3_VAL]]) +// CHECK: %[[E4M3FN_VAL:.*]] = f8e4m3fn[2] convert(f32[2] %[[F32_VAL]]) +// CHECK: %[[F32_VAL2:.*]] = f32[2] convert(f8e4m3fn[2] %[[E4M3FN_VAL]]) +// CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]]) +// CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]]) +// CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) // ----- diff --git a/xla/translate/mhlo_to_hlo/type_to_shape.cc b/xla/translate/mhlo_to_hlo/type_to_shape.cc index 97c690eafb283..244d794307f9d 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ 
b/xla/translate/mhlo_to_hlo/type_to_shape.cc @@ -53,6 +53,10 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) { return PrimitiveType::F8E4M3FN; } else if (type.isFloat8E4M3B11FNUZ()) { return PrimitiveType::F8E4M3B11FNUZ; + } else if (type.isFloat8E4M3FNUZ()) { + return PrimitiveType::F8E4M3FNUZ; + } else if (type.isFloat8E5M2FNUZ()) { + return PrimitiveType::F8E5M2FNUZ; } else if (type.isBF16()) { return PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/util.cc b/xla/util.cc index 93f30c1c1da93..5d1eae734c213 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -134,6 +134,10 @@ template static void RoundTripNanPayload(FloatT value, std::string* result) { static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3"); + static_assert(!std::is_same::value, + "RoundTripNanPayload does not support E4M3FNUZ"); + static_assert(!std::is_same::value, + "RoundTripNanPayload does not support E5M2FNUZ"); const int kPayloadBits = NanPayloadBits(); if (Eigen::numext::isnan(value) && kPayloadBits > 0) { auto rep = absl::bit_cast< @@ -160,6 +164,16 @@ std::string RoundTripFpToString(tsl::float8_e5m2 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e4m3fnuz value) { + std::string result = GenericRoundTripFpToString(value); + return result; +} + +std::string RoundTripFpToString(tsl::float8_e5m2fnuz value) { + std::string result = GenericRoundTripFpToString(value); + return result; +} + std::string RoundTripFpToString(tsl::float8_e4m3fn value) { std::string result = GenericRoundTripFpToString(value); return result; diff --git a/xla/util.h b/xla/util.h index ff67caf34b307..04c306ec2f1d8 100644 --- a/xla/util.h +++ b/xla/util.h @@ -339,6 +339,12 @@ std::string RoundTripFpToString(tsl::float8_e4m3fn value); // Returns a string which can losslessly round trip to a float8 E4M3B11. std::string RoundTripFpToString(tsl::float8_e4m3b11 value); +// Returns a string which can losslessly round trip to a float8 E5M2FNUZ. 
+std::string RoundTripFpToString(tsl::float8_e5m2fnuz value); + +// Returns a string which can losslessly round trip to a float8 E4M3FNUZ. +std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); + // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); @@ -527,9 +533,11 @@ auto SignAndMagnitude(T x) { BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); const BitType x_bits = Eigen::numext::bit_cast(x); const BitType x_sign = x_bits ^ x_abs_bits; - if constexpr (std::is_same_v) { - // f8e4m3b11 does not support -0, adjust negative numbers to fill in the - // gap. + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative + // numbers to fill in the gap. if (x_sign) { x_abs_bits -= 1; } diff --git a/xla/util_test.cc b/xla/util_test.cc index b066fb3062830..3dfb2e8682af7 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -128,11 +128,17 @@ TEST(UtilTest, RoundTripFpToString) { EXPECT_EQ( RoundTripFpToString(std::numeric_limits::quiet_NaN()), "nan"); + EXPECT_EQ(RoundTripFpToString( + -std::numeric_limits::quiet_NaN()), + "-nan"); EXPECT_EQ(RoundTripFpToString( std::numeric_limits::quiet_NaN()), "-nan"); EXPECT_EQ(RoundTripFpToString( - -std::numeric_limits::quiet_NaN()), + std::numeric_limits::quiet_NaN()), + "-nan"); + EXPECT_EQ(RoundTripFpToString( + std::numeric_limits::quiet_NaN()), "-nan"); EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( false, QuietNanWithoutPayload())), @@ -258,5 +264,29 @@ TEST(UtilTest, TotalOrder_F8E4M3B11) { } } +TEST(UtilTest, TotalOrder_F8E4M3FNUZ) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e4m3fnuz x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e4m3fnuz y = Eigen::numext::bit_cast( + static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + +TEST(UtilTest, TotalOrder_F8E5M2FNUZ) { + for (int a = 0; a < 
256; ++a) { + tsl::float8_e5m2fnuz x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e5m2fnuz y = Eigen::numext::bit_cast( + static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + } // namespace } // namespace xla diff --git a/xla/xla_data.proto b/xla/xla_data.proto index c4c0f97d24967..e6a872b899db6 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -82,6 +82,23 @@ enum PrimitiveType { F8E4M3FN = 20; F8E4M3B11FNUZ = 23; + // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915 + // + // F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits. + // F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits. + // + // The "FNUZ" means only Finite and NaN values are supported; zero is + // unsigned. Unlike IEEE types, infinities are not supported. NaN is + // represented when the exponent and mantissa bits are all 0s with a sign bit + // of 1. All other values are finite. + // + // These differences mean there's an additional exponent value available. To + // keep the same dynamic range as an IEEE-like FP8 type, the exponent is + // biased one more than would be expected given the number of exponent bits + // (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ). + F8E5M2FNUZ = 24; + F8E4M3FNUZ = 25; + // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. C128 = 18; // Paired F64 (real, imag), as in std::complex. @@ -107,7 +124,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 24 + // Next = 26 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, @@ -522,8 +539,10 @@ message LiteralProto { bytes f8e5m2s = 19; bytes f8e4m3fns = 20; bytes f8e4m3b11fnuzs = 23; + bytes f8e5m2fnuzs = 24; + bytes f8e4m3fnuzs = 25; repeated int64 sparse_indices = 14; - // Next = 24 + // Next = 26 } message WindowDimension {