Skip to content

Commit

Permalink
multi gpu test bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Dec 26, 2022
1 parent 428c524 commit 5587d32
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions tests/integration/cli/classification/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def tmp_dir_path():
yield Path(tmp_dir)


MULTI_GPU_AVAILABLE = torch.cuda.device_count() > 1
MULTI_GPU_UNAVAILABLE = torch.cuda.device_count() <= 1
TT_STABILITY_TESTS = os.environ.get("TT_STABILITY_TESTS", False)
if TT_STABILITY_TESTS:
default_template = parse_model_template(
Expand Down Expand Up @@ -112,12 +112,12 @@ def test_otx_train(self, template, tmp_dir_path):

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.skipif(MULTI_GPU_AVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.skipif(MULTI_GPU_UNAVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_multi_gpu_train(self, template, tmp_dir_path):
args = args.copy()
args["--gpus"] = "0,1"
otx_train_testing(template, tmp_dir_path, otx_dir, args)
args1 = args.copy()
args1["--gpus"] = "0,1"
otx_train_testing(template, tmp_dir_path, otx_dir, args1)

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/cli/detection/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

otx_dir = os.getcwd()

MULTI_GPU_AVAILABLE = torch.cuda.device_count() > 1
MULTI_GPU_UNAVAILABLE = torch.cuda.device_count() <= 1
TT_STABILITY_TESTS = os.environ.get("TT_STABILITY_TESTS", False)
if TT_STABILITY_TESTS:
default_template = parse_model_template(
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_otx_train(self, template, tmp_dir_path):

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.skipif(MULTI_GPU_AVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.skipif(MULTI_GPU_UNAVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_multi_gpu_train(self, template, tmp_dir_path):
args1 = args.copy()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/cli/detection/test_instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

otx_dir = os.getcwd()

MULTI_GPU_AVAILABLE = torch.cuda.device_count() > 1
MULTI_GPU_UNAVAILABLE = torch.cuda.device_count() <= 1
TT_STABILITY_TESTS = os.environ.get("TT_STABILITY_TESTS", False)
if TT_STABILITY_TESTS:
default_template = parse_model_template(
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_otx_train(self, template, tmp_dir_path):

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.skipif(MULTI_GPU_AVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.skipif(MULTI_GPU_UNAVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_multi_gpu_train(self, template, tmp_dir_path):
args1 = args.copy()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/cli/segmentation/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

otx_dir = os.getcwd()

MULTI_GPU_AVAILABLE = torch.cuda.device_count() > 1
MULTI_GPU_UNAVAILABLE = torch.cuda.device_count() <= 1
TT_STABILITY_TESTS = os.environ.get("TT_STABILITY_TESTS", False)
if TT_STABILITY_TESTS:
default_template = parse_model_template(
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_otx_train(self, template, tmp_dir_path):

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.skipif(MULTI_GPU_AVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.skipif(MULTI_GPU_UNAVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_multi_gpu_train(self, template, tmp_dir_path):
args1 = args.copy()
Expand Down

0 comments on commit 5587d32

Please sign in to comment.