Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for float8_e4m3fnuz and float8_e5m2fnuz. #3200

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion third_party/tsl/tsl/framework/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ struct is_simple_type {
is_quantized<T>::value || std::is_same<T, bfloat16>::value ||
std::is_same<T, float8_e4m3fn>::value ||
std::is_same<T, float8_e4m3b11>::value ||
std::is_same<T, float8_e5m2>::value;
std::is_same<T, float8_e5m2>::value ||
std::is_same<T, float8_e4m3fnuz>::value ||
std::is_same<T, float8_e5m2fnuz>::value;
};

} // namespace tsl
Expand Down
2 changes: 1 addition & 1 deletion third_party/tsl/tsl/platform/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using float8_e4m3b11fnuz = ml_dtypes::float8_e4m3b11fnuz;
// Deprecated: old name for backward-compatibility only.
using float8_e4m3b11 = float8_e4m3b11fnuz;
using float8_e5m2 = ml_dtypes::float8_e5m2;
using float8_e5m2funz = ml_dtypes::float8_e5m2fnuz;
using float8_e5m2fnuz = ml_dtypes::float8_e5m2fnuz;
} // namespace tsl

#endif // TENSORFLOW_TSL_PLATFORM_FLOAT8_H_
2 changes: 2 additions & 0 deletions third_party/tsl/tsl/protobuf/dnn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ enum DataType {
kBF16 = 7;
kF8E5M2 = 8;
kF8E4M3FN = 9;
kF8E5M2FNUZ = 10;
kF8E4M3FNUZ = 11;
}

// Describes how a convolution input or output layer's data is formatted.
Expand Down
29 changes: 20 additions & 9 deletions xla/client/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ XlaOp IsNegZero(XlaOp operand) {
case F8E5M2:
case F8E4M3FN:
case F8E4M3B11FNUZ:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
case F16:
case BF16:
// Not all XLA backends handle U16 well, so we convert to F32/U32.
Expand Down Expand Up @@ -302,7 +304,8 @@ XlaOp Erfc(XlaOp x) {
// Erf(c)Impl don't have enough precision when run with bf16 intermediates
// (not surprising!), so upcast to f32 in this case.
return DoWithUpcastToF32(
x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, [](XlaOp x) {
x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
[](XlaOp x) {
return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
ScalarLike(x, 1) - ErfImpl32Cephes(x));
});
Expand Down Expand Up @@ -347,8 +350,9 @@ XlaOp Erf(XlaOp x) {
}
// Erf(c)Impl don't have enough precision when run with bf16 intermediates
// (not surprising!), so upcast to f32 in this case.
return DoWithUpcastToF32(x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ},
[](XlaOp x) { return ErfImpl32(x); });
return DoWithUpcastToF32(
x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
[](XlaOp x) { return ErfImpl32(x); });
});
}

Expand Down Expand Up @@ -496,8 +500,9 @@ XlaOp ErfInv(XlaOp x) {
if (shape.element_type() == F64) {
return ErfInv64(x);
}
return DoWithUpcastToF32(x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ},
[](XlaOp x) { return ErfInv32(x); });
return DoWithUpcastToF32(
x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
[](XlaOp x) { return ErfInv32(x); });
});
}

Expand Down Expand Up @@ -626,7 +631,9 @@ XlaOp Lgamma(XlaOp input) {
// here (although it's better than you might expect!), so do the
// computations in F32.
return DoWithUpcastToF32(
input, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, do_it);
input,
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
do_it);
});
}

Expand Down Expand Up @@ -722,7 +729,9 @@ XlaOp Digamma(XlaOp input) {
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
return DoWithUpcastToF32(
input, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, do_it);
input,
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
do_it);
});
}

Expand Down Expand Up @@ -977,7 +986,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}) {
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down Expand Up @@ -1029,7 +1039,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}) {
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down
8 changes: 7 additions & 1 deletion xla/hlo/evaluator/hlo_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,13 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations)
this);
typed_visitors_[F8E4M3B11FNUZ] =
std::make_unique<HloEvaluatorTypedVisitor<tsl::float8_e4m3b11, float>>(
this);
this);
typed_visitors_[F8E5M2FNUZ] =
std::make_unique<HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>>(
this);
typed_visitors_[F8E4M3FNUZ] =
std::make_unique<HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>>(
this);

typed_visitors_[TUPLE] =
std::make_unique<ConstFunctionVisitor>([](const HloInstruction*) {
Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,8 @@ extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;

} // namespace xla

Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ namespace xla {
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
} // namespace xla
27 changes: 26 additions & 1 deletion xla/literal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ bool LiteralProtoHasValues(const LiteralProto& proto) {
!proto.f16s().empty() || !proto.bf16s().empty() ||
!proto.u16s().empty() || !proto.s16s().empty() ||
!proto.f8e5m2s().empty() || !proto.f8e4m3fns().empty() ||
!proto.f8e4m3b11fnuzs().empty();
!proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() ||
!proto.f8e4m3fnuzs().empty();
}

// Lazy getter for the interned scalar shape in static storage. We reuse this
Expand Down Expand Up @@ -2243,6 +2244,16 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
reinterpret_cast<const char*>(data<tsl::float8_e4m3b11>().data()),
size_bytes_dense());
break;
case F8E5M2FNUZ:
*proto->mutable_f8e5m2fnuzs() = std::string(
reinterpret_cast<const char*>(data<tsl::float8_e5m2fnuz>().data()),
size_bytes_dense());
break;
case F8E4M3FNUZ:
*proto->mutable_f8e4m3fnuzs() = std::string(
reinterpret_cast<const char*>(data<tsl::float8_e4m3fnuz>().data()),
size_bytes_dense());
break;
case F32:
CopyToRepeatedField(proto->mutable_f32s(), data<float>());
break;
Expand Down Expand Up @@ -2386,6 +2397,20 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
s.size());
memcpy(untyped_data(), s.data(), s.size());
} break;
case F8E5M2FNUZ: {
const std::string& s(proto.f8e5m2fnuzs());
TF_RET_CHECK(data<tsl::float8_e5m2fnuz>().size() *
sizeof(tsl::float8_e5m2fnuz) ==
s.size());
memcpy(untyped_data(), s.data(), s.size());
} break;
case F8E4M3FNUZ: {
const std::string& s(proto.f8e4m3fnuzs());
TF_RET_CHECK(data<tsl::float8_e4m3fnuz>().size() *
sizeof(tsl::float8_e4m3fnuz) ==
s.size());
memcpy(untyped_data(), s.data(), s.size());
} break;
case F16: {
const std::string& s(proto.f16s());
TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
Expand Down
57 changes: 57 additions & 0 deletions xla/literal_comparison.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,23 @@ bool CompareEqual<tsl::float8_e4m3b11>(tsl::float8_e4m3b11 lhs,
return CompareFloatsBitwiseEqual<tsl::float8_e4m3b11, uint8_t>(lhs, rhs,
multi_index);
}

template<>
bool CompareEqual<tsl::float8_e5m2fnuz>(tsl::float8_e5m2fnuz lhs,
tsl::float8_e5m2fnuz rhs,
absl::Span<const int64_t> multi_index) {
return CompareFloatsBitwiseEqual<tsl::float8_e5m2fnuz, uint8_t>(lhs, rhs,
multi_index);
}

template <>
bool CompareEqual<tsl::float8_e4m3fnuz>(tsl::float8_e4m3fnuz lhs,
tsl::float8_e4m3fnuz rhs,
absl::Span<const int64_t> multi_index) {
return CompareFloatsBitwiseEqual<tsl::float8_e4m3fnuz, uint8_t>(lhs, rhs,
multi_index);
}

template <>
bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
absl::Span<const int64_t> multi_index) {
Expand Down Expand Up @@ -188,6 +205,18 @@ Status MakeErrorStatus(tsl::float8_e4m3b11 lhs, tsl::float8_e4m3b11 rhs,
return MakeBitwiseErrorStatus<tsl::float8_e4m3b11, uint8_t>(lhs, rhs,
multi_index);
}
template<>
Status MakeErrorStatus(tsl::float8_e5m2fnuz lhs, tsl::float8_e5m2fnuz rhs,
absl::Span<const int64_t> multi_index) {
return MakeBitwiseErrorStatus<tsl::float8_e5m2fnuz, uint8_t>(lhs, rhs,
multi_index);
}
template <>
Status MakeErrorStatus(tsl::float8_e4m3fnuz lhs, tsl::float8_e4m3fnuz rhs,
absl::Span<const int64_t> multi_index) {
return MakeBitwiseErrorStatus<tsl::float8_e4m3fnuz, uint8_t>(lhs, rhs,
multi_index);
}
template <>
Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
absl::Span<const int64_t> multi_index) {
Expand Down Expand Up @@ -312,6 +341,14 @@ std::string FpValueToString(tsl::float8_e4m3b11 value) {
return absl::StrFormat("%5.3g", static_cast<double>(value));
}

std::string FpValueToString(tsl::float8_e5m2fnuz value) {
return absl::StrFormat("%5.3g", static_cast<double>(value));
}

std::string FpValueToString(tsl::float8_e4m3fnuz value) {
return absl::StrFormat("%5.3g", static_cast<double>(value));
}

std::string FpValueToString(bfloat16 value) {
return absl::StrFormat("%10.4g", static_cast<double>(value));
}
Expand Down Expand Up @@ -370,6 +407,16 @@ double FpAbsoluteValue(tsl::float8_e4m3b11 value) {
return FpAbsoluteValue<float>(static_cast<float>(value));
}

template<>
double FpAbsoluteValue(tsl::float8_e5m2fnuz value) {
return FpAbsoluteValue<float>(static_cast<float>(value));
}

template <>
double FpAbsoluteValue(tsl::float8_e4m3fnuz value) {
return FpAbsoluteValue<float>(static_cast<float>(value));
}

// Helper class for comparing floating-point literals within an error bound.
template <typename NativeT>
class NearComparator {
Expand Down Expand Up @@ -914,6 +961,16 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
expected, actual, shape_index, error, use_detailed_message,
miscompare_callback);
break;
case F8E5M2FNUZ:
return NearComparator<tsl::float8_e5m2fnuz>::Compare(
expected, actual, shape_index, error, use_detailed_message,
miscompare_callback);
break;
case F8E4M3FNUZ:
return NearComparator<tsl::float8_e4m3fnuz>::Compare(
expected, actual, shape_index, error, use_detailed_message,
miscompare_callback);
break;
case BF16:
return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
error, use_detailed_message,
Expand Down
Loading