From c8453b21755dafcece8f3744c3133fcc3fc28686 Mon Sep 17 00:00:00 2001
From: Liyang Lu
Date: Wed, 23 Aug 2023 16:58:31 +0000
Subject: [PATCH 1/4] Fix FSDP not freeing frozen full params

Previously, the `_finalize_parameters` helper in `_wait_for_post_backward`
skipped every parameter with `requires_grad == False`, so the full
(unsharded) tensors of frozen parameters were never freed after the
backward pass. Collect those frozen params and free them via
`_free_full_params`.
---
 .../distributed/fsdp/xla_fully_sharded_data_parallel.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
index e51a31e0b4d..f1b62d1700b 100644
--- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
+++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -1299,13 +1299,19 @@ def _wait_for_post_backward(self) -> None:
     # A backward pass is done, clean up below.
     def _finalize_parameters(fsdp_module: XlaFullyShardedDataParallel) -> None:
       """Helper used below on all fsdp modules."""
+      frozen_params = []
       for p in fsdp_module.full_params:
         if not p.requires_grad:
-          continue
+          frozen_params.append(p)
         if hasattr(p, "_shard_bwd_hook"):
           assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
           p._shard_bwd_hook[1].remove()
           delattr(p, "_shard_bwd_hook")
+      # Free the full params with `requires_grad==False`
+      if frozen_params:
+        fsdp_module._free_full_params(
+            frozen_params,
+            apply_opt_barrier=self.optimization_barrier_in_backward)
 
     # Update root and nested FSDP's hooks and flags.
     for m in self.modules():  # includes self

From b32ec87cbf015e6ac2f8d896ef813540cea26933 Mon Sep 17 00:00:00 2001
From: Liyang Lu
Date: Wed, 23 Aug 2023 17:45:24 +0000
Subject: [PATCH 2/4] Add test

---
 ...st_torch_distributed_fsdp_frozen_weight.py | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 test/test_torch_distributed_fsdp_frozen_weight.py

diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py
new file mode 100644
index 00000000000..4abb8bc6582
--- /dev/null
+++ b/test/test_torch_distributed_fsdp_frozen_weight.py
@@ -0,0 +1,35 @@
+import os
+os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"
+
+import sys
+import torch
+import torch.nn as nn
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+
+
+def _mp_fn(index):
+    dev = xm.xla_device()
+    if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
+        print(
+        'Default device {} is not a TPU or GPU device'.format(dev),
+        file=sys.stderr)
+        return
+
+    model = nn.Linear(1024, 1024)
+    model.weight.requires_grad = False  # the weight param is frozen
+
+    model = FSDP(model)  # wrapping the linear module with FSDP
+
+    input = torch.rand((2, 1024), device=xm.xla_device())
+
+    output = model(input)
+    loss = torch.sum(output)
+    loss.backward()
+    assert not any(p._has_full_param for p in model.full_params), \
+        'Expecting all the full params to be freed at this moment.'
+
+
+if __name__ == "__main__":
+    xmp.spawn(_mp_fn, args=())

From 5dc7e9058806d6f86c4ad7364743a2308ae7b623 Mon Sep 17 00:00:00 2001
From: Liyang Lu
Date: Wed, 23 Aug 2023 17:49:04 +0000
Subject: [PATCH 3/4] Formatting

---
 ...st_torch_distributed_fsdp_frozen_weight.py | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py
index 4abb8bc6582..4dd82ce094c 100644
--- a/test/test_torch_distributed_fsdp_frozen_weight.py
+++ b/test/test_torch_distributed_fsdp_frozen_weight.py
@@ -1,4 +1,5 @@
 import os
+
 os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"
 
 import sys
@@ -10,26 +11,26 @@
 
 
 def _mp_fn(index):
-    dev = xm.xla_device()
-    if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
-        print(
+  dev = xm.xla_device()
+  if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
+    print(
         'Default device {} is not a TPU or GPU device'.format(dev),
         file=sys.stderr)
-        return
+    return
 
-    model = nn.Linear(1024, 1024)
-    model.weight.requires_grad = False  # the weight param is frozen
+  model = nn.Linear(1024, 1024)
+  model.weight.requires_grad = False  # the weight param is frozen
 
-    model = FSDP(model)  # wrapping the linear module with FSDP
+  model = FSDP(model)  # wrapping the linear module with FSDP
 
-    input = torch.rand((2, 1024), device=xm.xla_device())
+  input = torch.rand((2, 1024), device=xm.xla_device())
 
-    output = model(input)
-    loss = torch.sum(output)
-    loss.backward()
-    assert not any(p._has_full_param for p in model.full_params), \
-        'Expecting all the full params to be freed at this moment.'
+  output = model(input)
+  loss = torch.sum(output)
+  loss.backward()
+  assert not any(p._has_full_param for p in model.full_params), \
+      'Expecting all the full params to be freed at this moment.'
 
 
 if __name__ == "__main__":
-    xmp.spawn(_mp_fn, args=())
+  xmp.spawn(_mp_fn, args=())

From d92e36f581bdffa1fc14d30625afc54e32c4008f Mon Sep 17 00:00:00 2001
From: Liyang Lu
Date: Wed, 30 Aug 2023 18:02:36 +0000
Subject: [PATCH 4/4] Remove unnecessary env var in test

---
 test/test_torch_distributed_fsdp_frozen_weight.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py
index 4dd82ce094c..79b65a46999 100644
--- a/test/test_torch_distributed_fsdp_frozen_weight.py
+++ b/test/test_torch_distributed_fsdp_frozen_weight.py
@@ -1,7 +1,3 @@
-import os
-
-os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"
-
 import sys
 import torch
 import torch.nn as nn