
Fix TF loading PT safetensors when weights are tied #27490

Merged: 21 commits merged into main on Dec 7, 2023

Conversation

@Rocketknight1 (Member) commented Nov 14, 2023

This PR resolves issues when loading PT safetensors files in TF.

The cause of the problem was that safetensors saving discards "aliased" tensors, weight tensors that share the same underlying weight array. This commonly occurs when models use tied weights. Many of our TF models don't support weight tying, however, and as a result the decoder output weights fail to load correctly.
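
A minimal sketch of that aliasing behaviour, using a hypothetical toy module rather than any actual Transformers class:

```python
# Toy illustration (hypothetical module, not a Transformers model): two
# parameters share the same storage, so safetensors keeps only one of the
# aliased entries when saving.
import torch
from safetensors.torch import load_file, save_model


class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(8, 4, bias=False)
        self.lm_head = torch.nn.Linear(8, 4, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the weights


model = TinyLM()
save_model(model, "tiny.safetensors")  # save_model de-duplicates shared tensors
print(load_file("tiny.safetensors").keys())  # only one of the two weight names survives
```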

The solution is to use a trick we already use for encoder-decoder models, a model-specific tf_to_pt_weight_rename method. This PR refactors the way that method is called to make it more accessible (no more need to override from_pretrained), and adds tf_to_pt_weight_rename methods to the affected models.

However, I suspect there are more affected models which aren't showing up in this test, because the only models failing in this test are the models that always use weight tying without needing a config flag. If a model optionally ties weights based on a flag, that flag will not be set in this test. I suspect the same fix will be needed for several more models as a result, even though this test doesn't flag them.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Rocketknight1 Rocketknight1 marked this pull request as ready for review November 14, 2023 15:44
@Rocketknight1 (Member, Author)

Update: I tried to get the test to trigger failures in other models by adding config.tie_word_embeddings=True, but everything still seems to be passing, so I guess this PR is ready for review! cc @ArthurZucker or @amyeroberts

@amyeroberts (Collaborator) left a comment:

Thanks for adding!

Main question / comment is about the pooling layer being added.

For a follow-up PR it would be good to add a set of tests that check saving/loading of models when optional weight tying is activated.

src/transformers/modeling_tf_pytorch_utils.py (review thread, outdated and resolved)
@@ -1088,7 +1094,7 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels

-        self.mpnet = TFMPNetMainLayer(config, name="mpnet")
+        self.mpnet = TFMPNetMainLayer(config, add_pooling_layer=False, name="mpnet")
amyeroberts (Collaborator):

Is this backwards compatible? It looks like the pooling layer would have been added previously.

@Rocketknight1 (Member, Author) Nov 14, 2023:

It isn't, but as-is the model is incompatible with Torch (and was causing tests to fail) because the TF models were gaining an extra layer that the Torch ones didn't have! I think this is the lesser of two evils, since I don't think there are too many TF MPNet saves floating around.

amyeroberts (Collaborator):

> but as-is the model is incompatible with Torch

My understanding is that the model would still load from PT weights, but would randomly initialize these layers. Are there any other incompatibilities?

> I don't think there are too many TF MPNet saves floating around.

Unfortunately, we can't know TF vs. PT downloads from the Hub, but there's a handful of repos with a decent number of downloads, e.g. here and here.

It's arguably making the model "correct", but I believe this will drop layers for pretrained checkpoints that previously had pooling layers, affecting finetuned models and downstream tasks.

We should try to find a way to make sure the layers are kept for older models but not added to newer ones.

Rocketknight1 (Member, Author):

Good point, actually - I'll revert these changes and skip the TF-PT compatibility check for now.

Rocketknight1 (Member, Author):

I'm not sure it's possible to have different behaviour for old and new models, though - weights are initialized before being loaded from the checkpoint, so we can't really change the config based on whether the checkpoint has a certain tensor or not!

Still, you're definitely right that preserving backward compatibility matters more here, so I'll just revert the changes and patch the tests.

@Rocketknight1 (Member, Author)

Added your suggestion @amyeroberts! The method always returns a tuple now.
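
For context, a rough sketch of what a tuple-returning rename hook could look like (the weight names below are hypothetical, not taken from this PR):

```python
# Sketch only: a model-specific hook mapping a TF weight name to a tuple of
# candidate PyTorch names. Returning a tuple lets the loading code fall back
# to an aliased (tied) tensor if safetensors dropped the primary name.
def tf_to_pt_weight_rename(self, tf_weight):
    if tf_weight == "lm_head.weight":
        # Hypothetical case: the LM head may be tied to the input embeddings,
        # so only the embedding tensor made it into the safetensors file.
        return ("lm_head.weight", "model.embed_tokens.weight")
    return (tf_weight,)
```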

@Rocketknight1 (Member, Author)

Quick ping again for final approval @amyeroberts !

@amyeroberts (Collaborator) left a comment:

Thanks for iterating!

My only concern is with the changes to the MPNet models - all others LGTM.

@@ -3224,7 +3224,6 @@ def test_flash_attn_2_fp32_ln(self):
def test_tf_from_pt_safetensors(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.tie_word_embeddings = True  # Tied weights often cause safetensors loading to fail
amyeroberts (Collaborator):

Why remove this?

Rocketknight1 (Member, Author):

I added it in this PR to check if it caused any more failures - it didn't, so I took it back out again. It doesn't really fit in this test anyway; I just wanted to see if there were more models that might have issues with it!

@ydshieh (Collaborator) commented Dec 1, 2023

Before merging, let's check this PR against the Hub repo "facebook/bart-large-cnn", say by running

tests/models/bart/test_modeling_tf_bart.py::TFBartModelIntegrationTest::test_cnn_summarization_same_as_fairseq_hard

@Rocketknight1 Rocketknight1 force-pushed the fix-pt-tf-tied-weights-safetensors branch from c1829b7 to 2981063 Compare December 1, 2023 12:52
@Rocketknight1 (Member, Author)

@ydshieh - good spot! I missed adding this method for BART. I tested, and the slow tests that were failing in the CI are passing now that I've added it.

@ydshieh (Collaborator) commented Dec 1, 2023

Thanks a lot, @Rocketknight1. The TFBartModelIntegrationTest tests all pass now with this PR.

However, TFRag still has 5 failing tests (as in the Slack report):

FAILED tests/models/rag/test_modeling_tf_rag.py::TFRagModelIntegrationTests::test_rag_sequence_inference - ValueError: Weight name final_logits_bias:0 does not start with name_scope tf_rag_sequence_for_generation_1/rag. This is an internal error in Transformers, so (unless you were doing something really evil) please open an...

FAILED tests/models/rag/test_modeling_tf_rag.py::TFRagModelIntegrationTests::test_rag_token_inference - ValueError: Weight name final_logits_bias:0 does not start with name_scope tf_rag_token_for_generation_1/rag. This is an internal error in Transformers, so (unless you were doing something really evil) please open an issue...

FAILED tests/models/rag/test_modeling_tf_rag.py::TFRagModelIntegrationTests::test_rag_token_inference_save_pretrained - ValueError: Weight name final_logits_bias:0 does not start with name_scope tf_rag_token_for_generation_1/rag. This is an internal error in Transformers, so (unless you were doing something really evil) plea...

FAILED tests/models/rag/test_modeling_tf_rag.py::TFRagModelSaveLoadTests::test_rag_sequence_from_pretrained - ValueError: Weight name final_logits_bias:0 does not start with name_scope tf_rag_sequence_for_generation_1/rag. This is an internal error in Transformers, so (unless you were doing something really evil) please open...

FAILED tests/models/rag/test_modeling_tf_rag.py::TFRagModelSaveLoadTests::test_rag_token_from_pretrained - ValueError: Weight name final_logits_bias:0 does not start with name_scope tf_rag_token_for_generation_1/rag. This is an internal error in Transformers, so (unless you were doing something really evil) please open an is...


@Rocketknight1 (Member, Author)

@ydshieh noted! I think I'll need a separate PR for those, since RAG is a composite model. This PR should fix most models, though - can we merge it urgently before the release while I work on a fix for composite models like RAG?

@LysandreJik (Member) left a comment:

Thank you @Rocketknight1 for the quick fix! Let's find a more elegant fix after the release, and also fix composite models.

@Rocketknight1 (Member, Author)

Rocketknight1 commented Dec 7, 2023

@LysandreJik yes, will do! I think this will synergize with the new weight-building work in #27794 as well, and we should be able to get TF safetensors support in good shape soon.

@Rocketknight1 Rocketknight1 force-pushed the fix-pt-tf-tied-weights-safetensors branch from 0eb1373 to 0c5c883 Compare December 7, 2023 14:04
@Rocketknight1 Rocketknight1 merged commit 47500b1 into main Dec 7, 2023
3 checks passed
@Rocketknight1 Rocketknight1 deleted the fix-pt-tf-tied-weights-safetensors branch December 7, 2023 14:28