[fix] better assert and better test for frozen weights (#657)
* [fix] better assert and better test for frozen weights

- the precise condition should have been to check m.parameters(), not m.params.
- fixes #643

* add changelog

* using an enum is so much better

Co-authored-by: Min Xu <min.xu@acm.org>
min-xu-ai and Min Xu authored May 5, 2021
1 parent 1ae7778 commit b54eed1
Showing 3 changed files with 72 additions and 23 deletions.
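For context before the file diffs: the failure mode being fixed involves an FSDP-wrapped submodule whose own parameters are all frozen. Below is a minimal plain-PyTorch sketch of that setup (a hypothetical model loosely mirroring the test in this commit; the real test wraps the trunk and head in FSDP and spawns two GPU workers):

import torch
import torch.nn as nn

class TrunkHeadModel(nn.Module):
    # Hypothetical stand-in for the test's trunk/head model.
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.ReLU(inplace=True))
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(64, 10))

    def forward(self, x):
        return self.head(self.trunk(x))

model = TrunkHeadModel()

# Freeze the trunk the way the test's "requires_grad" method does.
for p in model.trunk.parameters():
    p.requires_grad = False

model(torch.randn(2, 3, 32, 32)).sum().backward()

# The frozen trunk receives no gradients; only the head does. Under FSDP this
# means the wrapper's post-backward hook does not fire for the trunk, which is
# the case the relaxed assert below has to account for.
assert all(p.grad is None for p in model.trunk.parameters())
assert all(p.grad is not None for p in model.head.parameters())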
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## NEXT - TBD
 ### Fixed
+- FSDP: improved frozen weight support
 - FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
 - SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
16 changes: 10 additions & 6 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1301,15 +1301,19 @@ def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
             if isinstance(m, FullyShardedDataParallel):
                 _remove_shard_bwd_hook(m)
                 m._pre_backward_hook_has_run = False
-                if m._has_params:
-                    if any(p.requires_grad for p in m.params):
+                # Note: m.parameters() should not be an empty list. FSDP
+                # wrapping modules without weights is not tested at the moment.
+                if any(p.requires_grad for p in m.parameters()):
+                    if m._has_params:
                         m.assert_state(TrainingState.BACKWARD_POST)
                     else:
-                        # Unlikely case, should only happens if `m` has params but none of the
-                        # params has `requires_grad==True`.
-                        m.assert_state(TrainingState.IDLE)
+                        m.assert_state(TrainingState.BACKWARD_PRE)
                 else:
-                    m.assert_state(TrainingState.BACKWARD_PRE)
+                    # Unlikely case. When m and its children has no params
+                    # with `requires_grad==True`, then m's pre-backward and
+                    # post-backward hooks aren't called by autograd. Therefore,
+                    # it is in IDLE state.
+                    m.assert_state(TrainingState.IDLE)
                 m.training_state = TrainingState.IDLE
 
     @torch.no_grad()
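Why the condition moved from m.params to m.parameters(): m.params holds the parameters an FSDP instance manages itself, while nn.Module.parameters() also recurses into child modules, including nested FSDP wrappers, so the two checks can disagree when only part of the tree is frozen. A rough analogy with a plain module, using recurse=False as a stand-in for the wrapper-local view (this is not FSDP's actual internals):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    # Toy analogy only: "own" plays the role of the parameters a wrapper manages
    # itself, "child" the role of a nested wrapped submodule.
    def __init__(self):
        super().__init__()
        self.own = nn.Parameter(torch.zeros(4))
        self.child = nn.Linear(4, 2)

m = Wrapper()
m.own.requires_grad = False  # freeze only what the wrapper "owns" directly

print(any(p.requires_grad for p in m.parameters(recurse=False)))  # False: local view
print(any(p.requires_grad for p in m.parameters()))               # True: recursive view still sees the child

The relaxed assert therefore reasons about the whole subtree: only when nothing under m requires grad do both hooks stay silent, leaving the module legitimately in the IDLE state.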
78 changes: 61 additions & 17 deletions tests/nn/data_parallel/test_fsdp_freezing_weights.py
@@ -10,6 +10,7 @@
 """ Test FSDP with some params frozen. """
 
 
+from enum import Enum
 import tempfile
 
 import pytest
@@ -38,16 +39,51 @@ def forward(self, x):
         return self.head(self.trunk(x))
 
 
-def _create_model(with_fsdp):
-    model = Model()
-    if with_fsdp:
-        model.trunk = FSDP(model.trunk)
-        model.head = FSDP(model.head)
+class NestedTrunkModel(nn.Module):
+    def __init__(self, with_fsdp):
+        super().__init__()
+        self.trunk = nn.Sequential(self._create_block(3, 64, with_fsdp), self._create_block(64, 64, with_fsdp),)
+        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(64, 10),)
+        if with_fsdp:
+            self.trunk = FSDP(self.trunk)
+            self.head = FSDP(self.head)
+
+    def forward(self, x):
+        return self.head(self.trunk(x))
+
+    def _create_block(self, in_channels, out_channels, with_fsdp):
+        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3), nn.ReLU(inplace=True),)
+        if with_fsdp:
+            block = FSDP(block)
+        return block
+
+
+def _create_model(with_fsdp, with_nested_trunk):
+    if with_nested_trunk:
+        model = NestedTrunkModel(with_fsdp)
+    else:
+        model = Model()
+        if with_fsdp:
+            model.trunk = FSDP(model.trunk)
+            model.head = FSDP(model.head)
     return model
 
 
+class FreezingMethod(str, Enum):
+    GradToNone = "grad_to_none"
+    RequiresGrad = "requires_grad"
+
+
 def _distributed_worker(
-    gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state
+    gpu_id,
+    world_size,
+    with_fsdp,
+    with_nested_trunk,
+    freezing_method,
+    tempfile_name,
+    unused,
+    rank_0_output,
+    expected_state,
 ):
     torch.cuda.set_device(gpu_id)
 
@@ -59,12 +95,11 @@ def _distributed_worker(
     torch.backends.cudnn.deterministic = True
     batch = torch.randn(size=(2, 3, 224, 224)).cuda()
 
-    model = _create_model(with_fsdp)
+    model = _create_model(with_fsdp, with_nested_trunk)
     model = model.cuda()
 
     # freezing the trunk using requires_grad.
-    assert freezing_method in ["requires_grad", "grad_to_none"]
-    if freezing_method == "requires_grad":
+    if freezing_method == FreezingMethod.RequiresGrad:
         for param in model.trunk.parameters():
             param.requires_grad = False
 
@@ -86,7 +121,7 @@ def _distributed_worker(
         print("Loss", iteration, ":", fake_loss.item())
         optimizer.zero_grad()
         fake_loss.backward()
-        if freezing_method == "grad_to_none":
+        if freezing_method == FreezingMethod.GradToNone:
             for param in model.trunk.parameters():
                 param.grad = None
         optimizer.step()
@@ -118,21 +153,30 @@ def temp_files():
 
 
 @skip_if_single_gpu
-def test_freezing_weights(temp_files):
+@pytest.mark.parametrize("nested_trunk", ["nested_trunk", "simple_trunk"])
+def test_freezing_weights(temp_files, nested_trunk):
+    with_nested_trunk = nested_trunk == "nested_trunk"
+
     world_size = 2
     # DDP
-    fsdp = False
-    freezing_method = "requires_grad"
-    mp.spawn(_distributed_worker, (world_size, fsdp, freezing_method) + temp_files[0:3] + (None,), nprocs=world_size)
+    with_fsdp = False
+    freezing_method = FreezingMethod.RequiresGrad
+    mp.spawn(
+        _distributed_worker,
+        (world_size, with_fsdp, with_nested_trunk, freezing_method) + temp_files[0:3] + (None,),
+        nprocs=world_size,
+    )
     # FSDP, case 1 and 2.
-    fsdp = True
+    with_fsdp = True
     expected_state = torch.load(temp_files[2])
     temp_file_idx = 3
-    for freezing_method in ["requires_grad", "grad_to_none"]:
+    for freezing_method in [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]:
         print(f"Testing FSDP with freezing method {freezing_method}")
         mp.spawn(
             _distributed_worker,
-            (world_size, fsdp, freezing_method) + temp_files[temp_file_idx : temp_file_idx + 3] + (expected_state,),
+            (world_size, with_fsdp, with_nested_trunk, freezing_method)
+            + temp_files[temp_file_idx : temp_file_idx + 3]
+            + (expected_state,),
             nprocs=world_size,
         )
         temp_file_idx += 3
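The two FreezingMethod branches can also be compared outside FSDP. A hypothetical single-process sketch (not the test itself), assuming plain SGD with no momentum or weight decay so that a parameter whose gradient is dropped is simply skipped by the optimizer step:

import torch
import torch.nn as nn
import torch.optim as optim

def run(freezing_method: str) -> nn.Sequential:
    # Train a tiny trunk+head for a few steps, freezing the trunk with the given method.
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # [trunk, head]
    trunk = model[0]
    if freezing_method == "requires_grad":
        for p in trunk.parameters():
            p.requires_grad = False
    opt = optim.SGD(model.parameters(), lr=0.1)  # no momentum/weight decay on purpose
    torch.manual_seed(1)  # same data stream for both runs
    for _ in range(3):
        loss = model(torch.randn(4, 8)).sum()
        opt.zero_grad()
        loss.backward()
        if freezing_method == "grad_to_none":
            for p in trunk.parameters():
                p.grad = None  # drop the trunk's gradients before stepping
        opt.step()
    return model

m1, m2 = run("requires_grad"), run("grad_to_none")
# Both methods leave the trunk untouched and update the head identically.
for p1, p2 in zip(m1.parameters(), m2.parameters()):
    assert torch.allclose(p1, p2)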
