-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add handling for batch dim in float8nocompile (#1512)
- Loading branch information
1 parent
457c5b1
commit f86fda9
Showing
4 changed files
with
138 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
torchao/prototype/float8nocompile/float8nocompile_linear_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import pytest | ||
import torch | ||
|
||
from torchao.float8.config import Float8LinearConfig | ||
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp | ||
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig | ||
from torchao.prototype.float8nocompile.float8nocompile_linear import ( | ||
matmul_with_args_in_hp, | ||
) | ||
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import ( | ||
KernelAlgorithm, | ||
) | ||
|
||
|
||
# unit test comparing the two implementations | ||
@pytest.mark.parametrize( | ||
"input_shape", | ||
[(32, 16), (1, 32, 16), (2, 32, 16)], | ||
) | ||
def test_matmul_with_args_in_hp(input_shape: tuple[int, int]): | ||
assert torch.cuda.is_available() | ||
device = "cuda" | ||
|
||
# high precision inputs | ||
input_bf16 = torch.randn( | ||
input_shape, dtype=torch.bfloat16, device=device, requires_grad=True | ||
) | ||
prod_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True) | ||
prototype_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True) | ||
|
||
# high precision weights | ||
# nn.Linear stores weights in transposed form | ||
weight_bf16 = torch.randn( | ||
(32, input_bf16.shape[-1]), | ||
dtype=torch.bfloat16, | ||
device=device, | ||
requires_grad=True, | ||
) | ||
prod_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True) | ||
prototype_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True) | ||
|
||
# default configs | ||
config = Float8LinearConfig() | ||
emulate = False | ||
linear_mm_config = linear_mm_config = LinearMMConfig( | ||
# output | ||
ScaledMMConfig( | ||
emulate, | ||
config.gemm_config_output.use_fast_accum, | ||
False, | ||
config.pad_inner_dim, | ||
), | ||
# grad_input | ||
ScaledMMConfig( | ||
emulate, | ||
config.gemm_config_grad_input.use_fast_accum, | ||
False, | ||
config.pad_inner_dim, | ||
), | ||
# grad_weight | ||
ScaledMMConfig( | ||
emulate, | ||
config.gemm_config_grad_weight.use_fast_accum, | ||
False, | ||
config.pad_inner_dim, | ||
), | ||
) | ||
|
||
# prod forward. expects transposed weight. | ||
out_prod = manual_float8_matmul_with_args_in_hp.apply( | ||
prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config | ||
) | ||
|
||
# prototype forward. expects non-transposed weight | ||
out_prototype = matmul_with_args_in_hp.apply( | ||
prototype_input_bf16, | ||
prototype_weight_bf16, | ||
config, | ||
linear_mm_config, | ||
KernelAlgorithm.ATOMIC_MAX, | ||
) | ||
|
||
# compare model outputs | ||
assert torch.allclose(out_prod, out_prototype, atol=0, rtol=0) | ||
|
||
out_prod.sum().backward() | ||
out_prototype.sum().backward() | ||
|
||
# compare input gradients | ||
assert torch.allclose( | ||
prod_input_bf16.grad, prototype_input_bf16.grad, atol=0, rtol=0 | ||
) | ||
|
||
# compare weight gradients | ||
assert torch.allclose( | ||
prod_weight_bf16.grad, prototype_weight_bf16.grad, atol=0, rtol=0 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters