From 886caa8adf4b6cc9a32b2bd431d7ca8d24918372 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 7 Jan 2025 17:55:29 -0800 Subject: [PATCH] float8nocompile: add e2e fsdp test ghstack-source-id: da38b4a141de7dfee4cec9132967ad76d7d6dc20 ghstack-comment-id: 2576459235 Pull Request resolved: https://github.com/pytorch/ao/pull/1523 --- torchao/prototype/float8nocompile/.gitignore | 3 - .../float8nocompile/test/fsdp_test.py | 97 +++++++++++++++++++ 2 files changed, 97 insertions(+), 3 deletions(-) delete mode 100644 torchao/prototype/float8nocompile/.gitignore create mode 100644 torchao/prototype/float8nocompile/test/fsdp_test.py diff --git a/torchao/prototype/float8nocompile/.gitignore b/torchao/prototype/float8nocompile/.gitignore deleted file mode 100644 index 38e0f6f87e..0000000000 --- a/torchao/prototype/float8nocompile/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -kernels/autogen/ -test/activation_checkpoint_test.py -test/distributed_test.py diff --git a/torchao/prototype/float8nocompile/test/fsdp_test.py b/torchao/prototype/float8nocompile/test/fsdp_test.py new file mode 100644 index 0000000000..20e422e3da --- /dev/null +++ b/torchao/prototype/float8nocompile/test/fsdp_test.py @@ -0,0 +1,97 @@ +###################################################################### +# +# To run these unit tests, use the following command: +# +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test/fsdp_test.py +# +####################################################################### +import os + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( + convert_to_float8_nocompile_training, +) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") + + +class TestModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(2048, 4096, bias=False), + nn.Linear(4096, 16, bias=False), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +def setup_distributed(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +@pytest.fixture +def model1(): + torch.manual_seed(0) + return TestModel() + + +@pytest.fixture +def model2(): + torch.manual_seed(0) + return TestModel() + + +def test_model_weights_and_gradients(model1, model2): + assert torch.cuda.is_available() + device = torch.device("cuda") + + setup_distributed() + + model1 = model1.to(torch.bfloat16).to(device) + model2 = model2.to(torch.bfloat16).to(device) + + # compare production float8 linear conversion with no-compile version + convert_to_float8_training(model2) + convert_to_float8_nocompile_training(model1) + + # distributed training with FSDP + model1 = FSDP(model1) + model2 = FSDP(model2) + + input_tensor = torch.randn( + 16, 2048, requires_grad=True, dtype=torch.bfloat16, device=device + ) + input_copy1 = input_tensor.clone().detach().requires_grad_(True) + input_copy2 = input_tensor.clone().detach().requires_grad_(True) + + loss_fn = nn.MSELoss() + + output1 = model1(input_copy1) + output2 = model2(input_copy2) + + loss1 = loss_fn(output1, torch.zeros_like(output1)) + loss2 = loss_fn(output2, torch.zeros_like(output2)) + + loss1.backward() + loss2.backward() + + dist.destroy_process_group() + + # compare the outputs, weight gradients, and input gradients + assert torch.allclose(output1, output2, atol=0, rtol=0) + assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=0, rtol=0) + for param1, param2 in zip(model1.parameters(), model2.parameters()): + assert torch.allclose(param1.grad, param2.grad, atol=0, rtol=0)