Fix weight loading of weight_g_idx compressed-tensors parameters #35741

Closed
wants to merge 3 commits

Conversation

@dsikka (Contributor) commented Jan 16, 2025

What does this PR do?

  • There is currently a check in modeling_utils.py that replaces any parameter key containing the substring weight_g with parametrizations.weight.original0. This check incorrectly rewrites compressed-tensors parameters whose keys contain weight_g_idx, resulting in a loading error (see the sketch after this list):

if "weight_g" in key:

  • Updates the src/transformers/quantizers/quantizer_compressed_tensors.py logic to temporarily disable the weight_norm attribute so that the substring replacement in modeling_utils.py does not occur during weight loading, then re-enables it in _process_model_after_weight_loading. This allows the models to load correctly.
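
For illustration only, a minimal sketch of the mismatch; rename_weight_norm_key and the key name below are hypothetical stand-ins, not the exact transformers code path:

    def rename_weight_norm_key(key):
        # Substring check, mirroring the condition in modeling_utils.py:
        # it also matches compressed-tensors keys such as "weight_g_idx".
        if "weight_g" in key:
            return key.replace("weight_g", "parametrizations.weight.original0")
        return key

    # A compressed-tensors key that should be left untouched gets mangled:
    print(rename_weight_norm_key("model.layers.0.mlp.down_proj.weight_g_idx"))
    # -> "model.layers.0.mlp.down_proj.parametrizations.weight.original0_idx"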

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc
@ArthurZucker

@dsikka changed the title from "Fix weight loading of weight_g_idx parameters of compressed-tensors models" to "Fix weight loading of weight_g_idx compressed-tensors parameters" on Jan 16, 2025
@SunMarc (Member) left a comment

Thanks for the PR @dsikka! Instead of a hotfix inside this quantization method, we can maybe solve the issue at its core. I think you are hitting this because of a recently merged PR that adds support for Timm models (cc @qubvel). Before, we were only modifying the loaded keys, but now we also fix the state_dict, which led to this situation. The real issue is that the condition used to rewrite the keys is not strict enough:

    @staticmethod
    def _fix_state_dict_key_on_load(key):
        """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""

        if "beta" in key:
            return key.replace("beta", "bias")
        if "gamma" in key:
            return key.replace("gamma", "weight")

        # to avoid logging parametrized weight norm renaming
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            if "weight_g" in key:
                return key.replace("weight_g", "parametrizations.weight.original0")
            if "weight_v" in key:
                return key.replace("weight_v", "parametrizations.weight.original1")
        else:
            if "parametrizations.weight.original0" in key:
                return key.replace("parametrizations.weight.original0", "weight_g")
            if "parametrizations.weight.original1" in key:
                return key.replace("parametrizations.weight.original1", "weight_v")
        return key

A potential workaround would be to check key.endswith("weight_g") or key.endswith("weight_v") instead.
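
For illustration, a hedged sketch of what that stricter check could look like; rename_weight_norm_key is a hypothetical stand-in for the renaming logic above, not the actual fix that later landed:

    def rename_weight_norm_key(key):
        # endswith() only matches keys that terminate in the legacy suffix,
        # so "weight_g_idx" is left alone while a real "weight_g" is still renamed.
        if key.endswith("weight_g"):
            return key.replace("weight_g", "parametrizations.weight.original0")
        if key.endswith("weight_v"):
            return key.replace("weight_v", "parametrizations.weight.original1")
        return key

    print(rename_weight_norm_key("encoder.conv.weight_g"))        # renamed
    print(rename_weight_norm_key("mlp.down_proj.weight_g_idx"))   # unchanged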

@qubvel (Member) commented Jan 17, 2025

This is also discussed in the latest fix for the gamma/beta renaming, #35615.

It should already be fixed, because that PR introduced the .endswith("weight_v") change.

@dsikka please update to the latest source code from main and let us know if you still face any issues, thanks 🤗

@dsikka (Contributor, Author) commented Jan 17, 2025

Perfect timing, that PR fixes our issue.

Thanks for linking!

@dsikka closed this Jan 17, 2025