diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c17cdc06cc19..241424deae720 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -95,6 +95,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Added feature to move tensors to CPU before saving ([#4309](https://github.com/PyTorchLightning/pytorch-lightning/pull/4309))
+
 
 ## [1.0.8] - 2020-11-24
 
diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet
index 47a3966167ea2..7e4d841387800 100644
--- a/dockers/tpu-tests/tpu_test_cases.jsonnet
+++ b/dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -21,7 +21,7 @@ local tputests = base.BaseTest {
   command: utils.scriptCommand(
     |||
       cd pytorch-lightning
-      coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py -v
+      coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py tests/backends/test_tpu_backend.py pytorch_lightning/utilities/xla_device_utils.py -v
       test_exit_code=$?
       echo "\n||| END PYTEST LOGS |||\n"
       coverage xml
diff --git a/docs/source/amp.rst b/docs/source/amp.rst
index 38a255c81aa4b..ac36c5893e6b6 100644
--- a/docs/source/amp.rst
+++ b/docs/source/amp.rst
@@ -88,7 +88,7 @@ TPU 16-bit
 16-bit on TPUs is much simpler. To use 16-bit with TPUs set precision to 16 when using the TPU flag
 
 .. testcode::
-    :skipif: not XLA_AVAILABLE
+    :skipif: not TPU_AVAILABLE
 
     # DEFAULT
     trainer = Trainer(tpu_cores=8, precision=32)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7c0f8d63f1b2a..655e8dba30a36 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -362,6 +362,7 @@ def package_list_from_file(file):
     NATIVE_AMP_AVAILABLE,
     APEX_AVAILABLE,
     XLA_AVAILABLE,
+    TPU_AVAILABLE,
 )
 TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
 
diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst
index c390db8d7537e..79d6284a4e27c 100644
--- a/docs/source/trainer.rst
+++ b/docs/source/trainer.rst
@@ -1100,7 +1100,7 @@ Your effective batch size is batch_size * total tpu cores.
 
 This parameter can be either 1 or 8.
 
-.. testcode::
+Example::
 
     # your_trainer_file.py
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 931a39e07af89..0f61c53ffa1a0 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -241,6 +241,9 @@ def __setstate__(self, d):
         self.dist = d['dist']
         self.ddp_plugin = d['ddp_plugin']
 
+    def on_save(self, checkpoint):
+        return checkpoint
+
 
 # TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
 class BackendType(Enum):
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index ed0379e46dc5b..a22a8fb3702ee 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -20,7 +20,7 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
-from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only
+from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only, TPU_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -350,7 +350,7 @@ def set_distributed_mode(self):
         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
 
         num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
-        rank_zero_info(f'TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores')
+        rank_zero_info(f'TPU available: {TPU_AVAILABLE}, using: {num_cores} TPU cores')
 
         if torch.cuda.is_available() and not self.trainer.on_gpu:
             rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index 30cf6c9dbf169..6da5150d1fa8a 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -23,7 +23,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn, move_data_to_device
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -344,7 +344,8 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
         last_path = None
         if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-            atomic_save(model.state_dict(), last_path)
+            state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
+            atomic_save(state_dict, last_path)
         mp_queue.put(last_path)
 
     def broadcast(self, obj, src=0):
@@ -366,3 +367,11 @@ def sync_tensor(self,
     @property
     def norm_clipping_epsilon(self):
         return 1e-6
+
+    def on_save(self, checkpoint):
+        """
+        Move XLA tensors to CPU before saving
+        Recommended on XLA Guide:
+        https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
+        """
+        return move_data_to_device(checkpoint, torch.device("cpu"))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index d02124dd533a7..78fc740e389aa 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -20,7 +20,6 @@
 import os
 import re
 import tempfile
-import types
 from abc import ABC
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
@@ -31,14 +30,12 @@
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
-from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
 from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, rank_zero_warn
+from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
@@ -1239,7 +1236,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                 model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself.
 
         """
-        if on_tpu:
+        if on_tpu and TPU_AVAILABLE:
             xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs})
         elif self.trainer.amp_backend is not None:
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index 66ce64b0c6887..f8f6a7b6c0f12 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -11,38 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections
-import copy
-import inspect
-import os
-import re
-import tempfile
 import types
-from abc import ABC
-from argparse import Namespace
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Optional
 from weakref import proxy
 
-import torch
-from torch import ScriptModule, Tensor
-from torch.nn import Module
-from torch.optim import SGD
 from torch.optim.optimizer import Optimizer
 
-from pytorch_lightning import _logger as log
-from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.grads import GradInformation
-from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
-from pytorch_lightning.core.memory import ModelSummary
-from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
-from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.utilities import AMPType, rank_zero_warn
-from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.utilities import TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
-
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 
 if TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 2177ce530aa5c..2311cc767de2d 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -227,6 +227,9 @@ def hpc_save(self, folderpath: str, logger):
 
         model.on_hpc_save(checkpoint)
 
+        if self.trainer.accelerator_backend:
+            checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
+
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
@@ -380,6 +383,8 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
 
         if self.trainer.is_global_zero:
             # write the checkpoint dictionary on the file
+            if self.trainer.accelerator_backend:
+                checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
             try:
                 atomic_save(checkpoint, filepath)
             except AttributeError as err:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fd715988ef370..012d9b3a6fd5e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -785,7 +785,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
                     f'specify a path for a checkpoint .test(ckpt_path=PATH)'
                 )
                 return {}
-            if self.accelerator_backend is not None:
+            if self.accelerator_backend is not None and not self.use_tpu:
                 self.accelerator_backend.barrier()
 
             ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index 33845384faaa8..7dc9c90e16dbd 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -14,11 +14,11 @@
 
 import io
 from distutils.version import LooseVersion
-from typing import Union, IO
 from pathlib import Path
-from urllib.parse import urlparse
-import torch
+from typing import IO, Union
+
 import fsspec
+import torch
 
 
 def load(path_or_url: Union[str, IO, Path], map_location=None):
@@ -52,6 +52,7 @@ def atomic_save(checkpoint, filepath: str):
         filepath: The path to which the checkpoint will be saved.
             This points to the file that the checkpoint will be stored in.
     """
+
     bytesbuffer = io.BytesIO()
     # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
     # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index eb0a6fe5c95a4..05a342a2e7180 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 import torch
 from typing import Union, Any, List, Optional, MutableSequence
+
+from pytorch_lightning.utilities import TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -104,6 +106,9 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
     if not _tpu_cores_valid(tpu_cores):
         raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
 
+    if tpu_cores is not None and not TPU_AVAILABLE:
+        raise MisconfigurationException('No TPU devices were found.')
+
     return tpu_cores
 
diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
index c6dd63237e121..f08b8e114e939 100644
--- a/pytorch_lightning/utilities/xla_device_utils.py
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -14,6 +14,7 @@
 import functools
 import importlib
 import queue as q
+import traceback
 from multiprocessing import Process, Queue
 
 import torch
@@ -28,8 +29,6 @@ def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
     try:
         queue.put(func(*args, **kwargs))
     except Exception:
-        import traceback
-
         traceback.print_exc()
         queue.put(None)
 
@@ -40,10 +39,11 @@ def wrapper(*args, **kwargs):
         queue = Queue()
         proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
         proc.start()
-        proc.join(10)
+        proc.join(20)
         try:
             return queue.get_nowait()
         except q.Empty:
+            traceback.print_exc()
             return False
 
     return wrapper
@@ -81,10 +81,20 @@ def _is_device_tpu() -> bool:
             device_type = XLADeviceUtils._fetch_xla_device_type(device)
             return device_type == "TPU"
 
+    @staticmethod
+    def xla_available() -> bool:
+        """
+        Check if XLA library is installed
+
+        Return:
+            A boolean value indicating if a XLA is installed
+        """
+        return XLA_AVAILABLE
+
     @staticmethod
     def tpu_device_exists() -> bool:
         """
-        Public method to check if TPU is available
+        Runs XLA device check within a separate process
 
         Return:
             A boolean value indicating if a TPU device exists on the system
diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py
new file mode 100644
index 0000000000000..cb8ffce38913a
--- /dev/null
+++ b/tests/backends/test_tpu_backend.py
@@ -0,0 +1,58 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
+from tests.base.boring_model import BoringModel
+from tests.base.develop_utils import pl_multi_process_test
+
+
+@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+@pl_multi_process_test
+def test_resume_training_on_cpu(tmpdir):
+    """ Checks if training can be resumed from a saved checkpoint on CPU"""
+
+    # Train a model on TPU
+    model = BoringModel()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8,)
+    trainer.fit(model)
+
+    model_path = trainer.checkpoint_callback.best_model_path
+
+    # Verify saved Tensors are on CPU
+    ckpt = torch.load(model_path)
+    weight_tensor = list(ckpt["state_dict"].values())[0]
+    assert weight_tensor.device == torch.device("cpu")
+
+    # Verify that training is resumed on CPU
+    trainer = Trainer(resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir)
+    result = trainer.fit(model)
+
+    assert result == 1
+
+
+@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+@pl_multi_process_test
+def test_if_test_works_after_train(tmpdir):
+    """ Ensure that .test() works after .fit() """
+
+    # Train a model on TPU
+    model = BoringModel()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8, default_root_dir=tmpdir)
+    trainer.fit(model)
+
+    assert trainer.test() == 1
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index b69f1b60fcbf7..e838dc60d81b3 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from argparse import ArgumentParser
 from unittest import mock
 
 import pytest
@@ -216,6 +217,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
     ['tpu_cores', 'expected_tpu_id'],
     [pytest.param(1, None), pytest.param(8, None), pytest.param([1], 1), pytest.param([8], 8)],
 )
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
     """Test if trainer.tpu_id is set as expected"""
     assert Trainer(tpu_cores=tpu_cores).tpu_id == expected_tpu_id
@@ -230,20 +232,13 @@ def test_tpu_misconfiguration():
 @pytest.mark.skipif(TPU_AVAILABLE, reason="test requires missing TPU")
 def test_exception_when_no_tpu_found(tmpdir):
     """Test if exception is thrown when xla devices are not available"""
-    model = EvalModelTemplate()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
-        tpu_cores=8,
-    )
-    with pytest.raises(MisconfigurationException, match='PyTorch XLA not installed.'):
-        trainer.fit(model)
+    with pytest.raises(MisconfigurationException, match='No TPU devices were found.'):
+        Trainer(tpu_cores=8)
 
 
 @pytest.mark.parametrize('tpu_cores', [1, 8, [1]])
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
     """Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
     assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"
@@ -294,3 +289,51 @@ def test_broadcast(rank):
         assert result == ("ver_0.5", "logger_name", 0)
 
     xmp.spawn(test_broadcast, nprocs=8, start_method='fork')
+
+
+@pytest.mark.parametrize(
+    ["tpu_cores", "expected_tpu_id", "error_expected"],
+    [
+        pytest.param(1, None, False),
+        pytest.param(8, None, False),
+        pytest.param([1], 1, False),
+        pytest.param([8], 8, False),
+        pytest.param("1,", 1, False),
+        pytest.param("1", None, False),
+        pytest.param("9, ", 9, True),
+        pytest.param([9], 9, True),
+        pytest.param([0], 0, True),
+        pytest.param(2, None, True),
+        pytest.param(10, None, True),
+    ],
+)
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
+    if error_expected:
+        with pytest.raises(MisconfigurationException, match=r".*tpu_cores` can only be 1, 8 or [<1-8>]*"):
+            Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores)
+    else:
+        trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores)
+        assert trainer.tpu_id == expected_tpu_id
+
+
+@pytest.mark.parametrize(['cli_args', 'expected'], [
+    pytest.param('--tpu_cores=8',
+                 {'tpu_cores': 8}),
+    pytest.param("--tpu_cores=1,",
+                 {'tpu_cores': '1,'})
+])
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_tpu_cores_with_argparse(cli_args, expected):
+    """Test passing tpu_cores in command line"""
+    cli_args = cli_args.split(' ') if cli_args else []
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
+        parser = ArgumentParser(add_help=False)
+        parser = Trainer.add_argparse_args(parent_parser=parser)
+        args = Trainer.parse_argparser(parser)
+
+    for k, v in expected.items():
+        assert getattr(args, k) == v
+    assert Trainer.from_argparse_args(args)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 085d361952844..328b2c0a0f859 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1042,31 +1042,6 @@ def test_gpu_choice(tmpdir):
         Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True)
 
 
-@pytest.mark.parametrize(
-    ["tpu_cores", "expected_tpu_id", "error_expected"],
-    [
-        pytest.param(1, None, False),
-        pytest.param(8, None, False),
-        pytest.param([1], 1, False),
-        pytest.param([8], 8, False),
-        pytest.param("1,", 1, False),
-        pytest.param("1", None, False),
-        pytest.param("9, ", 9, True),
-        pytest.param([9], 9, True),
-        pytest.param([0], 0, True),
-        pytest.param(2, None, True),
-        pytest.param(10, None, True),
-    ],
-)
-def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
-    if error_expected:
-        with pytest.raises(MisconfigurationException, match=r".*tpu_cores` can only be 1, 8 or [<1-8>]*"):
-            Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores, auto_select_gpus=True)
-    else:
-        trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores, auto_select_gpus=True)
-        assert trainer.tpu_id == expected_tpu_id
-
-
 @pytest.mark.parametrize(
     ["limit_val_batches"],
     [
diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py
index b5dbdb0e803ba..c39d643aed407 100644
--- a/tests/trainer/test_trainer_cli.py
+++ b/tests/trainer/test_trainer_cli.py
@@ -117,10 +117,6 @@ def _raise():
                 {'auto_lr_find': True, 'auto_scale_batch_size': True}),
     pytest.param('--auto_lr_find 0 --auto_scale_batch_size n',
                 {'auto_lr_find': False, 'auto_scale_batch_size': False}),
-    pytest.param('--tpu_cores=8',
-                {'tpu_cores': 8}),
-    pytest.param("--tpu_cores=1,",
-                {'tpu_cores': '1,'}),
     pytest.param(
         "",
         {
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 1b3911d4152c0..19174c41527e0 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -16,40 +16,33 @@
 import pytest
 
 import pytorch_lightning.utilities.xla_device_utils as xla_utils
-from pytorch_lightning.utilities import XLA_AVAILABLE
+from pytorch_lightning.utilities import XLA_AVAILABLE, TPU_AVAILABLE
 from tests.base.develop_utils import pl_multi_process_test
 
 if XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
 
 
+# let's hope that in our env we have installed XLA only for TPU devices, otherwise, it is testing in the cycle "if I am true test that I am true :D"
 @pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
     """Check tpu_device_exists returns None when torch_xla is not available"""
     assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
 
 
-@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires torch_xla to be installed")
+@pl_multi_process_test
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
     assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
 
 
-@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
-@pl_multi_process_test
-def test_xla_device_is_a_tpu():
-    """Check that the XLA device is a TPU"""
-    device = xm.xla_device()
-    device_type = xm.xla_device_hw(device)
-    return device_type == "TPU"
-
-
-def test_result_returns_within_10_seconds():
+def test_result_returns_within_20_seconds():
     """Check that pl_multi_process returns within 10 seconds"""
     start = time.time()
     result = xla_utils.pl_multi_process(time.sleep)(25)
     end = time.time()
     elapsed_time = int(end - start)
 
-    assert elapsed_time <= 10
+    assert elapsed_time <= 20
     assert result is False
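
Note (not part of the patch above): the change boils down to a new accelerator `on_save` hook that the checkpoint connector calls before writing to disk, so XLA tensors land on CPU first. A minimal sketch of that flow, assuming a pytorch_lightning build that already contains these changes; the standalone helper name `save_checkpoint_sketch` is hypothetical and only mirrors the connector logic shown in the diff.

import torch

from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.cloud_io import atomic_save


def save_checkpoint_sketch(trainer, checkpoint: dict, filepath: str) -> None:
    # Let the accelerator post-process the checkpoint dict before it is written.
    # The base Accelerator.on_save added above is a no-op; the TPU accelerator's
    # on_save moves every XLA tensor to CPU, as the XLA API guide recommends.
    if trainer.accelerator_backend:
        checkpoint = trainer.accelerator_backend.on_save(checkpoint)

    # With all tensors on CPU, the checkpoint can later be loaded on a machine
    # without torch_xla installed.
    atomic_save(checkpoint, filepath)


def on_save(checkpoint: dict) -> dict:
    # Standalone mirror of TPUAccelerator.on_save from the diff above.
    return move_data_to_device(checkpoint, torch.device("cpu"))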