added on_save to accelerators
lezwon committed Nov 5, 2020
1 parent b2ce064 commit 6013ed8
Showing 6 changed files with 86 additions and 36 deletions.
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -225,6 +225,9 @@ def __setstate__(self, d):
self.dist = d['dist']
self.ddp_plugin = d['ddp_plugin']

def on_save(self, checkpoint):
return checkpoint


# TODO: allow user to compare with string even though internally we shall use these Enums to prevent typos...
class BackendType(Enum):
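
The base hook is intentionally a pass-through, so an accelerator only overrides it when the checkpoint needs backend-specific handling before it is written. A minimal sketch of such an override, assuming a hypothetical subclass and key name (neither is part of this commit):

from pytorch_lightning.accelerators.accelerator import Accelerator


class CustomAccelerator(Accelerator):
    def on_save(self, checkpoint):
        # Drop an illustrative key before the checkpoint dict reaches
        # atomic_save; the base implementation returns it unchanged.
        checkpoint.pop("transient_buffers", None)
        return checkpoint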
13 changes: 11 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -22,7 +22,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, move_data_to_device
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
@@ -320,7 +320,8 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
atomic_save(state_dict, last_path)
mp_queue.put(last_path)

def broadcast(self, obj, src=0):
@@ -332,3 +333,11 @@ def broadcast(self, obj, src=0):
buffer = io.BytesIO(data.cpu().byte().numpy())
obj = torch.load(buffer)
return obj

def on_save(self, checkpoint):
"""
Move XLA tensors to CPU before saving
Recommended in the XLA guide:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
"""
return move_data_to_device(checkpoint, torch.device("cpu"))
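
For reference, move_data_to_device recurses through nested containers and calls .to(device) on any tensors it finds, which is why the whole checkpoint dict can be passed in one call. A small usage sketch with made-up dictionary contents:

import torch

from pytorch_lightning.utilities import move_data_to_device

# A checkpoint-like structure; on TPU the tensors would live on an XLA device.
checkpoint = {"state_dict": {"layer.weight": torch.randn(2, 2)}, "epoch": 3}
cpu_checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
assert cpu_checkpoint["state_dict"]["layer.weight"].device.type == "cpu"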
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -229,6 +229,7 @@ def hpc_save(self, folderpath: str, logger):
checkpoint = self.dump_checkpoint()

model.on_hpc_save(checkpoint)
checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)

# do the actual save
# TODO: fix for anything with multiprocess DP, DDP, DDP2
@@ -388,6 +389,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False):

if self.trainer.is_global_zero:
# write the checkpoint dictionary on the file
checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
try:
atomic_save(checkpoint, filepath)
except AttributeError as err:
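
In effect, the rank-zero save path now runs in this order; a simplified sketch of the surrounding method, omitting the weights_only handling and the error fallback in the except branch above:

def save_checkpoint(self, filepath, weights_only=False):
    checkpoint = self.dump_checkpoint(weights_only)
    if self.trainer.is_global_zero:
        # Give the accelerator a chance to transform the checkpoint first;
        # on TPU this moves XLA tensors to CPU so the file loads anywhere.
        checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
        atomic_save(checkpoint, filepath)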
2 changes: 0 additions & 2 deletions pytorch_lightning/utilities/cloud_io.py
@@ -55,8 +55,6 @@ def atomic_save(checkpoint, filepath: str):
This points to the file that the checkpoint will be stored in.
"""

checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))

bytesbuffer = io.BytesIO()
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
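
With the unconditional CPU transfer removed, atomic_save is back to pure serialization; the device handling lives in the accelerator's on_save hook instead. A simplified sketch of the buffer-then-write pattern that remains (the helper name is illustrative; the real function also applies the torch 1.6 workaround described in the comment above):

import io

import torch


def save_via_buffer(checkpoint, filepath):
    # Serialize into an in-memory buffer first, then write the raw bytes
    # to the target path in a single pass.
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)
    with open(filepath, "wb") as f:
        f.write(buffer.getvalue())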
70 changes: 70 additions & 0 deletions tests/backends/test_tpu.py
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from tests.base.boring_model import BoringModel
from tests.base.develop_utils import pl_multi_process_test


@pytest.mark.skipif(not XLADeviceUtils.TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu():
""" Checks if training can be resumed from a saved checkpoint on CPU"""

# Train a model on TPU
model = BoringModel()
trainer = Trainer(
checkpoint_callback=True,
max_epochs=10,
tpu_cores=8,
)
trainer.fit(model)

model_path = trainer.checkpoint_callback.best_model_path

# Verify saved Tensors are on CPU
ckpt = torch.load(model_path)
weight_tensor = list(ckpt["state_dict"].values())[0]
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(
resume_from_checkpoint=model_path,
checkpoint_callback=True,
max_epochs=20,
)
result = trainer.fit(model)

assert result == 1


@pytest.mark.skipif(not XLADeviceUtils.TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_if_test_works_after_train():
""" Ensure that .test() works after .fit() """

# Train a model on TPU
model = BoringModel()
trainer = Trainer(
checkpoint_callback=True,
max_epochs=10,
tpu_cores=8,
)
trainer.fit(model)

assert trainer.test() is not None
32 changes: 0 additions & 32 deletions tests/models/test_tpu.py
@@ -295,35 +295,3 @@ def test_broadcast(rank):
assert result == ("ver_0.5", "logger_name", 0)

xmp.spawn(test_broadcast, nprocs=8, start_method='fork')


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu():
""" Checks if training can be resumed from a saved checkpoint on CPU"""

# Train a model on TPU
model = EvalModelTemplate()
trainer = Trainer(
checkpoint_callback=True,
max_epochs=10,
tpu_cores=8,
)
trainer.fit(model)

model_path = trainer.checkpoint_callback.best_model_path

# Verify saved Tensors are on CPU
ckpt = torch.load(model_path)
weight_tensor = list(ckpt["state_dict"].values())[0]
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(
resume_from_checkpoint=model_path,
checkpoint_callback=True,
max_epochs=20,
)
result = trainer.fit(model)

assert result == 1
