From 070345d9676f3dd6cdd325e987764fc3c71ccaf5 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 8 Jan 2025 06:39:20 -0800 Subject: [PATCH] add fused transpose and non-transpose kernel and use it for grad output (#1497) --- torchao/prototype/float8nocompile/__init__.py | 0 .../float8nocompile/float8nocompile_linear.py | 25 ++- .../float8nocompile_scaling_utils.py | 36 +++- .../float8nocompile/kernels/__init__.py | 0 .../kernels/fp8_dynamic_tensorwise.py | 158 ++++++++++++++++++ .../kernels/fp8_dynamic_tensorwise_test.py | 75 +++++++++ 6 files changed, 275 insertions(+), 19 deletions(-) create mode 100644 torchao/prototype/float8nocompile/__init__.py create mode 100644 torchao/prototype/float8nocompile/kernels/__init__.py diff --git a/torchao/prototype/float8nocompile/__init__.py b/torchao/prototype/float8nocompile/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear.py b/torchao/prototype/float8nocompile/float8nocompile_linear.py index 1e50fe0bdc..37de7b852c 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear.py @@ -16,8 +16,7 @@ ToFP8ColumnMajor, ToFP8ColumnMajorT, ToFP8RowAndColumnMajor, - ToFP8RowMajor, - ToFP8RowMajorT, + ToFP8RowMajorTAndNonT, ) from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( KernelAlgorithm, @@ -138,12 +137,14 @@ def backward(ctx, grad_output): input_fp8_col_major, weight_hp = ctx.saved_tensors # cast grad output to float8_e5m2 for backward - grad_output_fp8_row_major = ToFP8RowMajor.apply( - grad_output, - ctx.config.cast_config_grad_output.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ctx.kernel_algo, + grad_output_fp8_row_major, grad_output_t_row_major = ( + ToFP8RowMajorTAndNonT.apply( + grad_output, + ctx.config.cast_config_grad_output.target_dtype, + ctx.linear_mm_config, + GemmInputRole.GRAD_OUTPUT, + ctx.kernel_algo, + ) ) # grad_input = grad_output @ weight @@ -159,12 +160,6 @@ def backward(ctx, grad_output): # grad_weight = grad_output_t @ input # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output` # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85 - grad_output_t_row_major = ToFP8RowMajorT.apply( - grad_output, - ctx.config.cast_config_grad_output.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ctx.kernel_algo, - ) grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major) + return grad_input, grad_weight, None, None, None diff --git a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py index bade82c616..7b6a25e3f9 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py +++ b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py @@ -10,10 +10,7 @@ import torch -from torchao.float8.float8_tensor import ( - GemmInputRole, - LinearMMConfig, -) +from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( KernelAlgorithm, hp_to_fp8_col_major, @@ -21,6 +18,7 @@ hp_to_fp8_row_and_col_major, hp_to_fp8_row_major, hp_to_fp8_row_major_t, + hp_to_fp8_row_major_t_and_non_t, ) @@ -172,3 +170,33 @@ def forward( @staticmethod def backward(ctx, g): return g, None, None, None, None + + +class ToFP8RowMajorTAndNonT(torch.autograd.Function): + """ + A differentiable conversion to fp8. + * forward: convert from high precision to float8 and produces both row-major (transposed) and row-major (non-transposed) outputs + * backward: pass the gradient without changes + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + float8_dtype: torch.dtype, + linear_mm_config: LinearMMConfig, + gemm_input_role: GemmInputRole, + kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, + ): + fp8_row_major, fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t( + tensor, + float8_dtype, + linear_mm_config, + gemm_input_role, + algo=kernel_algo, + ) + return fp8_row_major, fp8_row_major_t + + @staticmethod + def backward(ctx, g): + return g, None, None, None, None diff --git a/torchao/prototype/float8nocompile/kernels/__init__.py b/torchao/prototype/float8nocompile/kernels/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py index 7da49e20dd..4400a587c1 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py @@ -305,6 +305,82 @@ def _to_fp8_row_and_col_major( tl.store(col_major_out_ptr + col_major_offs, fp8_vals, mask=mask) +@triton.autotune( + configs=kernel_configs_2D, + key=["num_elements"], +) +@triton.jit +def _to_fp8_row_major_t_and_non_t( + input_ptr, + row_major_out_ptr, + row_major_t_out_ptr, + scale_ptr, + num_elements: int, + fp8_dtype_min: float, + fp8_dtype_max: float, + input_num_rows: int, + input_num_cols: int, + input_stride_row: int, + input_stride_col: int, + row_major_out_stride_row: int, + row_major_out_stride_col: int, + row_major_t_out_stride_row: int, + row_major_t_out_stride_col: int, + input_dtype: tl.constexpr, + output_dtype: tl.constexpr, + BLOCK_SIZE_ROWS: tl.constexpr, + BLOCK_SIZE_COLS: tl.constexpr, + EPS: tl.constexpr, +): + """ + Reads a row-major, high precision input tensor and writes 2 output tensors: + 1) fp8 row major tensor (transposed) + 2) fp8 row major tensor + """ + block_row_id = tl.program_id(axis=0) + block_col_id = tl.program_id(axis=1) + + # load scaling factor + scale = tl.load(scale_ptr).to(tl.float32) + + # load block of input tensor + block_row_start = block_row_id * BLOCK_SIZE_ROWS + block_col_start = block_col_id * BLOCK_SIZE_COLS + block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS) + block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS) + input_offs = ( + block_row_offs[:, None] * input_stride_row + + block_col_offs[None, :] * input_stride_col + ) + mask = (block_row_offs[:, None] < input_num_rows) & ( + block_col_offs[None, :] < input_num_cols + ) + vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype) + + # perform conversion + vals = vals * scale + fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype) + + # write row-major output + row_major_offs = ( + block_row_offs[:, None] * row_major_out_stride_row + + block_col_offs[None, :] * row_major_out_stride_col + ) + tl.store(row_major_out_ptr + row_major_offs, fp8_vals, mask=mask) + + # write tranposed row-major output + row_major_t_num_rows = input_num_cols + row_major_t_num_cols = input_num_rows + row_major_t_offs = ( + block_col_offs[:, None] * row_major_t_out_stride_row + + block_row_offs[None, :] * row_major_t_out_stride_col + ) + mask = (block_row_offs[:, None] < row_major_t_num_rows) & ( + block_col_offs[None, :] < row_major_t_num_cols + ) + tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask) + + @triton.autotune(configs=kernel_configs_1D, key=["num_elements"]) @triton.jit def _amax_atomic( @@ -701,6 +777,88 @@ def hp_to_fp8_row_and_col_major( return fp8_tensor_row_major, fp8_tensor_col_major +def hp_to_fp8_row_major_t_and_non_t( + hp_tensor: torch.Tensor, + fp8_dtype: torch.dtype, + linear_mm_config: LinearMMConfig, + gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, +) -> Float8Tensor: + assert hp_tensor.is_contiguous(), "input tensor must be contiguous" + + tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] + tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype] + + fp8_dtype_min = torch.finfo(fp8_dtype).min + fp8_dtype_max = torch.finfo(fp8_dtype).max + + # compute scaling factor for tensor + scale = _hp_tensor_to_scale( + hp_tensor, + tl_input_dtype, + fp8_dtype_max, + algo, + ) + + # perform fp8 conversion + input_num_rows, input_num_cols = hp_tensor.shape + transposed_num_rows, transposed_num_cols = input_num_cols, input_num_rows + num_elements = hp_tensor.numel() + + # preallocate necessary output tensors + fp8_output_row_major = torch.empty( + (input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device + ) + fp8_output_row_major_t = torch.empty( + (transposed_num_rows, transposed_num_cols), + dtype=fp8_dtype, + device=hp_tensor.device, + ) + + # launch triton kernel to perform conversion + grid = lambda meta: ( + triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]), + triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]), + ) + _to_fp8_row_major_t_and_non_t[grid]( + hp_tensor, + fp8_output_row_major, + fp8_output_row_major_t, + scale, + num_elements, + fp8_dtype_min, + fp8_dtype_max, + input_num_rows, + input_num_cols, + hp_tensor.stride(0), + hp_tensor.stride(1), + fp8_output_row_major.stride(0), + fp8_output_row_major.stride(1), + fp8_output_row_major_t.stride(0), + fp8_output_row_major_t.stride(1), + input_dtype=tl_input_dtype, + output_dtype=tl_output_dtype, + EPS=EPS, + ) + + # wrap outputs in Float8Tensors + fp8_tensor_row_major = Float8Tensor( + fp8_output_row_major, + scale, + orig_dtype=hp_tensor.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) + fp8_tensor_row_major_t = Float8Tensor( + fp8_output_row_major_t, + scale, + orig_dtype=hp_tensor.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) + return fp8_tensor_row_major, fp8_tensor_row_major_t + + def _hp_tensor_to_scale( hp_tensor: torch.Tensor, tl_input_dtype: tl.core.dtype, diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py index df09728ab3..f0dd78bc01 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py @@ -11,6 +11,7 @@ hp_to_fp8_row_and_col_major, hp_to_fp8_row_major, hp_to_fp8_row_major_t, + hp_to_fp8_row_major_t_and_non_t, ) @@ -335,3 +336,77 @@ def test_fp8_hp_to_fp8_row_and_col_major( torch.float8_e4m3fn, LinearMMConfig(), ) + + +@pytest.mark.parametrize( + "algo", + [KernelAlgorithm.REDUCTION, KernelAlgorithm.ATOMIC_MAX], +) +@pytest.mark.parametrize( + "input_shape", + [(2, 4), (32, 16), (512, 512)], +) +def test_fp8_hp_to_fp8_row_major_t_and_non_t( + input_shape: tuple[int, int], algo: KernelAlgorithm +): + assert torch.cuda.is_available() + device = "cuda" + input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + x_bf16 = input_bf16.clone().detach().to(device) + y_bf16 = input_bf16.clone().detach().to(device) + + # production implementation + x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_bf16, + torch.float8_e4m3fn, + LinearMMConfig(), + ) + x_fp8_row_major_t = x_fp8_row_major.t().contiguous() + + # float8nocompile triton implementation + y_fp8_row_major, y_fp8_row_major_t = hp_to_fp8_row_major_t_and_non_t( + y_bf16, + torch.float8_e4m3fn, + LinearMMConfig(), + algo=algo, + ) + + # check scales + assert torch.eq(x_fp8_row_major._scale, y_fp8_row_major._scale) + assert torch.eq(x_fp8_row_major_t._scale, y_fp8_row_major_t._scale) + + # check data + assert torch.all(torch.eq(x_fp8_row_major._data, y_fp8_row_major._data)) + assert torch.all(torch.eq(x_fp8_row_major_t._data, y_fp8_row_major_t._data)) + + # check shapes + assert x_fp8_row_major.shape == y_fp8_row_major.shape + assert x_fp8_row_major_t.shape == y_fp8_row_major_t.shape + + # check strides + assert x_fp8_row_major.stride() == y_fp8_row_major.stride() + assert x_fp8_row_major_t.stride() == y_fp8_row_major_t.stride() + + # check memory layout + assert is_row_major(x_fp8_row_major.stride()) + assert is_row_major(y_fp8_row_major.stride()) + assert is_row_major(x_fp8_row_major_t.stride()) + assert is_row_major(y_fp8_row_major_t.stride()) + + # check underlying memory layout + assert ( + x_fp8_row_major._data.storage().tolist() + == y_fp8_row_major._data.storage().tolist() + ) + assert ( + x_fp8_row_major_t._data.storage().tolist() + == y_fp8_row_major_t._data.storage().tolist() + ) + + # assert that error is raised when input tensor is not contiguous + with pytest.raises(AssertionError, match="tensor must be contiguous"): + hp_to_fp8_row_major_t_and_non_t( + y_bf16.t(), # transpose so tensor memory layout is no longer contiguous + torch.float8_e4m3fn, + LinearMMConfig(), + )