diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py
index cda946e285..8e3a1a4be1 100644
--- a/torchao/prototype/mx_formats/custom_cast.py
+++ b/torchao/prototype/mx_formats/custom_cast.py
@@ -12,20 +12,13 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
-
-# TODO(future): if needed, make the below work on previous PyTorch versions,
-# just need to hunt down the previous location of `libdevice`. An assert
-# at the callsite prevents usage of this on unsupported versions.
-if TORCH_VERSION_AT_LEAST_2_4 and has_triton():
-    from torch._inductor.runtime.triton_helpers import libdevice
-
 from torchao.prototype.mx_formats.constants import (
     E8M0_EXPONENT_BIAS,
     E8M0_EXPONENT_NAN_VAL,
     F4_E2M1_EXP_BIAS,
     F32_EXP_BIAS,
 )
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 
 def get_bits(x: torch.Tensor) -> str:
@@ -294,8 +287,8 @@ def triton_f4_to_scaled_bf16_kernel(
         s = tl.load(s_ptr + offsets_s, mask=mask_s)
 
         # create the scale in bf16
-        s_offset = s.to(tl.int16) - e8m0_exponent_bias
-        s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
+        # S is already biased by 127, so we just have to shift it to align w/ bf16
+        s_fp = (s.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True)
         s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
 
         # multiply output by scale
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 2e67f5a4ac..a40df3e6f6 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -127,10 +127,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
@@ -176,14 +173,10 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
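Standalone sanity check (not part of the patch): a minimal sketch of why the two rewrites agree, assuming an e8m0 scale stored as a biased uint8 exponent with bias 127. Shifting the biased exponent left by 7 bits places it directly in the bfloat16 exponent field, so a bitcast gives the same value as exp2 of the unbiased exponent; the variable names and loop bounds below are illustrative only.

import torch

E8M0_EXPONENT_BIAS = 127  # assumed to match torchao.prototype.mx_formats.constants

# e8m0 codes 1..254 encode normal powers of two; 0 falls into the bf16
# denormal range and 255 is the NaN encoding, so both are excluded here.
s = torch.arange(1, 255, dtype=torch.int16)

# Bit-shift path, mirroring the Triton kernel change: move the biased
# exponent into the bf16 exponent field, then reinterpret the bits as bf16.
via_bitcast = (s << 7).view(torch.bfloat16)

# exp2 path, mirroring get_fp_scale after the change: unbias, then exponentiate.
via_exp2 = torch.exp2((s - E8M0_EXPONENT_BIAS).to(torch.float32)).to(torch.bfloat16)

assert torch.equal(via_bitcast, via_exp2)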