Commit: move_to_cpu

lezwon committed Oct 25, 2020
1 parent b3ef662 commit c0f2ad6
Showing 2 changed files with 39 additions and 9 deletions.
12 changes: 6 additions & 6 deletions pytorch_lightning/utilities/cloud_io.py
@@ -14,11 +14,13 @@

import io
from distutils.version import LooseVersion
from typing import Union, IO
from pathlib import Path
from urllib.parse import urlparse
import torch
from typing import IO, Union

import fsspec
import torch

from pytorch_lightning.utilities import move_data_to_device


def load(path_or_url: Union[str, IO, Path], map_location=None):
@@ -53,9 +55,7 @@ def atomic_save(checkpoint, filepath: str):
        This points to the file that the checkpoint will be stored in.
    """

    for key, value in checkpoint:
        if isinstance(value, torch.Tensor) and 'xla' in value.device:
            checkpoint[key] = value.cpu()
    checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))

    bytesbuffer = io.BytesIO()
    # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
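
For context, the atomic_save change above hands the whole checkpoint to move_data_to_device instead of moving tensors one key at a time; the utility walks nested containers (dicts, lists, tuples) and moves any tensor it finds to the target device. A minimal sketch of that behavior, with a made-up checkpoint dict (not part of this commit):

import torch
from pytorch_lightning.utilities import move_data_to_device

# Made-up checkpoint with a nested state_dict; after TPU training these
# tensors may live on an XLA device.
checkpoint = {
    "epoch": 10,
    "state_dict": {"layer.weight": torch.ones(2, 2)},
}

# One call moves every tensor in the nested structure to CPU; non-tensor
# values such as the epoch counter pass through unchanged.
checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
assert checkpoint["state_dict"]["layer.weight"].device.type == "cpu"
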
36 changes: 33 additions & 3 deletions tests/models/test_tpu.py
@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from multiprocessing import Process, Queue

import pytest
import torch
from torch.utils.data import DataLoader

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators.accelerator import BackendType
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -32,7 +31,6 @@

if TPU_AVAILABLE:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    SERIAL_EXEC = xmp.MpSerialExecutor()

@@ -297,3 +295,35 @@ def test_broadcast(rank):
        assert result == ("ver_0.5", "logger_name", 0)

    xmp.spawn(test_broadcast, nprocs=8, start_method='fork')


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu():
""" Checks if training can be resumed from a saved checkpoint on CPU"""

# Train a model on TPU
model = EvalModelTemplate()
trainer = Trainer(
checkpoint_callback=True,
max_epochs=10,
tpu_cores=8,
)
trainer.fit(model)

model_path = trainer.checkpoint_callback.best_model_path

# Verify saved Tensors are on CPU
ckpt = torch.load(model_path)
weight_tensor = list(ckpt["state_dict"].values())[0]
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(
resume_from_checkpoint=model_path,
checkpoint_callback=True,
max_epochs=20,
)
result = trainer.fit(model)

assert result == 1
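
As a usage note, the behavior this test checks can also be exercised directly through atomic_save. A minimal sketch, with a made-up checkpoint and file path (not part of this commit):

import torch
from pytorch_lightning.utilities.cloud_io import atomic_save

# Made-up checkpoint; after this commit atomic_save moves every tensor to CPU
# before serializing, so the file can be reloaded on a CPU-only machine.
checkpoint = {"state_dict": {"layer.weight": torch.randn(4, 4)}}
atomic_save(checkpoint, "example.ckpt")

loaded = torch.load("example.ckpt", map_location="cpu")
assert loaded["state_dict"]["layer.weight"].device == torch.device("cpu")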
