diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 8cfc8314ee296..6374b06702b50 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -34,6 +34,7 @@
     ShardedStateDictConfig,
     StateDictType,
 )
+from torch.distributed.fsdp._common_utils import FSDP_PREFIX
 from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
 from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
 from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
@@ -1179,6 +1180,20 @@ def test_shared_module_and_shared_parameter(self):
         self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
         self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])
 
+    @skip_if_lt_x_gpu(2)
+    def test_full_state_dict_missing_unexpected_keys_cleaned(self):
+        model = self._get_simple_nested_model()
+        sd = model.state_dict()
+        # Create a missing key
+        sd.pop(next(iter(sd.keys())))
+        # Create an unexpected key
+        sd["unexpected"] = torch.ones(1)
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+        assert len(missing) == 1
+        assert len(unexpected) == 1
+        self.assertTrue(FSDP_PREFIX not in missing[0])
+        self.assertTrue(FSDP_PREFIX not in unexpected[0])
+
     @skip_if_lt_x_gpu(2)
     def test_sharded_load_multi_backend_pg(self):
         auto_wrap_policy = ModuleWrapPolicy(
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index 66840851d17ec..eac5bed6f0ea3 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -2,7 +2,17 @@
 import logging
 import math
 import warnings
-from typing import Any, Callable, cast, Dict, Generator, Iterator, no_type_check, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    no_type_check,
+    Tuple,
+)
 
 import torch
 import torch.distributed as dist
@@ -854,6 +864,7 @@ def _pre_load_state_dict_hook(
 @torch.no_grad()
 def _post_load_state_dict_hook(
     module: nn.Module,
+    incompatible_keys: Tuple[List[str], List[str]],
     *args: Any,
 ) -> None:
     fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
@@ -877,6 +888,15 @@ def _post_load_state_dict_hook(
     # loading state_dict.
     _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
 
+    # When reporting incompatible keys, trim FSDP prefixes.
+    missing_keys = incompatible_keys[0]
+    unexpected_keys = incompatible_keys[1]
+    for i in range(len(missing_keys)):
+        missing_keys[i] = clean_tensor_name(missing_keys[i])
+
+    for i in range(len(unexpected_keys)):
+        unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])
+
     if fsdp_state._is_root:
         SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")