Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XAI support for rtmdet_inst_tiny #3356

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ XAI algorithms for instance segmentation

For instance segmentation networks, the following algorithm is used to generate saliency maps:

- **MaskRCNNExplainAlgo​** - in this approach the predicted object masks are combined and aggregated per class to generate the saliency maps for each class.
- **InstSegExplainAlgo** - in this approach the predicted object masks are combined and aggregated per class to generate the saliency maps for each class.


.. tab-set::
Expand Down
2 changes: 1 addition & 1 deletion docs/source/guide/tutorials/base/explain.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ based on the used model:
- ``ViT Recipro-CAM`` - for transformer-based classification models

- ``DetClassProbabilityMap`` - for single-stage detector models
- ``MaskRCNNExplainAlgo`` - for MaskRCNN instance segmentation models
- ``InstSegExplainAlgo`` - for MaskRCNN and RTMDetInst instance segmentation models

.. note::

Expand Down
7 changes: 5 additions & 2 deletions src/otx/algo/explain/explain_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,11 @@ def func(
return saliency_map.reshape((batch_size, self._num_classes, height, width))


class MaskRCNNExplainAlgo(BaseExplainAlgo):
"""Dummy saliency map algo for Mask R-CNN model."""
class InstSegExplainAlgo(BaseExplainAlgo):
"""Dummy saliency map algo for Mask R-CNN and RTMDetInst model.

Predicted masks are combined and aggregated per-class to generate the saliency maps.
"""

def __init__(self, num_classes: int) -> None:
super().__init__()
Expand Down
4 changes: 4 additions & 0 deletions src/otx/algo/utils/xai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def convert_maps_to_dict_all(saliency_map: list[np.ndarray]) -> list[dict[Any, n
"""Convert salincy maps to dict for TargetExplainGroup.ALL."""
processed_saliency_maps = []
for maps_per_image in saliency_map:
if maps_per_image.size == 0:
processed_saliency_maps.append({0: np.zeros((1, 1, 1))})
continue

if maps_per_image.ndim != 3:
msg = "Shape mismatch."
raise ValueError(msg)
Expand Down
48 changes: 31 additions & 17 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from openvino.model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors

from otx.algo.explain.explain_algo import feature_vector_fn
from otx.algo.explain.explain_algo import InstSegExplainAlgo, feature_vector_fn
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
Expand All @@ -35,7 +35,7 @@
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmdet.models.detectors.two_stage import TwoStageDetector
from mmdet.models.detectors import TwoStageDetector
from mmdet.structures import OptSampleList
from omegaconf import DictConfig
from openvino.model_api.models.utils import InstanceSegmentationResult
Expand Down Expand Up @@ -226,10 +226,9 @@ def __init__(
torch_compile=torch_compile,
)

from otx.algo.explain.explain_algo import feature_vector_fn

self.model.feature_vector_fn = feature_vector_fn
self.model.explain_fn = self.get_explain_fn()
self.model.get_results_from_head = self.get_results_from_head

def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity:
"""Model forward function."""
Expand Down Expand Up @@ -268,34 +267,49 @@ def _forward_explain_inst_seg(
x = self.extract_feat(inputs)

feature_vector = self.feature_vector_fn(x)
predictions = self.get_results_from_head(x, data_samples)

rpn_results_list = self.rpn_head.predict(x, data_samples, rescale=False)
results_list = self.roi_head.predict(x, rpn_results_list, data_samples, rescale=True)

if isinstance(results_list, tuple) and isinstance(results_list[0], torch.Tensor): # rewrite
if isinstance(predictions, tuple) and isinstance(predictions[0], torch.Tensor):
# Export case, consists of tensors
predictions = results_list
# For OV task saliency map are generated on MAPI side
saliency_map = torch.empty(1, dtype=torch.uint8)

elif isinstance(results_list, list) and isinstance(results_list[0], InstanceData): # rewrite
elif isinstance(predictions, list) and isinstance(predictions[0], InstanceData):
# Predict case, consists of InstanceData
predictions = self.add_pred_to_datasample(data_samples, results_list)

features_for_sal_map = [data_sample.pred_instances for data_sample in data_samples]
saliency_map = self.explain_fn(features_for_sal_map)
saliency_map = self.explain_fn(predictions)
predictions = self.add_pred_to_datasample(data_samples, predictions)

return {
"predictions": predictions,
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}

def get_results_from_head(
self,
x: tuple[torch.Tensor],
data_samples: OptSampleList | None,
) -> tuple[torch.Tensor] | list[InstanceData]:
"""Get the results from the head of the instance segmentation model.

Args:
x (tuple[torch.Tensor]): The features from backbone and neck.
data_samples (OptSampleList | None): A list of data samples.

Returns:
tuple[torch.Tensor] | list[InstanceData]: The predicted results from the head of the model.
Tuple for the Export case, list for the Predict case.
"""
# NOTE(review): function-scope import, presumably to avoid a circular import with rtmdet_inst - confirm.
from otx.algo.instance_segmentation.rtmdet_inst import RTMDetInst

# RTMDetInst: predictions come straight from bbox_head, called with rescale=False.
if isinstance(self, RTMDetInst):
return self.model.bbox_head.predict(x, data_samples, rescale=False)
# Any other model: rpn_head results are fed into roi_head, which is called with rescale=True.
rpn_results_list = self.model.rpn_head.predict(x, data_samples, rescale=False)
return self.model.roi_head.predict(x, rpn_results_list, data_samples, rescale=True)

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo

explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes)
explainer = InstSegExplainAlgo(num_classes=self.num_classes)
return explainer.func

@contextmanager
Expand Down
5 changes: 2 additions & 3 deletions src/otx/core/utils/tile_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchvision import tv_tensors
from torchvision.ops import batched_nms

from otx.algo.explain.explain_algo import InstSegExplainAlgo
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity
from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity
Expand Down Expand Up @@ -431,10 +432,8 @@ def get_saliency_maps_from_masks(
Returns:
np.array: Class-wise Saliency Maps. One saliency map per each class - [class_id, H, W]
"""
from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo

if masks is None:
return np.ndarray([])

pred = {"labels": labels, "scores": scores, "masks": masks}
return MaskRCNNExplainAlgo.average_and_normalize(pred, num_classes)
return InstSegExplainAlgo.average_and_normalize(pred, num_classes)
8 changes: 4 additions & 4 deletions tests/e2e/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ def test_otx_e2e_cli(
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
return # Supported only for classification, detection and instance segmentation task.

if "dino" in model_name or "rtmdet_inst_tiny" in model_name:
return # DINO and Rtmdet_tiny are not supported.
if "dino" in model_name:
return # DINO is not supported.

format_to_file = {
"ONNX": "exported_model.onnx",
Expand Down Expand Up @@ -284,8 +284,8 @@ def test_otx_explain_e2e_cli(
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
pytest.skip("Supported only for classification, detection and instance segmentation task.")

if "dino" in model_name or "rtmdet_inst_tiny" in model_name:
pytest.skip("DINO and Rtmdet_tiny are not supported.")
if "dino" in model_name:
pytest.skip("DINO is not supported.")

# otx explain
tmp_path_explain = tmp_path / f"otx_explain_{model_name}"
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/api/test_xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_forward_explain(
task = recipe.split("/")[-2]
model_name = recipe.split("/")[-1].split(".")[0]

if "dino" in model_name or "rtmdet_inst_tiny" in model_name:
pytest.skip("DINO and Rtmdet_tiny are not supported.")
if "dino" in model_name:
pytest.skip("DINO is not supported.")

engine = Engine.from_config(
config_path=recipe,
Expand Down Expand Up @@ -92,8 +92,8 @@ def test_predict_with_explain(
task = recipe.split("/")[-2]
model_name = recipe.split("/")[-1].split(".")[0]

if "dino" in model_name or "rtmdet_inst_tiny" in model_name:
pytest.skip("DINO and Rtmdet_tiny are not supported.")
if "dino" in model_name:
pytest.skip("DINO is not supported.")

if "mobilenet_v3_large" in model_name:
pytest.skip("There's issue with mobilenet_v3_large model. Skip for now.")
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def test_otx_e2e(
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
return # Supported only for classification, detection and instance segmentation task.

if "dino" in model_name or "rtmdet_inst_tiny" in model_name:
return # DINO and Rtmdet_tiny are not supported.
if "dino" in model_name:
return # DINO is not supported.

format_to_file = {
"ONNX": "exported_model.onnx",
Expand Down Expand Up @@ -294,8 +294,8 @@ def test_otx_explain_e2e(
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
pytest.skip("Supported only for classification, detection and instance segmentation task.")

if "dino" in model_name or "rtmdet_inst_tiny" in model_name:
pytest.skip("DINO and Rtmdet_tiny are not supported.")
if "dino" in model_name:
pytest.skip("DINO is not supported.")

# otx explain
tmp_path_explain = tmp_path / f"otx_explain_{model_name}"
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/algo/explain/test_xai_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from otx.algo.explain.explain_algo import (
ActivationMap,
DetClassProbabilityMap,
MaskRCNNExplainAlgo,
InstSegExplainAlgo,
ReciproCAM,
ViTReciproCAM,
)
Expand Down Expand Up @@ -80,9 +80,9 @@ def test_detclassprob() -> None:
assert saliency_maps.size() == torch.Size([5, 2, 2, 2])


def test_maskrcnn() -> None:
def test_instseg() -> None:
num_classes = 2
explain_algo = MaskRCNNExplainAlgo(
explain_algo = InstSegExplainAlgo(
num_classes=num_classes,
)

Expand Down
Loading