Support for frozen parameters (Linear classification benchmarks) #610

Closed
QuentinDuval opened this issue Apr 15, 2021 · 0 comments · Fixed by #614
Labels: FSDP FullyShardedDataParallel (zero-3)

QuentinDuval commented Apr 15, 2021

❓ Support for frozen parameters

In the context of VISSL, we are training a SWAV model with FSDP and then want to evaluate this model using linear evaluation, again using FSDP to wrap our model.

During linear evaluation, only the linear layer on top of the "trunk" is trained: the trunk's representations, which feed the linear layer, stay constant, so we set requires_grad=False on the trunk's parameters.
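For reference, the freezing itself uses plain PyTorch mechanisms. A minimal sketch of the two usual variants, assuming the Model/optimizer setup from the reproduction script below (the script uses the first variant; the optimizer-filtering variant is only sketched here as a common alternative):

# Variant 1: mark the trunk parameters as frozen so autograd skips them.
for name, param in model.named_parameters():
    if "trunk" in name:
        param.requires_grad = False

# Variant 2: leave requires_grad untouched and hand the optimizer only the
# head's parameters, so the trunk never receives updates.
optimizer = optim.SGD(model.head.parameters(), lr=0.1, momentum=0.9)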

While this works with DDP, it leads to an assert with FSDP:

expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE

To help you dig into this issue, I have put together a small test case that demonstrates it:

  • run this script with -f to enable FSDP (otherwise, we use DDP)
  • run this script with -l (for linear) to freeze the trunk

You will get the assert if you enable both -f and -l; everything works fine otherwise (an example invocation is given after the script).

import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        return self.head(self.trunk(x))


def create_model(with_fsdp: bool):
    model = Model()
    if with_fsdp:
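        # Wrap the trunk and head as separate (nested) FSDP instances.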
        model.trunk = FSDP(model.trunk)
        model.head = FSDP(model.head)
    return model


def _distributed_worker(gpu_id: int, with_fsdp: bool, with_linear: bool):
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl", init_method="tcp://127.0.0.1:9099", world_size=2, rank=gpu_id
    )

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp)
    model = model.cuda()

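    # Linear evaluation: freeze the trunk so only the head's parameters are trained.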
    if with_linear:
        for name, param in model.named_parameters():
            if "trunk" in name:
                param.requires_grad = False

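    # Wrap the whole model with FSDP, or fall back to DDP for the baseline.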
    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--fsdp", action="store_const", const=True, default=False)
    parser.add_argument("-l", "--linear", action="store_const", const=True, default=False)
    args = parser.parse_args()
    mp.spawn(_distributed_worker, (args.fsdp, args.linear), nprocs=2)
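
For example, assuming the script is saved as repro.py (a filename chosen here purely for illustration) and run on a machine with at least two GPUs, python repro.py -f -l triggers the assert, while python repro.py -f or python repro.py -l alone completes the three iterations normally.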

I use PyTorch 1.6.0+cu102, with the following modules loaded:
module load cuda/11.0 cudnn/v8.0.3.33-cuda.11.0 NCCL/2.8.3-1-cuda.11.0

Don't hesitate to reach out to me for more details.

CC: @prigoyal @min-xu-ai @myleott

@myleott myleott added the FSDP FullyShardedDataParallel (zero-3) label Apr 15, 2021
@min-xu-ai min-xu-ai self-assigned this Apr 16, 2021
min-xu-ai added a commit that referenced this issue Apr 16, 2021
- an assert is changed to catch this case correctly
- unit test added (based on Quentin's test code) for this case and
  compare DDP and FSDP

fixes: #610
min-xu-ai added a commit that referenced this issue Apr 19, 2021
* FSDP: fixing training with freezing weights

- an assert is changed to catch this case correctly
- unit test added (based on Quentin's test code) for this case and
  compare DDP and FSDP

fixes: #610

* added test file to list 1

* Use better and simpler code as suggested by Myle

* testing both methods of freezing as well

Co-authored-by: Min Xu <min.xu@acm.org>