Tpu save #4309

Merged: 27 commits, Dec 2, 2020
Changes from 23 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -95,6 +95,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Added feature to move tensors to CPU before saving ([#4309](https://github.com/PyTorchLightning/pytorch-lightning/pull/4309))



## [1.0.8] - 2020-11-24
2 changes: 1 addition & 1 deletion dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -21,7 +21,7 @@ local tputests = base.BaseTest {
command: utils.scriptCommand(
|||
cd pytorch-lightning
coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py -v
coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py tests/backends/test_tpu_backend.py pytorch_lightning/utilities/xla_device_utils.py -v
test_exit_code=$?
echo "\n||| END PYTEST LOGS |||\n"
coverage xml
2 changes: 1 addition & 1 deletion docs/source/amp.rst
@@ -88,7 +88,7 @@ TPU 16-bit
16-bit on TPUs is much simpler. To use 16-bit with TPUs set precision to 16 when using the TPU flag

.. testcode::
:skipif: not XLA_AVAILABLE
:skipif: not TPU_AVAILABLE

# DEFAULT
trainer = Trainer(tpu_cores=8, precision=32)
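For reference, the change above only swaps the skipif condition from XLA_AVAILABLE to TPU_AVAILABLE; the documented usage is unchanged. A minimal sketch of the 16-bit TPU setup this section describes (assumes a TPU-enabled environment):

# enable 16-bit precision on TPUs, per the surrounding docs
from pytorch_lightning import Trainer
trainer = Trainer(tpu_cores=8, precision=16)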
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -362,6 +362,7 @@ def package_list_from_file(file):
NATIVE_AMP_AVAILABLE,
APEX_AVAILABLE,
XLA_AVAILABLE,
TPU_AVAILABLE,
)
TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None

2 changes: 1 addition & 1 deletion docs/source/trainer.rst
@@ -1100,7 +1100,7 @@ Your effective batch size is batch_size * total tpu cores.

This parameter can be either 1 or 8.

.. testcode::
Example::

# your_trainer_file.py

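As context for the Example block above, a short sketch of the tpu_cores values the docs allow (1, 8, or a one-element list selecting a specific core); this assumes a TPU-enabled environment:

# your_trainer_file.py (illustrative)
from pytorch_lightning import Trainer
trainer = Trainer(tpu_cores=1)    # train on a single TPU core
trainer = Trainer(tpu_cores=8)    # train on all 8 cores
trainer = Trainer(tpu_cores=[5])  # train on one specific core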
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -241,6 +241,9 @@ def __setstate__(self, d):
self.dist = d['dist']
self.ddp_plugin = d['ddp_plugin']

def on_save(self, checkpoint):
return checkpoint
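A minimal sketch of the contract this new hook introduces (illustrative names, not the real classes): the base implementation is a pass-through, and the checkpoint-writing code added later in this PR calls it just before saving so a backend can rewrite the checkpoint first.

import torch

class BaseAcceleratorSketch:
    def on_save(self, checkpoint: dict) -> dict:
        # default: hand the checkpoint back unchanged
        return checkpoint

def save_sketch(accelerator: BaseAcceleratorSketch, checkpoint: dict, filepath: str) -> None:
    # mirrors the checkpoint connector change below: let the backend
    # transform the checkpoint, then persist it
    checkpoint = accelerator.on_save(checkpoint)
    torch.save(checkpoint, filepath)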


# TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
class BackendType(Enum):
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -20,7 +20,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only
from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only, TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -350,7 +350,7 @@ def set_distributed_mode(self):

rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
rank_zero_info(f'TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores')
rank_zero_info(f'TPU available: {TPU_AVAILABLE}, using: {num_cores} TPU cores')

if torch.cuda.is_available() and not self.trainer.on_gpu:
rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')
13 changes: 11 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -23,7 +23,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn, move_data_to_device
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -344,7 +344,8 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
atomic_save(state_dict, last_path)
mp_queue.put(last_path)

def broadcast(self, obj, src=0):
@@ -366,3 +367,11 @@ def sync_tensor(self,
@property
def norm_clipping_epsilon(self):
return 1e-6

def on_save(self, checkpoint):
"""
Move XLA tensors to CPU before saving
Recommended on XLA Guide:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
"""
return move_data_to_device(checkpoint, torch.device("cpu"))
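To make the effect of this override concrete, a small illustrative sketch (it relies on the move_data_to_device import shown above): every tensor in the possibly nested checkpoint dict is copied to CPU, which is what lets a TPU-trained checkpoint be loaded on a machine without XLA.

import torch
from pytorch_lightning.utilities import move_data_to_device

checkpoint = {"epoch": 1, "state_dict": {"layer.weight": torch.randn(2, 2)}}
cpu_checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
assert cpu_checkpoint["state_dict"]["layer.weight"].device == torch.device("cpu")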
7 changes: 2 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -20,7 +20,6 @@
import os
import re
import tempfile
import types
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
@@ -31,14 +30,12 @@
from torch.optim.optimizer import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, rank_zero_warn
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
@@ -1239,7 +1236,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself.

"""
if on_tpu:
if on_tpu and TPU_AVAILABLE:
xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs})

elif self.trainer.amp_backend is not None:
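A condensed sketch of the dispatch this change guards (simplified; the real optimizer_step also covers the AMP and plain-PyTorch branches, and names mirror the diff):

from pytorch_lightning.utilities import TPU_AVAILABLE

if TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

def optimizer_step_sketch(optimizer, optimizer_closure, on_tpu: bool = False, **kwargs):
    if on_tpu and TPU_AVAILABLE:
        # only touch torch_xla when a TPU is actually present
        xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs})
    else:
        optimizer.step(closure=optimizer_closure)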
28 changes: 2 additions & 26 deletions pytorch_lightning/core/optimizer.py
@@ -11,38 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import inspect
import os
import re
import tempfile
import types
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional
from weakref import proxy

import torch
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim import SGD
from torch.optim.optimizer import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
Contributor: where did this go? cc @tchaton

Contributor: I guess useless imports.

Member: shall not be there at all in the first place, @SeanNaren

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -227,6 +227,9 @@ def hpc_save(self, folderpath: str, logger):

model.on_hpc_save(checkpoint)

if self.trainer.accelerator_backend:
checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
Comment on lines +230 to +231
Contributor: isn't accelerator backend always set? even on cpu?
Contributor Author: I think it isn't set when auto_lr_find is used.


# do the actual save
# TODO: fix for anything with multiprocess DP, DDP, DDP2
try:
@@ -380,6 +383,8 @@ def save_checkpoint(self, filepath, weights_only: bool = False):

if self.trainer.is_global_zero:
# write the checkpoint dictionary on the file
if self.trainer.accelerator_backend:
checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
try:
atomic_save(checkpoint, filepath)
except AttributeError as err:
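A condensed sketch of the resulting save path (illustrative, not the full connector): when an accelerator backend is set, the checkpoint dict is routed through its on_save hook before being written, which is how TPU checkpoints end up with CPU tensors on disk.

from pytorch_lightning.utilities.cloud_io import atomic_save

def save_checkpoint_sketch(trainer, checkpoint: dict, filepath: str) -> None:
    if trainer.is_global_zero:
        if trainer.accelerator_backend:
            # e.g. the TPU backend returns a copy with every tensor on CPU
            checkpoint = trainer.accelerator_backend.on_save(checkpoint)
        atomic_save(checkpoint, filepath)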
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -785,7 +785,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
f'specify a path for a checkpoint .test(ckpt_path=PATH)'
)
return {}
if self.accelerator_backend is not None:
if self.accelerator_backend is not None and not self.use_tpu:
self.accelerator_backend.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
7 changes: 4 additions & 3 deletions pytorch_lightning/utilities/cloud_io.py
@@ -14,11 +14,11 @@

import io
from distutils.version import LooseVersion
from typing import Union, IO
from pathlib import Path
from urllib.parse import urlparse
import torch
from typing import IO, Union

import fsspec
import torch


def load(path_or_url: Union[str, IO, Path], map_location=None):
@@ -52,6 +52,7 @@ def atomic_save(checkpoint, filepath: str):
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
"""

bytesbuffer = io.BytesIO()
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
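For context, a simplified sketch of the atomic_save pattern this hunk touches (an assumption-laden reconstruction: fsspec handles both local and remote paths, and the real function additionally disables the new zipfile serialization on torch 1.6.0 because of the load bug noted in the comment above):

import io
import fsspec
import torch

def atomic_save_sketch(checkpoint, filepath: str) -> None:
    # serialize fully in memory first, then write the buffer in one shot
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(filepath, "wb") as f:
        f.write(bytesbuffer.getvalue())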
5 changes: 5 additions & 0 deletions pytorch_lightning/utilities/device_parser.py
@@ -13,6 +13,8 @@
# limitations under the License.
import torch
from typing import Union, Any, List, Optional, MutableSequence

from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -104,6 +106,9 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")

if tpu_cores is not None and not TPU_AVAILABLE:
raise MisconfigurationException('No TPU devices were found.')

return tpu_cores
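A short illustration of the behaviour the new guard adds (the wrapper name is illustrative; the real check sits inside parse_tpu_cores as shown above): requesting TPU cores on a machine without a TPU now fails fast.

from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

def require_tpu_sketch(tpu_cores):
    if tpu_cores is not None and not TPU_AVAILABLE:
        raise MisconfigurationException('No TPU devices were found.')
    return tpu_cores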


18 changes: 14 additions & 4 deletions pytorch_lightning/utilities/xla_device_utils.py
@@ -14,6 +14,7 @@
import functools
import importlib
import queue as q
import traceback
from multiprocessing import Process, Queue

import torch
@@ -28,8 +29,6 @@ def inner_f(queue, func, *args, **kwargs): # pragma: no cover
try:
queue.put(func(*args, **kwargs))
except Exception:
import traceback

traceback.print_exc()
queue.put(None)

@@ -40,10 +39,11 @@ def wrapper(*args, **kwargs):
queue = Queue()
proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
proc.start()
proc.join(10)
proc.join(20)
try:
return queue.get_nowait()
except q.Empty:
traceback.print_exc()
return False

return wrapper
@@ -81,10 +81,20 @@ def _is_device_tpu() -> bool:
device_type = XLADeviceUtils._fetch_xla_device_type(device)
return device_type == "TPU"

@staticmethod
def xla_available() -> bool:
"""
Check if XLA library is installed

Return:
A boolean value indicating if XLA is installed
"""
return XLA_AVAILABLE

@staticmethod
def tpu_device_exists() -> bool:
"""
Public method to check if TPU is available
Runs XLA device check within a separate process

Return:
A boolean value indicating if a TPU device exists on the system
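Putting the fragments above together, a simplified sketch of the subprocess probe pattern this file uses (the decorator name is illustrative; this PR bumps the join timeout from 10 s to 20 s and prints a traceback when the queue comes back empty):

import functools
import queue as q
import traceback
from multiprocessing import Process, Queue

def inner_f(queue, func, *args, **kwargs):
    try:
        queue.put(func(*args, **kwargs))
    except Exception:
        traceback.print_exc()
        queue.put(None)

def run_in_subprocess(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        queue = Queue()
        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
        proc.start()
        proc.join(20)  # give the XLA device check up to 20 seconds
        try:
            return queue.get_nowait()
        except q.Empty:
            traceback.print_exc()
            return False
    return wrapper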
58 changes: 58 additions & 0 deletions tests/backends/test_tpu_backend.py
@@ -0,0 +1,58 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from tests.base.boring_model import BoringModel
from tests.base.develop_utils import pl_multi_process_test


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu(tmpdir):
""" Checks if training can be resumed from a saved checkpoint on CPU"""

# Train a model on TPU
model = BoringModel()
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8,)
trainer.fit(model)

model_path = trainer.checkpoint_callback.best_model_path

# Verify saved Tensors are on CPU
ckpt = torch.load(model_path)
weight_tensor = list(ckpt["state_dict"].values())[0]
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir)
result = trainer.fit(model)

assert result == 1


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
@pl_multi_process_test
def test_if_test_works_after_train(tmpdir):
""" Ensure that .test() works after .fit() """

# Train a model on TPU
model = BoringModel()
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8, default_root_dir=tmpdir)
trainer.fit(model)

assert trainer.test() == 1