
Fix parametrization-based weight norm #33275

Merged

Conversation

@ylacombe (Contributor) commented Sep 3, 2024

What does this PR do?

Supersedes #32194 and fixes #31970 and #26796!

While #32194 was already great work, it wasn't compatible with versions of Torch that only have the legacy nn.utils.weight_norm (see the sketch below).

I've left a review to explain some choices and to highlight where I'm not quite sure of my solution!

cc @LysandreJik and @ArthurZucker !
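For context, here's a minimal sketch (not this PR's exact code) of the version dispatch the fix has to accommodate: torch >= 2.1 exposes the parametrization-based nn.utils.parametrizations.weight_norm, while older versions only ship the hook-based nn.utils.weight_norm, and the two serialize their parameters under different state-dict keys:

    import torch.nn as nn

    try:
        # torch >= 2.1: parametrization-based implementation; parameters are
        # saved as parametrizations.weight.original0 (g) / .original1 (v)
        weight_norm = nn.utils.parametrizations.weight_norm
    except AttributeError:
        # older torch: hook-based implementation; parameters are saved as
        # weight_g / weight_v
        weight_norm = nn.utils.weight_norm

    # Either way, callers apply it the same way:
    conv = weight_norm(nn.Conv1d(16, 16, kernel_size=3), name="weight", dim=2)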

@@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
def _load_state_dict_into_meta_model(
model,
state_dict,
loaded_state_dict_keys, # left for now but could be removed, see below
ylacombe (Contributor, Author) commented:

Here, I removed loaded_state_dict_keys from _load_state_dict_into_meta_model, because according to the following snippet, it was not actually used before:

        # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:

I might have overlooked some side effects, especially with quantization and/or training frameworks. WDYT @ArthurZucker and @LysandreJik? Who should I tag for more info?

Also happy to revert to the original behaviour.
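For illustration, a sketch of the simplified guard once the redundant argument is gone (names follow the snippet above; not the PR's exact code):

    # loaded_state_dict_keys was built from state_dict itself, so the first
    # membership test above could never fail; only expected_keys matters.
    if param_name not in expected_keys:
        continue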

LysandreJik (Member) commented:

If it doesn't break any tests, let's remove it and keep an eye out for eventual breakage

@@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
def _load_state_dict_into_meta_model(
ylacombe (Contributor, Author) commented:

As explained here, the issue doesn't appear during regular loading of the state dict, only during meta-device loading! (See the example below.)
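For reproduction context, a hedged example of the two loading paths (WavLM is the model from the linked issue; the checkpoint name is illustrative):

    from transformers import WavLMModel

    # Regular loading: the state dict is materialized on CPU first; no issue.
    model = WavLMModel.from_pretrained("microsoft/wavlm-base")

    # Meta loading: passing a device_map takes the low_cpu_mem_usage path,
    # where weights are loaded directly onto the target device; this is where
    # the bug appeared.
    model = WavLMModel.from_pretrained("microsoft/wavlm-base", device_map="cuda:0")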

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@LysandreJik (Member) left a comment:

This looks good to me in practice for the affected models; @ArthurZucker, could you give it a second look to confirm or refute?


@amyeroberts (Collaborator) left a comment:

Thanks for fixing @ylacombe!

I'm a bit squeamish about adding remapping in the loading functions, as it has caused issues for "gamma" and "beta", but this seems pretty well controlled and is only likely to hit some weights very rarely.
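For context, a hedged sketch of the kind of key remapping being discussed (the helper name is illustrative, not the PR's actual code):

    def remap_weight_norm_key(key: str) -> str:
        # Legacy hook-based weight norm stores weight_g / weight_v; the
        # parametrization-based version stores them as original0 / original1
        # under a parametrizations.weight prefix.
        if key.endswith("weight_g"):
            return key.replace("weight_g", "parametrizations.weight.original0")
        if key.endswith("weight_v"):
            return key.replace("weight_v", "parametrizations.weight.original1")
        return key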

@ylacombe ylacombe merged commit 18e1a9c into huggingface:main Sep 17, 2024
23 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* refactor weight_norm + propose uniformed solution to reconcile meta load_state_dict with classic loading

* make style

* fix sew

* fix sew and sew_d tests

Successfully merging this pull request may close these issues.

WavLM returns empty hidden states when loaded directly to GPU