🚨🚨🚨 An attempt to fix huggingface#29554. Include 'LayerNorm.' in gamma/beta rename scope, optimize string search. (huggingface#35615)

* An attempt to fix huggingface#29554. Include 'LayerNorm.' in the gamma/beta rename scope and considerably reduce the number of characters searched on every load.

* Fix the fix for the on-load issue

* Fix gamma/beta warning test

* Address a style complaint

* Improve efficiency of weight norm key rename. Add better comments about weight norm and layer norm renaming.

* Remove a habitual elif that was redundant with the return
rwightman authored and bursteratom committed Jan 28, 2025
1 parent 1aeaead commit b0b44c6
Showing 3 changed files with 60 additions and 65 deletions.
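In brief: the old load-time rename touched any key containing "gamma" or "beta", which could mangle legitimately named parameters (the problem reported in huggingface#29554). The new rule only renames keys ending in "LayerNorm.gamma"/"LayerNorm.beta" and returns a flag saying whether a rename happened. Below is a minimal standalone sketch of that rule, mirroring the modeling_utils.py diff that follows; it is not the library code itself, and the key names are illustrative rather than taken from a real checkpoint.

```python
from typing import Tuple

# Minimal sketch of the new load-time rename rule; illustrative only.
def fix_key_on_load(key: str) -> Tuple[str, bool]:
    # Only LayerNorm gamma/beta (old TensorFlow-ported naming) are renamed now.
    if key.endswith("LayerNorm.beta"):
        return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
    if key.endswith("LayerNorm.gamma"):
        return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
    return key, False

print(fix_key_on_load("encoder.layer.0.output.LayerNorm.gamma"))
# ('encoder.layer.0.output.LayerNorm.weight', True)
print(fix_key_on_load("decoder.gamma_proj.weight"))
# ('decoder.gamma_proj.weight', False); the old substring check would have
# turned this key into 'decoder.weight_proj.weight'
```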
59 changes: 32 additions & 27 deletions src/transformers/modeling_utils.py
@@ -4367,26 +4367,31 @@ def from_pretrained(
return model

@staticmethod
def _fix_state_dict_key_on_load(key):
def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""

if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
# This rename is logged.
if key.endswith("LayerNorm.beta"):
return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
if key.endswith("LayerNorm.gamma"):
return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

# to avoid logging parametrized weight norm renaming
# Rename weight norm parametrizations to match changes across torch versions.
# Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
# This rename is not logged.
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
return key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
return key.replace("weight_v", "parametrizations.weight.original1")
if key.endswith("weight_g"):
return key.replace("weight_g", "parametrizations.weight.original0"), True
if key.endswith("weight_v"):
return key.replace("weight_v", "parametrizations.weight.original1"), True
else:
if "parametrizations.weight.original0" in key:
return key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
return key.replace("parametrizations.weight.original1", "weight_v")
return key
if key.endswith("parametrizations.weight.original0"):
return key.replace("parametrizations.weight.original0", "weight_g"), True
if key.endswith("parametrizations.weight.original1"):
return key.replace("parametrizations.weight.original1", "weight_v"), True

return key, False

@classmethod
def _fix_state_dict_keys_on_load(cls, state_dict):
@@ -4397,15 +4402,15 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
renamed_keys = {}
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
new_key = cls._fix_state_dict_key_on_load(key)
if new_key != key:
new_key, has_changed = cls._fix_state_dict_key_on_load(key)
if has_changed:
state_dict[new_key] = state_dict.pop(key)

# add it once for logging
if "gamma" in key and "gamma" not in renamed_keys:
renamed_keys["gamma"] = (key, new_key)
if "beta" in key and "beta" not in renamed_keys:
renamed_keys["beta"] = (key, new_key)
# track gamma/beta rename for logging
if key.endswith("LayerNorm.gamma"):
renamed_keys["LayerNorm.gamma"] = (key, new_key)
elif key.endswith("LayerNorm.beta"):
renamed_keys["LayerNorm.beta"] = (key, new_key)

if renamed_keys:
warning_msg = f"A pretrained model of type `{cls.__name__}` "
@@ -4418,19 +4423,19 @@ def _fix_state_dict_keys_on_load(cls, state_dict):
return state_dict

@staticmethod
def _fix_state_dict_key_on_save(key):
def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
"""
Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
Do nothing by default, but can be overriden in particular models.
Do nothing by default, but can be overridden in particular models.
"""
return key
return key, False

def _fix_state_dict_keys_on_save(self, state_dict):
"""
Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
"""
return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

@classmethod
def _load_pretrained_model(
@@ -4488,7 +4493,7 @@ def _load_pretrained_model(
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)

original_loaded_keys = loaded_keys
loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]

if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
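The weight-norm branch above now also uses endswith checks and returns the changed flag; the direction of the rename depends on whether the installed torch exposes the parametrizations-based weight_norm API. A rough illustration of that mapping follows; the key is a hypothetical Wav2Vec2-style name, and the printed result assumes a recent torch.

```python
from typing import Tuple

import torch.nn as nn

# Mirrors the weight-norm branch in the diff above; illustrative only.
def fix_weight_norm_key(key: str) -> Tuple[str, bool]:
    if hasattr(nn.utils.parametrizations, "weight_norm"):
        # Newer torch: legacy weight_g/weight_v keys map to the parametrization names.
        if key.endswith("weight_g"):
            return key.replace("weight_g", "parametrizations.weight.original0"), True
        if key.endswith("weight_v"):
            return key.replace("weight_v", "parametrizations.weight.original1"), True
    else:
        # Older torch: parametrization-style keys map back to weight_g/weight_v.
        if key.endswith("parametrizations.weight.original0"):
            return key.replace("parametrizations.weight.original0", "weight_g"), True
        if key.endswith("parametrizations.weight.original1"):
            return key.replace("parametrizations.weight.original1", "weight_v"), True
    return key, False

print(fix_weight_norm_key("encoder.pos_conv_embed.conv.weight_g"))
# On a recent torch: ('encoder.pos_conv_embed.conv.parametrizations.weight.original0', True)
```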
8 changes: 4 additions & 4 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -90,22 +90,22 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@staticmethod
def _fix_state_dict_key_on_load(key):
def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""
Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
We don't want this behavior for timm wrapped models. Instead, this method adds a
"timm_model." prefix to enable loading official timm Hub checkpoints.
"""
if "timm_model." not in key:
return f"timm_model.{key}"
return key
return f"timm_model.{key}", True
return key, False

def _fix_state_dict_key_on_save(self, key):
"""
Overrides original method to remove "timm_model." prefix from state_dict keys.
Makes the saved checkpoint compatible with the `timm` library.
"""
return key.replace("timm_model.", "")
return key.replace("timm_model.", ""), True

def load_state_dict(self, state_dict, *args, **kwargs):
"""
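The TimmWrapper override keeps its own behaviour but now also reports whether a key changed: on load a "timm_model." prefix is added so official timm Hub checkpoints can be loaded, and on save the prefix is stripped again so the checkpoint remains usable with timm directly. A small sketch of that round trip, using a made-up timm-style key name:

```python
from typing import Tuple

# Illustrative sketch of the TimmWrapper overrides in the diff above.
def on_load(key: str) -> Tuple[str, bool]:
    # Prefix bare timm checkpoint keys so they land inside the wrapper module.
    if "timm_model." not in key:
        return f"timm_model.{key}", True
    return key, False

def on_save(key: str) -> Tuple[str, bool]:
    # Strip the prefix again so the saved checkpoint stays timm-compatible.
    return key.replace("timm_model.", ""), True

loaded_key, changed = on_load("blocks.0.attn.qkv.weight")
print(loaded_key, changed)      # timm_model.blocks.0.attn.qkv.weight True
print(on_save(loaded_key)[0])   # blocks.0.attn.qkv.weight
```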
58 changes: 24 additions & 34 deletions tests/utils/test_modeling_utils.py
@@ -1618,57 +1618,47 @@ def test_model_from_pretrained_from_mlx(self):
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))

def test_warning_for_beta_gamma_parameters(self):
class TestModelGamma(PreTrainedModel):
class TestGammaBetaNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.gamma = torch.nn.Parameter(torch.ones(1))
self.beta = torch.nn.Parameter(torch.zeros(1))

def forward(self):
return self.gamma.sum() + self.beta.sum()

class TestModelGammaBeta(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.gamma_param = nn.Parameter(torch.ones(10))
self.LayerNorm = TestGammaBetaNorm()
self.post_init()

def forward(self):
return self.gamma_param.sum()
return self.LayerNorm()

logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig()
warning_msg_gamma = "`gamma_param` -> `weight_param`"
model = TestModelGamma(config)
warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`"
model = TestModelGammaBeta(config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl1:
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
_, loading_info = TestModelGammaBeta.from_pretrained(
tmp_dir, config=config, output_loading_info=True
)

missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelGamma`", cl1.out)
self.assertIn("`TestModelGammaBeta`", cl1.out)
self.assertIn(warning_msg_gamma, cl1.out)
self.assertIn("gamma_param", missing_keys)
self.assertIn("weight_param", unexpected_keys)

class TestModelBeta(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.beta_param = nn.Parameter(torch.ones(10))
self.post_init()

def forward(self):
return self.beta_param.sum()

warning_msg_beta = "`beta_param` -> `bias_param`"
model = TestModelBeta(config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl2:
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)

missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelBeta`", cl2.out)
self.assertIn(warning_msg_beta, cl2.out)
self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys)
self.assertIn(warning_msg_beta, cl1.out)
self.assertIn("LayerNorm.gamma", missing_keys)
self.assertIn("LayerNorm.weight", unexpected_keys)
self.assertIn("LayerNorm.beta", missing_keys)
self.assertIn("LayerNorm.bias", unexpected_keys)

def test_isin_mps_friendly(self):
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""
