Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FSDP for Models with Frozen Weights #5484

Merged
merged 5 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/test_torch_distributed_fsdp_frozen_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP


def _mp_fn(index):
dev = xm.xla_device()
if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
print(
'Default device {} is not a TPU or GPU device'.format(dev),
file=sys.stderr)
return

model = nn.Linear(1024, 1024)
model.weight.requires_grad = False # the weight param is frozen

model = FSDP(model) # wrapping the linear module with FSDP

input = torch.rand((2, 1024), device=xm.xla_device())

output = model(input)
loss = torch.sum(output)
loss.backward()
assert not any(p._has_full_param for p in model.full_params), \
'Expecting all the full params to be freed at this moment.'


if __name__ == "__main__":
xmp.spawn(_mp_fn, args=())
Original file line number Diff line number Diff line change
Expand Up @@ -1299,13 +1299,19 @@ def _wait_for_post_backward(self) -> None:
# A backward pass is done, clean up below.
def _finalize_parameters(fsdp_module: XlaFullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
frozen_params = []
Liyang90 marked this conversation as resolved.
Show resolved Hide resolved
for p in fsdp_module.full_params:
if not p.requires_grad:
continue
frozen_params.append(p)
if hasattr(p, "_shard_bwd_hook"):
assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
p._shard_bwd_hook[1].remove()
delattr(p, "_shard_bwd_hook")
# Free the full params with `requires_grad==False`
if frozen_params:
fsdp_module._free_full_params(
frozen_params,
apply_opt_barrier=self.optimization_barrier_in_backward)

# Update root and nested FSDP's hooks and flags.
for m in self.modules(): # includes self
Expand Down