Skip to content

Commit

Permalink
Per-class saliency maps for M-RCNN (#2301)
Browse files Browse the repository at this point in the history
* per class sal maps for maskrcnn

* tiling support + test enablement/fix

* enable xai detection e2e tests
  • Loading branch information
negvet authored Jul 6, 2023
1 parent ae81031 commit f1baed1
Show file tree
Hide file tree
Showing 14 changed files with 602 additions and 245 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from __future__ import annotations

from abc import ABC
from typing import List, Sequence, Union
from typing import List, Optional, Sequence, Union

import numpy as np
import torch

from otx.algorithms.classification import MMCLS_AVAILABLE
Expand Down Expand Up @@ -69,10 +70,24 @@ def _recording_forward(
self, _: torch.nn.Module, x: torch.Tensor, output: torch.Tensor
): # pylint: disable=unused-argument
tensors = self.func(output)
tensors = tensors.detach().cpu().numpy()
for tensor in tensors:
if isinstance(tensors, torch.Tensor):
tensors_np = tensors.detach().cpu().numpy()
elif isinstance(tensors, np.ndarray):
tensors_np = tensors
else:
self._torch_to_numpy_from_list(tensors)
tensors_np = tensors

for tensor in tensors_np:
self._records.append(tensor)

def _torch_to_numpy_from_list(self, tensor_list: List[Optional[torch.Tensor]]):
for i in range(len(tensor_list)):
if isinstance(tensor_list[i], list):
self._torch_to_numpy_from_list(tensor_list[i])
elif isinstance(tensor_list[i], torch.Tensor):
tensor_list[i] = tensor_list[i].detach().cpu().numpy()

def __enter__(self) -> BaseRecordingForwardHook:
"""Enter."""
self._handle = self._module.backbone.register_forward_hook(self._recording_forward)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from typing import List, Tuple, Union
import copy
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import bbox2roi

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
BaseRecordingForwardHook,
Expand Down Expand Up @@ -120,3 +123,127 @@ def forward_single(x, cls_convs, conv_cls):
"YOLOXHead, ATSSHead, SSDHead, VFNetHead."
)
return cls_scores


class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
"""Saliency map hook for Mask R-CNN model. Only for torch model, does not support OpenVINO IR model.
Args:
module (torch.nn.Module): Mask R-CNN model.
input_img_shape (Tuple[int]): Resolution of the model input image.
saliency_map_shape (Tuple[int]): Resolution of the output saliency map.
max_detections_per_img (int): Upper limit of the number of detections
from which soft mask predictions are getting aggregated.
normalize (bool): Flag that defines if the output saliency map will be normalized.
Although, partial normalization is anyway done by segmentation mask head.
"""

def __init__(
self,
module: torch.nn.Module,
input_img_shape: Tuple[int, int],
saliency_map_shape: Tuple[int, int] = (224, 224),
max_detections_per_img: int = 300,
normalize: bool = True,
) -> None:
super().__init__(module)
self._neck = module.neck if module.with_neck else None
self._input_img_shape = input_img_shape
self._saliency_map_shape = saliency_map_shape
self._max_detections_per_img = max_detections_per_img
self._norm_saliency_maps = normalize

def func(
self,
feature_map: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
_: int = -1,
) -> List[List[Optional[np.ndarray]]]:
"""Generate saliency maps by aggregating per-class soft predictions of mask head for all detected boxes.
:param feature_map: Feature maps from backbone.
:return: Class-wise Saliency Maps. One saliency map per each predicted class.
"""
with torch.no_grad():
if self._neck is not None:
feature_map = self._module.neck(feature_map)

det_bboxes, det_labels = self._get_detections(feature_map)
saliency_maps = self._get_saliency_maps_from_mask_predictions(feature_map, det_bboxes, det_labels)
if self._norm_saliency_maps:
saliency_maps = self._normalize(saliency_maps)
return saliency_maps

def _get_detections(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
batch_size = x[0].shape[0]
img_metas = [
{
"scale_factor": [1, 1, 1, 1], # dummy scale_factor, not used
"img_shape": self._input_img_shape,
}
]
img_metas *= batch_size
proposals = self._module.rpn_head.simple_test_rpn(x, img_metas)
test_cfg = copy.deepcopy(self._module.roi_head.test_cfg)
test_cfg["max_per_img"] = self._max_detections_per_img
test_cfg["nms"]["iou_threshold"] = 1
test_cfg["nms"]["max_num"] = self._max_detections_per_img
det_bboxes, det_labels = self._module.roi_head.simple_test_bboxes(
x, img_metas, proposals, test_cfg, rescale=False
)
return det_bboxes, det_labels

def _get_saliency_maps_from_mask_predictions(
self, x: torch.Tensor, det_bboxes: List[torch.Tensor], det_labels: List[torch.Tensor]
) -> List[List[Optional[np.ndarray]]]:
_bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
mask_rois = bbox2roi(_bboxes)
mask_results = self._module.roi_head._mask_forward(x, mask_rois)
mask_pred = mask_results["mask_pred"]
num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
mask_preds = mask_pred.split(num_mask_roi_per_img, 0)

batch_size = x[0].shape[0]

scale_x = self._input_img_shape[1] / self._saliency_map_shape[1]
scale_y = self._input_img_shape[0] / self._saliency_map_shape[0]
scale_factor = torch.FloatTensor((scale_x, scale_y, scale_x, scale_y))
test_cfg = self._module.roi_head.test_cfg.copy()
test_cfg["mask_thr_binary"] = -1

saliency_maps = [] # type: List[List[Optional[np.ndarray]]]
for i in range(batch_size):
saliency_maps.append([])
for j in range(self._module.roi_head.mask_head.num_classes):
saliency_maps[i].append(None)

for i in range(batch_size):
if det_bboxes[i].shape[0] == 0:
continue
else:
segm_result = self._module.roi_head.mask_head.get_seg_masks(
mask_preds[i],
_bboxes[i],
det_labels[i],
test_cfg,
self._saliency_map_shape,
scale_factor=scale_factor,
rescale=True,
)
for class_id, segm_res in enumerate(segm_result):
if segm_res:
saliency_maps[i][class_id] = np.mean(np.array(segm_res), axis=0)
return saliency_maps

@staticmethod
def _normalize(saliency_maps: List[List[Optional[np.ndarray]]]) -> List[List[Optional[np.ndarray]]]:
batch_size = len(saliency_maps)
num_classes = len(saliency_maps[0])
for i in range(batch_size):
for class_id in range(num_classes):
per_class_map = saliency_maps[i][class_id]
if per_class_map is not None:
max_values = np.max(per_class_map)
per_class_map = 255 * (per_class_map) / (max_values + 1e-12)
per_class_map = per_class_map.astype(np.uint8)
saliency_maps[i][class_id] = per_class_map
return saliency_maps
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from mmdet.models.detectors.mask_rcnn import MaskRCNN

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
ActivationMapHook,
FeatureVectorHook,
)
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
Expand Down Expand Up @@ -99,7 +98,7 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
"""Function for custom_mask_rcnn__simple_test."""
assert self.with_bbox, "Bbox head must be implemented."
x = backbone_out = self.backbone(img)
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
if proposals is None:
Expand All @@ -108,7 +107,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **k

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(x)
saliency_map = ActivationMapHook.func(backbone_out)
# Saliency map will be generated from predictions. Generate dummy saliency_map.
saliency_map = torch.empty(1, dtype=torch.uint8)
return (*out, feature_vector, saliency_map)

return out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False, full_res_im

# pylint: disable=ungrouped-imports
from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
ActivationMapHook,
FeatureVectorHook,
)

Expand Down Expand Up @@ -327,7 +326,7 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
assert self.with_bbox, "Bbox head must be implemented."
tile_prob = self.tile_classifier.simple_test(img)

x = backbone_out = self.backbone(img)
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
if proposals is None:
Expand All @@ -336,7 +335,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(x)
saliency_map = ActivationMapHook.func(backbone_out)
# Saliency map will be generated from predictions. Generate dummy saliency_map.
saliency_map = torch.empty(1, dtype=torch.uint8)
return (*out, tile_prob, feature_vector, saliency_map)

return (*out, tile_prob)
25 changes: 16 additions & 9 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
MaskRCNNRecordingForwardHook,
)
from otx.algorithms.detection.adapters.mmdet.utils import (
patch_input_preprocessing,
Expand Down Expand Up @@ -397,7 +398,8 @@ def hook(module, inp, outp):
if raw_model.__class__.__name__ == "NNCFNetwork":
raw_model = raw_model.get_nncf_wrapped_model()
if isinstance(raw_model, TwoStageDetector):
saliency_hook = ActivationMapHook(feature_model)
height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
saliency_hook = MaskRCNNRecordingForwardHook(feature_model, input_img_shape=(height, width))
else:
saliency_hook = DetClassProbabilityMapHook(feature_model)

Expand Down Expand Up @@ -515,15 +517,9 @@ def _explain_model(
explain_parameters: Optional[ExplainParameters] = None,
) -> Dict[str, Any]:
"""Main explain function of MMDetectionTask."""

for item in dataset:
item.subset = Subset.TESTING

explainer_hook_selector = {
"classwisesaliencymap": DetClassProbabilityMapHook,
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
}
self._data_cfg = ConfigDict(
data=ConfigDict(
train=ConfigDict(
Expand Down Expand Up @@ -593,6 +589,18 @@ def hook(module, inp, outp):
model.register_forward_pre_hook(pre_hook)
model.register_forward_hook(hook)

if isinstance(feature_model, TwoStageDetector):
height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
per_class_xai_algorithm = partial(MaskRCNNRecordingForwardHook, input_img_shape=(width, height))
else:
per_class_xai_algorithm = DetClassProbabilityMapHook # type: ignore

explainer_hook_selector = {
"classwisesaliencymap": per_class_xai_algorithm,
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
}

explainer = explain_parameters.explainer if explain_parameters else None
if explainer is not None:
explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
Expand All @@ -602,9 +610,8 @@ def hook(module, inp, outp):
raise NotImplementedError(f"Explainer algorithm {explainer} not supported!")
logger.info(f"Explainer algorithm: {explainer}")

# Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map.
eval_predictions = []
with explainer_hook(feature_model) as saliency_hook:
with explainer_hook(feature_model) as saliency_hook: # type: ignore
for data in dataloader:
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Dict

import cv2
import numpy as np

try:
Expand Down Expand Up @@ -107,6 +108,85 @@ def postprocess(self, outputs, meta):

return scores, classes, boxes, resized_masks

def get_saliency_map_from_prediction(self, outputs, meta, num_classes):
"""Post process function for saliency map of OTX MaskRCNN model."""
boxes = outputs[self.output_blob_name["boxes"]]
if boxes.shape[0] == 1:
boxes = boxes.squeeze(0)
scores = boxes[:, 4]
boxes = boxes[:, :4]
masks = outputs[self.output_blob_name["masks"]]
if masks.shape[0] == 1:
masks = masks.squeeze(0)
classes = outputs[self.output_blob_name["labels"]].astype(np.uint32)
if classes.shape[0] == 1:
classes = classes.squeeze(0)

scale_x = meta["resized_shape"][1] / meta["original_shape"][1]
scale_y = meta["resized_shape"][0] / meta["original_shape"][0]
boxes[:, 0::2] /= scale_x
boxes[:, 1::2] /= scale_y

saliency_maps = [None for _ in range(num_classes)]
for box, score, cls, raw_mask in zip(boxes, scores, classes, masks):
resized_mask = self._resize_mask(box, raw_mask * score, *meta["original_shape"][:-1])
if saliency_maps[cls] is None:
saliency_maps[cls] = [resized_mask]
else:
saliency_maps[cls].append(resized_mask)

saliency_maps = self._average_and_normalize(saliency_maps, num_classes)
return saliency_maps

def _resize_mask(self, box, raw_cls_mask, im_h, im_w):
# Add zero border to prevent upsampling artifacts on segment borders.
raw_cls_mask = np.pad(raw_cls_mask, ((1, 1), (1, 1)), "constant", constant_values=0)
extended_box = self._expand_box(box, raw_cls_mask.shape[0] / (raw_cls_mask.shape[0] - 2.0)).astype(int)
w, h = np.maximum(extended_box[2:] - extended_box[:2] + 1, 1)
x0, y0 = np.clip(extended_box[:2], a_min=0, a_max=[im_w, im_h])
x1, y1 = np.clip(extended_box[2:] + 1, a_min=0, a_max=[im_w, im_h])

raw_cls_mask = cv2.resize(raw_cls_mask.astype(np.float32), (w, h))
# Put an object mask in an image mask.
im_mask = np.zeros((im_h, im_w), dtype=np.float32)
im_mask[y0:y1, x0:x1] = raw_cls_mask[
(y0 - extended_box[1]) : (y1 - extended_box[1]), (x0 - extended_box[0]) : (x1 - extended_box[0])
]
return im_mask

@staticmethod
def _average_and_normalize(saliency_maps, num_classes):
for i in range(num_classes):
if saliency_maps[i] is not None:
saliency_maps[i] = np.array(saliency_maps[i]).mean(0)

for i in range(num_classes):
per_class_map = saliency_maps[i]
if per_class_map is not None:
max_values = np.max(per_class_map)
per_class_map = 255 * (per_class_map) / (max_values + 1e-12)
per_class_map = per_class_map.astype(np.uint8)
saliency_maps[i] = per_class_map
return saliency_maps

def get_tiling_saliency_map_from_prediction(self, detections, num_classes):
"""Post process function for saliency map of OTX MaskRCNN model for tiling."""
saliency_maps = [None for _ in range(num_classes)]

# No detection case
if isinstance(detections, np.ndarray) and detections.size == 0:
return saliency_maps

classes = [int(cls) - 1 for cls in detections[1]]
masks = detections[3]
for mask, cls in zip(masks, classes):
if saliency_maps[cls] is None:
saliency_maps[cls] = [mask]
else:
saliency_maps[cls].append(mask)
saliency_maps = self._average_and_normalize(saliency_maps, num_classes)
return saliency_maps

def segm_postprocess(self, *args, **kwargs):
"""Post-process for segmentation masks."""
return self._segm_postprocess(*args, **kwargs)
Expand Down
Loading

0 comments on commit f1baed1

Please sign in to comment.