-
Notifications
You must be signed in to change notification settings - Fork 714
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add initial design Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Refactor + add to CLI Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Support grid search on class path Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * redirect outputs Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * design v2 Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * remove commented code Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * add dummy experiment Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * add config Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Refactor Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add tests Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Apply suggestions from code review Co-authored-by: Samet Akcay <samet.akcay@intel.com> * address pr comments Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Apply suggestions from code review Co-authored-by: Samet Akcay <samet.akcay@intel.com> * refactor Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Simplify argparse Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * modify logger redirect Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * update docstrings Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> Co-authored-by: Samet Akcay <samet.akcay@intel.com>
- Loading branch information
1 parent
c36f87e
commit 5ff7f10
Showing
32 changed files
with
992 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Subcommand for pipelines.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import logging | ||
|
||
from jsonargparse import Namespace | ||
|
||
from anomalib.cli.utils.help_formatter import get_short_docstring | ||
from anomalib.utils.exceptions import try_import | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
if try_import("anomalib.pipelines"): | ||
from anomalib.pipelines import Benchmark | ||
from anomalib.pipelines.components.base import Pipeline | ||
|
||
PIPELINE_REGISTRY: dict[str, type[Pipeline]] | None = {"benchmark": Benchmark} | ||
else: | ||
PIPELINE_REGISTRY = None | ||
|
||
|
||
def pipeline_subcommands() -> dict[str, dict[str, str]]: | ||
"""Return subcommands for pipelines.""" | ||
if PIPELINE_REGISTRY is not None: | ||
return {name: {"description": get_short_docstring(pipeline)} for name, pipeline in PIPELINE_REGISTRY.items()} | ||
return {} | ||
|
||
|
||
def run_pipeline(args: Namespace) -> None: | ||
"""Run pipeline.""" | ||
logger.warning("This feature is experimental. It may change or be removed in the future.") | ||
if PIPELINE_REGISTRY is not None: | ||
subcommand = args.subcommand | ||
config = args[subcommand] | ||
PIPELINE_REGISTRY[subcommand]().run(config) | ||
else: | ||
msg = "Pipeline is not available" | ||
raise ValueError(msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Pipelines for end-to-end usecases.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .benchmark import Benchmark | ||
|
||
__all__ = ["Benchmark"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Benchmarking.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .pipeline import Benchmark | ||
|
||
__all__ = ["Benchmark"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Benchmark job generator.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from collections.abc import Generator | ||
|
||
from anomalib.data import get_datamodule | ||
from anomalib.models import get_model | ||
from anomalib.pipelines.components import JobGenerator | ||
from anomalib.pipelines.components.utils import get_iterator_from_grid_dict | ||
from anomalib.utils.logging import hide_output | ||
|
||
from .job import BenchmarkJob | ||
|
||
|
||
class BenchmarkJobGenerator(JobGenerator): | ||
"""Generate BenchmarkJob. | ||
Args: | ||
accelerator (str): The accelerator to use. | ||
""" | ||
|
||
def __init__(self, accelerator: str) -> None: | ||
self.accelerator = accelerator | ||
|
||
@property | ||
def job_class(self) -> type: | ||
"""Return the job class.""" | ||
return BenchmarkJob | ||
|
||
@hide_output | ||
def generate_jobs(self, args: dict) -> Generator[BenchmarkJob, None, None]: | ||
"""Return iterator based on the arguments.""" | ||
for _container in get_iterator_from_grid_dict(args): | ||
yield BenchmarkJob( | ||
accelerator=self.accelerator, | ||
seed=_container["seed"], | ||
model=get_model(_container["model"]), | ||
datamodule=get_datamodule(_container["data"]), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""Benchmarking job.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import logging | ||
from datetime import datetime | ||
from pathlib import Path | ||
from tempfile import TemporaryDirectory | ||
from typing import Any | ||
|
||
import pandas as pd | ||
from lightning import seed_everything | ||
from rich.console import Console | ||
from rich.table import Table | ||
|
||
from anomalib.data import AnomalibDataModule | ||
from anomalib.engine import Engine | ||
from anomalib.models import AnomalyModule | ||
from anomalib.pipelines.components import Job | ||
from anomalib.utils.logging import hide_output | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class BenchmarkJob(Job): | ||
"""Benchmarking job. | ||
Args: | ||
accelerator (str): The accelerator to use. | ||
model (AnomalyModule): The model to use. | ||
datamodule (AnomalibDataModule): The data module to use. | ||
seed (int): The seed to use. | ||
""" | ||
|
||
name = "benchmark" | ||
|
||
def __init__(self, accelerator: str, model: AnomalyModule, datamodule: AnomalibDataModule, seed: int) -> None: | ||
super().__init__() | ||
self.accelerator = accelerator | ||
self.model = model | ||
self.datamodule = datamodule | ||
self.seed = seed | ||
|
||
@hide_output | ||
def run( | ||
self, | ||
task_id: int | None = None, | ||
) -> dict[str, Any]: | ||
"""Run the benchmark.""" | ||
devices: str | list[int] = "auto" | ||
if task_id is not None: | ||
devices = [task_id] | ||
logger.info(f"Running job {self.model.__class__.__name__} with device {task_id}") | ||
with TemporaryDirectory() as temp_dir: | ||
seed_everything(self.seed) | ||
engine = Engine( | ||
accelerator=self.accelerator, | ||
devices=devices, | ||
default_root_dir=temp_dir, | ||
) | ||
engine.fit(self.model, self.datamodule) | ||
test_results = engine.test(self.model, self.datamodule) | ||
# TODO(ashwinvaidya17): Restore throughput | ||
# https://github.com/openvinotoolkit/anomalib/issues/2054 | ||
output = { | ||
"seed": self.seed, | ||
"accelerator": self.accelerator, | ||
"model": self.model.__class__.__name__, | ||
"data": self.datamodule.__class__.__name__, | ||
"category": self.datamodule.category, | ||
**test_results[0], | ||
} | ||
logger.info(f"Completed with result {output}") | ||
return output | ||
|
||
@staticmethod | ||
def collect(results: list[dict[str, Any]]) -> pd.DataFrame: | ||
"""Gather the results returned from run.""" | ||
output: dict[str, Any] = {} | ||
for key in results[0]: | ||
output[key] = [] | ||
for result in results: | ||
for key, value in result.items(): | ||
output[key].append(value) | ||
return pd.DataFrame(output) | ||
|
||
@staticmethod | ||
def save(result: pd.DataFrame) -> None: | ||
"""Save the result to a csv file.""" | ||
BenchmarkJob._print_tabular_results(result) | ||
file_path = Path("runs") / BenchmarkJob.name / datetime.now().strftime("%Y-%m-%d-%H:%M:%S") / "results.csv" | ||
file_path.parent.mkdir(parents=True, exist_ok=True) | ||
result.to_csv(file_path, index=False) | ||
logger.info(f"Saved results to {file_path}") | ||
|
||
@staticmethod | ||
def _print_tabular_results(gathered_result: pd.DataFrame) -> None: | ||
"""Print the tabular results.""" | ||
if gathered_result is not None: | ||
console = Console() | ||
table = Table(title=f"{BenchmarkJob.name} Results", show_header=True, header_style="bold magenta") | ||
_results = gathered_result.to_dict("list") | ||
for column in _results: | ||
table.add_column(column) | ||
for row in zip(*_results.values(), strict=False): | ||
table.add_row(*[str(value) for value in row]) | ||
console.print(table) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Benchmarking.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
from anomalib.pipelines.components.base import Pipeline, Runner | ||
from anomalib.pipelines.components.runners import ParallelRunner, SerialRunner | ||
|
||
from .generator import BenchmarkJobGenerator | ||
|
||
|
||
class Benchmark(Pipeline): | ||
"""Benchmarking pipeline.""" | ||
|
||
def _setup_runners(self, args: dict) -> list[Runner]: | ||
"""Setup the runners for the pipeline.""" | ||
accelerators = args["accelerator"] if isinstance(args["accelerator"], list) else [args["accelerator"]] | ||
runners: list[Runner] = [] | ||
for accelerator in accelerators: | ||
if accelerator == "cpu": | ||
runners.append(SerialRunner(BenchmarkJobGenerator("cpu"))) | ||
elif accelerator == "cuda": | ||
runners.append(ParallelRunner(BenchmarkJobGenerator("cuda"), n_jobs=torch.cuda.device_count())) | ||
else: | ||
msg = f"Unsupported accelerator: {accelerator}" | ||
raise ValueError(msg) | ||
return runners |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Utilities for the pipeline modules.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .base import Job, JobGenerator, Pipeline, Runner | ||
|
||
__all__ = [ | ||
"Job", | ||
"JobGenerator", | ||
"Pipeline", | ||
"Runner", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Base classes for pipelines.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .job import Job, JobGenerator | ||
from .pipeline import Pipeline | ||
from .runner import Runner | ||
|
||
__all__ = ["Job", "JobGenerator", "Runner", "Pipeline"] |
Oops, something went wrong.