🚨🚨🚨 An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, optimize string search. #35615

Merged on Jan 17, 2025 (6 commits)
Changes from 2 commits
45 changes: 23 additions & 22 deletions src/transformers/modeling_utils.py
@@ -4336,26 +4336,27 @@ def from_pretrained(
         return model

     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""

-        if "beta" in key:
-            return key.replace("beta", "bias")
-        if "gamma" in key:
-            return key.replace("gamma", "weight")
+        if key.endswith("LayerNorm.beta"):
+            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
+        elif key.endswith("LayerNorm.gamma"):
+            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

         # to avoid logging parametrized weight norm renaming
         if hasattr(nn.utils.parametrizations, "weight_norm"):
             if "weight_g" in key:
-                return key.replace("weight_g", "parametrizations.weight.original0")
+                return key.replace("weight_g", "parametrizations.weight.original0"), True
             if "weight_v" in key:
-                return key.replace("weight_v", "parametrizations.weight.original1")
+                return key.replace("weight_v", "parametrizations.weight.original1"), True
         else:
             if "parametrizations.weight.original0" in key:
-                return key.replace("parametrizations.weight.original0", "weight_g")
+                return key.replace("parametrizations.weight.original0", "weight_g"), True
             if "parametrizations.weight.original1" in key:
-                return key.replace("parametrizations.weight.original1", "weight_v")
-        return key
+                return key.replace("parametrizations.weight.original1", "weight_v"), True
+
+        return key, False

     @classmethod
     def _fix_state_dict_keys_on_load(cls, state_dict):
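As context for the hunk above (not part of the diff; the parameter names below are made up), a minimal sketch of the behavioral change: a key that merely contains "gamma" is no longer renamed, only keys ending in "LayerNorm.gamma"/"LayerNorm.beta" are, and the hook now also reports whether it changed anything:

```python
def old_fix(key):
    # Previous behavior: substring match renames any key containing "beta"/"gamma".
    if "beta" in key:
        return key.replace("beta", "bias")
    if "gamma" in key:
        return key.replace("gamma", "weight")
    return key

def new_fix(key):
    # New behavior: suffix-scoped rename plus an explicit has_changed flag.
    if key.endswith("LayerNorm.beta"):
        return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
    if key.endswith("LayerNorm.gamma"):
        return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
    return key, False

for key in ["encoder.layer.0.LayerNorm.gamma", "vit.layer.0.gamma_1"]:
    print(old_fix(key), "->", new_fix(key))
# encoder.layer.0.LayerNorm.weight -> ('encoder.layer.0.LayerNorm.weight', True)
# vit.layer.0.weight_1 -> ('vit.layer.0.gamma_1', False)   # no longer mangled
```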
@@ -4366,15 +4367,15 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
         renamed_keys = {}
         state_dict_keys = list(state_dict.keys())
         for key in state_dict_keys:
-            new_key = cls._fix_state_dict_key_on_load(key)
-            if new_key != key:
+            new_key, has_changed = cls._fix_state_dict_key_on_load(key)
+            if has_changed:
                 state_dict[new_key] = state_dict.pop(key)

-                # add it once for logging
-                if "gamma" in key and "gamma" not in renamed_keys:
-                    renamed_keys["gamma"] = (key, new_key)
-                if "beta" in key and "beta" not in renamed_keys:
-                    renamed_keys["beta"] = (key, new_key)
+                # track gamma/beta rename for logging
+                if key.endswith("LayerNorm.gamma"):
+                    renamed_keys["LayerNorm.gamma"] = (key, new_key)
+                elif key.endswith("LayerNorm.beta"):
+                    renamed_keys["LayerNorm.beta"] = (key, new_key)

         if renamed_keys:
             warning_msg = f"A pretrained model of type `{cls.__name__}` "
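To make the loop change above concrete, here is a small self-contained sketch (hypothetical keys, simplified hook, string placeholders instead of tensors) of how renamed entries are popped back in under their new names while only LayerNorm renames are recorded for the warning:

```python
def fix_key(key):
    # Simplified stand-in for _fix_state_dict_key_on_load (LayerNorm cases only).
    if key.endswith("LayerNorm.gamma"):
        return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
    if key.endswith("LayerNorm.beta"):
        return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
    return key, False

state_dict = {
    "embeddings.LayerNorm.gamma": "tensor-0",
    "embeddings.LayerNorm.beta": "tensor-1",
    "pooler.dense.weight": "tensor-2",
}

renamed_keys = {}
for key in list(state_dict.keys()):
    new_key, has_changed = fix_key(key)
    if has_changed:
        state_dict[new_key] = state_dict.pop(key)
        if key.endswith("LayerNorm.gamma"):
            renamed_keys["LayerNorm.gamma"] = (key, new_key)
        elif key.endswith("LayerNorm.beta"):
            renamed_keys["LayerNorm.beta"] = (key, new_key)

print(sorted(state_dict))
# ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'pooler.dense.weight']
print(renamed_keys)  # this is what feeds the rename warning
```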
@@ -4387,19 +4388,19 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
         return state_dict

     @staticmethod
-    def _fix_state_dict_key_on_save(key):
+    def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
         """
         Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
-        Do nothing by default, but can be overriden in particular models.
+        Do nothing by default, but can be overridden in particular models.
         """
-        return key
+        return key, False

     def _fix_state_dict_keys_on_save(self, state_dict):
         """
         Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
         Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
         """
-        return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
+        return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

     @classmethod
     def _load_pretrained_model(
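On the save side, the default hook is now a no-op that returns `(key, False)`; a hypothetical override (the "wrapper." prefix is invented for illustration) shows why `_fix_state_dict_keys_on_save` indexes the result with `[0]`:

```python
def fix_key_on_save(key):
    # Hypothetical override: strip a made-up "wrapper." prefix before saving.
    if key.startswith("wrapper."):
        return key[len("wrapper."):], True
    return key, False

state_dict = {"wrapper.linear.weight": "tensor-0", "classifier.bias": "tensor-1"}
# Mirrors _fix_state_dict_keys_on_save: only the renamed key ([0]) is kept.
saved = {fix_key_on_save(k)[0]: v for k, v in state_dict.items()}
print(saved)  # {'linear.weight': 'tensor-0', 'classifier.bias': 'tensor-1'}
```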
@@ -4457,7 +4458,7 @@ def _load_pretrained_model(
             expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)

         original_loaded_keys = loaded_keys
-        loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
+        loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]

         if len(prefix) > 0:
             has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
8 changes: 4 additions & 4 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -90,22 +90,22 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """
         Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
         We don't want this behavior for timm wrapped models. Instead, this method adds a
         "timm_model." prefix to enable loading official timm Hub checkpoints.
         """
         if "timm_model." not in key:
-            return f"timm_model.{key}"
-        return key
+            return f"timm_model.{key}", True
+        return key, False

     def _fix_state_dict_key_on_save(self, key):
         """
         Overrides original method to remove "timm_model." prefix from state_dict keys.
         Makes the saved checkpoint compatible with the `timm` library.
         """
-        return key.replace("timm_model.", "")
+        return key.replace("timm_model.", ""), True

     def load_state_dict(self, state_dict, *args, **kwargs):
         """
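Finally, a sketch (with a hypothetical timm checkpoint key) of the round trip these wrapper hooks implement: official timm keys gain a "timm_model." prefix on load and lose it again on save:

```python
def on_load(key):
    # Mirrors the load-side hook in the diff above: add the wrapper prefix once.
    if "timm_model." not in key:
        return f"timm_model.{key}", True
    return key, False

def on_save(key):
    # Mirrors the save-side hook: drop the prefix so timm can load the checkpoint.
    return key.replace("timm_model.", ""), True

key = "blocks.0.attn.qkv.weight"   # hypothetical timm checkpoint key
loaded, changed = on_load(key)
print(loaded, changed)             # timm_model.blocks.0.attn.qkv.weight True
print(on_save(loaded))             # ('blocks.0.attn.qkv.weight', True)
```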