Skip to content

Commit

Permalink
convert xla tensor to cpu before save
Browse files Browse the repository at this point in the history
  • Loading branch information
lezwon committed Oct 25, 2020
1 parent 6ad2995 commit b3ef662
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def atomic_save(checkpoint, filepath: str):
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
"""

for key, value in checkpoint:
if isinstance(value, torch.Tensor) and 'xka' in value.device:
checkpoint[key] = value.cpu()

bytesbuffer = io.BytesIO()
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
Expand Down

0 comments on commit b3ef662

Please sign in to comment.