Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per-class saliency maps for M-RCNN #2227

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from __future__ import annotations

from abc import ABC
from typing import List, Sequence, Union
from typing import List, Optional, Sequence, Union

import numpy as np
import torch

from otx.algorithms.classification import MMCLS_AVAILABLE
Expand Down Expand Up @@ -69,10 +70,25 @@ def _recording_forward(
self, _: torch.nn.Module, x: torch.Tensor, output: torch.Tensor
): # pylint: disable=unused-argument
tensors = self.func(output)
tensors = tensors.detach().cpu().numpy()
for tensor in tensors:
if isinstance(tensors, torch.Tensor):
tensors_np = tensors.detach().cpu().numpy()
elif isinstance(tensors, np.ndarray):
tensors_np = tensors
else:
self._torch_to_numpy_from_list(tensors)
tensors_np = tensors

for tensor in tensors_np:
self._records.append(tensor)

def _torch_to_numpy_from_list(self, tensor_list: List[Optional[torch.Tensor]]):
if isinstance(tensor_list[0], list):
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
self._torch_to_numpy_from_list(tensor_list[0])
else:
for i in range(len(tensor_list)):
if isinstance(tensor_list[i], torch.Tensor):
tensor_list[i] = tensor_list[i].detach().cpu().numpy()

def __enter__(self) -> BaseRecordingForwardHook:
"""Enter."""
self._handle = self._module.backbone.register_forward_hook(self._recording_forward)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from typing import List, Tuple, Union
import copy
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import bbox2roi

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
BaseRecordingForwardHook,
Expand Down Expand Up @@ -121,3 +124,127 @@ def forward_single(x, cls_convs, conv_cls):
"YOLOXHead, ATSSHead, SSDHead, VFNetHead."
)
return cls_scores


class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
"""Saliency map hook for Mask R-CNN model. Only for torch model, does not support OV IR model.
negvet marked this conversation as resolved.
Show resolved Hide resolved

Args:
module (torch.nn.Module): Mask R-CNN model.
input_img_shape (Tuple[int]): Resolution of the model input image.
saliency_map_shape (Tuple[int]): Resolution of the output saliency map.
max_detections_per_img (int): Upper limit of the number of detections
from which soft mask predictions are getting aggregated.
normalize (bool): Flag that defines if the output saliency map will be normalized.
Although, partial normalization is anyway done by segmentation mask head.
"""

def __init__(
self,
module: torch.nn.Module,
input_img_shape: Tuple[int, int],
saliency_map_shape: Tuple[int, int] = (224, 224),
max_detections_per_img: int = 300,
normalize: bool = True,
) -> None:
super().__init__(module)
self._neck = module.neck if module.with_neck else None
self._input_img_shape = input_img_shape
self._saliency_map_shape = saliency_map_shape
self._max_detections_per_img = max_detections_per_img
self._norm_saliency_maps = normalize

def func(
self,
feature_map: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
_: int = -1,
) -> List[List[Optional[np.ndarray]]]:
"""Generate saliency maps by aggregating per-class soft predictions of mask head for all detected boxes.

:param feature_map: Feature maps from backbone.
:return: Class-wise Saliency Maps. One saliency map per each predicted class.
"""
with torch.no_grad():
if self._neck is not None:
feature_map = self._module.neck(feature_map)

det_bboxes, det_labels = self._get_detections(feature_map)
saliency_maps = self._get_saliency_maps_from_mask_predictions(feature_map, det_bboxes, det_labels)
if self._norm_saliency_maps:
saliency_maps = self._normalize(saliency_maps)
return saliency_maps

def _get_detections(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
batch_size = x[0].shape[0]
img_metas = [
{
"scale_factor": [1, 1, 1, 1], # dummy scale_factor, not used
"img_shape": self._input_img_shape,
}
]
img_metas *= batch_size
proposals = self._module.rpn_head.simple_test_rpn(x, img_metas)
test_cfg = copy.deepcopy(self._module.roi_head.test_cfg)
test_cfg["max_per_img"] = self._max_detections_per_img
test_cfg["nms"]["iou_threshold"] = 1
test_cfg["nms"]["max_num"] = self._max_detections_per_img
det_bboxes, det_labels = self._module.roi_head.simple_test_bboxes(
x, img_metas, proposals, test_cfg, rescale=False
)
return det_bboxes, det_labels

def _get_saliency_maps_from_mask_predictions(
self, x: torch.Tensor, det_bboxes: List[torch.Tensor], det_labels: List[torch.Tensor]
) -> List[List[Optional[np.ndarray]]]:
_bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
mask_rois = bbox2roi(_bboxes)
mask_results = self._module.roi_head._mask_forward(x, mask_rois)
mask_pred = mask_results["mask_pred"]
num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
mask_preds = mask_pred.split(num_mask_roi_per_img, 0)

batch_size = x[0].shape[0]

scale_x = self._input_img_shape[1] / self._saliency_map_shape[1]
scale_y = self._input_img_shape[0] / self._saliency_map_shape[0]
scale_factor = torch.FloatTensor((scale_x, scale_y, scale_x, scale_y))
test_cfg = self._module.roi_head.test_cfg.copy()
test_cfg["mask_thr_binary"] = -1

saliency_maps = [] # type: List[List[Optional[np.ndarray]]]
for i in range(batch_size):
saliency_maps.append([])
for j in range(self._module.roi_head.mask_head.num_classes):
saliency_maps[i].append(None)

for i in range(batch_size):
if det_bboxes[i].shape[0] == 0:
continue
else:
segm_result = self._module.roi_head.mask_head.get_seg_masks(
mask_preds[i],
_bboxes[i],
det_labels[i],
test_cfg,
self._saliency_map_shape,
scale_factor=scale_factor,
rescale=True,
)
for class_id, segm_res in enumerate(segm_result):
if segm_res:
saliency_maps[i][class_id] = np.mean(np.array(segm_res), axis=0)
return saliency_maps

@staticmethod
def _normalize(saliency_maps: List[List[Optional[np.ndarray]]]) -> List[List[Optional[np.ndarray]]]:
batch_size = len(saliency_maps)
num_classes = len(saliency_maps[0])
for i in range(batch_size):
for class_id in range(num_classes):
per_class_map = saliency_maps[i][class_id]
if per_class_map is not None:
max_values = np.max(per_class_map)
per_class_map = 255 * (per_class_map) / (max_values + 1e-12)
GalyaZalesskaya marked this conversation as resolved.
Show resolved Hide resolved
per_class_map = per_class_map.astype(np.uint8)
saliency_maps[i][class_id] = per_class_map
return saliency_maps
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from mmdet.models.detectors.mask_rcnn import MaskRCNN

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
ActivationMapHook,
FeatureVectorHook,
)
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
Expand Down Expand Up @@ -99,7 +98,7 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
"""Function for custom_mask_rcnn__simple_test."""
assert self.with_bbox, "Bbox head must be implemented."
x = backbone_out = self.backbone(img)
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
if proposals is None:
Expand All @@ -108,7 +107,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **k

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(x)
saliency_map = ActivationMapHook.func(backbone_out)
# Saliency map will be generated from predictions. Generate dummy saliency_map.
saliency_map = torch.empty(1, dtype=torch.uint8)
return (*out, feature_vector, saliency_map)

return out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False, full_res_im

# pylint: disable=ungrouped-imports
from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
ActivationMapHook,
FeatureVectorHook,
)

Expand Down Expand Up @@ -327,7 +326,7 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
assert self.with_bbox, "Bbox head must be implemented."
tile_prob = self.tile_classifier.simple_test(img)

x = backbone_out = self.backbone(img)
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
if proposals is None:
Expand All @@ -336,7 +335,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(x)
saliency_map = ActivationMapHook.func(backbone_out)
# Saliency map will be generated from predictions. Generate dummy saliency_map.
saliency_map = torch.empty(1, dtype=torch.uint8)
return (*out, tile_prob, feature_vector, saliency_map)

return (*out, tile_prob)
35 changes: 26 additions & 9 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
MaskRCNNRecordingForwardHook,
)
from otx.algorithms.detection.adapters.mmdet.utils import (
patch_input_preprocessing,
Expand Down Expand Up @@ -399,7 +400,13 @@ def hook(module, inp, outp):
if raw_model.__class__.__name__ == "NNCFNetwork":
raw_model = raw_model.get_nncf_wrapped_model()
if isinstance(raw_model, TwoStageDetector):
saliency_hook = ActivationMapHook(feature_model)
test_pipeline = cfg.data.test.pipeline
width, height = None, None
for pipeline in test_pipeline:
width, height = pipeline.get("img_scale", (None, None))
sovrasov marked this conversation as resolved.
Show resolved Hide resolved
if height is None:
raise ValueError("img_scale has to be defined in the test pipeline.")
saliency_hook = MaskRCNNRecordingForwardHook(feature_model, input_img_shape=(height, width))
else:
saliency_hook = DetClassProbabilityMapHook(feature_model)

Expand Down Expand Up @@ -522,15 +529,9 @@ def _explain_model(
explain_parameters: Optional[ExplainParameters] = None,
) -> Dict[str, Any]:
"""Main explain function of MMDetectionTask."""

for item in dataset:
item.subset = Subset.TESTING

explainer_hook_selector = {
"classwisesaliencymap": DetClassProbabilityMapHook,
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
}
self._data_cfg = ConfigDict(
data=ConfigDict(
train=ConfigDict(
Expand Down Expand Up @@ -599,6 +600,23 @@ def hook(module, inp, outp):
model.register_forward_pre_hook(pre_hook)
model.register_forward_hook(hook)

if isinstance(feature_model, TwoStageDetector):
test_pipeline = cfg.data.test.pipeline
width, height = None, None
for pipeline in test_pipeline:
width, height = pipeline.get("img_scale", (None, None))
if height is None:
raise ValueError("img_scale has to be defined in the test pipeline.")
per_class_xai_algorithm = partial(MaskRCNNRecordingForwardHook, input_img_shape=(height, width))
else:
per_class_xai_algorithm = DetClassProbabilityMapHook # type: ignore

explainer_hook_selector = {
"classwisesaliencymap": per_class_xai_algorithm,
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
}

explainer = explain_parameters.explainer if explain_parameters else None
if explainer is not None:
explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
Expand All @@ -608,9 +626,8 @@ def hook(module, inp, outp):
raise NotImplementedError(f"Explainer algorithm {explainer} not supported!")
logger.info(f"Explainer algorithm: {explainer}")

# Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map.
eval_predictions = []
with explainer_hook(feature_model) as saliency_hook:
with explainer_hook(feature_model) as saliency_hook: # type: ignore
GalyaZalesskaya marked this conversation as resolved.
Show resolved Hide resolved
for data in dataloader:
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
Expand Down
Loading