fix linter errors #1525

Closed
61 changes: 32 additions & 29 deletions torchao/prototype/float8nocompile/float8nocompile_linear.py
@@ -12,15 +12,12 @@
from torch.utils.checkpoint import checkpoint

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_float8
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig

from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
ToFP8ColumnMajor,
ToFP8ColumnMajorT,
ToFP8RowAndColumnMajor,
ToFP8RowMajor,
ToFP8RowMajorT,
ToFP8RowMajorTAndNonT,
)
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
@@ -43,7 +40,9 @@ def __init__(self, *args, **kwargs):
"""
self.config = kwargs.pop("config")
self.kernel_algo = kwargs.pop("kernel_algo")
self.use_activation_checkpointing = kwargs.pop("use_activation_checkpointing", False)
self.use_activation_checkpointing = kwargs.pop(
"use_activation_checkpointing", False
)
super().__init__(*args, **kwargs)

self.linear_mm_config = LinearMMConfig(
@@ -73,7 +72,7 @@ def __init__(self, *args, **kwargs):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.use_activation_checkpointing:
output = checkpoint(
matmul_with_args_in_hp.apply,
matmul_with_args_in_hp.apply,
input,
self.weight,
self.config,
@@ -94,9 +93,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

@classmethod
def from_float(
cls,
mod,
config: Float8LinearConfig,
cls,
mod,
config: Float8LinearConfig,
kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
use_activation_checkpointing: bool = False,
):
@@ -126,18 +125,22 @@ def from_float(
class matmul_with_args_in_hp(torch.autograd.Function):
@staticmethod
def forward(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
use_activation_checkpointing: bool,
):
if use_activation_checkpointing:
return matmul_with_args_in_hp._forward_with_ac(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo)
return matmul_with_args_in_hp._forward_with_ac(
ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo
)
else:
return matmul_with_args_in_hp._forward_no_ac(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo)
return matmul_with_args_in_hp._forward_no_ac(
ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo
)

@staticmethod
def backward(ctx, grad_output):
@@ -148,12 +151,12 @@ def backward(ctx, grad_output):

@staticmethod
def _forward_no_ac(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
):
# reshape to be 2D for triton kernels
orig_input_shape = input_hp.shape
@@ -237,12 +240,12 @@ def _backward_no_ac(ctx, grad_output):

@staticmethod
def _forward_with_ac(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
config: Float8LinearConfig,
linear_mm_config: LinearMMConfig,
kernel_algo: KernelAlgorithm,
):
# reshape to be 2D for triton kernels
orig_input_shape = input_hp.shape
@@ -275,7 +278,7 @@ def _forward_with_ac(

# reshape back to expected dims
output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
return output
return output

@staticmethod
def _backward_with_ac(ctx, grad_output):
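Note: the reformatted forward above wraps a custom `torch.autograd.Function` in `torch.utils.checkpoint.checkpoint` when `use_activation_checkpointing` is set. A minimal sketch of that pattern, with a simplified stand-in for `matmul_with_args_in_hp` (no fp8 casting; the names and the `use_reentrant=False` choice are illustrative, not taken from the PR):

```python
# Sketch only: simplified stand-in for the checkpointed forward path above.
import torch
from torch.utils.checkpoint import checkpoint


class _ScaledMatmul(torch.autograd.Function):
    """Placeholder for matmul_with_args_in_hp: plain matmul, no fp8 casting."""

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return torch.mm(input, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        return torch.mm(grad_output, weight), torch.mm(grad_output.t(), input)


def linear_forward(input, weight, use_activation_checkpointing=False):
    if use_activation_checkpointing:
        # Recompute intermediates during backward instead of storing them.
        return checkpoint(_ScaledMatmul.apply, input, weight, use_reentrant=False)
    return _ScaledMatmul.apply(input, weight)


x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)
linear_forward(x, w, use_activation_checkpointing=True).sum().backward()
```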
83 changes: 50 additions & 33 deletions torchao/prototype/float8nocompile/float8nocompile_linear_test.py
@@ -1,67 +1,84 @@
import pytest

import torch
from torch.autograd.function import FunctionCtx
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
from torch.autograd import gradcheck
from torchao.prototype.float8nocompile.float8nocompile_linear import (
matmul_with_args_in_hp,
)
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
KernelAlgorithm,
)

from torchao.prototype.float8nocompile.float8nocompile_linear import matmul_with_args_in_hp
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import KernelAlgorithm

# unit test comparing the two implementations
@pytest.mark.parametrize(
"input_shape",
[(32, 16), (1,32,16), (2,32,16)],
[(32, 16), (1, 32, 16), (2, 32, 16)],
)
def test_matmul_with_args_in_hp(input_shape: tuple[int, int]):
assert torch.cuda.is_available()
device = "cuda"

# high precision inputs
input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device, requires_grad=True)
input_bf16 = torch.randn(
input_shape, dtype=torch.bfloat16, device=device, requires_grad=True
)
x_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
y_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)

# high precision weights
# nn.Linear stores weights in transposed form
weight_bf16 = torch.randn((32, input_bf16.shape[-1]), dtype=torch.bfloat16, device=device, requires_grad=True)
weight_bf16 = torch.randn(
(32, input_bf16.shape[-1]),
dtype=torch.bfloat16,
device=device,
requires_grad=True,
)
x_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
y_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)

# default configs
config = Float8LinearConfig()
emulate = False
linear_mm_config = linear_mm_config = LinearMMConfig(
# output
ScaledMMConfig(
emulate,
config.gemm_config_output.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
config.gemm_config_grad_input.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
config.gemm_config_grad_weight.use_fast_accum,
False,
config.pad_inner_dim,
),
)
# output
ScaledMMConfig(
emulate,
config.gemm_config_output.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
emulate,
config.gemm_config_grad_input.use_fast_accum,
False,
config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
config.gemm_config_grad_weight.use_fast_accum,
False,
config.pad_inner_dim,
),
)

# prod forward. expects transposed weight.
out_prod = manual_float8_matmul_with_args_in_hp.apply(x_input_bf16, x_weight_bf16.t(), linear_mm_config, config)
out_prod = manual_float8_matmul_with_args_in_hp.apply(
x_input_bf16, x_weight_bf16.t(), linear_mm_config, config
)

# prototype forward. expects non-transposed weight
out_prototype = matmul_with_args_in_hp.apply(y_input_bf16, y_weight_bf16, config, linear_mm_config, KernelAlgorithm.ATOMIC_MAX)
out_prototype = matmul_with_args_in_hp.apply(
y_input_bf16,
y_weight_bf16,
config,
linear_mm_config,
KernelAlgorithm.ATOMIC_MAX,
)

# compare
assert torch.allclose(out_prod, out_prototype, atol=1e-3, rtol=1e-3)
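The test above builds one `LinearMMConfig`, runs identical bf16 inputs through the production and prototype autograd functions, and compares the outputs with `torch.allclose`. A self-contained sketch of that parity-test pattern (float32 on CPU so it runs anywhere; `reference_linear` and `prototype_linear` are placeholders, not the functions under test):

```python
# Sketch only: parity-test pattern used above, with placeholder implementations.
import pytest
import torch


def reference_linear(x, w_t):
    # "prod"-style path: takes the transposed weight
    return torch.matmul(x, w_t)


def prototype_linear(x, w):
    # prototype-style path: takes the non-transposed weight
    return torch.nn.functional.linear(x, w)


@pytest.mark.parametrize("input_shape", [(32, 16), (1, 32, 16), (2, 32, 16)])
def test_parity(input_shape):
    x = torch.randn(input_shape, dtype=torch.float32)
    w = torch.randn(32, input_shape[-1], dtype=torch.float32)
    assert torch.allclose(
        reference_linear(x, w.t()), prototype_linear(x, w), atol=1e-3, rtol=1e-3
    )
```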
@@ -8,6 +8,7 @@

import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.prototype.float8nocompile.float8nocompile_linear import (
Float8LinearNoCompile,
@@ -45,9 +46,9 @@ def convert_to_float8_nocompile_training(
config = Float8LinearConfig()

from_float = lambda m: Float8LinearNoCompile.from_float(
m,
config=config,
kernel_algo=kernel_algo,
m,
config=config,
kernel_algo=kernel_algo,
use_activation_checkpointing=use_activation_checkpointing,
)
return swap_linear_layers(
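`convert_to_float8_nocompile_training` passes a `from_float` factory to `swap_linear_layers`, which replaces every `nn.Linear` in the model. A rough sketch of that swap pattern under assumed names (`FakeFloat8Linear` and `swap_linears` are placeholders, not torchao APIs):

```python
# Sketch only: module-swap pattern that the conversion helper above relies on.
import torch.nn as nn


class FakeFloat8Linear(nn.Linear):
    """Illustrative stand-in for Float8LinearNoCompile; not a torchao class."""

    @classmethod
    def from_float(cls, mod: nn.Linear, use_activation_checkpointing: bool = False):
        new_mod = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        new_mod.weight = mod.weight
        if mod.bias is not None:
            new_mod.bias = mod.bias
        new_mod.use_activation_checkpointing = use_activation_checkpointing
        return new_mod


def swap_linears(module: nn.Module, from_float) -> nn.Module:
    """Recursively replace every nn.Linear with the module returned by from_float."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, from_float(child))
        else:
            swap_linears(child, from_float)
    return module


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
swap_linears(model, lambda m: FakeFloat8Linear.from_float(m, use_activation_checkpointing=True))
```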
@@ -10,21 +10,15 @@

import torch

from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
LinearMMConfig,
_ToFloat8ConstrFunc,
)

from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
KernelAlgorithm,
hp_to_fp8_col_major,
hp_to_fp8_col_major_t,
hp_to_fp8_row_and_col_major,
hp_to_fp8_row_major,
hp_to_fp8_row_major_t,
hp_to_fp8_row_major_t_and_non_t,
KernelAlgorithm,
)


@@ -12,7 +12,6 @@
hp_to_fp8_row_major,
hp_to_fp8_row_major_t,
hp_to_fp8_row_major_t_and_non_t,
KernelAlgorithm,
)


16 changes: 8 additions & 8 deletions torchao/prototype/float8nocompile/test/train_test.py
@@ -36,13 +36,11 @@ def model2():
return TestModel()


@pytest.mark.parametrize(
"input_shape", [(16,32), (1,16,32), (2,16,32)]
)
@pytest.mark.parametrize(
"use_activation_checkpointing", [True, False]
)
def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int], use_activation_checkpointing: bool):
@pytest.mark.parametrize("input_shape", [(16, 32), (1, 16, 32), (2, 16, 32)])
@pytest.mark.parametrize("use_activation_checkpointing", [True, False])
def test_model_weights_and_gradients(
model1, model2, input_shape: tuple[int, int], use_activation_checkpointing: bool
):
assert torch.cuda.is_available()
device = torch.device("cuda")

@@ -51,7 +49,9 @@ def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int

# compare production float8 linear conversion with no-compile version
convert_to_float8_training(model2)
convert_to_float8_nocompile_training(model1, use_activation_checkpointing=use_activation_checkpointing)
convert_to_float8_nocompile_training(
model1, use_activation_checkpointing=use_activation_checkpointing
)

input_tensor = torch.randn(
*input_shape, requires_grad=True, dtype=torch.bfloat16, device=device
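`train_test.py` converts one of two identically initialized models and then checks that parameters and gradients still agree after a training step. A minimal CPU-only sketch of that comparison (no float8 conversion here, so the copies match exactly; names are illustrative):

```python
# Sketch only: weight/gradient parity check between two identically initialized models.
import copy

import torch
import torch.nn as nn


def train_step(model, x):
    model(x).sum().backward()


model1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
model2 = copy.deepcopy(model1)  # the real test would convert one copy to float8

x = torch.randn(2, 16, 32, requires_grad=True)
train_step(model1, x)
train_step(model2, x.clone().detach().requires_grad_(True))

for p1, p2 in zip(model1.parameters(), model2.parameters()):
    assert torch.allclose(p1, p2)
    assert torch.allclose(p1.grad, p2.grad)
```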