From 866cc0e782f6af5673679e1887d3357b8b92b244 Mon Sep 17 00:00:00 2001 From: Galina Date: Fri, 19 Apr 2024 00:52:14 +0900 Subject: [PATCH 1/3] XAI for rmdet_inst_tiny --- src/otx/algo/utils/xai_utils.py | 4 ++ src/otx/core/model/instance_segmentation.py | 42 ++++++++++++++------- tests/e2e/cli/test_cli.py | 8 ++-- tests/integration/api/test_xai.py | 8 ++-- tests/integration/cli/test_cli.py | 8 ++-- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/otx/algo/utils/xai_utils.py b/src/otx/algo/utils/xai_utils.py index 9718e498e57..2466a2a367f 100644 --- a/src/otx/algo/utils/xai_utils.py +++ b/src/otx/algo/utils/xai_utils.py @@ -96,6 +96,10 @@ def convert_maps_to_dict_all(saliency_map: list[np.ndarray]) -> list[dict[Any, n """Convert salincy maps to dict for TargetExplainGroup.ALL.""" processed_saliency_maps = [] for maps_per_image in saliency_map: + if maps_per_image.size == 0: + processed_saliency_maps.append({0: np.zeros((1, 1, 1))}) + continue + if maps_per_image.ndim != 3: msg = "Shape mismatch." raise ValueError(msg) diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index 121c6ca585e..c997f0f976c 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from mmdet.models.data_preprocessors import DetDataPreprocessor - from mmdet.models.detectors.two_stage import TwoStageDetector + from mmdet.models.detectors import TwoStageDetector from mmdet.structures import OptSampleList from omegaconf import DictConfig from openvino.model_api.models.utils import InstanceSegmentationResult @@ -224,10 +224,9 @@ def __init__( torch_compile=torch_compile, ) - from otx.algo.explain.explain_algo import get_feature_vector - self.model.feature_vector_fn = get_feature_vector self.model.explain_fn = self.get_explain_fn() + self.model.get_results_from_head = self.get_results_from_head def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity: """Model forward function.""" @@ -266,22 +265,17 @@ def _forward_explain_inst_seg( x = self.extract_feat(inputs) feature_vector = self.feature_vector_fn(x) + predictions = self.get_results_from_head(x, data_samples) - rpn_results_list = self.rpn_head.predict(x, data_samples, rescale=False) - results_list = self.roi_head.predict(x, rpn_results_list, data_samples, rescale=True) - - if isinstance(results_list, tuple) and isinstance(results_list[0], torch.Tensor): # rewrite + if isinstance(predictions, tuple) and isinstance(predictions[0], torch.Tensor): # Export case, consists of tensors - predictions = results_list # For OV task saliency map are generated on MAPI side saliency_map = torch.empty(1, dtype=torch.uint8) - elif isinstance(results_list, list) and isinstance(results_list[0], InstanceData): # rewrite + elif isinstance(predictions, list) and isinstance(predictions[0], InstanceData): # Predict case, consists of InstanceData - predictions = self.add_pred_to_datasample(data_samples, results_list) - - features_for_sal_map = [data_sample.pred_instances for data_sample in data_samples] - saliency_map = self.explain_fn(features_for_sal_map) + saliency_map = self.explain_fn(predictions) + predictions = self.add_pred_to_datasample(data_samples, predictions) return { "predictions": predictions, @@ -289,6 +283,28 @@ def _forward_explain_inst_seg( "saliency_map": saliency_map, } + def get_results_from_head( + self, + x: tuple[torch.Tensor], + data_samples: OptSampleList | None, + ) -> tuple[torch.Tensor] | list[InstanceData]: + """Get the results from the head of the instance segmentation model. + + Args: + x (tuple[torch.Tensor]): The features from backbone and neck. + data_samples (OptSampleList | None): A list of data samples. + + Returns: + tuple[torch.Tensor] | list[InstanceData]: The predicted results from the head of the model. + Tuple for the Export case, list for the Predict case. + """ + from otx.algo.instance_segmentation.rtmdet_inst import RTMDetInst + + if isinstance(self, RTMDetInst): + return self.model.bbox_head.predict(x, data_samples, rescale=False) + rpn_results_list = self.model.rpn_head.predict(x, data_samples, rescale=False) + return self.model.roi_head.predict(x, rpn_results_list, data_samples, rescale=True) + def get_explain_fn(self) -> Callable: """Returns explain function.""" from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo diff --git a/tests/e2e/cli/test_cli.py b/tests/e2e/cli/test_cli.py index d65908fab33..278bd6b6aae 100644 --- a/tests/e2e/cli/test_cli.py +++ b/tests/e2e/cli/test_cli.py @@ -210,8 +210,8 @@ def test_otx_e2e_cli( if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): return # Supported only for classification, detection and instance segmentation task. - if "dino" in model_name or "rtmdet_inst_tiny" in model_name: - return # DINO and Rtmdet_tiny are not supported. + if "dino" in model_name: + return # DINO and is not supported. format_to_file = { "ONNX": "exported_model.onnx", @@ -284,8 +284,8 @@ def test_otx_explain_e2e_cli( if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): pytest.skip("Supported only for classification, detection and instance segmentation task.") - if "dino" in model_name or "rtmdet_inst_tiny" in model_name: - pytest.skip("DINO and Rtmdet_tiny are not supported.") + if "dino" in model_name: + pytest.skip("DINO is not supported.") # otx explain tmp_path_explain = tmp_path / f"otx_explain_{model_name}" diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py index f8d1fa724be..8554c21ff9b 100644 --- a/tests/integration/api/test_xai.py +++ b/tests/integration/api/test_xai.py @@ -44,8 +44,8 @@ def test_forward_explain( task = recipe.split("/")[-2] model_name = recipe.split("/")[-1].split(".")[0] - if "dino" in model_name or "rtmdet_inst_tiny" in model_name: - pytest.skip("DINO and Rtmdet_tiny are not supported.") + if "dino" in model_name: + pytest.skip("DINO is not supported.") engine = Engine.from_config( config_path=recipe, @@ -92,8 +92,8 @@ def test_predict_with_explain( task = recipe.split("/")[-2] model_name = recipe.split("/")[-1].split(".")[0] - if "dino" in model_name or "rtmdet_inst_tiny" in model_name: - pytest.skip("DINO and Rtmdet_tiny are not supported.") + if "dino" in model_name: + pytest.skip("DINO is not supported.") if "mobilenet_v3_large" in model_name: pytest.skip("There's issue with mobilenet_v3_large model. Skip for now.") diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index 07acab85b16..d5c3b0a0333 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -222,8 +222,8 @@ def test_otx_e2e( if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): return # Supported only for classification, detection and instance segmentation task. - if "dino" in model_name or "rtmdet_inst_tiny" in model_name: - return # DINO and Rtmdet_tiny are not supported. + if "dino" in model_name: + return # DINO is not supported. format_to_file = { "ONNX": "exported_model.onnx", @@ -294,8 +294,8 @@ def test_otx_explain_e2e( if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): pytest.skip("Supported only for classification, detection and instance segmentation task.") - if "dino" in model_name or "rtmdet_inst_tiny" in model_name: - pytest.skip("DINO and Rtmdet_tiny are not supported.") + if "dino" in model_name: + pytest.skip("DINO is not supported.") # otx explain tmp_path_explain = tmp_path / f"otx_explain_{model_name}" From 24ae02470690b9767d2ed335edabaf37ebd845f3 Mon Sep 17 00:00:00 2001 From: Galina Date: Sat, 20 Apr 2024 04:01:10 +0900 Subject: [PATCH 2/3] Rename explain hook --- docs/source/guide/explanation/additional_features/xai.rst | 2 +- docs/source/guide/tutorials/base/explain.rst | 2 +- src/otx/algo/explain/explain_algo.py | 7 +++++-- src/otx/core/model/instance_segmentation.py | 6 ++---- src/otx/core/utils/tile_merge.py | 5 ++--- tests/unit/algo/explain/test_xai_algorithms.py | 6 +++--- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/source/guide/explanation/additional_features/xai.rst b/docs/source/guide/explanation/additional_features/xai.rst index 49c30c31194..4356e1ef5e1 100644 --- a/docs/source/guide/explanation/additional_features/xai.rst +++ b/docs/source/guide/explanation/additional_features/xai.rst @@ -100,7 +100,7 @@ XAI algorithms for instance segmentation For instance segmentation networks the following algorithm is used to generate saliency maps: -- **MaskRCNNExplainAlgo​** - in this approach the predicted object masks are combined and aggregated per class to generate the saliency maps for each class. +- **InstSegExplainAlgo​** - in this approach the predicted object masks are combined and aggregated per class to generate the saliency maps for each class. .. tab-set:: diff --git a/docs/source/guide/tutorials/base/explain.rst b/docs/source/guide/tutorials/base/explain.rst index c94875f3fda..bf2af135783 100644 --- a/docs/source/guide/tutorials/base/explain.rst +++ b/docs/source/guide/tutorials/base/explain.rst @@ -124,7 +124,7 @@ based on the used model: - ``ViT Recipro-CAM`` - for transformer-based classification models - ``DetClassProbabilityMap`` - for single-stage detector models -- ``MaskRCNNExplainAlgo`` - for MaskRCNN instance segmentation models +- ``InstSegExplainAlgo`` - for MaskRCNN and RTMDetInst instance segmentation models .. note:: diff --git a/src/otx/algo/explain/explain_algo.py b/src/otx/algo/explain/explain_algo.py index 07e521db462..1aafa59d648 100644 --- a/src/otx/algo/explain/explain_algo.py +++ b/src/otx/algo/explain/explain_algo.py @@ -311,8 +311,11 @@ def func( return saliency_map.reshape((batch_size, self._num_classes, height, width)) -class MaskRCNNExplainAlgo(BaseExplainAlgo): - """Dummy saliency map algo for Mask R-CNN model.""" +class InstSegExplainAlgo(BaseExplainAlgo): + """Dummy saliency map algo for Mask R-CNN and RTMDetInst model. + + Predicted masks are combined and aggregated per-class to generate the saliency maps. + """ def __init__(self, num_classes: int) -> None: super().__init__() diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index 339405f61f8..064c6bf2e59 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -17,7 +17,7 @@ from openvino.model_api.tilers import InstanceSegmentationTiler from torchvision import tv_tensors -from otx.algo.explain.explain_algo import feature_vector_fn +from otx.algo.explain.explain_algo import InstSegExplainAlgo, feature_vector_fn from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity @@ -309,9 +309,7 @@ def get_results_from_head( def get_explain_fn(self) -> Callable: """Returns explain function.""" - from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo - - explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes) + explainer = InstSegExplainAlgo(num_classes=self.num_classes) return explainer.func @contextmanager diff --git a/src/otx/core/utils/tile_merge.py b/src/otx/core/utils/tile_merge.py index a99cb5d24aa..609d1ced365 100644 --- a/src/otx/core/utils/tile_merge.py +++ b/src/otx/core/utils/tile_merge.py @@ -15,6 +15,7 @@ from torchvision import tv_tensors from torchvision.ops import batched_nms +from otx.algo.explain.explain_algo import InstSegExplainAlgo from otx.core.config.data import TileConfig from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity @@ -431,10 +432,8 @@ def get_saliency_maps_from_masks( Returns: np.array: Class-wise Saliency Maps. One saliency map per each class - [class_id, H, W] """ - from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo - if masks is None: return np.ndarray([]) pred = {"labels": labels, "scores": scores, "masks": masks} - return MaskRCNNExplainAlgo.average_and_normalize(pred, num_classes) + return InstSegExplainAlgo.average_and_normalize(pred, num_classes) diff --git a/tests/unit/algo/explain/test_xai_algorithms.py b/tests/unit/algo/explain/test_xai_algorithms.py index 141b52a00de..fc4e5b2f562 100644 --- a/tests/unit/algo/explain/test_xai_algorithms.py +++ b/tests/unit/algo/explain/test_xai_algorithms.py @@ -5,7 +5,7 @@ from otx.algo.explain.explain_algo import ( ActivationMap, DetClassProbabilityMap, - MaskRCNNExplainAlgo, + InstSegExplainAlgo, ReciproCAM, ViTReciproCAM, ) @@ -80,9 +80,9 @@ def test_detclassprob() -> None: assert saliency_maps.size() == torch.Size([5, 2, 2, 2]) -def test_maskrcnn() -> None: +def test_instseg() -> None: num_classes = 2 - explain_algo = MaskRCNNExplainAlgo( + explain_algo = InstSegExplainAlgo( num_classes=num_classes, ) From d029ca2761c39092abe7460b0eb883f33156ba3d Mon Sep 17 00:00:00 2001 From: Galina Date: Sat, 20 Apr 2024 04:16:32 +0900 Subject: [PATCH 3/3] Minor --- tests/e2e/cli/test_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/cli/test_cli.py b/tests/e2e/cli/test_cli.py index 278bd6b6aae..f0909df548d 100644 --- a/tests/e2e/cli/test_cli.py +++ b/tests/e2e/cli/test_cli.py @@ -211,7 +211,7 @@ def test_otx_e2e_cli( return # Supported only for classification, detection and instance segmentation task. if "dino" in model_name: - return # DINO and is not supported. + return # DINO is not supported. format_to_file = { "ONNX": "exported_model.onnx",