Skip to content

Commit

Permalink
Add --benchmark_dispatches option to pytest. (huggingface#800)
Browse files Browse the repository at this point in the history
* Add --benchmark_dispatches option to pytest.

* Update README.md and fix filepath for dispatch benchmarks
  • Loading branch information
monorimet authored Feb 19, 2023
1 parent 4f045db commit 6d2a485
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 9 deletions.
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
pytest tank/test_models.py -k "MiniLM"
```


### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.

### How to use your locally built Torch-MLIR with SHARK
How to use your locally built Torch-MLIR with SHARK:
```shell
1.) Run `./setup_venv.sh` in SHARK and activate the `shark.venv` virtual env.
2.) Run `pip uninstall torch-mlir`.
Expand All @@ -240,9 +240,15 @@ Now the SHARK will use your locally build Torch-MLIR repo.

## Benchmarking Dispatches

To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line argument.
To produce benchmarks of individual dispatches, you can add `--benchmark_dispatches=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line arguments.
If you only want to compile specific dispatches, you can specify them with a space-separated string instead of `"All"`, e.g. `--benchmark_dispatches="0 1 2 10"`.

For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.

If you want to instead incorporate this into a Python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled, e.g.:

```
Expand All @@ -266,7 +272,7 @@ Output will include:
- A .txt file containing benchmark output


See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.

</details>

Expand Down
10 changes: 10 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,13 @@ def pytest_addoption(parser):
default="gs://shark_tank/latest",
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
)
parser.addoption(
"--benchmark_dispatches",
default=None,
help="Benchmark individual dispatch kernels produced by IREE compiler. Use 'All' for all, or specific dispatches e.g. '0 1 2 10'",
)
parser.addoption(
"--dispatch_benchmarks_dir",
default="./temp_dispatch_benchmarks",
help="Directory in which dispatch benchmarks are saved.",
)
15 changes: 10 additions & 5 deletions shark/iree_utils/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,14 @@ def run_benchmark_module(benchmark_cl):
benchmark_path
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_result = run_cmd(" ".join(benchmark_cl))
print(bench_result)
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
unit = match.group(3)
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
unit = match.group(3)
except AttributeError:
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
match = regex_split.search(bench_result)
time = float(match.group(1))
unit = match.group(2)
return 1.0 / (time * 0.001)
19 changes: 19 additions & 0 deletions tank/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,19 @@ def __init__(self, config):
def create_and_check_module(self, dynamic, device):
shark_args.local_tank_cache = self.local_tank_cache
shark_args.force_update_tank = self.update_tank
shark_args.dispatch_benchmarks = self.benchmark_dispatches
if self.benchmark_dispatches is not None:
_m = self.config["model_name"].split("/")
_m.extend([self.config["framework"], str(dynamic), device])
_m = "_".join(_m)
shark_args.dispatch_benchmarks_dir = os.path.join(
self.dispatch_benchmarks_dir,
_m,
)
if not os.path.exists(self.dispatch_benchmarks_dir):
os.mkdir(self.dispatch_benchmarks_dir)
if not os.path.exists(shark_args.dispatch_benchmarks_dir):
os.mkdir(shark_args.dispatch_benchmarks_dir)
if "nhcw-nhwc" in self.config["flags"] and not os.path.isfile(
".use-iree"
):
Expand Down Expand Up @@ -278,6 +291,12 @@ def test_module(self, dynamic, device, config):
"update_tank"
)
self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
self.module_tester.benchmark_dispatches = self.pytestconfig.getoption(
"benchmark_dispatches"
)
self.module_tester.dispatch_benchmarks_dir = (
self.pytestconfig.getoption("dispatch_benchmarks_dir")
)

if config["xfail_cpu"] == "True" and device == "cpu":
pytest.xfail(reason=config["xfail_reason"])
Expand Down

0 comments on commit 6d2a485

Please sign in to comment.