Fix TF loading PT safetensors when weights are tied #27490
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Update: I tried to get the test to trigger other models by adding `config.tie_word_embeddings = True`, but it didn't cause any unexpected failures.
Thanks for adding!
Main question / comment is about the pooling layer being added.
For a follow-up PR it would be good to add a set of tests which check saving/loading of models when optional weight tying is activated
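As a loose illustration of that suggestion (not code from this PR), such a follow-up test might look roughly like the sketch below. It assumes the tested configs expose a `tie_word_embeddings` flag and that calling the model on the prepared inputs is enough to build its weights:

```python
import tempfile

import numpy as np


def test_save_and_reload_with_tied_weights(self):
    # Sketch only: activate optional weight tying, round-trip through
    # save_pretrained/from_pretrained, and check the weights survive.
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.tie_word_embeddings = True  # turn on the optional tying
    for model_class in self.all_model_classes:
        model = model_class(config)
        model(inputs_dict)  # build the weights before saving
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            reloaded = model_class.from_pretrained(tmp_dir)
            for before, after in zip(model.weights, reloaded.weights):
                np.testing.assert_allclose(before.numpy(), after.numpy())
```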
```diff
@@ -1088,7 +1094,7 @@ def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels

-        self.mpnet = TFMPNetMainLayer(config, name="mpnet")
+        self.mpnet = TFMPNetMainLayer(config, add_pooling_layer=False, name="mpnet")
```
Is this backwards compatible? It looks like the pooling layer would have been added previously
It isn't, but as-is the model is incompatible with Torch (and was causing tests to fail) because the TF models were gaining an extra layer that the Torch ones didn't have! I think this is the lesser of two evils, since I don't think there are too many TF MPNet saves floating around.
> but as-is the model is incompatible with Torch

My understanding is that the model would still load from PT weights, but would randomly initialize these layers. Are there any other incompatibilities?

> I don't think there are too many TF MPNet saves floating around.

Unfortunately, we can't know TF vs. PT downloads from the Hub, but there's a handful of repos with a decent number of downloads, e.g. here and here.

It's arguably making the model "correct", but I believe this will drop layers for pretrained checkpoints that previously had pooling layers - affecting finetuned models and downstream tasks.

We should try and find a way to make sure the layers are kept for older models but not added to newer ones.
Good point, actually - I'll revert these changes and skip the TF-PT compatibility check for now.
I'm not sure it's possible to have different behaviour for old and new models, though - weights are initialized before being loaded from the checkpoint, so we can't really change the config based on whether the checkpoint has a certain tensor or not!
Still, you're definitely right that preserving backward compatibility matters more here, so I'll just revert the changes and patch the tests.
Added your suggestion @amyeroberts! The method always returns a tuple now.
Quick ping again for final approval @amyeroberts!
Thanks for iterating!
Only concern is the changes to the MPNet models - all others LGTM
tests/test_modeling_common.py
```diff
@@ -3224,7 +3224,6 @@ def test_flash_attn_2_fp32_ln(self):
     def test_tf_from_pt_safetensors(self):
         for model_class in self.all_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            config.tie_word_embeddings = True  # Tied weights often cause safetensors loading to fail
```
Why remove this?
I added it in this PR to check if it caused any more failures - it didn't, so I took it back out again. It doesn't really fit in this test anyway, I just wanted to see if there were more models that might have issues with it!
Before merge, let's check this PR against the Hub repo:
Force-pushed from c1829b7 to 2981063
@ydshieh - good spot! I missed adding this method for BART. I tested, and the slow tests that were failing in the CI are passing now that I've added it.
Thanks a lot, @Rocketknight1. The fix works for the models above. However, for composite models like RAG, loading still fails.
@ydshieh noted! I think I'll need a separate PR for those, though, since it's a composite model. This PR should fix most models, though - can we merge it urgently before the release while I work on something to fix models like RAG? |
Thank you @Rocketknight1 for the quick fix! Let's find a more elegant fix after the release, and also fix composite models
@LysandreJik yes, will do! I think this will synergize with the new weight building #27794 as well, and we should be able to get TF-safetensors in good shape soon.
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Force-pushed from 0eb1373 to 0c5c883
This PR resolves issues when loading PT safetensors files in TF.
The cause of the problem was that safetensors saving discards "aliased" tensors, weight tensors that share the same underlying weight array. This commonly occurs when models use tied weights. Many of our TF models don't support weight tying, however, and as a result the decoder output weights fail to load correctly.
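As a small standalone illustration of that aliasing behaviour (not code from this PR), `safetensors.torch.save_model` deduplicates tensors that share storage, so only one of the tied names survives in the file:

```python
# Demonstrates why a tied decoder weight can go missing from a safetensors
# file: aliased tensors share storage, and only one name is written out.
import torch
from torch import nn
from safetensors.torch import load_file, save_model


class TinyTied(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.decoder = nn.Linear(4, 10, bias=False)
        self.decoder.weight = self.embed.weight  # tie the weights


model = TinyTied()
save_model(model, "tied.safetensors")
# Only one of "embed.weight" / "decoder.weight" is present on disk.
print(load_file("tied.safetensors").keys())
```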
The solution is to use a trick we already use for encoder-decoder models: a model-specific `tf_to_pt_weight_rename` method. This PR refactors the way that method is called to make it more accessible (no more need to override `from_pretrained`), and adds `tf_to_pt_weight_rename` methods to the affected models.

However, I suspect there are more affected models which aren't showing up in this test, because the only models failing in this test are the models that always use weight tying without needing a config flag. If a model optionally ties weights based on a flag, that flag will not be set in this test. I suspect the same fix will be needed for several more models as a result, even though this test doesn't flag them.
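As a rough illustration of the pattern (the class and weight-name strings here are made up, not taken from this PR), a per-model `tf_to_pt_weight_rename` maps a TF weight onto one or more candidate PT checkpoint names; per the discussion above, the method now always returns a tuple:

```python
# Illustrative sketch only - the class and the weight-name strings are
# hypothetical, but the hook itself is the one this PR wires up.
class TFExampleForMaskedLM(TFExamplePreTrainedModel):
    def tf_to_pt_weight_rename(self, tf_weight):
        # safetensors kept only the tied input embedding, so offer the
        # surviving tensor as a fallback candidate for the decoder weight.
        if tf_weight == "lm_head.decoder.weight":
            return tf_weight, "model.embeddings.word_embeddings.weight"
        # Always return a tuple of candidate PT names to try in order.
        return (tf_weight,)
```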