Automatic safetensors conversion when lacking these files #29390
src/transformers/modeling_utils.py
cls._auto_conversion = Thread(
    target=auto_conversion,
    args=(pretrained_model_name_or_path,),
    kwargs=cached_file_kwargs,
)
cls._auto_conversion.start()
@Wauplin curious if you have a better idea in mind to have access to the thread started here; I don't need to join it during runtime, I'm only attributing it to the class here so that I can access it within the test files (but not super keen on modifying internals just for the tests to be simpler ...)
@LysandreJik I'm not shocked by having a cls._auto_conversion attribute TBH. Though a solution to get rid of it is to give a name to the thread. Something like that:
Thread(
    target=auto_conversion,
    args=(pretrained_model_name_or_path,),
    kwargs=cached_file_kwargs,
    name="Thread-autoconversion-{<unique id here>}",
).start()
and then in the tests:
for thread in threading.enumerate():
    print(thread.name)
# ...
# Thread-autoconversion-0
Thread names don't have to be unique BTW (they have a thread id anyway). But I think it's best to at least assign a unique number to the name.
But it's quite hacky IMO. In a simple case it should work fine but if you start to have several threads / parallel tests, it might get harder to be 100% sure the thread you've started is indeed the one you retrieve in the test logic.
yeah here it's really only for testing and I don't want to depend on a flaky time.sleep or something so ensuring that the thread joins first is optimal. The thread name is actually much better IMO, I'll implement that! Thanks a lot!
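The thread-name idea discussed above can be sketched roughly as follows. This is a minimal illustration, not the code merged in the PR: `start_auto_conversion`, `join_auto_conversion_threads`, the `converted` list, and the body of `auto_conversion` are hypothetical stand-ins, and the unique suffix is generated with `uuid` purely as an assumption.

```python
import threading
import uuid

# Hypothetical record of finished conversions, only here to make the sketch observable.
converted = []


def auto_conversion(model_id, **kwargs):
    # Stand-in for the real conversion routine (hypothetical body).
    converted.append(model_id)


def start_auto_conversion(model_id, **kwargs):
    # Recognizable prefix plus a unique suffix, as suggested in the review.
    name = f"Thread-autoconversion-{uuid.uuid4().hex[:8]}"
    thread = threading.Thread(
        target=auto_conversion,
        args=(model_id,),
        kwargs=kwargs,
        name=name,
    )
    thread.start()
    return name


def join_auto_conversion_threads(timeout=None):
    # In a test: wait for every live conversion thread instead of sleeping.
    joined = []
    for thread in threading.enumerate():
        if thread.name.startswith("Thread-autoconversion"):
            thread.join(timeout)
            joined.append(thread.name)
    return joined
```

A test would call `start_auto_conversion(...)`, then `join_auto_conversion_threads()`, and only afterwards assert on the side effects. As noted in the comments above, matching on a name prefix can pick up threads started by other tests when tests run in parallel.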
neat stuff.
UI-wise, let's also think about whether we add some kind of "official HF Staff" tag to the bot's PRs, or something
Clean and very nice!
# message.

if resolved_archive_file is not None:
    if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
one thing I would be wary of is just that if we convert a big checkpoint from torch to safetensors and we want to load it in Flax, sharded safetensors are not supported yet
Flax defaults to loading flax checkpoints, not safetensors, so it won't be affected by a repo where there are sharded safetensors
@@ -1428,7 +1429,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
         bot_opened_pr_title = None

         for discussion in discussions:
-            if discussion.author == "SFconvertBot":
+            if discussion.author == "SFconvertbot":
➕ on @julien-c's comment, have had feedback that this is not explicit enough.
-            if discussion.author == "SFconvertbot":
+            if discussion.author == "HuggingFaceOfficialSafetensorConverter":
"bot" is scary for some 😅
we can't change the account name now, but we will think of a way to make it clearer in the UI that it's an "official bot"
Sounds good 👍🏻
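For reference, the author check from the test hunk above can be sketched in isolation. This is only an illustration under assumptions: the `Discussion` dataclass is a hypothetical stand-in for whatever object the Hub client returns, and `find_bot_pr_title` is not a function from the PR. The account name is the one shown in the diff.

```python
from dataclasses import dataclass
from typing import Iterable, Optional

BOT_ACCOUNT = "SFconvertbot"  # account name as it appears in the diff


@dataclass
class Discussion:
    # Minimal hypothetical stand-in for a Hub discussion object.
    author: str
    title: str
    is_pull_request: bool = True


def find_bot_pr_title(discussions: Iterable[Discussion]) -> Optional[str]:
    # Return the title of the first PR opened by the conversion bot, if any.
    for discussion in discussions:
        if discussion.is_pull_request and discussion.author == BOT_ACCOUNT:
            return discussion.title
    return None
```

The comparison is case-sensitive, which is why the "SFconvertBot" spelling in the earlier version of the test would never match the account.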
Thanks both for the review!
I'll merge it now and will keep monitoring issues to ensure it doesn't break things in the wild.
* Automatic safetensors conversion when lacking these files
* Remove debug
* Thread name
* Typo
* Ensure that raises do not affect the main thread
When a user calls the PyTorch from_pretrained on a repository that only contains PyTorch/Flax/TF files, start an auto conversion in the background so that it has a PR opened with safetensors files.
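The behaviour described above, together with the "raises do not affect the main thread" commit, might look roughly like the sketch below. It is written under assumptions and is not the transformers implementation: `has_safetensors`, `_safe_run`, and `maybe_start_conversion` are hypothetical names, and `convert` is whatever callable performs the conversion.

```python
import logging
import threading

logger = logging.getLogger(__name__)


def has_safetensors(filenames):
    # True if at least one file in the repo listing is a safetensors file.
    return any(name.endswith(".safetensors") for name in filenames)


def _safe_run(convert, model_id):
    # Catch and log any error so a failed conversion never disturbs model loading.
    try:
        convert(model_id)
    except Exception:
        logger.warning("Auto conversion failed for %s", model_id, exc_info=True)


def maybe_start_conversion(model_id, filenames, convert):
    # Start a background conversion only when the repo has no safetensors files.
    if has_safetensors(filenames):
        return None
    thread = threading.Thread(
        target=_safe_run,
        args=(convert, model_id),
        name="Thread-autoconversion",
    )
    thread.start()
    return thread
```

The caller continues loading the existing weights immediately; the returned thread is only useful to code (such as tests) that wants to wait for the conversion to finish.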