Fix parametrization-based weight norm #33275
Conversation
…oad_state_dict with classic loading
@@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
 def _load_state_dict_into_meta_model(
     model,
     state_dict,
-    loaded_state_dict_keys,  # left for now but could be removed, see below
Here, I removed `loaded_state_dict_keys` from `_load_state_dict_into_meta_model`, because according to the following snippet, it was not actually used before:

    # First part of the test is always true, as loaded_state_dict_keys always contains the state_dict keys.
    if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
I might have overlooked some side effects, especially with quantization and/or training frameworks. WDYT @ArthurZucker and @LysandreJik? Who should I tag for more info?
I'm also happy to revert to the original behaviour.
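The redundancy of that first clause can be illustrated with a small self-contained sketch (the key sets below are made up for illustration and are not the actual transformers internals):

```python
# Hypothetical illustration of the redundant check: loaded_state_dict_keys
# always contains every key of state_dict, so for any param_name drawn from
# state_dict the first clause can never be True on its own.
state_dict = {"weight_g": 1, "weight_v": 2, "bias": 3}
loaded_state_dict_keys = list(state_dict)      # superset by construction
expected_keys = ["weight_g", "weight_v"]       # "bias" is unexpected here

for param_name in state_dict:
    full = param_name not in loaded_state_dict_keys or param_name not in expected_keys
    reduced = param_name not in expected_keys  # equivalent condition
    assert full == reduced
```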
If it doesn't break any tests, let's remove it and keep an eye out for eventual breakage
As explained here, the issue doesn't appear when doing regular loading of the state dict, but only when doing meta loading!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This looks good to me in practice for the affected models; @ArthurZucker, if you can give it a second look just to confirm or refute.
Thanks for fixing @ylacombe!
I'm a bit squeamish about adding remapping in the loading functions, as it caused issues for "gamma" and "beta", but this seems pretty well controlled and only likely to hit some weights very rarely.
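To make the trade-off concrete, here is a hedged sketch of suffix-based key remapping at load time (the helper `remap_key` and the mapping shown are illustrative, not the exact transformers code): matching on a full key suffix, rather than a bare substring, is what keeps the rename from accidentally hitting unrelated weights, which was the failure mode with the "gamma"/"beta" renames.

```python
# Illustrative-only remapping of legacy weight-norm keys to the
# parametrization-based names; `remap_key` and OLD_TO_NEW are hypothetical.
OLD_TO_NEW = {
    "weight_g": "parametrizations.weight.original0",
    "weight_v": "parametrizations.weight.original1",
}

def remap_key(key: str) -> str:
    # Only rewrite keys that end with a legacy suffix, leaving others intact.
    for old, new in OLD_TO_NEW.items():
        if key.endswith(old):
            return key[: -len(old)] + new
    return key

state_dict = {"conv.weight_g": 0, "conv.weight_v": 1, "conv.bias": 2}
remapped = {remap_key(k): v for k, v in state_dict.items()}
print(sorted(remapped))
# ['conv.bias', 'conv.parametrizations.weight.original0', 'conv.parametrizations.weight.original1']
```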
* refactor weight_norm + propose uniformed solution to reconcile meta load_state_dict with classic loading
* make style
* fix sew
* fix sew and sew_d tests
What does this PR do?
Supersedes #32194 and fixes #31970 and #26796!
While #32194 was already great work, it wasn't compatible with versions of Torch that only have `nn.utils.weight_norm`. I've left a review to explain some choices and to highlight where I'm not quite sure of my solution!
cc @LysandreJik and @ArthurZucker !