Use exp2 for mx scaling #1530

Open · wants to merge 1 commit into main
13 changes: 3 additions & 10 deletions torchao/prototype/mx_formats/custom_cast.py
@@ -12,20 +12,13 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
-
-# TODO(future): if needed, make the below work on previous PyTorch versions,
-# just need to hunt down the previous location of `libdevice`. An assert
-# at the callsite prevents usage of this on unsupported versions.
-if TORCH_VERSION_AT_LEAST_2_4 and has_triton():
-    from torch._inductor.runtime.triton_helpers import libdevice
-
 from torchao.prototype.mx_formats.constants import (
     E8M0_EXPONENT_BIAS,
     E8M0_EXPONENT_NAN_VAL,
     F4_E2M1_EXP_BIAS,
     F32_EXP_BIAS,
 )
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 
 def get_bits(x: torch.Tensor) -> str:
@@ -294,8 +287,8 @@ def triton_f4_to_scaled_bf16_kernel(
     s = tl.load(s_ptr + offsets_s, mask=mask_s)
 
     # create the scale in bf16
-    s_offset = s.to(tl.int16) - e8m0_exponent_bias
-    s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
+    # S is already biased by 127, so we just have to shift it to align w/ bf16
+    s_fp = (s.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True)
     s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
 
     # multiply output by scale
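The new kernel line relies on the bfloat16 bit layout (1 sign bit, 8 exponent bits, 7 mantissa bits): an e8m0 scale is already a biased exponent, so shifting it left by 7 bits and bitcasting gives exactly 2**(s - 127) without any call to pow. The sketch below shows that equivalence in plain PyTorch for illustration only; the helper names are made up here, and the real kernel does this in Triton with tl.uint16 and bitcast=True.

```python
import torch

E8M0_EXPONENT_BIAS = 127  # e8m0 and float32/bfloat16 use the same exponent bias

def e8m0_to_bf16_scale_bitshift(s: torch.Tensor) -> torch.Tensor:
    # Place the biased exponent into the bfloat16 exponent field (bits [14:7]).
    # Sign and mantissa stay zero, so the bit pattern is exactly 2**(s - 127).
    bits = s.to(torch.int16) << 7
    return bits.view(torch.bfloat16)  # bitcast; no numeric conversion

def e8m0_to_bf16_scale_pow(s: torch.Tensor) -> torch.Tensor:
    # The arithmetic formulation the kernel used before this PR.
    return torch.exp2(s.to(torch.int16) - E8M0_EXPONENT_BIAS).to(torch.bfloat16)

# 0 (a denormal result) and 255 (the e8m0 NaN encoding) need special handling,
# which the kernel does elsewhere, so compare only the normal range here.
s = torch.arange(1, 255, dtype=torch.uint8)
assert torch.equal(e8m0_to_bf16_scale_bitshift(s), e8m0_to_bf16_scale_pow(s))
```

This also explains the import change above: with libdevice.pow gone from this kernel, the conditional libdevice import is no longer needed.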
13 changes: 3 additions & 10 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -127,10 +127,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is

Review comment from a Contributor on the torch.exp2 line: TIL about torch.exp2
@@ -176,14 +173,10 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
-    # TODO(later): handle this for float16 if we decide to support float16
-    # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
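For reference, here is a minimal self-contained sketch of the decode path after this change. The constant values follow what the diff itself states (the scale is biased by 127, and 255 marks NaN blocks); the function name mirrors the torchao code, but this snippet is illustrative rather than the library API.

```python
import torch

E8M0_EXPONENT_BIAS = 127      # "S is already biased by 127"
E8M0_EXPONENT_NAN_VAL = 255   # "If a block exponent was 255, set values of that block to NaN"

def get_fp_scale_sketch(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # torch.exp2(x) computes 2**x elementwise, so no tensor of twos is needed
    # the way it was with torch.pow(two, s_offset).
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    s_fp = torch.exp2(s_offset)
    # 255 is the e8m0 NaN encoding; propagate it as NaN.
    return torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))

scales = torch.tensor([126, 127, 128, 255], dtype=torch.uint8)
print(get_fp_scale_sketch(scales))  # tensor([0.5000, 1.0000, 2.0000, nan])
```

Functionally this matches the old torch.pow path; the practical difference is that it avoids materializing the tensor of twos that the removed TODO comment called out.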