[FSDP] Clean missing and unexpected keys (pytorch#120600)
Currently, when loading a state_dict with strict=False, or with strict=True and inspecting the resulting error message, the reported FQNs are garbled with FSDP internals such as "_fsdp_wrapped_module". This makes it tricky for upstream applications to validate the expected set of missing / unexpected keys (for example PEFT, where the state_dict is loaded non-strictly), and it complicates the error message with FSDP details.

This PR cleans those prefixes by applying `clean_tensor_name` in FSDP's existing post-load_state_dict hooks. Currently only the full_state_dict implementation is tested; the remaining implementations can be covered as follow-up work.
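
For illustration, a minimal sketch (not part of this change) of what the cleanup amounts to: `clean_tensor_name` is a private FSDP helper that strips the wrapper prefixes from an FQN, so reported keys match the unwrapped module's names. The example FQN below is made up.

```python
from torch.distributed.fsdp._common_utils import clean_tensor_name  # private FSDP helper

# Hypothetical wrapped FQN, as FSDP would report it before this change:
wrapped_fqn = "_fsdp_wrapped_module.net.0._fsdp_wrapped_module.weight"
print(clean_tensor_name(wrapped_fqn))  # -> "net.0.weight"
```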

Differential Revision: [D54182472](https://our.internmc.facebook.com/intern/diff/D54182472/)

Pull Request resolved: pytorch#120600
Approved by: https://github.com/XilunWu, https://github.com/fegin
rohan-varma authored and pytorchmergebot committed Feb 27, 2024
1 parent b2a318d commit 9db6a84
Showing 2 changed files with 36 additions and 1 deletion.
15 changes: 15 additions & 0 deletions test/distributed/fsdp/test_fsdp_state_dict.py
@@ -34,6 +34,7 @@
     ShardedStateDictConfig,
     StateDictType,
 )
+from torch.distributed.fsdp._common_utils import FSDP_PREFIX
 from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
 from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
 from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
@@ -1179,6 +1180,20 @@ def test_shared_module_and_shared_parameter(self):
         self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
         self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])
 
+    @skip_if_lt_x_gpu(2)
+    def test_full_state_dict_missing_unexpected_keys_cleaned(self):
+        model = self._get_simple_nested_model()
+        sd = model.state_dict()
+        # Create a missing key
+        sd.pop(next(iter(sd.keys())))
+        # Create an unexpected key
+        sd["unexpected"] = torch.ones(1)
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+        assert len(missing) == 1
+        assert len(unexpected) == 1
+        self.assertTrue(FSDP_PREFIX not in missing[0])
+        self.assertTrue(FSDP_PREFIX not in unexpected[0])
+
     @skip_if_lt_x_gpu(2)
     def test_sharded_load_multi_backend_pg(self):
         auto_wrap_policy = ModuleWrapPolicy(
22 changes: 21 additions & 1 deletion torch/distributed/fsdp/_state_dict_utils.py
@@ -2,7 +2,17 @@
 import logging
 import math
 import warnings
-from typing import Any, Callable, cast, Dict, Generator, Iterator, no_type_check, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    no_type_check,
+    Tuple,
+)
 
 import torch
 import torch.distributed as dist
@@ -854,6 +864,7 @@ def _pre_load_state_dict_hook(
 @torch.no_grad()
 def _post_load_state_dict_hook(
     module: nn.Module,
+    incompatible_keys: Tuple[List[str], List[str]],
     *args: Any,
 ) -> None:
     fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
@@ -877,6 +888,15 @@ def _post_load_state_dict_hook(
     # loading state_dict.
     _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
 
+    # When reporting incompatible keys, trim FSDP prefixes.
+    missing_keys = incompatible_keys[0]
+    unexpected_keys = incompatible_keys[1]
+    for i in range(len(missing_keys)):
+        missing_keys[i] = clean_tensor_name(missing_keys[i])
+
+    for i in range(len(unexpected_keys)):
+        unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])
+
     if fsdp_state._is_root:
         SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")

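The cleanup takes effect because `nn.Module` load_state_dict post-hooks receive the incompatible-keys result and may modify it in place; the mutation is reflected both in the value returned by `load_state_dict(strict=False)` and in the strict-mode error message. Below is a standalone sketch of that mechanism on a plain `nn.Linear`, using a hypothetical "wrapped." prefix in place of FSDP's; the hook name and prefix are illustrative only.

```python
import torch
import torch.nn as nn

def _strip_prefix_hook(module: nn.Module, incompatible_keys) -> None:
    # incompatible_keys is a (missing_keys, unexpected_keys) named tuple;
    # rewriting the lists in place changes what load_state_dict reports.
    missing, unexpected = incompatible_keys
    missing[:] = [k.replace("wrapped.", "") for k in missing]
    unexpected[:] = [k.replace("wrapped.", "") for k in unexpected]

model = nn.Linear(2, 2)
model.register_load_state_dict_post_hook(_strip_prefix_hook)
missing, unexpected = model.load_state_dict({"wrapped.extra": torch.ones(1)}, strict=False)
print(missing)     # ['weight', 'bias']
print(unexpected)  # ['extra'] rather than ['wrapped.extra']
```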
