Fix safetensors failing tests #27231
Conversation
@@ -2313,5 +2326,8 @@ def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
         self.decoder = ProphetNetDecoder(config)
+
+        # This is a link so that tied weights work across classes
+        self.word_embeddings = self.decoder.word_embeddings
As far as I know, this is necessary because safetensors refuses to save duplicated tied weights, and therefore keeps a single copy, under a single identifier, for the whole tied group. It becomes an issue when we have a few: in this case we have 4 tied weights, with some being loaded in some models (like encoders) and others being loaded in other models (like decoders). If the encoder-decoder parent class saves a checkpoint, it can choose to keep the single copy under a name that is only visible in the encoder; when loading the decoder on its own, that weight would then be discarded even though it is the only remaining reference to the tied weights.
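For illustration, here is a minimal sketch (toy model, not code from this PR) of the behavior described above: safetensors keeps a single copy of tensors that share memory, so only one name from a tied group ends up in the file.

```python
import torch.nn as nn
from safetensors.torch import save_model, load_file

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_emb = nn.Embedding(10, 4)
        self.decoder_emb = nn.Embedding(10, 4)
        # Tie the two embeddings so they share one underlying tensor
        self.decoder_emb.weight = self.encoder_emb.weight

model = Toy()
# save_model deduplicates tensors that share memory, so only one of
# the two embedding names is written to disk.
save_model(model, "toy.safetensors")
print(sorted(load_file("toy.safetensors")))
# If a downstream class only knows the dropped name, the surviving
# entry looks irrelevant and the weight appears to be missing;
# hence the explicit link in the diff above.
```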
Changes in this file are incorrect
good now
will be overridden by #27240
The documentation is not available anymore as the PR was closed or merged.
@@ -3715,6 +3722,7 @@ def __init__(self, config):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

         # Initialize weights and apply final processing
+        self.shared = self.lm_head
I can't see where this is used within SeamlessM4TForSpeechToSpeech (unlike the change above in SeamlessM4TForSpeechToText).
I'll remove this in favor of #27240
@@ -304,6 +304,25 @@ def test_forward_signature(self):
         expected_arg_names = ["pixel_values"]
         self.assertListEqual(arg_names[:1], expected_arg_names)

+    def test_load_save_without_tied_weights(self):
Kosmos-2 is the only model that requires overriding this test from the common tests: maybe I am doing something wrong when adding it.
Is this meant to be temporary? I can take a look at this one later, after the PR is merged.
It's because Kosmos has a config.text_config rather than just a config.
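A hedged sketch of what the override has to account for (the helper name below is illustrative, not the actual test code): on Kosmos-2 the settings the common test touches live one level deeper, on the nested text config.

```python
def resolve_text_config(config):
    # Kosmos-2 nests its text settings under config.text_config,
    # so attributes the common test manipulates sit there rather
    # than on the top-level config. Fall back to the config itself
    # for models without a nested text config.
    return getattr(config, "text_config", config)
```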
The tests all passed, and overall this looks good, although I am not familiar with this part. I am a bit worried about tests like …
Thanks for the review!
Thanks for digging into this and fixing!
From the explanation I think I understand the issue with loading weights for ProphetNet; happy if tests pass for these models.
+1 on @ydshieh's comment about still keeping regression tests that check saving & loading with the torch bin format.
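Something like the following hypothetical round-trip check could cover that (the function and its structure are illustrative, not the test added in this PR):

```python
import tempfile

import torch
from transformers import AutoModelForSeq2SeqLM

def check_bin_round_trip(model_id):
    """Save with the legacy torch .bin format and reload, checking
    that all parameters (including tied ones) survive the trip."""
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    with tempfile.TemporaryDirectory() as tmp:
        # Force the legacy torch .bin format instead of safetensors
        model.save_pretrained(tmp, safe_serialization=False)
        reloaded = AutoModelForSeq2SeqLM.from_pretrained(tmp)
    for (name, p), (_, q) in zip(
        model.named_parameters(), reloaded.named_parameters()
    ):
        assert torch.equal(p, q), f"mismatch in {name}"
```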
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
* Fix Kosmos2
* Fix ProphetNet
* Fix MarianMT
* Fix M4T
* XLM ProphetNet
* ProphetNet fix
* XLM ProphetNet
* Final M4T fixes
* Tied weights keys
* Revert M4T changes
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
cc @ydshieh