[float8nocompile] add e2e fsdp test #1523

Merged: 93 commits, Jan 16, 2025
Commits: 93 commits by danielvegamyhre, all titled "Update", spanning Jan 3 to Jan 16, 2025.
3 changes: 0 additions & 3 deletions torchao/prototype/float8nocompile/.gitignore

This file was deleted.

97 changes: 97 additions & 0 deletions torchao/prototype/float8nocompile/test/fsdp_test.py
@@ -0,0 +1,97 @@
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test/fsdp_test.py
#
#######################################################################
import os

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2048, 4096, bias=False),
            nn.Linear(4096, 16, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


@pytest.fixture
def model1():
    torch.manual_seed(0)
    return TestModel()


@pytest.fixture
def model2():
    torch.manual_seed(0)
    return TestModel()


def test_model_weights_and_gradients(model1, model2):
    assert torch.cuda.is_available()
    device = torch.device("cuda")

    setup_distributed()

    model1 = model1.to(torch.bfloat16).to(device)
    model2 = model2.to(torch.bfloat16).to(device)

    # compare production float8 linear conversion with no-compile version
    convert_to_float8_training(model2)
    convert_to_float8_nocompile_training(model1)

    # distributed training with FSDP2
    fully_shard(model1)
    fully_shard(model2)

    input_tensor = torch.randn(
        16, 2048, requires_grad=True, dtype=torch.bfloat16, device=device
    )
    input_copy1 = input_tensor.clone().detach().requires_grad_(True)
    input_copy2 = input_tensor.clone().detach().requires_grad_(True)

    loss_fn = nn.MSELoss()

    output1 = model1(input_copy1)
    output2 = model2(input_copy2)

    loss1 = loss_fn(output1, torch.zeros_like(output1))
    loss2 = loss_fn(output2, torch.zeros_like(output2))

    loss1.backward()
    loss2.backward()

    # compare the outputs, weight gradients, and input gradients
    assert torch.allclose(output1, output2, atol=0, rtol=0)
    assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=0, rtol=0)
    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        assert torch.equal(param1.grad, param2.grad)

    dist.destroy_process_group()
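
To run the new test locally, the command documented in the file's own header applies; for example, on a single node with two GPUs (the GPU count here is illustrative, and the command is presumably run from torchao/prototype/float8nocompile/ so that the relative test path resolves; the test itself only relies on CUDA, NCCL, and the RANK/WORLD_SIZE environment variables that torchrun sets):

torchrun --nproc_per_node=2 -m pytest test/fsdp_test.py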