Add support for safetensors #970

Merged: 5 commits merged into master from support-safetensors on Jul 26, 2023

Conversation

BenjaminBossan
Collaborator

Description

net.save_params and net.load_params now support passing use_safetensors=True, which will result in the underlying state_dict being serialized/deserialized using safetensors instead of torch.save and torch.load (which both rely on pickle).

Similarly, Checkpoint, TrainEndCheckpoint, and LoadInitState now support use_safetensors=True.

By default, nothing changes for the user; safetensors is opt-in.
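
To illustrate how this is meant to be used, here is a minimal sketch (the toy module and file names are just examples, not part of the PR):

```python
import torch
from torch import nn
from skorch import NeuralNetClassifier


class MyModule(nn.Module):
    # Hypothetical toy module; any torch.nn.Module works the same way.
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.dense(X), dim=-1)


net = NeuralNetClassifier(MyModule)
net.initialize()

# Serialize the module's state_dict with safetensors instead of pickle.
net.save_params(f_params='model.safetensors', use_safetensors=True)

# Restore the weights into a freshly initialized net of the same architecture.
new_net = NeuralNetClassifier(MyModule)
new_net.initialize()
new_net.load_params(f_params='model.safetensors', use_safetensors=True)
```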

Motivation

safetensors has a couple of advantages over pickle, which are described [here](https://github.com/huggingface/safetensors/#yet-another-format-). Most notably, it doesn't have the security issues that pickle has. Recently, it has even been [audited](https://huggingface.co/blog/safetensors-security-audit) and found to be safe. Furthermore, it is growing in popularity, being natively supported by the Hugging Face Hub for instance, and it is compatible with most major frameworks like TensorFlow.

Implementation

Using safetensors requires users to install the library. I haven't added it as a default dependency, to keep the default dependencies lean. I also haven't added any special error message for when safetensors is not installed; I think the standard Python error message is sufficient.

The code changes have been mostly straightforward. A small caveat is that, unlike with torch.save/torch.load, this does not work for serializing the optimizer. That is because safetensors can only serialize tensors, but the optimizer state_dict contains non-tensor values as well. If users absolutely need to save the optimizer, they have to use pickle for that.
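
To make the caveat concrete, this is roughly what an optimizer state_dict looks like (a quick illustration, not code from this PR):

```python
import torch
from torch import nn

# A throwaway linear layer just to have some parameters to optimize.
opt = torch.optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)
print(opt.state_dict())
# {'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0, ..., 'params': [0, 1]}]}
# The param_groups entries are plain Python values, not tensors, which is why
# safetensors cannot serialize the optimizer state_dict as-is.
```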

I was entertaining the idea of inferring the serialization format from the file name, enabling something like:

net = ...
net.save_params(f_params='module.safetensors', f_optimizer='optimizer.pkl')

i.e. allowing the module to be stored with safetensors and the optimizer with pickle in a single method call. But supporting this would add considerable complexity and ambiguity for very little benefit. Instead, the user can just make a separate call to save_params to save the optimizer without safetensors. Similarly, they can add a second Checkpoint callback for the optimizer.
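
What the separate calls could look like in practice (a rough sketch, continuing from a fitted net; file names are illustrative):

```python
# Save the module weights with safetensors ...
net.save_params(f_params='module.safetensors', use_safetensors=True)
# ... and the optimizer separately via the default torch.save/pickle path.
net.save_params(f_optimizer='optimizer.pkl')
```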

I added a check in the callbacks that raises an error if f_optimizer is set while use_safetensors=True. This seemed necessary to me because f_optimizer is set there by default, and because these callbacks intercept errors very broadly when saving params, it is easy for users to miss that the optimizer is not actually being saved (a message is printed if net.verbose, but that is easy to miss IMO).
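
For illustration, here is a sketch of how the callbacks could be configured to avoid that error, continuing the sketch from the description above (this assumes that passing f_optimizer=None disables saving the optimizer, as with the other f_* targets; names and file names are illustrative):

```python
from skorch.callbacks import Checkpoint

# Checkpoint that saves the module with safetensors; the optimizer target is
# disabled explicitly so the new use_safetensors check is not triggered.
cp_module = Checkpoint(
    f_params='module.safetensors',
    f_optimizer=None,
    use_safetensors=True,
)

# Optional second checkpoint that keeps saving the optimizer via pickle
# (the other targets keep their defaults here).
cp_optimizer = Checkpoint(f_params=None, f_optimizer='optimizer.pt')

net = NeuralNetClassifier(MyModule, callbacks=[cp_module, cp_optimizer])
```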

For net.save_params, there is already an error message when using safetensors fails. The only thing I did was intercept it and add more context to the error being raised.

Most of the changes were required for testing. There, I had to use quite a few if...else blocks because I needed to exclude the optimizer whenever safetensors is used.

I also made two unrelated changes:

  • For Checkpoint, I moved a docstring argument lower to reflect its order in the argument list.
  • I deleted the pickle_dump_mock fixture in TestTrainEndCheckpoint, which wasn't being used.

@BenjaminBossan
Collaborator Author

Tests failing for unrelated reasons, waiting for #990

Member

@ottonemo left a comment


I thought about the usability of this and think that it is OK. I would prefer inferring the file type from the extension, but I agree that this can quickly add complexity, for example if there are co-dependencies between formats.

If safetensors gains popularity, I can see platforms that take weights and store them for archival/experiment tracking using it as well. In that case, checkpointing is in the impossible situation of needing to store both the optimizer (state dict) and the weights. The user could, of course, add two checkpoint callbacks, one internal (using pickle) and one to export checkpoints to the platform. My argument also only holds for platforms where you store intermediate results (and want the user to be able to continue training afterwards), so it is very niche. Therefore I think the current solution is fine.

BenjaminBossan merged commit e83e6a4 into master on Jul 26, 2023
BenjaminBossan deleted the support-safetensors branch on July 26, 2023 at 09:57