Skip to content

Commit

Permalink
An attempt to fix huggingface#29554. Include 'LayerNorm.' in gamma/beta rename scope, reduce the number of characters searched on every load considerably.
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 10, 2025
1 parent 04eae98 commit faeed0d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
43 changes: 22 additions & 21 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4336,26 +4336,27 @@ def from_pretrained(
return model

@staticmethod
def _fix_state_dict_key_on_load(key):
def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""

if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
if key.endswith("LayerNorm.beta"):
return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
elif key.endswith("LayerNorm.gamma"):
return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

# to avoid logging parametrized weight norm renaming
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
return key.replace("weight_g", "parametrizations.weight.original0")
return key.replace("weight_g", "parametrizations.weight.original0"), True
if "weight_v" in key:
return key.replace("weight_v", "parametrizations.weight.original1")
return key.replace("weight_v", "parametrizations.weight.original1"), True
else:
if "parametrizations.weight.original0" in key:
return key.replace("parametrizations.weight.original0", "weight_g")
return key.replace("parametrizations.weight.original0", "weight_g"), True
if "parametrizations.weight.original1" in key:
return key.replace("parametrizations.weight.original1", "weight_v")
return key
return key.replace("parametrizations.weight.original1", "weight_v"), True

return key, False

@classmethod
def _fix_state_dict_keys_on_load(cls, state_dict):
# NOTE(review): the lines below are a diff hunk from a commit page — removed
# (pre-commit) and added (post-commit) lines are interleaved, and part of the
# body is hidden behind the "Expand All" collapse markers, so this span is not
# runnable as-is and is left byte-identical.
Expand All @@ -4366,15 +4367,15 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
renamed_keys = {}
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
# removed by this commit: the per-key helper used to return only the key,
# so a change was detected by string comparison.
new_key = cls._fix_state_dict_key_on_load(key)
if new_key != key:
# added by this commit: the helper now also returns whether the key changed.
new_key, has_changed = cls._fix_state_dict_key_on_load(key)
if has_changed:
state_dict[new_key] = state_dict.pop(key)

# add it once for logging
# removed by this commit: substring-based gamma/beta tracking.
if "gamma" in key and "gamma" not in renamed_keys:
renamed_keys["gamma"] = (key, new_key)
if "beta" in key and "beta" not in renamed_keys:
renamed_keys["beta"] = (key, new_key)
# track gamma/beta rename for logging
if key.endswith("LayerNorm.gamma"):
renamed_keys["LayerNorm.gamma"] = (key, new_key)
elif key.endswith("LayerNorm.beta"):
renamed_keys["LayerNorm.beta"] = (key, new_key)

if renamed_keys:
warning_msg = f"A pretrained model of type `{cls.__name__}` "
Expand All @@ -4387,19 +4388,19 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
return state_dict

@staticmethod
def _fix_state_dict_key_on_save(key):
def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
"""
Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
Do nothing by default, but can be overriden in particular models.
Do nothing by default, but can be overridden in particular models.
"""
return key
return key, False

def _fix_state_dict_keys_on_save(self, state_dict):
"""
Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
"""
return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

@classmethod
def _load_pretrained_model(
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@staticmethod
def _fix_state_dict_key_on_load(key):
def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""
Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
We don't want this behavior for timm wrapped models. Instead, this method adds a
"timm_model." prefix to enable loading official timm Hub checkpoints.
"""
if "timm_model." not in key:
return f"timm_model.{key}"
return key
return f"timm_model.{key}", True
return key, False

def _fix_state_dict_key_on_save(self, key):
"""
Expand Down

0 comments on commit faeed0d

Please sign in to comment.