[FSDP] Frozen weights support #643

Closed · QuentinDuval opened this issue Apr 30, 2021 · 0 comments · Fixed by #657
Labels: FSDP FullyShardedDataParallel (zero-3)
Assignees: min-xu-ai

QuentinDuval (Contributor) commented on Apr 30, 2021

🐛 Bug: Frozen weight support in FSDP

This bug is a follow-up to issue #610: I encountered a new test case, a linear-evaluation scenario, which fails.

Context (same as previous issue)

In the context of VISSL, we train a SwAV model with FSDP and then want to evaluate it with linear evaluation, again wrapping the model with FSDP.

In linear evaluation, only the linear layer on top of the "trunk" is trained: the trunk representations fed to the linear layer stay constant, so we set requires_grad=False on all trunk parameters.
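For reference, a minimal sketch of this setup, with illustrative module names rather than the actual VISSL code:

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def build_linear_eval_model():
    # Illustrative only; assumes torch.distributed is already initialized,
    # since FSDP needs a process group at construction time.
    trunk = FSDP(nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.ReLU()))
    head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(64, 10))
    for p in trunk.parameters():
        p.requires_grad = False  # freeze the trunk: only the head is learned
    return nn.Sequential(trunk, head)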

New symptom

The initial test case raised in #610 is now green, but I discovered that nesting FSDP blocks inside the trunk (the frozen part) leads to a new assertion error:

ERROR: expected to be in states [<TrainingState.BACKWARD_PRE: 3>] but current state is TrainingState.IDLE

To Reproduce

This script should help you reproduce the issue: run it with the options -f -l (an example invocation is given after the notes below), where:

  • -f means "enable FSDP"
  • -l means "freeze the trunk"

Important notes:

  • The issue only appears if both -f and -l are used together
  • The issue disappears if I comment out the line self.trunk = FSDP(self.trunk), so it seems to be linked to the nesting of FSDP blocks when layers are frozen
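
For example, assuming the script below is saved as repro.py (the filename is just for illustration) and the host has 2 GPUs:

python repro.py -f -l
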
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel


class Model(nn.Module):
    def __init__(self, with_fsdp: bool):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp),
            self._create_block(64, 64, with_fsdp),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        )
        if with_fsdp:
            self.trunk = FSDP(self.trunk)  # commenting out this line makes the issue disappear (see notes above)
            self.head = FSDP(self.head)

    def forward(self, x):
        return self.head(self.trunk(x))

    def _create_block(self, in_channels: int, out_channels: int, with_fsdp: bool):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        if with_fsdp:
            block = FSDP(block)
        return block


def create_model(with_fsdp: bool):
    return Model(with_fsdp)


def _distributed_worker(gpu_id: int, with_fsdp: bool, with_linear: bool):
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl", init_method="tcp://127.0.0.1:9099", world_size=2, rank=gpu_id
    )

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp)
    model = model.cuda()

    if with_linear:
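        # Linear-evaluation mode: freeze every trunk parameter so only the head is trained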
        for name, param in model.named_parameters():
            if "trunk" in name:
                param.requires_grad = False

    if with_fsdp:
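        # Outer FSDP wrapper around the whole model; the trunk blocks and head are already wrapped when with_fsdp is True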
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        if gpu_id == 0:
            print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--fsdp", action="store_const", const=True, default=False)
    parser.add_argument(
        "-l", "--linear", action="store_const", const=True, default=False
    )
    args = parser.parse_args()
    mp.spawn(_distributed_worker, (args.fsdp, args.linear), nprocs=2)

Environment

I use PyTorch 1.6.0+cu102, with the following environment modules loaded: cuda/11.0, cudnn/v8.0.3.33-cuda.11.0, and NCCL/2.8.3-1-cuda.11.0.

My version of fairscale is 0.3.6.

Additional context

CC: @prigoyal @min-xu-ai

facebook-github-bot pushed a commit to facebookresearch/vissl that referenced this issue Apr 30, 2021
Summary:
Corrections:
- upgrade to fairscale 0.3.6 to remove assertion errors corrected since 0.3.5
- make sure that integration test errors are raised in the CI
- disable the failing integration tests due to issue facebookresearch/fairscale#643
- minor refactor of the regnet_fsdp to use the fsdp_wrapper directly

Pull Request resolved: fairinternal/ssl_scaling#129

Reviewed By: prigoyal

Differential Revision: D28120296

Pulled By: QuentinDuval

fbshipit-source-id: 0c23b639a9d7103aafa5f25b14c05042269df061
min-xu-ai added the FSDP FullyShardedDataParallel (zero-3) label on Apr 30, 2021
min-xu-ai self-assigned this on Apr 30, 2021
min-xu-ai pushed a commit that referenced this issue May 5, 2021
- the precise condition should have checked m.parameters(), not m.params
- fixes #643
min-xu-ai added a commit that referenced this issue May 5, 2021
* [fix] better assert and better test for frozen weights

- the precise condition should have checked m.parameters(), not m.params
- fixes #643

* add changelog

* using an enum is much better

Co-authored-by: Min Xu <min.xu@acm.org>
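
For context, a rough sketch of the distinction behind the fix; the .params attribute refers to fairscale's FSDP internals and is used here only for illustration:

# Illustrative only: module.parameters() recurses into nested FSDP children,
# while an FSDP wrapper's own .params list only holds the flat parameters it
# manages directly. With a frozen trunk built from nested FSDP blocks, the
# two checks can therefore disagree.
def any_requires_grad_recursive(module):
    return any(p.requires_grad for p in module.parameters())

def any_requires_grad_owned(fsdp_module):
    return any(p.requires_grad for p in fsdp_module.params)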