diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8eb2d7439ef3..9d5266f2bcbe 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4336,26 +4336,27 @@ def from_pretrained(
         return model

     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
-        if "beta" in key:
-            return key.replace("beta", "bias")
-        if "gamma" in key:
-            return key.replace("gamma", "weight")
+        if key.endswith("LayerNorm.beta"):
+            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
+        elif key.endswith("LayerNorm.gamma"):
+            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

         # to avoid logging parametrized weight norm renaming
         if hasattr(nn.utils.parametrizations, "weight_norm"):
             if "weight_g" in key:
-                return key.replace("weight_g", "parametrizations.weight.original0")
+                return key.replace("weight_g", "parametrizations.weight.original0"), True
             if "weight_v" in key:
-                return key.replace("weight_v", "parametrizations.weight.original1")
+                return key.replace("weight_v", "parametrizations.weight.original1"), True
         else:
             if "parametrizations.weight.original0" in key:
-                return key.replace("parametrizations.weight.original0", "weight_g")
+                return key.replace("parametrizations.weight.original0", "weight_g"), True
             if "parametrizations.weight.original1" in key:
-                return key.replace("parametrizations.weight.original1", "weight_v")
-        return key
+                return key.replace("parametrizations.weight.original1", "weight_v"), True
+
+        return key, False

     @classmethod
     def _fix_state_dict_keys_on_load(cls, state_dict):
         """
@@ -4366,15 +4367,15 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
         renamed_keys = {}
         state_dict_keys = list(state_dict.keys())
         for key in state_dict_keys:
-            new_key = cls._fix_state_dict_key_on_load(key)
-            if new_key != key:
+            new_key, has_changed = cls._fix_state_dict_key_on_load(key)
+            if has_changed:
                 state_dict[new_key] = state_dict.pop(key)

-                # add it once for logging
-                if "gamma" in key and "gamma" not in renamed_keys:
-                    renamed_keys["gamma"] = (key, new_key)
-                if "beta" in key and "beta" not in renamed_keys:
-                    renamed_keys["beta"] = (key, new_key)
+                # track gamma/beta rename for logging
+                if key.endswith("LayerNorm.gamma"):
+                    renamed_keys["LayerNorm.gamma"] = (key, new_key)
+                elif key.endswith("LayerNorm.beta"):
+                    renamed_keys["LayerNorm.beta"] = (key, new_key)

         if renamed_keys:
             warning_msg = f"A pretrained model of type `{cls.__name__}` "
@@ -4387,19 +4388,19 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
         return state_dict

     @staticmethod
-    def _fix_state_dict_key_on_save(key):
+    def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
         """
         Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
-        Do nothing by default, but can be overriden in particular models.
+        Do nothing by default, but can be overridden in particular models.
         """
-        return key
+        return key, False

     def _fix_state_dict_keys_on_save(self, state_dict):
         """
         Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
         Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
         """
-        return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
+        return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

     @classmethod
     def _load_pretrained_model(
diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index 47e8944583b4..e160a965c4a9 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -90,15 +90,15 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """
         Overrides original method that renames `gamma` and `beta` to `weight` and `bias`. We don't want this
         behavior for timm wrapped models. Instead, this method adds a "timm_model." prefix to enable loading
         official timm Hub checkpoints.
         """
         if "timm_model." not in key:
-            return f"timm_model.{key}"
-        return key
+            return f"timm_model.{key}", True
+        return key, False

     def _fix_state_dict_key_on_save(self, key):
         """