diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py
index 7d3f7cecf29..4f4ee4cf624 100644
--- a/src/otx/cli/cli.py
+++ b/src/otx/cli/cli.py
@@ -123,6 +123,7 @@ def engine_subcommands() -> dict[str, set[str]]:
             "test": {"datamodule"}.union(device_kwargs),
             "predict": {"datamodule"}.union(device_kwargs),
             "export": device_kwargs,
+            "explain": {"datamodule"}.union(device_kwargs),
         }

     def add_subcommands(self) -> None:
diff --git a/src/otx/core/config/explain.py b/src/otx/core/config/explain.py
new file mode 100644
index 00000000000..dedd837a5f3
--- /dev/null
+++ b/src/otx/core/config/explain.py
@@ -0,0 +1,17 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Config data type objects for explain method."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from otx.core.types.explain import TargetExplainGroup
+
+
+@dataclass
+class ExplainConfig:
+    """Data Transfer Object (DTO) for explain configuration."""
+
+    target_explain_group: TargetExplainGroup = TargetExplainGroup.ALL
+    postprocess: bool = False
diff --git a/src/otx/core/data/module.py b/src/otx/core/data/module.py
index fc07554f919..b9dead5de3c 100644
--- a/src/otx/core/data/module.py
+++ b/src/otx/core/data/module.py
@@ -148,6 +148,21 @@ def test_dataloader(self) -> DataLoader:
             persistent_workers=config.num_workers > 0,
         )

+    def predict_dataloader(self) -> DataLoader:
+        """Get predict dataloader."""
+        config = self.config.test_subset
+        dataset = self._get_dataset(config.subset_name)
+
+        return DataLoader(
+            dataset=dataset,
+            batch_size=config.batch_size,
+            shuffle=False,
+            num_workers=config.num_workers,
+            pin_memory=True,
+            collate_fn=dataset.collate_fn,
+            persistent_workers=config.num_workers > 0,
+        )
+
     def setup(self, stage: str) -> None:
         """Setup for each stage."""
diff --git a/src/otx/core/model/entity/classification.py b/src/otx/core/model/entity/classification.py
index 0ed17864353..6199af508a5 100644
--- a/src/otx/core/model/entity/classification.py
+++ b/src/otx/core/model/entity/classification.py
@@ -64,7 +64,7 @@ def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
             raise ValueError

         output = neck(x)
-        return head(output)
+        return head([output])

     def remove_explain_hook_handle(self) -> None:
         """Removes explain hook from the model."""
diff --git a/src/otx/core/model/module/base.py b/src/otx/core/model/module/base.py
index 21d878dbad0..cfbe95fb997 100644
--- a/src/otx/core/model/module/base.py
+++ b/src/otx/core/model/module/base.py
@@ -12,7 +12,11 @@
 from lightning import LightningModule
 from torch import Tensor

-from otx.core.data.entity.base import OTXBatchDataEntity
+from otx.core.data.entity.base import (
+    OTXBatchDataEntity,
+    OTXBatchLossEntity,
+    OTXBatchPredEntity,
+)
 from otx.core.model.entity.base import OTXModel
 from otx.core.types.export import OTXExportFormat
 from otx.core.utils.utils import is_ckpt_for_finetuning, is_ckpt_from_otx_v1
@@ -205,3 +209,7 @@ def export(self, output_dir: Path, export_format: OTXExportFormat) -> None:
             export_format: Format in which this `OTXModel` is exported.
""" self.model.export(output_dir, export_format) + + def forward(self, *args, **kwargs) -> OTXBatchPredEntity | OTXBatchLossEntity: + """Model forward pass.""" + return self.model.forward(*args, **kwargs) diff --git a/src/otx/core/types/explain.py b/src/otx/core/types/explain.py new file mode 100644 index 00000000000..5ed8b94bea8 --- /dev/null +++ b/src/otx/core/types/explain.py @@ -0,0 +1,22 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""OTX explain type definition.""" + +from __future__ import annotations + +from enum import Enum + + +class TargetExplainGroup(str, Enum): + """OTX target explain group definition. + + Enum contains the following values: + IMAGE - This implies that single global saliency map will be generated for input image. + ALL - This implies that saliency maps will be generated for all possible targets. + PREDICTIONS - This implies that saliency map will be generated per each prediction. + """ + + IMAGE = "IMAGE" + ALL = "ALL" + PREDICTIONS = "PREDICTIONS" diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 2a2c7e17f29..4a55dd7cd2d 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -5,12 +5,14 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING, Any, Iterable import torch from lightning import Trainer, seed_everything from otx.core.config.device import DeviceConfig +from otx.core.config.explain import ExplainConfig from otx.core.data.module import OTXDataModule from otx.core.model.entity.base import OTXModel from otx.core.model.module.base import OTXLitModule @@ -19,8 +21,6 @@ from otx.core.utils.cache import TrainerArgumentsCache if TYPE_CHECKING: - from pathlib import Path - from lightning import Callback from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from lightning.pytorch.loggers import Logger @@ -327,6 +327,61 @@ def export(self, *args, **kwargs) -> None: """Export the trained model to OpenVINO Intermediate Representation (IR) or ONNX formats.""" raise NotImplementedError + def explain( + self, + checkpoint: str | Path | None = None, + datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, + explain_config: ExplainConfig | None = None, + **kwargs, + ) -> list | None: + """Run XAI using the specified model and data. + + Args: + checkpoint (str | Path | None, optional): The path to the checkpoint file to load the model from. + datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions. + explain_config (ExplainConfig | None, optional): Config used to handle saliency maps. + **kwargs: Additional keyword arguments for pl.Trainer configuration. + + Returns: + list: Saliency maps. + + Example: + >>> engine.explain( + ... datamodule=OTXDataModule(), + ... checkpoint=, + ... explain_config=ExplainConfig(), + ... 
) + """ + import cv2 + + ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint + if explain_config is None: + explain_config = ExplainConfig() + + lit_module = self._build_lightning_module( + model=self.model, + optimizer=self.optimizer, + scheduler=self.scheduler, + ) + if datamodule is None: + datamodule = self.datamodule + lit_module.meta_info = datamodule.meta_info + + lit_module.model.register_explain_hook() + + self._build_trainer(**kwargs) + + self.trainer.predict( + model=lit_module, + datamodule=datamodule, + ckpt_path=ckpt_path, + ) + # Optimize for memory <- TODO(negvet) + saliency_maps = self.trainer.model.model.explain_hook.records + # Temporary saving saliency map for image 0, class 0 (for tests) + cv2.imwrite(str(Path(self.work_dir) / "saliency_map.tiff"), saliency_maps[0][0]) + return saliency_maps + # ------------------------------------------------------------------------ # # Property and setter functions provided by Engine. # ------------------------------------------------------------------------ # diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index 939a667921e..0bccbadafdc 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -147,6 +147,50 @@ def test_otx_e2e(recipe: str, tmp_path: Path, fxt_accelerator: str) -> None: assert (tmp_path_test / "outputs" / "csv").exists() +@pytest.mark.parametrize("recipe", RECIPE_LIST) +def test_otx_explain_e2e(recipe: str, tmp_path: Path, fxt_accelerator: str) -> None: + """ + Test OTX CLI explain e2e command. + + Args: + recipe (str): The recipe to use for training. (eg. 'classification/otx_mobilenet_v3_large.yaml') + tmp_path (Path): The temporary path for storing the training outputs. + + Returns: + None + """ + task = recipe.split("/")[-2] + model_name = recipe.split("/")[-1].split(".")[0] + + if "_cls" not in task: + pytest.skip("Supported only for classification tast.") + + if "deit" in model_name or "dino" in model_name: + pytest.skip("Supported only for CNN models.") + + # otx explain + tmp_path_explain = tmp_path / f"otx_explain_{model_name}" + command_cfg = [ + "otx", + "explain", + "--config", + recipe, + "--data_root", + DATASET[task]["data_root"], + "--engine.work_dir", + str(tmp_path_explain / "outputs"), + "--engine.device", + fxt_accelerator, + *DATASET[task]["overrides"], + ] + + with patch("sys.argv", command_cfg): + main() + + assert (tmp_path_explain / "outputs").exists() + assert (tmp_path_explain / "outputs" / "saliency_map.tiff").exists() + + @pytest.mark.parametrize("recipe", RECIPE_OV_LIST) def test_otx_ov_test(recipe: str, tmp_path: Path) -> None: """ diff --git a/tests/unit/algo/hooks/__init__.py b/tests/unit/algo/hooks/__init__.py new file mode 100644 index 00000000000..916f3a44b27 --- /dev/null +++ b/tests/unit/algo/hooks/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/algo/hooks/test_xai_hooks.py b/tests/unit/algo/hooks/test_xai_hooks.py new file mode 100644 index 00000000000..450d998364e --- /dev/null +++ b/tests/unit/algo/hooks/test_xai_hooks.py @@ -0,0 +1,32 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import torch +from otx.algo.hooks.recording_forward_hook import ReciproCAMHook + + +def test_reciprocam() -> None: + def cls_head_forward_fn(_) -> None: + return torch.zeros((25, 2)) + + num_classes = 2 + optimize_gap = False + hook = ReciproCAMHook( + cls_head_forward_fn, + 
+        num_classes=num_classes,
+        optimize_gap=optimize_gap,
+    )
+
+    assert hook.handle is None
+    assert hook.records == []
+    assert hook._norm_saliency_maps
+
+    feature_map = torch.zeros((1, 10, 5, 5))
+
+    saliency_maps = hook.func(feature_map)
+    assert saliency_maps.size() == torch.Size([1, 2, 5, 5])
+
+    hook.recording_forward(None, None, feature_map)
+    assert len(hook.records) == 1
+
+    hook.reset()
+    assert hook.records == []
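
Reviewer note: a minimal usage sketch of the API introduced above, assuming the Engine import path and constructor arguments used elsewhere in OTX. The explain() call, ExplainConfig, and TargetExplainGroup come from this diff; the recipe, data paths, and Engine kwargs below are illustrative placeholders, not part of this change. The CLI invocation in the comments mirrors command_cfg in test_otx_explain_e2e.

# Sketch only: constructor arguments and paths are assumptions; explain(),
# ExplainConfig, and TargetExplainGroup are the objects added in this diff.
from otx.core.config.explain import ExplainConfig
from otx.core.types.explain import TargetExplainGroup
from otx.engine import Engine  # assumed public import path for the Engine edited above

# CLI equivalent (as exercised by tests/integration/cli/test_cli.py):
#   otx explain --config <recipe.yaml> --data_root <data/dir> \
#       --engine.work_dir <work/dir> --engine.device <device>

engine = Engine(data_root="<data/dir>", work_dir="<work/dir>")  # illustrative kwargs
saliency_maps = engine.explain(
    checkpoint="<checkpoint/path>",  # optional; falls back to engine.checkpoint
    explain_config=ExplainConfig(
        target_explain_group=TargetExplainGroup.ALL,  # IMAGE | ALL | PREDICTIONS
        postprocess=False,
    ),
)
# explain() returns the hook's recorded saliency maps and also writes
# <work_dir>/saliency_map.tiff (image 0, class 0), which the integration test asserts on.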