Add support for safetensors #970
Conversation
Tests failing for unrelated reasons, waiting for #990
I thought about the usability of this and think that it is OK.
I think I would prefer inferring the file type from the extension, but I agree that this can quickly add complexity, for example if there are co-dependencies between formats.
If safetensors gains popularity, I can see platforms that take weights and store them for archival/experiment tracking using the format as well. In that case, checkpointing is in the impossible situation of needing to store both the optimizer state dict and the weights. The user could, of course, add two checkpoint callbacks: one internal (using pickle) and one to export checkpoints to the platform. My argument only holds for platforms where you store intermediate results (and want the user to be able to continue training afterwards), so it is very niche. Therefore I think the current solution is fine.
Description

`net.save_params` and `net.load_params` now support passing `use_safetensors=True`, which results in the underlying `state_dict` being serialized/deserialized using safetensors instead of `torch.save` and `torch.load` (which both rely on pickle). Similarly, `Checkpoint`, `TrainEndCheckpoint`, and `LoadInitState` now support `use_safetensors=True`. By default, nothing changes for the user; safetensors is opt-in.
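As a rough, self-contained sketch of this opt-in behaviour (stand-in serializers only: JSON stands in for the safetensors format and pickle for `torch.save`/`torch.load` — this is not skorch's actual code):

```python
import io
import json
import pickle

def save_params(state_dict, f, use_safetensors=False):
    """Toy illustration of the opt-in flag: the default path
    (pickle, standing in for torch.save) is left unchanged."""
    if use_safetensors:
        # Stand-in for the safetensors format (JSON used for illustration only)
        f.write(json.dumps(state_dict).encode("utf-8"))
    else:
        pickle.dump(state_dict, f)

def load_params(f, use_safetensors=False):
    """Counterpart to save_params with the same opt-in flag."""
    if use_safetensors:
        return json.loads(f.read().decode("utf-8"))
    return pickle.load(f)

params = {"weight": [1.0, 2.0], "bias": [0.5]}

buf = io.BytesIO()
save_params(params, buf)  # default path: pickle, behaviour unchanged
buf.seek(0)
assert load_params(buf) == params

buf = io.BytesIO()
save_params(params, buf, use_safetensors=True)  # opt-in path
buf.seek(0)
assert load_params(buf, use_safetensors=True) == params
```

The point is only that the flag routes to a different serializer while the default path stays untouched.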
Motivation
safetensors has a couple of advantages over pickle, which are described [here](https://github.com/huggingface/safetensors/#yet-another-format-). Most notably, it doesn't have the security issues that pickle has. Recently, it has even been [audited](https://huggingface.co/blog/safetensors-security-audit) and found to be safe. Furthermore, it is growing in popularity, being natively supported by the Hugging Face Hub, for instance, and it is compatible with most major frameworks, such as TensorFlow.
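To make the security point concrete, here is a minimal, standard-library-only demonstration of why unpickling untrusted data is dangerous: on load, `pickle` calls whatever callable `__reduce__` returned, so a crafted blob executes code instead of merely restoring data (a harmless `print` stands in for an attacker's payload):

```python
import pickle

class Payload:
    def __reduce__(self):
        # pickle stores this callable and its args; unpickling calls it,
        # i.e. loading the bytes runs code of the serializer's choosing.
        return (print, ("arbitrary code ran during unpickling",))

blob = pickle.dumps(Payload())
obj = pickle.loads(blob)  # prints the message instead of restoring a Payload
assert obj is None        # print() returned None; no Payload came back
```

safetensors avoids this class of problem by design, since the format stores only raw tensor data plus metadata, with no embedded callables.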
Implementation
Using safetensors requires users to install the library. I haven't added it as a default dependency, to keep the default dependencies lean. I also haven't added any special error message for when safetensors is not installed; I think the standard Python error message is sufficient.
The code changes have mostly been very straightforward. A small caveat is that, contrary to using `torch.load`/`torch.save`, this does not work for serializing the optimizer. That is because safetensors is only capable of serializing tensors, but the optimizer `state_dict` contains other objects as well. If users absolutely need to save the optimizer, they have to use pickle for that.

I was entertaining the idea of inferring the serialization format from the file name, enabling something like `net.save_params(f_params='module.safetensors', f_optimizer='optimizer.pkl')`,
i.e. allowing the module to be stored with safetensors and the optimizer with pickle in a single method call. But supporting this would add quite some complexity and ambiguity for very little benefit. Instead, the user can just make a separate call to `save_params` to save the optimizer without safetensors. Similarly, they can add a second `Checkpoint` callback for the optimizer.

I added a check in the callbacks to see if
`f_optimizer` is being set when `use_safetensors=True`, in which case I raise an error. This seemed necessary to me because `f_optimizer` is set there by default, and because these callbacks intercept errors very broadly when saving params, it can be easy for users to miss that the optimizer is not actually being saved (there is a message printed if `net.verbose`,
but this is easy to miss, IMO).

For `net.save_params`, there already is an error message when using safetensors fails. The only thing I did was intercept it and add more context to the error being raised.

Most of the changes were required for testing. There, I had to use quite a few `if...else` branches because I needed to exclude the optimizer each time safetensors is used.

I also made two unrelated changes:

- For `Checkpoint`, I moved a docstring argument lower to reflect its order in the argument list.
- I deleted the `pickle_dump_mock` fixture in `TestTrainEndCheckpoint`, which wasn't being used.
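For illustration, the callback-level check described above — erroring out when `f_optimizer` is combined with `use_safetensors=True` — could be sketched like this; the function name and message are hypothetical, not skorch's actual implementation:

```python
def check_safetensors_params(use_safetensors, f_optimizer=None):
    """Hypothetical stand-in for the check added to the checkpoint
    callbacks: fail loudly instead of silently skipping the optimizer."""
    if use_safetensors and f_optimizer is not None:
        raise ValueError(
            "f_optimizer cannot be used with use_safetensors=True: "
            "safetensors only serializes tensors, but the optimizer "
            "state_dict contains non-tensor values; save the optimizer "
            "separately without safetensors."
        )

check_safetensors_params(use_safetensors=True)                         # OK
check_safetensors_params(use_safetensors=False, f_optimizer="opt.pt")  # OK
try:
    check_safetensors_params(use_safetensors=True, f_optimizer="opt.pt")
except ValueError as exc:
    print(f"rejected: {exc}")
```

Raising eagerly here matters because the callbacks otherwise swallow save-time errors broadly, which is exactly how a user could miss that the optimizer was never written.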