Don't default to other weights file when use_safetensors=True #31874
Looks good! Thank you so much for your help! @amyeroberts
elif os.path.isfile(
    os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
elif not use_safetensors and (
Maybe use if from_tf and ... here, to correspond to the check several lines above:
if is_local:
    if from_tf and os.path.isfile(
        os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
    ):
elif not use_safetensors and os.path.isfile(
    os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
):
same
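To make the effect of the guard under discussion concrete, here is a minimal sketch of a local lookup. It is not the library's implementation: `find_local_weights` is a hypothetical helper, the file-name values are assumptions standing in for the constants named in the diff, and the error messages are illustrative.

```python
import os

# Assumed values for the constants named in the diff; illustrative only.
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"


def find_local_weights(folder, use_safetensors=None, from_tf=False, from_flax=False):
    """Hypothetical helper: return the weights file to load from `folder`, or raise."""

    def has(name):
        return os.path.isfile(os.path.join(folder, name))

    if from_tf and has(TF2_WEIGHTS_NAME):
        return TF2_WEIGHTS_NAME
    if from_flax and has(FLAX_WEIGHTS_NAME):
        return FLAX_WEIGHTS_NAME
    # use_safetensors=None means "prefer safetensors, but allow a fallback".
    if use_safetensors is not False and has(SAFE_WEIGHTS_NAME):
        return SAFE_WEIGHTS_NAME
    # The guard under review: once safetensors is requested explicitly,
    # no branch below may match another format.
    if not use_safetensors and has(WEIGHTS_NAME):
        return WEIGHTS_NAME
    if not use_safetensors and has(TF2_WEIGHTS_NAME):
        raise EnvironmentError("TensorFlow weights found; pass from_tf=True to load them.")
    if not use_safetensors and has(FLAX_WEIGHTS_NAME):
        raise EnvironmentError("Flax weights found; pass from_flax=True to load them.")
    if use_safetensors:
        raise EnvironmentError(f"No file named {SAFE_WEIGHTS_NAME} found in {folder}.")
    raise EnvironmentError(f"No weights file found in {folder}.")
```

With only a `pytorch_model.bin` in the folder, the default call returns that file, while `use_safetensors=True` raises an error that names the safetensors file instead of mentioning or loading another format.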
@amyeroberts Thank you. Although I think the changes make some sense, I am not 100% sure whether the current design deliberately fails only when loading from a remote repository, while local loading is designed not to fail (i.e. a local file is always trusted). I would therefore prefer that @LysandreJik confirm whether this is by design or just an oversight.
@ydshieh Thanks for the review!
That's a good point. If that's the case, then I think this is a bit more involved, as I think we'd need to update documentation to make it clearer and possibly revise some of the flags in
Yes, I agree.
Makes sense! Thanks for your PR @amyeroberts!
And regarding the two messages above, I don't think we want to
We want local execution to behave exactly the same as it would with a remote repository. Otherwise, it can get very messy IMO.
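One way to picture "local behaves exactly like remote" is to keep the format decision in a single function that only sees a set of file names, so the two paths differ only in how the listing is obtained. The sketch below is a hypothetical illustration under that assumption: `choose_weights`, `choose_local`, and `choose_remote` are made-up names, and `repo_files` stands in for whatever listing a remote repository would return.

```python
import os

# Assumed file names; illustrative only.
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"


def choose_weights(available, use_safetensors=None):
    """Hypothetical: pick a file name from `available`, regardless of where it came from."""
    if use_safetensors is not False and SAFE_WEIGHTS_NAME in available:
        return SAFE_WEIGHTS_NAME
    if not use_safetensors and WEIGHTS_NAME in available:
        return WEIGHTS_NAME
    if use_safetensors:
        wanted = SAFE_WEIGHTS_NAME
    elif use_safetensors is False:
        wanted = WEIGHTS_NAME
    else:
        wanted = f"{SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME}"
    raise EnvironmentError(f"No file named {wanted} found.")


def choose_local(folder, use_safetensors=None):
    # Local path: the listing comes from the directory on disk.
    return choose_weights(set(os.listdir(folder)), use_safetensors)


def choose_remote(repo_files, use_safetensors=None):
    # Remote path: `repo_files` stands in for a listing fetched from a repository.
    return choose_weights(set(repo_files), use_safetensors)
```

Because both entry points delegate to the same decision, a folder and a repository with the same contents either resolve to the same file or fail with the same error.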
tests/utils/test_modeling_utils.py
@@ -815,6 +815,62 @@ def test_checkpoint_variant_local_sharded_safe(self):
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))

    def test_checkpoint_safetensors(self):
I think it would be good (for all of us) to write a short comment about why we add a test.
I am sure we already have some tests about saving/loading. With these 2 new tests, the main case is that loading fails if a format is specified (as we don't switch to another format).
Reading the test names here doesn't tell me what the main objective is, and I would be confused about why we are still adding new saving/loading tests.
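As an illustration of the kind of comment being asked for, here is a minimal, self-contained sketch. It does not reproduce the tests in this PR: `load_checkpoint` is a hypothetical stand-in for the loader, the file names are assumptions, and the test name is invented to state the objective directly.

```python
import os
import tempfile
import unittest


def load_checkpoint(folder, use_safetensors=None):
    """Hypothetical stand-in for the loader under test; only the file choice matters."""
    if use_safetensors is True:
        candidates = ["model.safetensors"]
    elif use_safetensors is False:
        candidates = ["pytorch_model.bin"]
    else:
        candidates = ["model.safetensors", "pytorch_model.bin"]
    for name in candidates:
        if os.path.isfile(os.path.join(folder, name)):
            return name
    raise OSError(f"None of {candidates} found in {folder}")


class CheckpointFormatTest(unittest.TestCase):
    def test_use_safetensors_true_does_not_fall_back(self):
        # Purpose: save/load round trips are covered by other tests. This one only
        # checks that an explicit use_safetensors=True fails when just a .bin file
        # exists, instead of silently loading the other format.
        with tempfile.TemporaryDirectory() as tmp_dir:
            open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb").close()
            self.assertEqual(load_checkpoint(tmp_dir), "pytorch_model.bin")
            with self.assertRaises(OSError):
                load_checkpoint(tmp_dir, use_safetensors=True)
```

A descriptive test name plus a two-line purpose comment is usually enough for a later reader to see why the test sits next to the existing save/load tests.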
I only have 2 minor comments; otherwise it's good!
(Comments explaining the tests' goal would be very nice.)
What does this PR do?
Fixes #31649
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.