Fix amp tests (Lightning-AI#661)
* Run AMP tests in their own process

With opt_level="O1" (the default), AMP patches many torch functions,
which breaks any tests that run afterwards. This patch introduces a
pytest extension that lets tests be marked with @pytest.mark.spawn so
that they run in their own process via torch.multiprocessing.spawn,
keeping the main Python interpreter unpatched.

Note that tests using DDP already run AMP in a separate process,
so they don't need this marker.
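
For reference, a minimal sketch of how a test opts into the new marker
(the test name and body are illustrative, not part of this commit):

import pytest
import torch


@pytest.mark.spawn
def test_runs_in_child_process():
    # The conftest.py hook added in this commit intercepts the call and runs
    # the test via torch.multiprocessing.spawn, so any AMP patching of torch
    # functions stays confined to the child process.
    assert torch.tensor([1.0]).item() == 1.0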

* Fix AMP tests

Since AMP defaults to O1 now, DP tests no longer throw exceptions.

Since AMP patches torch functions, CPU inference no longer works.
Skip prediction step for AMP tests.
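
A rough sketch of the underlying issue, assuming NVIDIA apex is installed
(the model and optimizer here are arbitrary placeholders):

import torch
from apex import amp

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# opt_level="O1" monkey-patches many torch functions to insert half-precision
# casts; the patching is process-wide and is not undone afterwards.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# A later forward pass on CPU tensors can then fail on unsupported half ops,
# which is why the prediction step is skipped when use_amp=True.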

* typo
neggert authored and williamFalcon committed Jan 5, 2020
1 parent c32f2b9 commit 019f612
Showing 2 changed files with 34 additions and 24 deletions.
22 changes: 22 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,22 @@
import pytest

import torch.multiprocessing as mp


def pytest_configure(config):
    config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")


def wrap(i, fn, args):
    return fn(*args)


@pytest.mark.tryfirst
def pytest_pyfunc_call(pyfuncitem):
    if pyfuncitem.get_closest_marker("spawn"):
        testfunction = pyfuncitem.obj
        funcargs = pyfuncitem.funcargs
        testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])

        mp.spawn(wrap, (testfunction, testargs))
        return True
36 changes: 12 additions & 24 deletions tests/test_amp.py
@@ -32,6 +32,7 @@ def test_amp_single_gpu(tmpdir):
    tutils.run_model_test(trainer_options, model)


@pytest.mark.spawn
def test_no_amp_single_gpu(tmpdir):
"""Make sure DDP + AMP work."""
tutils.reset_seed()
@@ -51,8 +52,10 @@ def test_no_amp_single_gpu(tmpdir):
        use_amp=True
    )

    with pytest.raises((MisconfigurationException, ModuleNotFoundError)):
        tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1


def test_amp_gpu_ddp(tmpdir):
@@ -78,6 +81,7 @@ def test_amp_gpu_ddp(tmpdir):
    tutils.run_model_test(trainer_options, model)


@pytest.mark.spawn
def test_amp_gpu_ddp_slurm_managed(tmpdir):
"""Make sure DDP + AMP work."""
if not tutils.can_run_gpu_test():
@@ -124,26 +128,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
    assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'

    # test model loading with a map_location
    pretrained_model = tutils.load_model(logger.experiment, trainer.checkpoint_callback.filepath)

    # test model preds
    for dataloader in trainer.get_test_dataloaders():
        tutils.run_prediction(dataloader, pretrained_model)

    if trainer.use_ddp:
        # on hpc this would work fine... but need to hack it for the purpose of the test
        trainer.model = pretrained_model
        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()

    # test HPC loading / saving
    trainer.hpc_save(tmpdir, logger)
    trainer.hpc_load(tmpdir, on_gpu=True)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()


def test_cpu_model_with_amp(tmpdir):
"""Make sure model trains on CPU."""
@@ -165,6 +149,7 @@ def test_cpu_model_with_amp(tmpdir):
    tutils.run_model_test(trainer_options, model, on_gpu=False)


@pytest.mark.spawn
def test_amp_gpu_dp(tmpdir):
"""Make sure DP + AMP work."""
tutils.reset_seed()
@@ -180,8 +165,11 @@ def test_amp_gpu_dp(tmpdir):
        distributed_backend='dp',
        use_amp=True
    )
    with pytest.raises(MisconfigurationException):
        tutils.run_model_test(trainer_options, model, hparams)

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1


if __name__ == '__main__':
