Merge pull request #3782 from harimkang/harimkan/enhancce-cli
Update 'otx benchmark' print outputs & cli print outputs
harimkang authored Aug 2, 2024
2 parents 6b916e3 + 764c42f commit 0a20117
Showing 4 changed files with 79 additions and 8 deletions.
25 changes: 24 additions & 1 deletion src/otx/cli/cli.py
@@ -529,11 +529,34 @@ def run(self) -> None:
             fn_kwargs = self.prepare_subcommand_kwargs(self.subcommand)
             fn = getattr(self.engine, self.subcommand)
             try:
-                fn(**fn_kwargs)
+                outputs = fn(**fn_kwargs)
+                self._print_results(outputs=outputs)
             except Exception:
                 self.console.print_exception(width=self.console.width)
                 raise
             self.save_config(work_dir=Path(self.engine.work_dir))
         else:
             msg = f"Unrecognized subcommand: {self.subcommand}"
             raise ValueError(msg)
+
+    def _print_results(self, outputs: Any) -> None:  # noqa: ANN401
+        if outputs is None:
+            return
+        if self.subcommand == "train" and isinstance(outputs, dict):
+            # Print metrics like 'otx test'
+            from rich.table import Column, Table
+            from torch import Tensor
+
+            table_headers = ["Train metric", "Value"]
+            columns = [Column(h, justify="center", style="magenta", width=self.console.width) for h in table_headers]
+            columns[0].style = "cyan"
+            table = Table(*columns)
+            for metric, row in outputs.items():
+                if isinstance(row, Tensor):
+                    row = row.item() if row.numel() == 1 else row.tolist()  # noqa: PLW2901
+                table.add_row(*[metric, f"{row}"])
+            self.console.print(table)
+        elif self.subcommand in ("export", "optimize"):
+            # Print output model path
+            self.console.print(f"{self.subcommand} output: {outputs}")
+        self.console.print(f"Work Directory: {self.engine.work_dir}")
5 changes: 5 additions & 0 deletions src/otx/cli/utils/help_formatter.py
@@ -66,6 +66,11 @@
         *BASE_ARGUMENTS,
         *ENGINE_ARGUMENTS,
     },
+    "benchmark": {
+        "checkpoint",
+        *BASE_ARGUMENTS,
+        *ENGINE_ARGUMENTS,
+    },
 }
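The `"benchmark"` entry above registers the new subcommand with the help formatter, so its help page surfaces `checkpoint` together with the base and engine arguments. The following is a hypothetical sketch, not OTX's actual formatter logic, of how such a per-subcommand mapping can filter which arguments get rendered; the `BASE_ARGUMENTS` and `ENGINE_ARGUMENTS` values are illustrative assumptions.

```python
# Hypothetical stand-ins for the argument groups referenced above.
BASE_ARGUMENTS = {"config", "data_root"}
ENGINE_ARGUMENTS = {"engine.work_dir", "engine.device"}

# Subcommand -> arguments worth surfacing on the rendered help page.
ARGUMENT_SETS = {
    "benchmark": {"checkpoint", *BASE_ARGUMENTS, *ENGINE_ARGUMENTS},
}

def visible_arguments(subcommand: str, all_arguments: set[str]) -> set[str]:
    """Keep only the arguments registered for the given subcommand."""
    return all_arguments & ARGUMENT_SETS.get(subcommand, all_arguments)

print(sorted(visible_arguments("benchmark", {"checkpoint", "config", "seed"})))
# -> ['checkpoint', 'config']
```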


38 changes: 31 additions & 7 deletions src/otx/engine/engine.py
@@ -789,8 +789,9 @@ def benchmark(
         batch_size: int = 1,
         n_iters: int = 10,
         extended_stats: bool = False,
+        print_table: bool = True,
     ) -> dict[str, str]:
-        """Executes model micro benchmarking on random data.
+        r"""Executes model micro benchmarking on random data.

         Benchmark can provide latency, throughput, number of parameters,
         and theoretical computational complexity with batch size 1.
@@ -802,24 +803,37 @@ def benchmark(
             batch_size (int, optional): Batch size for benchmarking. Defaults to 1.
             n_iters (int, optional): Number of iterations to average on. Defaults to 10.
             extended_stats (bool, optional): Flag that enables printing of per module complexity for torch model.
-            Defaults to False.
+                Defaults to False.
+            print_table (bool, optional): Flag that enables printing the benchmark results in a rich table.
+                Defaults to True.

         Returns:
             dict[str, str]: a dict with the benchmark results.

+        Example:
+            >>> engine.benchmark(
+            ...     datamodule=OTXDataModule(),
+            ...     checkpoint=<checkpoint-path>,
+            ...     batch_size=1,
+            ...     n_iters=20,
+            ...     extended_stats=True,
+            ... )
+
         CLI Usage:
-            To run benchmark using the configuration, launch
+            1. To run benchmark by specifying the work_dir where the training was done, run
                 ```shell
+                >>> otx benchmark --work_dir <WORK_DIR_PATH, str>
+                ```
+            2. To run benchmark by specifying the checkpoint, run
+                ```shell
                 >>> otx benchmark \
-                ... --config <CONFIG_PATH> --data_root <DATASET_PATH, str> \
+                ... --work_dir <WORK_DIR_PATH, str> \
                 ... --checkpoint <CKPT_PATH, str>
                 ```
+            3. To run benchmark using the configuration, launch
+                ```shell
+                >>> otx benchmark \
+                ... --config <CONFIG_PATH> \
+                ... --data_root <DATASET_PATH, str> \
+                ... --checkpoint <CKPT_PATH, str>
+                ```
         """
@@ -883,8 +897,18 @@ def dummy_infer(model: OTXModel, batch_size: int = 1) -> float:
             params_num_str = convert_num_with_suffix(params_num, get_suffix_str(params_num * 100))
             final_stats["parameters_number"] = params_num_str

-        for name, val in final_stats.items():
-            print(f"{name:<20} | {val}")
+        if print_table:
+            from rich.console import Console
+            from rich.table import Column, Table
+
+            console = Console()
+            table_headers = ["Benchmark", "Value"]
+            columns = [Column(h, justify="center", style="magenta", width=console.width) for h in table_headers]
+            columns[0].style = "cyan"
+            table = Table(*columns)
+            for name, val in final_stats.items():
+                table.add_row(*[f"{name:<20}", f"{val}"])
+            console.print(table)

         with (Path(self.work_dir) / "benchmark_report.csv").open("w") as f:
             writer = csv.writer(f)
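Alongside the printed table, `benchmark` still writes `benchmark_report.csv` into the work directory. A minimal sketch for reading the report back; the exact row layout is produced by code truncated in this diff, so rows are treated as opaque here, and "work_dir" is a placeholder path.

```python
# Sketch: dump benchmark_report.csv rows; each row is treated as opaque fields.
import csv
from pathlib import Path

report = Path("work_dir") / "benchmark_report.csv"
with report.open() as f:
    for row in csv.reader(f):
        print(" | ".join(row))
```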
19 changes: 19 additions & 0 deletions tests/unit/cli/test_cli.py
@@ -5,8 +5,10 @@
 import sys

 import pytest
+import torch
 import yaml
 from otx.cli import OTXCLI, main
+from rich.console import Console


class TestOTXCLI:
@@ -189,3 +191,20 @@ def test_print_metric_override_command(self, fxt_metric_override_command, capfd)
         out, _ = capfd.readouterr()
         result_config = yaml.safe_load(out)
         assert result_config["metric"] == "otx.core.metrics.fmeasure._f_measure_callable"
+
+    def test_print_results(self, mocker, capfd):
+        mocker.patch("otx.cli.cli.OTXCLI.__init__", return_value=None)
+        cli = OTXCLI()
+        cli.console = Console()
+        cli.engine = mocker.MagicMock()
+        cli.engine.work_dir.return_value = "work_dir"
+
+        cli.subcommand = "train"
+        output = {"loss": torch.tensor(0.1), "metric": torch.tensor(0.9)}
+        cli._print_results(output)
+        out, _ = capfd.readouterr()
+        assert "Train metric" in out
+        assert "Value" in out
+        assert "loss" in out
+        assert "metric" in out
+        assert "Work Directory:" in out
