From 13bd79e05af1badb826feccd9e60af991d49893c Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Fri, 19 Jan 2024 21:38:11 +0900 Subject: [PATCH 01/16] for a start --- src/otx/cli/__init__.py | 8 +++++ src/otx/cli/explain.py | 55 +++++++++++++++++++++++++++++++ src/otx/config/model/default.yaml | 4 +++ src/otx/core/config/explain.py | 17 ++++++++++ src/otx/core/config/model.py | 1 + src/otx/core/types/explain.py | 0 src/otx/engine/engine.py | 54 ++++++++++++++++++++++++++++++ 7 files changed, 139 insertions(+) create mode 100644 src/otx/cli/explain.py create mode 100644 src/otx/core/config/explain.py create mode 100644 src/otx/core/types/explain.py diff --git a/src/otx/cli/__init__.py b/src/otx/cli/__init__.py index 93095d6705a..97045bb66cb 100644 --- a/src/otx/cli/__init__.py +++ b/src/otx/cli/__init__.py @@ -57,6 +57,11 @@ def setup_subcommands(self) -> None: test_parser.add_argument("overrides", help="overrides values", default=[], nargs="+") parser_subcommands.add_subcommand("test", test_parser, help="Testing subcommand for OTX") + # otx explain parser + explain_parser = ArgumentParser() + explain_parser.add_argument("overrides", help="overrides values", default=[], nargs="+") + parser_subcommands.add_subcommand("explain", explain_parser, help="Explaining subcommand for OTX") + def run(self) -> None: """Run the OTX CLI.""" subcommand = self.config["subcommand"] @@ -72,7 +77,10 @@ def run(self) -> None: from otx.cli.test import otx_test otx_test(**self.config["test"]) + elif subcommand == "explain": + from otx.cli.explain import otx_explain + otx_explain(**self.config["explain"]) def main() -> None: """Entry point for OTX CLI. diff --git a/src/otx/cli/explain.py b/src/otx/cli/explain.py new file mode 100644 index 00000000000..3310fc76435 --- /dev/null +++ b/src/otx/cli/explain.py @@ -0,0 +1,55 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""CLI entrypoint for model XAI.""" + +from __future__ import annotations + +import hydra +import logging as log +from hydra import compose, initialize +from otx.core.model.entity.base import OTXModel +from otx.cli.utils.hydra import configure_hydra_outputs + + +def otx_explain(overrides: list[str]) -> None: + """Main entry point for XAI. + + :param overrides: Override List values. + :return: Optional[float] with optimized metric value. + """ + from otx.core.config import register_configs + + # This should be in front of hydra.initialize() + register_configs() + + with initialize(config_path="../config", version_base="1.3", job_name="otx_explain"): + cfg = compose(config_name="test", overrides=overrides, return_hydra_config=True) + configure_hydra_outputs(cfg) + + # explain the model + from otx.core.data.module import OTXDataModule + + log.info(f"Instantiating datamodule <{cfg.data}>") + datamodule = OTXDataModule(task=cfg.base.task, config=cfg.data) + + log.info(f"Instantiating model <{cfg.model}>") + model: OTXModel = hydra.utils.instantiate(cfg.model.otx_model) + optimizer = hydra.utils.instantiate(cfg.model.optimizer) + scheduler = hydra.utils.instantiate(cfg.model.scheduler) + + from otx.engine import Engine + + trainer_kwargs = {**cfg.trainer} + engine = Engine( + task=cfg.base.task, + work_dir=cfg.base.output_dir, + model=model, + optimizer=optimizer, + scheduler=scheduler, + datamodule=datamodule, + checkpoint=cfg.checkpoint, + device=trainer_kwargs.pop("accelerator", "auto"), + ) + saliency_maps = engine.explain() + return saliency_maps diff --git a/src/otx/config/model/default.yaml b/src/otx/config/model/default.yaml index c938721f066..2ad1625cd9b 100644 --- a/src/otx/config/model/default.yaml +++ b/src/otx/config/model/default.yaml @@ -13,3 +13,7 @@ scheduler: otx_model: num_classes: ??? + +explain_config: + target_explain_group: PREDICTIONS + postprocess: true diff --git a/src/otx/core/config/explain.py b/src/otx/core/config/explain.py new file mode 100644 index 00000000000..5dc251b4001 --- /dev/null +++ b/src/otx/core/config/explain.py @@ -0,0 +1,17 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Config data type objects for export method.""" +from __future__ import annotations + +from dataclasses import dataclass + +from otx.core.types.export import OTXExportFormatType, OTXExportPrecisionType + + +@dataclass +class ExplainConfig: + """DTO for explain configuration.""" + + target_explain_group: TargetExplainGroup + postprocess: bool diff --git a/src/otx/core/config/model.py b/src/otx/core/config/model.py index f62ead935f7..603ec7eb873 100644 --- a/src/otx/core/config/model.py +++ b/src/otx/core/config/model.py @@ -15,3 +15,4 @@ class ModelConfig: scheduler: dict otx_model: dict torch_compile: bool + explain_config: dict diff --git a/src/otx/core/types/explain.py b/src/otx/core/types/explain.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index bdac374329f..117e031fd82 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -329,6 +329,60 @@ def export(self, *args, **kwargs) -> None: # Property and setter functions provided by Engine. # ------------------------------------------------------------------------ # + def explain( + self, + checkpoint: str | Path | None = None, + datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, + **kwargs, + ) -> dict: + """Run the testing phase of the engine. + + Args: + datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module containing the test data. + checkpoint (str | Path | None, optional): Path to the checkpoint file to load the model from. + Defaults to None. + **kwargs: Additional keyword arguments for pl.Trainer configuration. + + Returns: + dict: Dictionary containing the callback metrics from the trainer. + + Example: + >>> engine.test( + ... datamodule=OTXDataModule(), + ... checkpoint=, + ... ) + + CLI Usage: + 1. you can pick a model. + ```python + otx test + --model --data_root + --checkpoint + ``` + 2. If you have a ready configuration file, run it like this. + ```python + otx test --config --checkpoint + ``` + """ + lit_module = self._build_lightning_module( + model=self.model, + optimizer=self.optimizer, + scheduler=self.scheduler, + ) + if datamodule is None: + datamodule = self.datamodule + lit_module.meta_info = datamodule.meta_info + + self._build_trainer(**kwargs) + + self.trainer.test( + model=lit_module, + dataloaders=datamodule, + ckpt_path=str(checkpoint) if checkpoint is not None else self.checkpoint, + ) + + return self.trainer.callback_metrics + @property def trainer(self) -> Trainer: """Returns the trainer object associated with the engine. From 42a9607fff36f3872abfa4e7ae04ad1d6ac7dee6 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Fri, 19 Jan 2024 22:20:11 +0900 Subject: [PATCH 02/16] combine all together, maps validated --- src/otx/cli/explain.py | 3 +-- src/otx/config/model/default.yaml | 2 +- src/otx/core/config/explain.py | 2 +- src/otx/core/config/model.py | 4 +++- src/otx/core/model/entity/classification.py | 2 +- src/otx/core/types/explain.py | 17 +++++++++++++++++ src/otx/engine/engine.py | 11 +++++++---- 7 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/otx/cli/explain.py b/src/otx/cli/explain.py index 3310fc76435..d3dd694b709 100644 --- a/src/otx/cli/explain.py +++ b/src/otx/cli/explain.py @@ -51,5 +51,4 @@ def otx_explain(overrides: list[str]) -> None: checkpoint=cfg.checkpoint, device=trainer_kwargs.pop("accelerator", "auto"), ) - saliency_maps = engine.explain() - return saliency_maps + engine.explain() # cfg.base.output_dir, cfg.model.explain_config diff --git a/src/otx/config/model/default.yaml b/src/otx/config/model/default.yaml index 2ad1625cd9b..c0c54ed974f 100644 --- a/src/otx/config/model/default.yaml +++ b/src/otx/config/model/default.yaml @@ -16,4 +16,4 @@ otx_model: explain_config: target_explain_group: PREDICTIONS - postprocess: true + postprocess: false diff --git a/src/otx/core/config/explain.py b/src/otx/core/config/explain.py index 5dc251b4001..ca05019c8db 100644 --- a/src/otx/core/config/explain.py +++ b/src/otx/core/config/explain.py @@ -6,7 +6,7 @@ from dataclasses import dataclass -from otx.core.types.export import OTXExportFormatType, OTXExportPrecisionType +from otx.core.types.explain import TargetExplainGroup @dataclass diff --git a/src/otx/core/config/model.py b/src/otx/core/config/model.py index 603ec7eb873..e0157d8ca70 100644 --- a/src/otx/core/config/model.py +++ b/src/otx/core/config/model.py @@ -5,6 +5,8 @@ from dataclasses import dataclass +from src.otx.core.config.explain import ExplainConfig + @dataclass class ModelConfig: @@ -15,4 +17,4 @@ class ModelConfig: scheduler: dict otx_model: dict torch_compile: bool - explain_config: dict + explain_config: ExplainConfig diff --git a/src/otx/core/model/entity/classification.py b/src/otx/core/model/entity/classification.py index 49252bf3936..392e6b5bf1b 100644 --- a/src/otx/core/model/entity/classification.py +++ b/src/otx/core/model/entity/classification.py @@ -63,7 +63,7 @@ def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor: raise ValueError output = neck(x) - return head(output) + return head([output]) def remove_explain_hook_handle(self) -> None: """Removes explain hook from the model.""" diff --git a/src/otx/core/types/explain.py b/src/otx/core/types/explain.py index e69de29bb2d..b6a42e8bde8 100644 --- a/src/otx/core/types/explain.py +++ b/src/otx/core/types/explain.py @@ -0,0 +1,17 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""OTX explain type definition.""" + +from __future__ import annotations + +from enum import Enum + + +class TargetExplainGroup(str, Enum): + """OTX target explain group definition.""" + + IMAGE = "IMAGE" + ALL = "ALL" + PREDICTIONS = "PREDICTIONS" + CUSTOM = "CUSTOM" diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 4f0c8214f9e..c7f08700c2b 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -336,7 +336,7 @@ def explain( datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, **kwargs, ) -> dict: - """Run the testing phase of the engine. + """Run the explain phase of the engine. Args: datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module containing the test data. @@ -348,7 +348,7 @@ def explain( dict: Dictionary containing the callback metrics from the trainer. Example: - >>> engine.test( + >>> engine.explain( ... datamodule=OTXDataModule(), ... checkpoint=, ... ) @@ -356,7 +356,7 @@ def explain( CLI Usage: 1. you can pick a model. ```python - otx test + otx explain --model --data_root --checkpoint ``` @@ -374,6 +374,8 @@ def explain( datamodule = self.datamodule lit_module.meta_info = datamodule.meta_info + lit_module.model.register_explain_hook() + self._build_trainer(**kwargs) self.trainer.test( @@ -382,7 +384,8 @@ def explain( ckpt_path=str(checkpoint) if checkpoint is not None else self.checkpoint, ) - return self.trainer.callback_metrics + saliency_maps = self.trainer.model.model.explain_hook.records + return saliency_maps @property def trainer(self) -> Trainer: From 23dd1a01a488fce30549f269f237069bf867836d Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Sat, 20 Jan 2024 00:43:49 +0900 Subject: [PATCH 03/16] Based on predict entry --- src/otx/cli/explain.py | 2 +- src/otx/engine/engine.py | 49 +++++++++++++--------------------------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/src/otx/cli/explain.py b/src/otx/cli/explain.py index d3dd694b709..7668283c47b 100644 --- a/src/otx/cli/explain.py +++ b/src/otx/cli/explain.py @@ -51,4 +51,4 @@ def otx_explain(overrides: list[str]) -> None: checkpoint=cfg.checkpoint, device=trainer_kwargs.pop("accelerator", "auto"), ) - engine.explain() # cfg.base.output_dir, cfg.model.explain_config + engine.explain(output_dir=cfg.base.output_dir, explain_config=cfg.model.explain_config) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index c7f08700c2b..ca0dd95b7c8 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -11,6 +11,7 @@ from lightning import Trainer, seed_everything from otx.core.config.device import DeviceConfig +from otx.core.config.explain import ExplainConfig from otx.core.data.module import OTXDataModule from otx.core.model.entity.base import OTXModel from otx.core.model.module.base import OTXLitModule @@ -326,44 +327,22 @@ def export(self, *args, **kwargs) -> None: """Export the trained model to OpenVINO Intermediate Representation (IR) or ONNX formats.""" raise NotImplementedError - # ------------------------------------------------------------------------ # - # Property and setter functions provided by Engine. - # ------------------------------------------------------------------------ # - def explain( self, checkpoint: str | Path | None = None, datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, + output_dir: str | Path | None = None, + explain_config: ExplainConfig | None = None, **kwargs, - ) -> dict: - """Run the explain phase of the engine. + ) -> list | None: + """Run XAI using the specified model and data. Args: - datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module containing the test data. - checkpoint (str | Path | None, optional): Path to the checkpoint file to load the model from. - Defaults to None. + checkpoint (str | Path | None, optional): The path to the checkpoint file to load the model from. + datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions. + output_dir (str | None, optional): Path to save saliency maps. + explain_config (ExplainConfig | None, optional): Config used to handle saliency maps. **kwargs: Additional keyword arguments for pl.Trainer configuration. - - Returns: - dict: Dictionary containing the callback metrics from the trainer. - - Example: - >>> engine.explain( - ... datamodule=OTXDataModule(), - ... checkpoint=, - ... ) - - CLI Usage: - 1. you can pick a model. - ```python - otx explain - --model --data_root - --checkpoint - ``` - 2. If you have a ready configuration file, run it like this. - ```python - otx test --config --checkpoint - ``` """ lit_module = self._build_lightning_module( model=self.model, @@ -378,15 +357,19 @@ def explain( self._build_trainer(**kwargs) - self.trainer.test( + prediction = self.trainer.predict( model=lit_module, - dataloaders=datamodule, + datamodule=datamodule, ckpt_path=str(checkpoint) if checkpoint is not None else self.checkpoint, ) - saliency_maps = self.trainer.model.model.explain_hook.records + # TODO: select, process, and save saliency maps. Should be done here? return saliency_maps + # ------------------------------------------------------------------------ # + # Property and setter functions provided by Engine. + # ------------------------------------------------------------------------ # + @property def trainer(self) -> Trainer: """Returns the trainer object associated with the engine. From d56df17c192af7bb61a464355593fe56e2b95f2f Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Sat, 20 Jan 2024 01:04:30 +0900 Subject: [PATCH 04/16] fix engine --- src/otx/engine/engine.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index ca0dd95b7c8..73299c6ce5c 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -357,14 +357,12 @@ def explain( self._build_trainer(**kwargs) - prediction = self.trainer.predict( + self.trainer.predict( model=lit_module, datamodule=datamodule, ckpt_path=str(checkpoint) if checkpoint is not None else self.checkpoint, ) - saliency_maps = self.trainer.model.model.explain_hook.records - # TODO: select, process, and save saliency maps. Should be done here? - return saliency_maps + return self.trainer.model.model.explain_hook.records # ------------------------------------------------------------------------ # # Property and setter functions provided by Engine. From 1820b890109715708aa3ad2b2d24e4c98bfe8b2f Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Mon, 22 Jan 2024 23:29:09 +0900 Subject: [PATCH 05/16] fix import --- src/otx/core/config/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/otx/core/config/model.py b/src/otx/core/config/model.py index e0157d8ca70..6c475e0920e 100644 --- a/src/otx/core/config/model.py +++ b/src/otx/core/config/model.py @@ -5,7 +5,7 @@ from dataclasses import dataclass -from src.otx.core.config.explain import ExplainConfig +from otx.core.config.explain import ExplainConfig @dataclass From 64ac3a42593e616a6cc712a73ce3ae77bb2417a1 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Mon, 22 Jan 2024 23:40:20 +0900 Subject: [PATCH 06/16] Docstring for TargetExplainGroup --- src/otx/core/types/explain.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/otx/core/types/explain.py b/src/otx/core/types/explain.py index b6a42e8bde8..110fed06d1e 100644 --- a/src/otx/core/types/explain.py +++ b/src/otx/core/types/explain.py @@ -9,9 +9,14 @@ class TargetExplainGroup(str, Enum): - """OTX target explain group definition.""" + """OTX target explain group definition. + + Enum contains the following values: + IMAGE - This implies that single global saliency map will be generated for input image. + ALL - This implies that saliency maps will be generated for all possible targets. + PREDICTIONS - This implies that saliency map will be generated per each prediction. + """ IMAGE = "IMAGE" ALL = "ALL" PREDICTIONS = "PREDICTIONS" - CUSTOM = "CUSTOM" From 5f0f7a3400cc55314d18f5586b2c368785ee19b9 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 23 Jan 2024 02:10:44 +0900 Subject: [PATCH 07/16] dummy save and tests --- src/otx/engine/engine.py | 6 +++- tests/integration/cli/test_cli.py | 40 +++++++++++++++++++++++++ tests/unit/algo/hooks/__init__.py | 2 ++ tests/unit/algo/hooks/test_xai_hooks.py | 32 ++++++++++++++++++++ 4 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 tests/unit/algo/hooks/__init__.py create mode 100644 tests/unit/algo/hooks/test_xai_hooks.py diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 73299c6ce5c..2e2617bce46 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -344,6 +344,8 @@ def explain( explain_config (ExplainConfig | None, optional): Config used to handle saliency maps. **kwargs: Additional keyword arguments for pl.Trainer configuration. """ + import cv2 + lit_module = self._build_lightning_module( model=self.model, optimizer=self.optimizer, @@ -362,7 +364,9 @@ def explain( datamodule=datamodule, ckpt_path=str(checkpoint) if checkpoint is not None else self.checkpoint, ) - return self.trainer.model.model.explain_hook.records + saliency_maps = self.trainer.model.model.explain_hook.records + cv2.imwrite(str(output_dir / "saliency_map.tiff"), saliency_maps[0][0]) + return saliency_maps # ------------------------------------------------------------------------ # # Property and setter functions provided by Engine. diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index b3a9fe6fc56..b3d8016ea9f 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -17,6 +17,11 @@ RECIPE_OV_LIST = [str(_.relative_to(RECIPE_PATH)) for _ in RECIPE_PATH.glob("**/openvino_model.yaml")] RECIPE_LIST = set(ALL_RECIPE_LIST) - set(RECIPE_OV_LIST) +RECIPE_PATH_CLS = Path(inspect.getfile(otx_module)).parent / "recipe" / "multiclass_classification" +RECIPE_PATH_CNN = list(RECIPE_PATH_CLS.glob("**/*efficient*.yaml")) + list(RECIPE_PATH_CLS.glob("**/*mobilenet*.yaml")) +RECIPE_LIST_XAI = [str(_.relative_to(RECIPE_PATH)) for _ in RECIPE_PATH_CNN] + + # [TODO]: This is a temporary approach. DATASET = { "multiclass_classification": { @@ -139,6 +144,41 @@ def test_otx_e2e(recipe: str, tmp_path: Path, fxt_accelerator: str) -> None: assert (tmp_path_test / "outputs" / "lightning_logs").exists() +@pytest.mark.parametrize("recipe", RECIPE_LIST_XAI) +def test_otx_explain_e2e(recipe: str, tmp_path: Path, fxt_accelerator: str) -> None: + """ + Test OTX CLI explain e2e command. + + Args: + recipe (str): The recipe to use for training. (eg. 'classification/otx_mobilenet_v3_large.yaml') + tmp_path (Path): The temporary path for storing the training outputs. + + Returns: + None + """ + task = recipe.split("/")[0] + model_name = recipe.split("/")[1].split(".")[0] + + # otx explain + tmp_path_explain = tmp_path / f"otx_explain_{model_name}" + command_cfg = [ + "otx", + "explain", + f"+recipe={recipe}", + f"base.data_dir={DATASET[task]['data_dir']}", + f"base.work_dir={tmp_path_explain}", + f"base.output_dir={tmp_path_explain / 'outputs'}", + f"trainer={fxt_accelerator}", + *DATASET[task]["overrides"], + ] + + with patch("sys.argv", command_cfg): + main() + + assert (tmp_path_explain / "outputs").exists() + assert (tmp_path_explain / "outputs" / "saliency_map.tiff").exists() + + @pytest.mark.parametrize("recipe", RECIPE_OV_LIST) def test_otx_ov_test(recipe: str, tmp_path: Path) -> None: """ diff --git a/tests/unit/algo/hooks/__init__.py b/tests/unit/algo/hooks/__init__.py new file mode 100644 index 00000000000..916f3a44b27 --- /dev/null +++ b/tests/unit/algo/hooks/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/algo/hooks/test_xai_hooks.py b/tests/unit/algo/hooks/test_xai_hooks.py new file mode 100644 index 00000000000..8525d89ae6e --- /dev/null +++ b/tests/unit/algo/hooks/test_xai_hooks.py @@ -0,0 +1,32 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import torch +from otx.algo.hooks.recording_forward_hook import ReciproCAMHook + + +def test_reciprocam(): + def cls_head_forward_fn(_): + return torch.zeros((25, 2)) + + num_classes = 2 + optimize_gap = False + hook = ReciproCAMHook( + cls_head_forward_fn, + num_classes=num_classes, + optimize_gap=optimize_gap, + ) + + assert hook.handle is None + assert hook.records == [] + assert hook._norm_saliency_maps + + feature_map = torch.zeros((1, 10, 5, 5)) + + saliency_maps = hook.func(feature_map) + assert saliency_maps.size() == torch.Size([1, 2, 5, 5]) + + hook.recording_forward(None, None, feature_map) + assert len(hook.records) == 1 + + hook.reset() + assert hook.records == [] From 0902aa17c5bfc5bb88aaeabc438033cdc40134c4 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 23 Jan 2024 02:12:57 +0900 Subject: [PATCH 08/16] minor --- src/otx/core/types/explain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/otx/core/types/explain.py b/src/otx/core/types/explain.py index 110fed06d1e..5ed8b94bea8 100644 --- a/src/otx/core/types/explain.py +++ b/src/otx/core/types/explain.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # """OTX explain type definition.""" From 37cfc4221c3a62987a9f08a2a05bad008b7b9eac Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 24 Jan 2024 21:30:27 +0900 Subject: [PATCH 09/16] Update cli and engine --- src/otx/cli/cli.py | 1 + src/otx/core/config/explain.py | 4 ++-- src/otx/engine/engine.py | 21 +++++++++++++++++---- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py index 7d3f7cecf29..4f4ee4cf624 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -123,6 +123,7 @@ def engine_subcommands() -> dict[str, set[str]]: "test": {"datamodule"}.union(device_kwargs), "predict": {"datamodule"}.union(device_kwargs), "export": device_kwargs, + "explain": {"datamodule"}.union(device_kwargs), } def add_subcommands(self) -> None: diff --git a/src/otx/core/config/explain.py b/src/otx/core/config/explain.py index ca05019c8db..2f9bc3f8873 100644 --- a/src/otx/core/config/explain.py +++ b/src/otx/core/config/explain.py @@ -13,5 +13,5 @@ class ExplainConfig: """DTO for explain configuration.""" - target_explain_group: TargetExplainGroup - postprocess: bool + target_explain_group: TargetExplainGroup = TargetExplainGroup.ALL + postprocess: bool = False diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 556534a6226..0d15ed6281c 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -332,7 +332,6 @@ def explain( self, checkpoint: str | Path | None = None, datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, - output_dir: str | Path | None = None, explain_config: ExplainConfig | None = None, **kwargs, ) -> list | None: @@ -341,12 +340,26 @@ def explain( Args: checkpoint (str | Path | None, optional): The path to the checkpoint file to load the model from. datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions. - output_dir (str | None, optional): Path to save saliency maps. explain_config (ExplainConfig | None, optional): Config used to handle saliency maps. **kwargs: Additional keyword arguments for pl.Trainer configuration. + + Returns: + list: Saliency maps. + + Example: + >>> engine.explain( + ... datamodule=OTXDataModule(), + ... checkpoint=, + ... explain_config=ExplainConfig(), + ... ) """ + from pathlib import Path import cv2 + ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint + if explain_config is None: + explain_config = ExplainConfig() + lit_module = self._build_lightning_module( model=self.model, optimizer=self.optimizer, @@ -363,10 +376,10 @@ def explain( self.trainer.predict( model=lit_module, datamodule=datamodule, - ckpt_path=str(checkpoint) if checkpoint is not None else self.checkpoint, + ckpt_path=ckpt_path, ) saliency_maps = self.trainer.model.model.explain_hook.records - cv2.imwrite(str(output_dir / "saliency_map.tiff"), saliency_maps[0][0]) + cv2.imwrite(str(Path(self.work_dir) / "saliency_map.tiff"), saliency_maps[0][0]) return saliency_maps # ------------------------------------------------------------------------ # From 18ed822f59ce65a25f5eb3ba2349e88933822e31 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 24 Jan 2024 21:31:09 +0900 Subject: [PATCH 10/16] make predict entry point to work --- src/otx/core/data/module.py | 15 +++++++++++++++ src/otx/core/model/module/base.py | 3 +++ 2 files changed, 18 insertions(+) diff --git a/src/otx/core/data/module.py b/src/otx/core/data/module.py index fc07554f919..b9dead5de3c 100644 --- a/src/otx/core/data/module.py +++ b/src/otx/core/data/module.py @@ -148,6 +148,21 @@ def test_dataloader(self) -> DataLoader: persistent_workers=config.num_workers > 0, ) + def predict_dataloader(self) -> DataLoader: + """Get test dataloader.""" + config = self.config.test_subset + dataset = self._get_dataset(config.subset_name) + + return DataLoader( + dataset=dataset, + batch_size=config.batch_size, + shuffle=False, + num_workers=config.num_workers, + pin_memory=True, + collate_fn=dataset.collate_fn, + persistent_workers=config.num_workers > 0, + ) + def setup(self, stage: str) -> None: """Setup for each stage.""" diff --git a/src/otx/core/model/module/base.py b/src/otx/core/model/module/base.py index 21d878dbad0..b2abdcd6ac1 100644 --- a/src/otx/core/model/module/base.py +++ b/src/otx/core/model/module/base.py @@ -205,3 +205,6 @@ def export(self, output_dir: Path, export_format: OTXExportFormat) -> None: export_format: Format in which this `OTXModel` is exported. """ self.model.export(output_dir, export_format) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.model.forward(*args, **kwargs) From 3ef65d1d3d560fb63396dca457cfe7fd0393ba48 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 24 Jan 2024 21:31:54 +0900 Subject: [PATCH 11/16] update test explain cli --- tests/integration/cli/test_cli.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index ea606683b47..0bccbadafdc 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -18,11 +18,6 @@ RECIPE_LIST = set(RECIPE_LIST) - set(RECIPE_OV_LIST) -RECIPE_PATH_CLS = Path(inspect.getfile(otx_module)).parent / "recipe" / "multiclass_classification" -RECIPE_PATH_CNN = list(RECIPE_PATH_CLS.glob("**/*efficient*.yaml")) + list(RECIPE_PATH_CLS.glob("**/*mobilenet*.yaml")) -RECIPE_LIST_XAI = [str(_.relative_to(RECIPE_PATH)) for _ in RECIPE_PATH_CNN] - - # [TODO]: This is a temporary approach. DATASET = { "multi_class_cls": { @@ -152,7 +147,7 @@ def test_otx_e2e(recipe: str, tmp_path: Path, fxt_accelerator: str) -> None: assert (tmp_path_test / "outputs" / "csv").exists() -@pytest.mark.parametrize("recipe", RECIPE_LIST_XAI) +@pytest.mark.parametrize("recipe", RECIPE_LIST) def test_otx_explain_e2e(recipe: str, tmp_path: Path, fxt_accelerator: str) -> None: """ Test OTX CLI explain e2e command. @@ -164,19 +159,28 @@ def test_otx_explain_e2e(recipe: str, tmp_path: Path, fxt_accelerator: str) -> N Returns: None """ - task = recipe.split("/")[0] - model_name = recipe.split("/")[1].split(".")[0] + task = recipe.split("/")[-2] + model_name = recipe.split("/")[-1].split(".")[0] + + if "_cls" not in task: + pytest.skip("Supported only for classification tast.") + + if "deit" in model_name or "dino" in model_name: + pytest.skip("Supported only for CNN models.") # otx explain tmp_path_explain = tmp_path / f"otx_explain_{model_name}" command_cfg = [ "otx", "explain", - f"+recipe={recipe}", - f"base.data_dir={DATASET[task]['data_dir']}", - f"base.work_dir={tmp_path_explain}", - f"base.output_dir={tmp_path_explain / 'outputs'}", - f"trainer={fxt_accelerator}", + "--config", + recipe, + "--data_root", + DATASET[task]["data_root"], + "--engine.work_dir", + str(tmp_path_explain / "outputs"), + "--engine.device", + fxt_accelerator, *DATASET[task]["overrides"], ] From a43348947733df57b0baf41b1ab7e3e0728f14fe Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 25 Jan 2024 01:02:04 +0900 Subject: [PATCH 12/16] fix linters --- src/otx/cli/explain.py | 54 ---------------------------------------- src/otx/engine/engine.py | 1 + 2 files changed, 1 insertion(+), 54 deletions(-) delete mode 100644 src/otx/cli/explain.py diff --git a/src/otx/cli/explain.py b/src/otx/cli/explain.py deleted file mode 100644 index 7668283c47b..00000000000 --- a/src/otx/cli/explain.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""CLI entrypoint for model XAI.""" - -from __future__ import annotations - -import hydra -import logging as log -from hydra import compose, initialize -from otx.core.model.entity.base import OTXModel -from otx.cli.utils.hydra import configure_hydra_outputs - - -def otx_explain(overrides: list[str]) -> None: - """Main entry point for XAI. - - :param overrides: Override List values. - :return: Optional[float] with optimized metric value. - """ - from otx.core.config import register_configs - - # This should be in front of hydra.initialize() - register_configs() - - with initialize(config_path="../config", version_base="1.3", job_name="otx_explain"): - cfg = compose(config_name="test", overrides=overrides, return_hydra_config=True) - configure_hydra_outputs(cfg) - - # explain the model - from otx.core.data.module import OTXDataModule - - log.info(f"Instantiating datamodule <{cfg.data}>") - datamodule = OTXDataModule(task=cfg.base.task, config=cfg.data) - - log.info(f"Instantiating model <{cfg.model}>") - model: OTXModel = hydra.utils.instantiate(cfg.model.otx_model) - optimizer = hydra.utils.instantiate(cfg.model.optimizer) - scheduler = hydra.utils.instantiate(cfg.model.scheduler) - - from otx.engine import Engine - - trainer_kwargs = {**cfg.trainer} - engine = Engine( - task=cfg.base.task, - work_dir=cfg.base.output_dir, - model=model, - optimizer=optimizer, - scheduler=scheduler, - datamodule=datamodule, - checkpoint=cfg.checkpoint, - device=trainer_kwargs.pop("accelerator", "auto"), - ) - engine.explain(output_dir=cfg.base.output_dir, explain_config=cfg.model.explain_config) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 0d15ed6281c..7271a94a0c5 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -354,6 +354,7 @@ def explain( ... ) """ from pathlib import Path + import cv2 ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint From 15f8dc6256a4f7591c531a38e619998d76553a56 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 25 Jan 2024 01:22:02 +0900 Subject: [PATCH 13/16] minor --- src/otx/core/model/module/base.py | 3 ++- tests/unit/algo/hooks/test_xai_hooks.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/otx/core/model/module/base.py b/src/otx/core/model/module/base.py index b2abdcd6ac1..e5ddb9832d6 100644 --- a/src/otx/core/model/module/base.py +++ b/src/otx/core/model/module/base.py @@ -206,5 +206,6 @@ def export(self, output_dir: Path, export_format: OTXExportFormat) -> None: """ self.model.export(output_dir, export_format) - def forward(self, *args: Any, **kwargs: Any) -> Any: + def forward(self, *args, **kwargs): + """Model forward pass.""" return self.model.forward(*args, **kwargs) diff --git a/tests/unit/algo/hooks/test_xai_hooks.py b/tests/unit/algo/hooks/test_xai_hooks.py index 8525d89ae6e..450d998364e 100644 --- a/tests/unit/algo/hooks/test_xai_hooks.py +++ b/tests/unit/algo/hooks/test_xai_hooks.py @@ -4,8 +4,8 @@ from otx.algo.hooks.recording_forward_hook import ReciproCAMHook -def test_reciprocam(): - def cls_head_forward_fn(_): +def test_reciprocam() -> None: + def cls_head_forward_fn(_) -> None: return torch.zeros((25, 2)) num_classes = 2 From c0a8804a07a1edae72876f42855c185f9fc1a442 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 25 Jan 2024 01:28:22 +0900 Subject: [PATCH 14/16] minor --- src/otx/core/model/module/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/otx/core/model/module/base.py b/src/otx/core/model/module/base.py index e5ddb9832d6..cdf61a03174 100644 --- a/src/otx/core/model/module/base.py +++ b/src/otx/core/model/module/base.py @@ -12,7 +12,10 @@ from lightning import LightningModule from torch import Tensor -from otx.core.data.entity.base import OTXBatchDataEntity +from otx.core.data.entity.base import ( + OTXBatchDataEntity, + OTXBatchPredEntity, +) from otx.core.model.entity.base import OTXModel from otx.core.types.export import OTXExportFormat from otx.core.utils.utils import is_ckpt_for_finetuning, is_ckpt_from_otx_v1 @@ -206,6 +209,6 @@ def export(self, output_dir: Path, export_format: OTXExportFormat) -> None: """ self.model.export(output_dir, export_format) - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> OTXBatchPredEntity: """Model forward pass.""" return self.model.forward(*args, **kwargs) From f3d994b408514e5f4275c6a684613cd48f9da599 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 25 Jan 2024 01:42:11 +0900 Subject: [PATCH 15/16] minor --- src/otx/core/model/module/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/otx/core/model/module/base.py b/src/otx/core/model/module/base.py index cdf61a03174..cfbe95fb997 100644 --- a/src/otx/core/model/module/base.py +++ b/src/otx/core/model/module/base.py @@ -14,6 +14,7 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, + OTXBatchLossEntity, OTXBatchPredEntity, ) from otx.core.model.entity.base import OTXModel @@ -209,6 +210,6 @@ def export(self, output_dir: Path, export_format: OTXExportFormat) -> None: """ self.model.export(output_dir, export_format) - def forward(self, *args, **kwargs) -> OTXBatchPredEntity: + def forward(self, *args, **kwargs) -> OTXBatchPredEntity | OTXBatchLossEntity: """Model forward pass.""" return self.model.forward(*args, **kwargs) From a63889a6ad1bf54d2bcc9002d23ac85845fb5257 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 25 Jan 2024 22:23:01 +0900 Subject: [PATCH 16/16] resolve comments --- src/otx/core/config/explain.py | 2 +- src/otx/engine/engine.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/otx/core/config/explain.py b/src/otx/core/config/explain.py index 2f9bc3f8873..dedd837a5f3 100644 --- a/src/otx/core/config/explain.py +++ b/src/otx/core/config/explain.py @@ -11,7 +11,7 @@ @dataclass class ExplainConfig: - """DTO for explain configuration.""" + """Data Transfer Object (DTO) for explain configuration.""" target_explain_group: TargetExplainGroup = TargetExplainGroup.ALL postprocess: bool = False diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 7271a94a0c5..4a55dd7cd2d 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -5,6 +5,7 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING, Any, Iterable import torch @@ -20,8 +21,6 @@ from otx.core.utils.cache import TrainerArgumentsCache if TYPE_CHECKING: - from pathlib import Path - from lightning import Callback from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from lightning.pytorch.loggers import Logger @@ -353,8 +352,6 @@ def explain( ... explain_config=ExplainConfig(), ... ) """ - from pathlib import Path - import cv2 ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint @@ -379,7 +376,9 @@ def explain( datamodule=datamodule, ckpt_path=ckpt_path, ) + # Optimize for memory <- TODO(negvet) saliency_maps = self.trainer.model.model.explain_hook.records + # Temporary saving saliency map for image 0, class 0 (for tests) cv2.imwrite(str(Path(self.work_dir) / "saliency_map.tiff"), saliency_maps[0][0]) return saliency_maps