add handling for batch dim in float8nocompile (#1512)
danielvegamyhre authored Jan 8, 2025
1 parent 457c5b1 commit f86fda9
Showing 4 changed files with 138 additions and 6 deletions.
33 changes: 30 additions & 3 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -80,15 +80,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return output

@classmethod
def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
def from_float(
cls,
mod,
config: Float8LinearConfig, # only default config is supported, non-defaults silently ignored
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
Args:
mod (torch.nn.Linear): nn.Linear to convert
config (Optional[Float8LinearConfig]): configuration for conversion to float8
config (Optional[Float8LinearConfig]): configuration for conversion to float8 (note: only
default config is supported, non-defaults silently ignored)
"""
config = Float8LinearConfig()
with torch.device("meta"):
new_mod = cls(
mod.in_features,
@@ -107,6 +112,10 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_M
class matmul_with_args_in_hp(torch.autograd.Function):
@staticmethod
def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
# reshape to be 2D for triton kernels
orig_input_shape = input_hp.shape
input_hp = input_hp.reshape(-1, input_hp.shape[-1])

# output = input @ weight_t
input_fp8_row_major, input_fp8_col_major = ToFP8RowAndColumnMajor.apply(
input_hp,
@@ -130,12 +139,24 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
ctx.linear_mm_config = linear_mm_config
ctx.kernel_algo = kernel_algo

# reshape back to expected dims
output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
return output

@staticmethod
def backward(ctx, grad_output):
# grad_output may not be contiguous in cases like:
# output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
# results in a non-contiguous tensor with stride (0,0).
if not grad_output.is_contiguous():
grad_output = grad_output.contiguous()

input_fp8_col_major, weight_hp = ctx.saved_tensors

# reshape to be 2D for triton kernels
orig_grad_output_shape = grad_output.shape
grad_output = grad_output.reshape(-1, grad_output.shape[-1])

# cast grad output to float8_e5m2 for backward
grad_output_fp8_row_major, grad_output_t_row_major = (
ToFP8RowMajorTAndNonT.apply(
@@ -162,4 +183,10 @@ def backward(ctx, grad_output):
# source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

# reshape grad input to match original shape
grad_input = grad_input.reshape(
*orig_grad_output_shape[:-1], grad_input.shape[-1]
)

# return grads for input and weight; None for the non-tensor args (config, linear_mm_config, kernel_algo)
return grad_input, grad_weight, None, None, None
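
The batch-dim handling above boils down to one pattern: flatten all leading dimensions into a single row dimension before the 2D Triton kernels run, then restore them on the way out (backward applies the same treatment to grad_output). Below is a minimal standalone sketch of that pattern in plain PyTorch, not part of the commit; fake_scaled_mm is a hypothetical stand-in for the float8 matmul and none of the torchao kernels are used.

    import torch


    def fake_scaled_mm(a_2d: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # stand-in for the float8 matmul; the real kernels expect 2D operands
        return a_2d @ b


    def matmul_any_leading_dims(input_hp: torch.Tensor, weight_t: torch.Tensor) -> torch.Tensor:
        # flatten all leading (batch) dims: (..., K) -> (M, K)
        orig_shape = input_hp.shape
        input_2d = input_hp.reshape(-1, orig_shape[-1])

        out_2d = fake_scaled_mm(input_2d, weight_t)  # (M, N)

        # restore leading dims: (M, N) -> (..., N)
        return out_2d.reshape(*orig_shape[:-1], out_2d.shape[-1])


    x = torch.randn(2, 32, 16)   # batched 3D input
    w_t = torch.randn(16, 32)    # weight already transposed to (K, N)
    y = matmul_any_leading_dims(x, w_t)
    assert y.shape == (2, 32, 32)
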
97 changes: 97 additions & 0 deletions torchao/prototype/float8nocompile/float8nocompile_linear_test.py
@@ -0,0 +1,97 @@
import pytest
import torch

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
from torchao.prototype.float8nocompile.float8nocompile_linear import (
matmul_with_args_in_hp,
)
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
KernelAlgorithm,
)


# unit test comparing the two implementations
@pytest.mark.parametrize(
"input_shape",
[(32, 16), (1, 32, 16), (2, 32, 16)],
)
def test_matmul_with_args_in_hp(input_shape: tuple[int, ...]):
assert torch.cuda.is_available()
device = "cuda"

# high precision inputs
input_bf16 = torch.randn(
input_shape, dtype=torch.bfloat16, device=device, requires_grad=True
)
prod_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
prototype_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)

# high precision weights
# nn.Linear stores weights in transposed form
weight_bf16 = torch.randn(
(32, input_bf16.shape[-1]),
dtype=torch.bfloat16,
device=device,
requires_grad=True,
)
prod_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
prototype_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)

# default configs
config = Float8LinearConfig()
emulate = False
linear_mm_config = LinearMMConfig(
# output
ScaledMMConfig(
emulate,
config.gemm_config_output.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
config.gemm_config_grad_input.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
config.gemm_config_grad_weight.use_fast_accum,
False,
config.pad_inner_dim,
),
)

# prod forward. expects transposed weight.
out_prod = manual_float8_matmul_with_args_in_hp.apply(
prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config
)

# prototype forward. expects non-transposed weight
out_prototype = matmul_with_args_in_hp.apply(
prototype_input_bf16,
prototype_weight_bf16,
config,
linear_mm_config,
KernelAlgorithm.ATOMIC_MAX,
)

# compare model outputs
assert torch.allclose(out_prod, out_prototype, atol=0, rtol=0)

out_prod.sum().backward()
out_prototype.sum().backward()

# compare input gradients
assert torch.allclose(
prod_input_bf16.grad, prototype_input_bf16.grad, atol=0, rtol=0
)

# compare weight gradients
assert torch.allclose(
prod_weight_bf16.grad, prototype_weight_bf16.grad, atol=0, rtol=0
)
@@ -8,6 +8,7 @@

import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.prototype.float8nocompile.float8nocompile_linear import (
Float8LinearNoCompile,
@@ -23,6 +24,7 @@
def convert_to_float8_nocompile_training(
module: nn.Module,
*,
config: Float8LinearConfig = None,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> nn.Module:
@@ -39,7 +41,12 @@ def convert_to_float8_nocompile_training(
Returns:
nn.Module: The modified module with swapped linear layers.
"""
from_float = lambda m: Float8LinearNoCompile.from_float(m, kernel_algo=kernel_algo)
if config is None:
config = Float8LinearConfig()

from_float = lambda m: Float8LinearNoCompile.from_float(
m, config=config, kernel_algo=kernel_algo
)
return swap_linear_layers(
module,
from_float,
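
For context on how the updated entry point is meant to be used (not part of this commit): a hedged sketch that converts a toy model with an explicit default config and then runs a batched 3D input through it. The module path for the conversion helper is assumed from the files in this diff; the model and shapes are illustrative.

    import torch
    import torch.nn as nn

    from torchao.float8.config import Float8LinearConfig
    # assumed module path for the conversion helper, based on this diff
    from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
        convert_to_float8_nocompile_training,
    )
    from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
        KernelAlgorithm,
    )

    model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 16)).to("cuda").to(torch.bfloat16)

    # only the default config is supported; non-default fields are silently ignored
    convert_to_float8_nocompile_training(
        model,
        config=Float8LinearConfig(),
        kernel_algo=KernelAlgorithm.ATOMIC_MAX,
    )

    # a batched (3D) input now works because the autograd.Function flattens leading dims
    x = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    out = model(x)
    out.sum().backward()
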
5 changes: 3 additions & 2 deletions torchao/prototype/float8nocompile/test/train_test.py
@@ -36,7 +36,8 @@ def model2():
return TestModel()


def test_model_weights_and_gradients(model1, model2):
@pytest.mark.parametrize("input_shape", [(16, 32), (1, 16, 32), (2, 16, 32)])
def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, ...]):
assert torch.cuda.is_available()
device = torch.device("cuda")

@@ -48,7 +49,7 @@ def test_model_weights_and_gradients(model1, model2):
convert_to_float8_nocompile_training(model1)

input_tensor = torch.randn(
16, 32, requires_grad=True, dtype=torch.bfloat16, device=device
*input_shape, requires_grad=True, dtype=torch.bfloat16, device=device
)
input_copy1 = input_tensor.clone().detach().requires_grad_(True)
input_copy2 = input_tensor.clone().detach().requires_grad_(True)
