diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 33845384faaa8..cc245e99e7929 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -52,6 +52,11 @@ def atomic_save(checkpoint, filepath: str): filepath: The path to which the checkpoint will be saved. This points to the file that the checkpoint will be stored in. """ + + for key, value in checkpoint: + if isinstance(value, torch.Tensor) and 'xka' in value.device: + checkpoint[key] = value.cpu() + bytesbuffer = io.BytesIO() # Can't use the new zipfile serialization for 1.6.0 because there's a bug in # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.