❓ Support for frozen parameters
In the context of VISSL, we are training a SWAV model with FSDP and then want to evaluate this model using linear evaluation, again using FSDP to wrap our model.
For linear evaluation, only the linear layer on top of the "trunk" is learned: the representations produced by the trunk and fed to the linear layer stay constant, so we set requires_grad=False on the parameters of the trunk.
While this works with DDP, it leads to an assert in the case of FSDP:
expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
To help you dig into this issue, I have put together a small test case that reproduces it:
- run this script with -f to enable FSDP (otherwise, DDP is used)
- run this script with -l (for "linear") to freeze the trunk
You will get the assert only if you enable both -f and -l; everything works fine otherwise.
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        return self.head(self.trunk(x))


def create_model(with_fsdp: bool):
    model = Model()
    if with_fsdp:
        model.trunk = FSDP(model.trunk)
        model.head = FSDP(model.head)
    return model


def _distributed_worker(gpu_id: int, with_fsdp: bool, with_linear: bool):
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl", init_method="tcp://127.0.0.1:9099", world_size=2, rank=gpu_id
    )
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp)
    model = model.cuda()

    if with_linear:
        # Linear evaluation: freeze the trunk, only the head will be trained.
        for name, param in model.named_parameters():
            if "trunk" in name:
                param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])
    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--fsdp", action="store_const", const=True, default=False)
    parser.add_argument("-l", "--linear", action="store_const", const=True, default=False)
    args = parser.parse_args()
    mp.spawn(_distributed_worker, (args.fsdp, args.linear), nprocs=2)
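For reference (the file name repro.py below is just an assumption), running python repro.py -l trains the DDP baseline with a frozen trunk and works fine, while python repro.py -f -l wraps the model with FSDP and hits the assert. The script spawns two processes with the NCCL backend, so two GPUs are needed.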
I use PyTorch 1.6.0+cu102, with the following modules loaded: module load cuda/11.0 cudnn/v8.0.3.33-cuda.11.0 NCCL/2.8.3-1-cuda.11.0
The corresponding fix, as described in the PR commit message:
* FSDP: fixing training with freezing weights
- an assert is changed to catch this case correctly
- a unit test (based on Quentin's test code) is added for this case, comparing DDP and FSDP
fixes: #610
* added test file to list 1
* Use better and simpler code as suggested by Myle
* testing both methods of freezing as well (see the sketch after this list)
Co-authored-by: Min Xu <min.xu@acm.org>
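For context, "both methods of freezing" presumably refers to the two standard PyTorch idioms for freezing a trunk. Below is a minimal single-GPU sketch of the two idioms (no DDP/FSDP wrapping; the names are illustrative and not taken from the actual fairscale unit test):

import torch
import torch.nn as nn

trunk = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    nn.Flatten(),
)
head = nn.Linear(64, 10)
x = torch.randn(2, 3, 32, 32)

# Method 1: flag the trunk parameters so autograd never computes gradients for them.
for p in trunk.parameters():
    p.requires_grad = False
out = head(trunk(x))
out.sum().backward()  # only head.weight and head.bias receive gradients

# Method 2: run the trunk under no_grad so no autograd graph is built for it at all.
with torch.no_grad():
    features = trunk(x)
out = head(features)
out.sum().backward()  # again, only the head gets gradients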
Don't hesitate to reach out to me for more details.
CC: @prigoyal @min-xu-ai @myleott