Skip to content

Commit

Permalink
[TKW] Distribute gpu tests (#353)
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
  • Loading branch information
Hardcode84 authored Jan 7, 2025
1 parent 4d2eaef commit 3dbb4e5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
run: |
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1
WAVE_CACHE_ON=0 pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
WAVE_CACHE_ON=0 pytest -n 8 --capture=tee-sys -vv --gpu-distribute 8 ./tests/kernel/wave/
- name: Run e2e tests on AMD GPU MI250
if: "contains(matrix.os, 'mi250') && !cancelled()"
Expand Down
12 changes: 11 additions & 1 deletion iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,8 +1053,18 @@ def all_equal(input_list: list[Any]) -> bool:
return all(elem == input_list[0] for elem in input_list)


# GPU index used to build the default device name ("cuda:<index>").
# None means no specific device was selected, so plain "cuda" is used.
DEFAULT_GPU_DEVICE = None


def get_default_gpu_device_name() -> str:
    """Return the torch device string for the default GPU.

    Yields plain ``"cuda"`` while ``DEFAULT_GPU_DEVICE`` is ``None``;
    otherwise the configured value is appended as ``"cuda:<value>"``.
    """
    selected = DEFAULT_GPU_DEVICE
    return "cuda" if selected is None else f"cuda:{selected}"


def get_default_device() -> str:
    """Return the name of the device tensors should be placed on by default.

    Returns:
        The result of ``get_default_gpu_device_name()`` when
        ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    # Delegate to get_default_gpu_device_name() so that a specific GPU index
    # selected through DEFAULT_GPU_DEVICE is honored; a bare "cuda" return
    # here would ignore that selection.
    return get_default_gpu_device_name() if torch.cuda.is_available() else "cpu"


def to_default_device(tensor: torch.Tensor) -> torch.Tensor:
Expand Down
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def pytest_addoption(parser):
default=None,
help="save performance info into provided directory, filename based on current test name",
)
parser.addoption(
"--gpu-distribute",
type=int,
default=0,
help="Distribute over N gpu devices when running with pytest-xdist",
)


def pytest_configure(config):
Expand All @@ -28,11 +34,31 @@ def pytest_configure(config):
)


def _set_default_device(config):
distribute = int(config.getoption("--gpu-distribute"))
if distribute < 1:
return

if not hasattr(config, "workerinput"):
return

worker_id = config.workerinput["workerid"]
if not worker_id.startswith("gw"):
return

device_id = int(worker_id[2:]) % int(distribute)

import iree.turbine.kernel.wave.utils as utils

utils.DEFAULT_GPU_DEVICE = device_id


def _has_marker(item, marker):
return next(item.iter_markers(marker), None) is not None


def pytest_collection_modifyitems(config, items):
_set_default_device(config)
run_perf = config.getoption("--runperf")
for item in items:
is_validate_only = _has_marker(item, "validate_only")
Expand Down

0 comments on commit 3dbb4e5

Please sign in to comment.