
Can't load models with a gamma or beta parameter #29554

Closed

malik-ali opened this issue Mar 9, 2024 · 13 comments · Fixed by #31654 · May be fixed by #33192
Labels
Feature request · Good Difficult Issue · Should Fix

Comments

@malik-ali

It seems that you cannot create parameters with the string gamma or beta in any modules you write if you intend to save/load them with the transformers library. There is a small function called _fix_keys implemented in the model loading (link). It renames any occurrence of beta or gamma in the state_dict keys to bias and weight, respectively. This means that if your modules actually have a parameter with one of these names, it won't be loaded when using a pretrained model.

As far as I can tell, it's completely undocumented that people shouldn't create any parameters with the string gamma or beta in them.

Here is a minimal reproducible example:

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class Model(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.gamma = nn.Parameter(torch.zeros(4))

    def forward(self):
        return self.gamma.sum()


if __name__ == '__main__':
    config = PretrainedConfig()

    # 1) First run this, then comment it back out:
    # model = Model(config)
    # print(model())
    # model.save_pretrained('test_out')

    # 2) Then run this:
    model = Model.from_pretrained('test_out', config=config)
    print(model())

When you run this code, you get the following error:

Some weights of Model were not initialized from the model checkpoint at test_out and are newly initialized: ['gamma']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
@malik-ali malik-ali changed the title Can't have models with a gamma or beta parameter Can't load models with a gamma or beta parameter Mar 10, 2024
@NielsRogge
Contributor

Yes that's correct, it's a bug I pointed out in my video series on contributing to Transformers.

This is due to these lines:

if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
.
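
To see why a plain substring replace is so dangerous, here is a small self-contained demonstration (the keys below are invented purely for illustration):

# Invented keys, purely to illustrate the substring replace.
state_dict = {"encoder.gamma": 1, "head.gamma_proj.weight": 2}
for key in list(state_dict.keys()):
    if "gamma" in key:
        state_dict[key.replace("gamma", "weight")] = state_dict.pop(key)
print(state_dict)
# {'encoder.weight': 1, 'head.weight_proj.weight': 2} -- both keys rewritten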

I assume they are there for backwards compatibility reasons. If we knew which models require this exception, we could fix this.

@malik-ali
Author

malik-ali commented Mar 10, 2024

I assumed the same, but it's a pretty annoying bug to have to find on your own. Would it be worth adding a warning to the init method of the PreTrainedModel class to let users know if their parameters have the string "gamma" or "beta" in them and encourage them to change it? At least while this block of code still exists in the codebase.

It's further complicated by the fact that accelerate loads state dicts directly, without this renaming. So there is an incompatibility between two tightly coupled libraries.

@amyeroberts
Collaborator

Hi @malik-ali, thanks for raising this issue! Indeed, this isn't a desired behaviour.

If we would know which models require this exception, we could fix this.

I think this would be very hard to do. There are many saved checkpoints both on and off the hub, as well as all sorts of custom models which might rely on this behaviour.

Would it be worth adding a warning to the init method of the PreTrainedModel class to let users know if their parameters have the string "gamma" or "beta" in them and encourage them to change it? At least while this block of code still exists in the codebase.

Yes, I think a warning for a few release cycles is the best way to go. I would put this in the _load_state_dict_into_model function and trigger it if "gamma" or "beta" are in the key.

It won't be possible to tell whether the parameter comes from an "old" state dict or a new model, but we can warn that the renaming is happening, that the behaviour will be removed in a future release, and that users should update the weights in their state dict to use "weight" or "bias" so they load properly.
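
A minimal sketch of what that load-path warning could look like (the helper name and wording here are illustrative, not the actual transformers internals):

import logging

logger = logging.getLogger(__name__)

def _fix_key_with_warning(key):
    # Illustrative helper: perform the legacy rename, but tell the user about it.
    new_key = key
    if "gamma" in key:
        new_key = new_key.replace("gamma", "weight")
    if "beta" in key:
        new_key = new_key.replace("beta", "bias")
    if new_key != key:
        logger.warning(
            f"'{key}' was renamed to '{new_key}'. This legacy renaming will be "
            "removed in a future release; rename the weights in your state dict "
            "to 'weight'/'bias' to keep loading them correctly."
        )
    return new_key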

@malik-ali Would you like to open a PR to add this? That way you get the GitHub contribution for your suggested solution.

@malik-ali
Author

@amyeroberts I'd be happy to! Just one question: if we add this to the _load_state_dict_into_model, is it correct that users would only see this warning when loading their pretrained model?

I ask because I ran into this issue after training a model for several days and later loading it. It would have been nice to see the warning before doing all the training, so that I could rename the parameters on the spot. Do you think a warning like that would be feasible?

(My fix was to manually rename the keys of the saved state_dict and then rename the parameters in my model)
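
For anyone hitting this before a fix lands, a minimal sketch of that workaround (the checkpoint filename and the replacement name g are placeholders):

import torch

# Rewrite the offending keys in the saved checkpoint; the path below is a
# placeholder for wherever save_pretrained wrote the weights.
path = "test_out/pytorch_model.bin"
state_dict = torch.load(path)
state_dict = {k.replace("gamma", "g"): v for k, v in state_dict.items()}
torch.save(state_dict, path)
# ...then rename self.gamma to self.g in the model definition to match.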

@amyeroberts
Collaborator

Good point! In this case, we'll need to add a warning in two places to make sure we catch both new model creations and old state dicts being loaded in.
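
The creation-time side of that could be a simple scan over parameter names, e.g. (a sketch, not the actual implementation):

import warnings

def warn_on_reserved_param_names(model):
    # Sketch: flag parameters whose names would be silently renamed on load.
    for name, _ in model.named_parameters():
        if "gamma" in name or "beta" in name:
            warnings.warn(
                f"Parameter '{name}' contains 'gamma' or 'beta'; from_pretrained "
                "renames such keys to 'weight'/'bias', so it will not load back "
                "correctly. Consider renaming it."
            )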

@fzyzcjy
Contributor

fzyzcjy commented Mar 21, 2024

+1, found this problem today...

@amyeroberts amyeroberts added Should Fix This has been identified as a bug and should be fixed. Feature request Request for a new feature labels Mar 21, 2024
@malik-ali
Author

@amyeroberts I might not have a chance to push a fix for this for at least a few weeks so please feel free to make any changes as you (or anyone) wishes!

@amyeroberts
Collaborator

@malik-ali OK - thanks for letting us know. I've added a 'Good Difficult Issue' label to flag this for anyone in the community who might want to tackle it in the meantime.

@whwangovo
Good catch. It took me two full days to debug this. At first, while troubleshooting, I kept thinking it was a problem with my training process. :(

Why have you added warnings only for the initialization process and not for renaming during loading as well?

The model I'm using is timm's ConvNeXt (timm is even a companion library to transformers), which has a parameter named gamma.

When loading, it just tells me that gamma wasn't loaded successfully, without telling me why. I think the user should be informed when the state_dict keys are renamed; otherwise it causes unnecessary confusion.

@rwightman
Contributor

rwightman commented Jan 9, 2025

Maybe we should revive this and fix the issue once and for all? This is the sort of legacy baggage that really should get cleaned up instead of being ignored for 'backwards compat'. For reference to others here: I'm bringing it up because it required a special workaround for the TimmWrapper, and it's still breaking timm models that have 'gamma' keys in the TimmBackbone module.

It dates back to old BERT models ported from TensorFlow. I believe @thomwolf was overseeing that?

There probably aren't actually that many weight instances in the wild that rely on this mechanism. And there's likely an identifiable signature (key names in the state_dict) for models that actually need the rename.

EDIT: By a signature based on key names, I mean an absolute key name like encoder.blocks[0].norm1.gamma: you verify that it is there (or similar patterns for any known model) before you rename *.gamma and *.beta.

CC @ArthurZucker @qubvel

Also discussing in huggingface/pytorch-image-models#2324

@rwightman rwightman reopened this Jan 9, 2025
@rwightman
Contributor

rwightman commented Jan 9, 2025

The original BERT weights, including the safetensors version, include the old keys that need renaming. If you look at the listing below, covering BERT would be easy to do without impacting other models.

Slightly more specific, but it would still make me a bit uneasy: *.LayerNorm.gamma (or beta) would cover BERT, and it's quite unlikely that PyTorch users name a layer LayerNorm. You could also regex all keys, or check for a specific key to enable the renaming (bert.embeddings.LayerNorm.gamma, bert.encoder.layer.0.attention.output.LayerNorm.gamma).

I guess the big question is: are there any models besides BERT that needed this? T5 looks fine, and I couldn't find any other suspects, but I'm not intimately familiar with the old TF LM ports.

bert.embeddings
bert.embeddings.position_embeddings.weight              [512, 768]
bert.embeddings.token_type_embeddings.weight            [2, 768]
bert.embeddings.word_embeddings.weight                  [30522, 768]
bert.embeddings.LayerNorm.beta                          [768]
bert.embeddings.LayerNorm.gamma                         [768]
bert.encoder
bert.encoder.layer.0.attention.self.key.bias            [768]
bert.encoder.layer.0.attention.self.key.weight          [768, 768]
bert.encoder.layer.0.attention.self.query.bias          [768]
bert.encoder.layer.0.attention.self.query.weight        [768, 768]
bert.encoder.layer.0.attention.self.value.bias          [768]
bert.encoder.layer.0.attention.self.value.weight        [768, 768]
bert.encoder.layer.0.attention.output.dense.bias        [768]
bert.encoder.layer.0.attention.output.dense.weight      [768, 768]
bert.encoder.layer.0.attention.output.LayerNorm.beta    [768]
bert.encoder.layer.0.attention.output.LayerNorm.gamma   [768]
bert.encoder.layer.0.intermediate.dense.bias            [3072]
bert.encoder.layer.0.intermediate.dense.weight          [3072, 768]
bert.encoder.layer.0.output.dense.bias                  [768]
bert.encoder.layer.0.output.dense.weight                [768, 3072]
bert.encoder.layer.0.output.LayerNorm.beta
...

Something like this would narrow the scope considerably and probably catch the intended models?

for key in list(state_dict.keys()):
    if key.endswith("LayerNorm.gamma"):
        new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
        state_dict[new_key] = state_dict.pop(key)
    elif key.endswith("LayerNorm.beta"):
        new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
        state_dict[new_key] = state_dict.pop(key)
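
Combining that with the signature idea from the earlier comment, the rename could additionally be gated on the presence of a known old-style key (a sketch; the gating key is the BERT example from the listing above):

def maybe_fix_layernorm_keys(state_dict):
    # Only enable the legacy rename when a known old-style BERT key is present.
    if "bert.embeddings.LayerNorm.gamma" not in state_dict:
        return state_dict
    for key in list(state_dict.keys()):
        if key.endswith("LayerNorm.gamma"):
            new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
            state_dict[new_key] = state_dict.pop(key)
        elif key.endswith("LayerNorm.beta"):
            new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
            state_dict[new_key] = state_dict.pop(key)
    return state_dict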

@NielsRogge
Contributor

Thanks so much for fixing this @rwightman!

@rwightman
Contributor

@NielsRogge thanks. If you happen to be aware of any other high-risk model weights that might have actually been relying on this rename (vs being constrained by it, heh), let me know. I could really only find BERT (which was what it was added for originally, as far as I can tell). Hopefully we're safe with respect to any lurking regressions...

bursteratom pushed a commit to bursteratom/transformers that referenced this issue Jan 31, 2025
…a/beta rename scope, optimize string search. (huggingface#35615)

* An attempt to fix huggingface#29554. Include 'LayerNorm.' in gamma/beta rename scope, reduce number of characters searched on every load considerably.

* Fix fix on load issue

* Fix gamma/beta warning test

* A style complaint

* Improve efficiency of weight norm key rename. Add better comments about weight norm and layer norm renaming.

* Habitual elif redundant with the return
elvircrn pushed a commit to elvircrn/transformers that referenced this issue Feb 13, 2025
…a/beta rename scope, optimize string search. (huggingface#35615)