Fix amp tests #661

Merged: 3 commits, merged Jan 5, 2020
tests/conftest.py (22 additions, 0 deletions)
@@ -0,0 +1,22 @@
+import pytest
+
+import torch.multiprocessing as mp
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")
+
+
+def wrap(i, fn, args):
+    return fn(*args)
+
+
+@pytest.mark.tryfirst
+def pytest_pyfunc_call(pyfuncitem):
+    if pyfuncitem.get_closest_marker("spawn"):
+        testfunction = pyfuncitem.obj
+        funcargs = pyfuncitem.funcargs
+        testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])
+
+        mp.spawn(wrap, (testfunction, testargs))
+        return True
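
For context, a minimal sketch of how a test module consumes the new marker (the module path, test name, and assertion below are hypothetical, not part of this PR): any test carrying @pytest.mark.spawn is intercepted by the pytest_pyfunc_call hook above, handed to torch.multiprocessing.spawn, and executed in a freshly spawned child process, with wrap discarding the process index that spawn passes as the first argument.

# tests/test_spawn_example.py (hypothetical illustration, not part of this PR)
import pytest

import torch


@pytest.mark.spawn
def test_runs_in_spawned_process():
    # Because of the `spawn` marker, pytest_pyfunc_call does not invoke this
    # body directly; torch.multiprocessing.spawn runs it in a child process.
    result = torch.ones(2) + torch.ones(2)
    assert torch.equal(result, torch.full((2,), 2.0))
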
tests/test_amp.py (12 additions, 24 deletions)
@@ -32,6 +32,7 @@ def test_amp_single_gpu(tmpdir):
    tutils.run_model_test(trainer_options, model)


+@pytest.mark.spawn
def test_no_amp_single_gpu(tmpdir):
    """Make sure DDP + AMP work."""
    tutils.reset_seed()
@@ -51,8 +52,10 @@ def test_no_amp_single_gpu(tmpdir):
        use_amp=True
    )

-    with pytest.raises((MisconfigurationException, ModuleNotFoundError)):
-        tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1


def test_amp_gpu_ddp(tmpdir):
@@ -78,6 +81,7 @@ def test_amp_gpu_ddp(tmpdir):
    tutils.run_model_test(trainer_options, model)


+@pytest.mark.spawn
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work."""
    if not tutils.can_run_gpu_test():
@@ -124,26 +128,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
    assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'

-    # test model loading with a map_location

[Inline review comment by a Member on the removed block: maybe just comment it?]

-    pretrained_model = tutils.load_model(logger.experiment, trainer.checkpoint_callback.filepath)
-
-    # test model preds
-    for dataloader in trainer.get_test_dataloaders():
-        tutils.run_prediction(dataloader, pretrained_model)
-
-    if trainer.use_ddp:
-        # on hpc this would work fine... but need to hack it for the purpose of the test
-        trainer.model = pretrained_model
-        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
-
-    # test HPC loading / saving
-    trainer.hpc_save(tmpdir, logger)
-    trainer.hpc_load(tmpdir, on_gpu=True)
-
-    # test freeze on gpu
-    model.freeze()
-    model.unfreeze()
-

def test_cpu_model_with_amp(tmpdir):
    """Make sure model trains on CPU."""
@@ -165,6 +149,7 @@ def test_cpu_model_with_amp(tmpdir):
    tutils.run_model_test(trainer_options, model, on_gpu=False)


+@pytest.mark.spawn
def test_amp_gpu_dp(tmpdir):
    """Make sure DP + AMP work."""
    tutils.reset_seed()
@@ -180,8 +165,11 @@ def test_amp_gpu_dp(tmpdir):
        distributed_backend='dp',
        use_amp=True
    )
-    with pytest.raises(MisconfigurationException):
-        tutils.run_model_test(trainer_options, model, hparams)
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1


if __name__ == '__main__':
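
Putting the two files together, the reworked tests follow one pattern: mark the test with spawn so it runs in a child process, build a Trainer from the option dict, call fit, and assert that the returned status code is 1. Below is a hedged, self-contained sketch of that pattern; the model and trainer_options fixtures are illustrative stand-ins, not fixtures defined in this PR.

# Hypothetical sketch of the reworked test pattern; `model` and `trainer_options`
# are assumed fixtures standing in for the LightningModule and the option dict
# (gpus, use_amp=True, ...) that the real tests construct inline.
import pytest

from pytorch_lightning import Trainer


@pytest.mark.spawn
def test_amp_fit_pattern(model, trainer_options):
    # New pattern from this PR: run the Trainer directly rather than wrapping
    # the call in pytest.raises, and check the integer result of fit().
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1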