[float8nocompile] Add alternate Triton kernels for FP8 conversion which use atomic_max-based algo instead of reduction-based algo (#1455)

* refactor float8nocompile kernel so autotune is easily usable

* refactor to make kernel algo configurable; refactor unit tests to test both algos

* address comments
danielvegamyhre authored Dec 23, 2024
1 parent eab345c commit 567cb46
Showing 3 changed files with 189 additions and 48 deletions.
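For context, the new algo keyword added by this PR selects between the two conversion kernels (ATOMIC_MAX is the default). A minimal usage sketch, mirroring the call made in the updated unit test below; the tensor shape and dtype here are illustrative:

# Illustrative sketch mirroring the updated unit test; not part of the diff.
import torch

from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
    triton_hp_tensor_to_float8_dynamic,
)

x = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")

# default path: atomic_max-based global amax
x_fp8_atomic = triton_hp_tensor_to_float8_dynamic(
    x, torch.float8_e4m3fn, LinearMMConfig(), algo=KernelAlgorithm.ATOMIC_MAX
)

# original reduction-based path, still available
x_fp8_reduction = triton_hp_tensor_to_float8_dynamic(
    x, torch.float8_e4m3fn, LinearMMConfig(), algo=KernelAlgorithm.REDUCTION
)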
10 changes: 6 additions & 4 deletions torchao/prototype/float8nocompile/benchmark/benchmark.py
@@ -59,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def get_configs() -> List[ExperimentConfig]:
layer_sizes = [[4096, 4096]]
input_shapes = [(2**4, 4096), (2**8, 4096), (2**12, 4096), (2**16, 4096)]
- high_precision_dtypes = [torch.float32, torch.bfloat16]
+ high_precision_dtypes = [torch.bfloat16]
configs = []
for layer_size, input_shape, high_precision_dtype in itertools.product(
layer_sizes, input_shapes, high_precision_dtypes
@@ -133,18 +133,20 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:

def print_results(experiments: List[Experiment]):
headers = [
"input_size",
"input_shape",
"high_precision_dtype",
"eager_time",
"compiled_time",
"float8nocompile",
]
rows = []
for experiment in experiments:
- input_size = experiment.config.input_shape[0] * experiment.config.input_shape[1]
+ input_shape = (
+     f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
+ )
rows.append(
[
f"{input_size:.2e}",
input_shape,
experiment.config.high_precision_dtype,
experiment.result.eager_time,
experiment.result.compiled_time,
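The benchmark changes above drop float32 from the swept high-precision dtypes and report the input shape as a "(rows, cols)" string instead of a flattened element count in scientific notation. A hypothetical sketch of how the headers and rows might be rendered (tabulate is an assumption here; the actual print_results may use a different table printer):

# Hypothetical rendering of the benchmark table; values are made up, not measured.
from tabulate import tabulate

headers = ["input_shape", "high_precision_dtype", "eager_time", "compiled_time", "float8nocompile"]
rows = [["(16, 4096)", "torch.bfloat16", 1.23, 0.98, 0.71]]
print(tabulate(rows, headers=headers))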
212 changes: 170 additions & 42 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -7,9 +7,9 @@
"""
Triton kernels for scaling high precision tensors to float8.
"""
from enum import Enum

import torch

import triton
import triton.language as tl

@@ -31,8 +31,99 @@
}


class KernelAlgorithm(Enum):
"""Enum for FP8 conversion strategy."""

# use atomic max to compute global amax between blocks
ATOMIC_MAX = "atomic_max"

# reduce shared buffer containing local block amaxes to find global amax
REDUCTION = "reduction"


kernel_configs = [
triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
]
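Both variants share the autotuning configs above and differ only in how per-block amaxes are combined into a global amax: the atomic variant folds each block's amax into a single accumulator via tl.atomic_max, while the reduction variant stores per-block amaxes in a buffer and reduces that buffer in a second kernel. A host-side PyTorch analogue of the two strategies, for illustration only:

# Illustrative host-side analogue of the two amax strategies; not part of the kernels.
import torch

x = torch.randn(10_000, dtype=torch.bfloat16)
BLOCK_SIZE = 512

# "atomic_max" style: one global accumulator, updated once per block
global_amax = torch.zeros((), dtype=torch.float32)
for block in x.abs().split(BLOCK_SIZE):
    global_amax = torch.maximum(global_amax, block.max().to(torch.float32))

# "reduction" style: materialize per-block amaxes, then reduce the buffer
block_amaxes = torch.stack([b.max().to(torch.float32) for b in x.abs().split(BLOCK_SIZE)])
assert torch.equal(global_amax, block_amaxes.max())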


# --- atomic max version of kernel ---
@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
def _block_amax_atomic(
input_ptr,
amax_ptr,
num_elements,
input_dtype: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
EPS: tl.constexpr,
):
# compute local amax for each block
block_id = tl.program_id(axis=0)
block_start = block_id * BLOCK_SIZE
block_offs = block_start + tl.arange(0, BLOCK_SIZE)
block_mask = block_offs < num_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
block_amax = tl.max(tl.abs(vals))
tl.atomic_max(amax_ptr, block_amax)


@triton.jit
def _fp8_scale_atomic(
amax_ptr,
scale_out_ptr,
fp8_dtype_max,
EPS: tl.constexpr,
):
# load previously computed global amax
global_amax = tl.load(amax_ptr)

# compute scale, must be fp32
scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
tl.float32
)

# store scale for use in Float8Tensor constructor
scale_off = tl.arange(0, 1)
tl.store(scale_out_ptr + scale_off, scale)
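The scale computed here is the usual tensorwise dynamic scale: the fp8 dtype's max divided by the global amax, with the amax clamped to EPS so an all-zero tensor does not divide by zero. A host-side sketch with concrete numbers (torch.finfo(torch.float8_e4m3fn).max is 448; the EPS value below is assumed for illustration):

# Host-side reference for the scale computation above; illustrative only.
import torch

EPS = 1e-12  # assumed value; the real EPS constant is defined elsewhere in this file
fp8_dtype_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0

global_amax = torch.tensor(3.5)
scale = (fp8_dtype_max / torch.clamp(global_amax, min=EPS)).to(torch.float32)
# scale == 128.0: values in [-3.5, 3.5] are stretched onto [-448, 448] before the fp8 cast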


@triton.autotune(configs=kernel_configs, key=["input_size"])
@triton.jit
- def _block_amax(
+ def _to_fp8_atomic(
input_ptr,
scale_ptr,
amax_ptr,
out_ptr,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
EPS: tl.constexpr,
):
block_id = tl.program_id(axis=0)

# load scale
scale = tl.load(scale_ptr)

# load block of input tensor
block_start = block_id * BLOCK_SIZE
block_offs = block_start + tl.arange(0, BLOCK_SIZE)
mask = block_offs < num_elements
vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)

# perform conversion
vals = vals * scale
fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
tl.store(out_ptr + block_offs, fp8_vals, mask=mask)


# --- reduction version of kernel ---
@triton.jit
def _block_amax_reduction(
input_ptr,
block_amaxes_ptr,
num_elements,
@@ -46,12 +137,12 @@ def _block_amax(
block_offs = block_start + tl.arange(0, BLOCK_SIZE)
block_mask = block_offs < num_elements
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
- block_amax = tl.max(tl.abs(vals), axis=0)
+ block_amax = tl.max(tl.abs(vals))
tl.store(block_amaxes_ptr + block_id, block_amax)


@triton.jit
- def _fp8_scale(
+ def _fp8_scale_reduction(
block_amaxes_ptr,
scale_out_ptr,
num_elements,
@@ -75,7 +166,7 @@ def _fp8_scale(


@triton.jit
- def _to_fp8(
+ def _to_fp8_reduction(
input_ptr,
scale_ptr,
out_ptr,
@@ -108,12 +199,10 @@ def triton_hp_tensor_to_float8_dynamic(
fp8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> Float8Tensor:

assert hp_tensor.is_contiguous(), "tensor must be contiguous"

BLOCK_SIZE = 8 # TODO(danielvegamyhre): tune this for perf

num_elements = hp_tensor.numel()
orig_shape = hp_tensor.shape
flattened_input = hp_tensor.flatten()
@@ -126,47 +215,86 @@ def triton_hp_tensor_to_float8_dynamic(

# allocate memory for computed scale, local block maxes, and output fp8 tensor
scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
block_amaxes = torch.zeros(
(num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
)

fp8_output = torch.empty_like(
flattened_input, dtype=fp8_dtype, device=hp_tensor.device
)

# compute local amax for each block
grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
_block_amax[grid](
flattened_input,
block_amaxes,
num_elements,
input_dtype=tl_input_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# calculate global amax across all blocks and use it to compute scale
_fp8_scale[(1, 1, 1)](
block_amaxes,
scale_out,
num_elements,
fp8_dtype_max,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)
if algo == KernelAlgorithm.ATOMIC_MAX:
global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
# compute global amax to be used for scaling
_block_amax_atomic[grid](
flattened_input,
global_amax,
num_elements,
input_dtype=tl_input_dtype,
EPS=EPS,
)

# perform conversion
_to_fp8[grid](
flattened_input,
scale_out,
fp8_output,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)
# compute scale for fp8 conversion
_fp8_scale_atomic[1, 1, 1](
global_amax,
scale_out,
fp8_dtype_max,
EPS=EPS,
)

# perform conversion and store scale for use in Float8Tensor
_to_fp8_atomic[grid](
flattened_input,
scale_out,
global_amax,
fp8_output,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
EPS=EPS,
)
elif algo == KernelAlgorithm.REDUCTION:
max_block_size = 512
BLOCK_SIZE = min(max_block_size, num_elements)
block_amaxes = torch.zeros(
(num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
)
# compute local amax for each block
_block_amax_reduction[grid](
flattened_input,
block_amaxes,
num_elements,
input_dtype=tl_input_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# calculate global amax across all blocks and use it to compute scale
_fp8_scale_reduction[(1, 1, 1)](
block_amaxes,
scale_out,
num_elements,
fp8_dtype_max,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)

# perform conversion
_to_fp8_reduction[grid](
flattened_input,
scale_out,
fp8_output,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
BLOCK_SIZE=BLOCK_SIZE,
EPS=EPS,
)
else:
raise ValueError(f"Unsupported kernel algorithm: {algo}")

return Float8Tensor(
fp8_output.reshape(orig_shape),
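For reference, the tensorwise conversion that both kernel paths implement can be written in a few lines of eager PyTorch. This is a simplified sketch only: it omits the Float8Tensor wrapper and the linear_mm_config / gemm_input_role plumbing shown above, and may differ from the kernels in rounding and intermediate-dtype details:

# Simplified eager-mode sketch of the math the Triton kernels implement; illustrative only.
import torch

def to_fp8_tensorwise_reference(hp_tensor, fp8_dtype=torch.float8_e4m3fn, eps=1e-12):
    # eps value assumed for illustration
    fp8_max = torch.finfo(fp8_dtype).max
    fp8_min = torch.finfo(fp8_dtype).min
    amax = hp_tensor.abs().max().to(torch.float32)       # global amax
    scale = fp8_max / torch.clamp(amax, min=eps)          # tensorwise scale, fp32
    fp8_data = (hp_tensor.to(torch.float32) * scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
    return fp8_data, scale                                # scale is needed to dequantize

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
fp8_data, scale = to_fp8_tensorwise_reference(x)
x_roundtrip = fp8_data.to(torch.float32) / scale          # approximate reconstruction of x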
@@ -3,14 +3,24 @@
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
KernelAlgorithm,
triton_hp_tensor_to_float8_dynamic,
)


- def test_fp8_triton_hp_tensor_to_float8_dynamic():
+ @pytest.mark.parametrize(
+     "algo", [KernelAlgorithm.ATOMIC_MAX, KernelAlgorithm.REDUCTION]
+ )
+ @pytest.mark.parametrize(
+     "input_shape",
+     [(32, 32), (512, 512), (4096, 4096)],
+ )
+ def test_fp8_triton_hp_tensor_to_float8_dynamic(
+     algo: KernelAlgorithm, input_shape: tuple[int, int]
+ ):
assert torch.cuda.is_available()
device = "cuda"
- input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device)
+ input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
x_bf16 = input_bf16.clone().detach().to(device)
y_bf16 = input_bf16.clone().detach().to(device)

@@ -26,6 +36,7 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic():
y_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
algo=algo,
)

def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
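The body of allclose_fp8 is truncated in this view. A hypothetical helper of that shape would typically upcast both fp8 tensors before comparing, along these lines (the actual implementation may differ):

# Hypothetical comparison helper; the real allclose_fp8 body is not shown in this diff.
import torch

def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
    return torch.allclose(
        tensor1.to(torch.float32), tensor2.to(torch.float32), atol=atol, rtol=rtol
    )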
