From 9db6a849ed53b5468b4b1f2e4bfba15a60ba3a88 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Mon, 26 Feb 2024 20:49:35 -0800
Subject: [PATCH] [FSDP] Clean missing and unexpected keys (#120600)

Currently, when loading with strict=False (or inspecting the error message
with strict=True), the reported FQNs are garbled with FSDP details such as
"_fsdp_wrapped_module". This makes it tricky for upstream applications to
validate the expected set of missing / unexpected keys (for example with
PEFT, where the state_dict is loaded non-strictly), and it clutters the
error message with FSDP internals.

This PR cleans those prefixes by using `clean_tensor_name` in FSDP's
existing post-load_state_dict hooks. Currently, only the full_state_dict
implementation is tested; the remaining implementations can be covered in
follow-up work.

Differential Revision: [D54182472](https://our.internmc.facebook.com/intern/diff/D54182472/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120600
Approved by: https://github.com/XilunWu, https://github.com/fegin
---
 test/distributed/fsdp/test_fsdp_state_dict.py | 15 +++++++++++++
 torch/distributed/fsdp/_state_dict_utils.py   | 22 ++++++++++++++++++-
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 8cfc8314ee2968..6374b06702b50f 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -34,6 +34,7 @@
     ShardedStateDictConfig,
     StateDictType,
 )
+from torch.distributed.fsdp._common_utils import FSDP_PREFIX
 from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
 from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
 from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
@@ -1179,6 +1180,20 @@ def test_shared_module_and_shared_parameter(self):
         self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
         self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])
 
+    @skip_if_lt_x_gpu(2)
+    def test_full_state_dict_missing_unexpected_keys_cleaned(self):
+        model = self._get_simple_nested_model()
+        sd = model.state_dict()
+        # Create a missing key
+        sd.pop(next(iter(sd.keys())))
+        # Create an unexpected key
+        sd["unexpected"] = torch.ones(1)
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+        assert len(missing) == 1
+        assert len(unexpected) == 1
+        self.assertTrue(FSDP_PREFIX not in missing[0])
+        self.assertTrue(FSDP_PREFIX not in unexpected[0])
+
     @skip_if_lt_x_gpu(2)
     def test_sharded_load_multi_backend_pg(self):
         auto_wrap_policy = ModuleWrapPolicy(
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index 66840851d17ec8..eac5bed6f0ea33 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -2,7 +2,17 @@
 import logging
 import math
 import warnings
-from typing import Any, Callable, cast, Dict, Generator, Iterator, no_type_check, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    no_type_check,
+    Tuple,
+)
 
 import torch
 import torch.distributed as dist
@@ -854,6 +864,7 @@ def _pre_load_state_dict_hook(
 @torch.no_grad()
 def _post_load_state_dict_hook(
     module: nn.Module,
+    incompatible_keys: Tuple[List[str], List[str]],
     *args: Any,
 ) -> None:
     fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
@@ -877,6 +888,15 @@ def _post_load_state_dict_hook(
         # Dispatch into state_dict specific implementation of post-hook for
         # loading state_dict.
         _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
 
+    # When reporting incompatible keys, trim FSDP prefixes.
+    missing_keys = incompatible_keys[0]
+    unexpected_keys = incompatible_keys[1]
+    for i in range(len(missing_keys)):
+        missing_keys[i] = clean_tensor_name(missing_keys[i])
+
+    for i in range(len(unexpected_keys)):
+        unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])
+
     if fsdp_state._is_root:
         SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")
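
To see the end-to-end effect, here is a minimal standalone sketch (not taken
from the patch) of loading a state_dict into an FSDP-wrapped model with
strict=False. It assumes a single GPU, a throwaway world-size-1 NCCL process
group, and a made-up nn.Sequential model chosen purely for illustration;
the real coverage is the new test in test_fsdp_state_dict.py, which runs
under the multi-GPU FSDPTest harness.

    # Illustrative sketch only -- not part of this patch.
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Assumed single-process setup; the MASTER_PORT value is arbitrary.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    model = FSDP(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).cuda())

    sd = model.state_dict()
    sd.pop(next(iter(sd)))            # drop one entry -> reported as missing
    sd["unexpected"] = torch.ones(1)  # extra entry -> reported as unexpected

    missing, unexpected = model.load_state_dict(sd, strict=False)
    # With this patch the reported keys are clean FQNs (e.g. "0.weight")
    # rather than "_fsdp_wrapped_module.0.weight".
    print(missing, unexpected)

    dist.destroy_process_group()

Because the post-hook rewrites the missing/unexpected key lists in place,
the cleaned names are exactly what load_state_dict hands back to the caller.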