🐛 Bug: Frozen weight support in FSDP
This bug is a follow-up to issue #610. I encountered a new test case in which a scenario of linear evaluation fails.
Context (same as previous issue)
In the context of VISSL, we train a SwAV model with FSDP and then want to evaluate this model with linear evaluation, again wrapping the model with FSDP.
During linear evaluation, only the linear layer on top of the "trunk" is learned: the representations produced by the trunk and fed to the linear layer stay constant, so we set requires_grad=False on the parameters of the trunk.
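As a minimal, self-contained sketch of that setup (illustrative only, with hypothetical toy modules standing in for the actual VISSL trunk and head):

import torch.nn as nn

# Hypothetical stand-ins for the pre-trained trunk and the linear evaluation head.
trunk = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.ReLU(inplace=True))
head = nn.Linear(64, 10)

# Linear evaluation: freeze the trunk so that only the head is learned.
for param in trunk.parameters():
    param.requires_grad = False

assert not any(p.requires_grad for p in trunk.parameters())
assert all(p.requires_grad for p in head.parameters())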
New symptom
The initial test case raised in #610 is green, but I discovered that nesting FSDP blocks inside the trunk (the frozen part) leads to a new assertion error:
ERROR: expected to be in states [<TrainingState.BACKWARD_PRE: 3>] but current state is TrainingState.IDLE
Command
To Reproduce
To reproduce the issue, simply run the script below with the options -f -l, where:
-f means "enable FSDP"
-l means "freeze the trunk"
Important notes:
The issue only appears if both -f and -l are used together
The issue disappears if I comment out the line self.trunk = FSDP(self.trunk), so it seems that the issue is linked to the nesting of FSDP blocks when layers are frozen
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel


class Model(nn.Module):
    def __init__(self, with_fsdp: bool):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp),
            self._create_block(64, 64, with_fsdp),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        )
        if with_fsdp:
            # Nested FSDP wrapping: the trunk and head are wrapped in addition
            # to the inner blocks created in _create_block.
            self.trunk = FSDP(self.trunk)
            self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))

    def _create_block(self, in_channels: int, out_channels: int, with_fsdp: bool):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        if with_fsdp:
            block = FSDP(block)
        return block


def create_model(with_fsdp: bool):
    return Model(with_fsdp)


def _distributed_worker(gpu_id: int, with_fsdp: bool, with_linear: bool):
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl", init_method="tcp://127.0.0.1:9099", world_size=2, rank=gpu_id
    )
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp)
    model = model.cuda()
    if with_linear:
        # Freeze the trunk: only the head keeps requires_grad=True.
        for name, param in model.named_parameters():
            if "trunk" in name:
                param.requires_grad = False
    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])
    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        if gpu_id == 0:
            print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--fsdp", action="store_const", const=True, default=False)
    parser.add_argument(
        "-l", "--linear", action="store_const", const=True, default=False
    )
    args = parser.parse_args()
    mp.spawn(_distributed_worker, (args.fsdp, args.linear), nprocs=2)
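Assuming the script above is saved as, e.g., repro_frozen_trunk.py (hypothetical name), running python repro_frozen_trunk.py -f -l on a machine with two GPUs reproduces the assertion error, while dropping either flag lets the three training iterations complete normally.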
Environment
I use PyTorch 1.6.0+cu102, with the following modules enabled: module load cuda/11.0 cudnn/v8.0.3.33-cuda.11.0 NCCL/2.8.3-1-cuda.11.0. My version of fairscale is 0.3.6.
Additional context
CC: @prigoyal @min-xu-ai
Summary:
Corrections:
- upgrade to fairscale 0.3.6 to remove assertion errors corrected since 0.3.5
- make sure that integration test errors are raised in the CI
- disable the failing integration tests due to issue facebookresearch/fairscale#643
- minor refactor of the regnet_fsdp to use the fsdp_wrapper directly
Pull Request resolved: fairinternal/ssl_scaling#129
Reviewed By: prigoyal
Differential Revision: D28120296
Pulled By: QuentinDuval
fbshipit-source-id: 0c23b639a9d7103aafa5f25b14c05042269df061
* [fix] better assert and better test for frozen weights
- the precise condition should have been to check m.parameters(), not m.params.
- fixes #643
* add changelog
* use enum is so much better
Co-authored-by: Min Xu <min.xu@acm.org>