Per-class saliency maps for M-RCNN #2227

Merged
otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py

@@ -16,8 +16,9 @@
 from __future__ import annotations

 from abc import ABC
-from typing import List, Sequence, Union
+from typing import List, Optional, Sequence, Union

+import numpy as np
 import torch

 from otx.algorithms.classification import MMCLS_AVAILABLE
@@ -69,10 +70,24 @@ def _recording_forward(
         self, _: torch.nn.Module, x: torch.Tensor, output: torch.Tensor
     ):  # pylint: disable=unused-argument
         tensors = self.func(output)
-        tensors = tensors.detach().cpu().numpy()
-        for tensor in tensors:
+        if isinstance(tensors, torch.Tensor):
+            tensors_np = tensors.detach().cpu().numpy()
+        elif isinstance(tensors, np.ndarray):
+            tensors_np = tensors
+        else:
+            self._torch_to_numpy_from_list(tensors)
+            tensors_np = tensors
+
+        for tensor in tensors_np:
             self._records.append(tensor)

+    def _torch_to_numpy_from_list(self, tensor_list: List[Optional[torch.Tensor]]):
+        for i in range(len(tensor_list)):
+            if isinstance(tensor_list[i], list):
+                self._torch_to_numpy_from_list(tensor_list[i])
+            elif isinstance(tensor_list[i], torch.Tensor):
+                tensor_list[i] = tensor_list[i].detach().cpu().numpy()
+
     def __enter__(self) -> BaseRecordingForwardHook:
         """Enter."""
         self._handle = self._module.backbone.register_forward_hook(self._recording_forward)
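The new `else` branch in `_recording_forward` handles nested lists of tensors, i.e. the per-image, per-class structure the new Mask R-CNN hook returns. Below is a self-contained sketch of that in-place conversion; the nested layout in the example is an assumed illustration, not taken from this diff.

import numpy as np
import torch


def torch_to_numpy_from_list(tensor_list: list) -> None:
    """Recursively replace torch.Tensor entries with numpy arrays, in place."""
    for i, item in enumerate(tensor_list):
        if isinstance(item, list):
            torch_to_numpy_from_list(item)
        elif isinstance(item, torch.Tensor):
            tensor_list[i] = item.detach().cpu().numpy()


# Batch of two images, two classes; class 1 of image 0 has no detections.
records = [[torch.rand(14, 14), None], [torch.rand(14, 14), torch.rand(14, 14)]]
torch_to_numpy_from_list(records)
assert isinstance(records[0][0], np.ndarray) and records[0][1] is None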
otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py

@@ -2,10 +2,13 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
-from typing import List, Tuple, Union
+import copy
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
+import torch.nn.functional as F
+from mmdet.core import bbox2roi

 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
     BaseRecordingForwardHook,
@@ -121,3 +124,127 @@ def forward_single(x, cls_convs, conv_cls):
             "YOLOXHead, ATSSHead, SSDHead, VFNetHead."
         )
         return cls_scores
+
+
+class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
+    """Saliency map hook for the Mask R-CNN model. Torch model only; the OpenVINO IR model is not supported.
+
+    Args:
+        module (torch.nn.Module): Mask R-CNN model.
+        input_img_shape (Tuple[int]): Resolution of the model input image.
+        saliency_map_shape (Tuple[int]): Resolution of the output saliency map.
+        max_detections_per_img (int): Upper limit of the number of detections
+            from which soft mask predictions are aggregated.
+        normalize (bool): Whether to normalize the output saliency maps.
+            Note that partial normalization is already done by the segmentation mask head.
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        input_img_shape: Tuple[int, int],
+        saliency_map_shape: Tuple[int, int] = (224, 224),
+        max_detections_per_img: int = 300,
+        normalize: bool = True,
+    ) -> None:
+        super().__init__(module)
+        self._neck = module.neck if module.with_neck else None
+        self._input_img_shape = input_img_shape
+        self._saliency_map_shape = saliency_map_shape
+        self._max_detections_per_img = max_detections_per_img
+        self._norm_saliency_maps = normalize
+
+    def func(
+        self,
+        feature_map: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
+        _: int = -1,
+    ) -> List[List[Optional[np.ndarray]]]:
+        """Generate saliency maps by aggregating per-class soft predictions of the mask head for all detected boxes.
+
+        :param feature_map: Feature maps from the backbone.
+        :return: Class-wise saliency maps, one map per predicted class.
+        """
+        with torch.no_grad():
+            if self._neck is not None:
+                feature_map = self._module.neck(feature_map)
+
+            det_bboxes, det_labels = self._get_detections(feature_map)
+            saliency_maps = self._get_saliency_maps_from_mask_predictions(feature_map, det_bboxes, det_labels)
+            if self._norm_saliency_maps:
+                saliency_maps = self._normalize(saliency_maps)
+        return saliency_maps
+
+    def _get_detections(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        batch_size = x[0].shape[0]
+        img_metas = [
+            {
+                "scale_factor": [1, 1, 1, 1],  # dummy scale_factor, not used
+                "img_shape": self._input_img_shape,
+            }
+        ]
+        img_metas *= batch_size
+        proposals = self._module.rpn_head.simple_test_rpn(x, img_metas)
+        test_cfg = copy.deepcopy(self._module.roi_head.test_cfg)
+        test_cfg["max_per_img"] = self._max_detections_per_img
+        test_cfg["nms"]["iou_threshold"] = 1
+        test_cfg["nms"]["max_num"] = self._max_detections_per_img
+        det_bboxes, det_labels = self._module.roi_head.simple_test_bboxes(
+            x, img_metas, proposals, test_cfg, rescale=False
+        )
+        return det_bboxes, det_labels
+
+    def _get_saliency_maps_from_mask_predictions(
+        self, x: torch.Tensor, det_bboxes: List[torch.Tensor], det_labels: List[torch.Tensor]
+    ) -> List[List[Optional[np.ndarray]]]:
+        _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
+        mask_rois = bbox2roi(_bboxes)
+        mask_results = self._module.roi_head._mask_forward(x, mask_rois)
+        mask_pred = mask_results["mask_pred"]
+        num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+        mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+
+        batch_size = x[0].shape[0]
+
+        scale_x = self._input_img_shape[1] / self._saliency_map_shape[1]
+        scale_y = self._input_img_shape[0] / self._saliency_map_shape[0]
+        scale_factor = torch.FloatTensor((scale_x, scale_y, scale_x, scale_y))
+        test_cfg = self._module.roi_head.test_cfg.copy()
+        test_cfg["mask_thr_binary"] = -1
+
+        saliency_maps = []  # type: List[List[Optional[np.ndarray]]]
+        for i in range(batch_size):
+            saliency_maps.append([])
+            for _ in range(self._module.roi_head.mask_head.num_classes):
+                saliency_maps[i].append(None)
+
+        for i in range(batch_size):
+            if det_bboxes[i].shape[0] == 0:
+                continue
+            segm_result = self._module.roi_head.mask_head.get_seg_masks(
+                mask_preds[i],
+                _bboxes[i],
+                det_labels[i],
+                test_cfg,
+                self._saliency_map_shape,
+                scale_factor=scale_factor,
+                rescale=True,
+            )
+            for class_id, segm_res in enumerate(segm_result):
+                if segm_res:
+                    saliency_maps[i][class_id] = np.mean(np.array(segm_res), axis=0)
+        return saliency_maps
+
+    @staticmethod
+    def _normalize(saliency_maps: List[List[Optional[np.ndarray]]]) -> List[List[Optional[np.ndarray]]]:
+        batch_size = len(saliency_maps)
+        num_classes = len(saliency_maps[0])
+        for i in range(batch_size):
+            for class_id in range(num_classes):
+                per_class_map = saliency_maps[i][class_id]
+                if per_class_map is not None:
+                    max_values = np.max(per_class_map)
+                    per_class_map = 255 * per_class_map / (max_values + 1e-12)
+                    per_class_map = per_class_map.astype(np.uint8)
+                    saliency_maps[i][class_id] = per_class_map
+        return saliency_maps
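For orientation, this is roughly how the hook is driven: it is a context manager that registers `_recording_forward` on the detector backbone, so a single forward pass is enough to populate its records. A hedged sketch, assuming an already-built mmdetection Mask R-CNN and a preprocessed batch (both placeholders, not part of this diff):

import torch

detector = ...  # an mmdet Mask R-CNN (TwoStageDetector), assumed already built
detector.eval()

hook = MaskRCNNRecordingForwardHook(detector, input_img_shape=(800, 1344))
with hook:  # __enter__ registers the forward hook on detector.backbone
    with torch.no_grad():
        detector(return_loss=False, rescale=True, **data_batch)  # data_batch is assumed

# One entry per image; each entry holds one 224x224 uint8 map per class,
# or None for classes without detections.
per_class_maps = hook.records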
@@ -10,7 +10,6 @@
 from mmdet.models.detectors.mask_rcnn import MaskRCNN

 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )
 from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
@@ -99,7 +98,7 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
     def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
         """Function for custom_mask_rcnn__simple_test."""
         assert self.with_bbox, "Bbox head must be implemented."
-        x = backbone_out = self.backbone(img)
+        x = self.backbone(img)
         if self.with_neck:
             x = self.neck(x)
         if proposals is None:
@@ -108,7 +107,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):

         if ctx.cfg["dump_features"]:
             feature_vector = FeatureVectorHook.func(x)
-            saliency_map = ActivationMapHook.func(backbone_out)
+            # Saliency map will be generated from predictions. Generate dummy saliency_map.
+            saliency_map = torch.empty(1, dtype=torch.uint8)
             return (*out, feature_vector, saliency_map)

         return out
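The exported graph no longer computes an activation map, so the dummy tensor only keeps the output signature stable for XAI-enabled exports; the real per-class maps are reconstructed later from the predicted boxes and masks. A hypothetical consumer-side sketch (the helper name and the one-element-placeholder check are illustrative, not part of this PR):

import torch


def unpack_xai_outputs(outputs):
    """Hypothetical helper: split detections from the XAI extras appended above."""
    *detections, feature_vector, saliency_map = outputs
    if saliency_map.numel() == 1:  # dummy placeholder; real maps come from predictions
        saliency_map = None
    return detections, feature_vector, saliency_map


dets, fv, sal = unpack_xai_outputs([torch.zeros(1, 5), torch.zeros(1, 256), torch.empty(1, dtype=torch.uint8)])
assert sal is None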
@@ -209,7 +209,6 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False, full_res_im

     # pylint: disable=ungrouped-imports
     from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-        ActivationMapHook,
         FeatureVectorHook,
     )

@@ -327,7 +326,7 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
         assert self.with_bbox, "Bbox head must be implemented."
         tile_prob = self.tile_classifier.simple_test(img)

-        x = backbone_out = self.backbone(img)
+        x = self.backbone(img)
         if self.with_neck:
             x = self.neck(x)
         if proposals is None:
@@ -336,7 +335,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):

         if ctx.cfg["dump_features"]:
             feature_vector = FeatureVectorHook.func(x)
-            saliency_map = ActivationMapHook.func(backbone_out)
+            # Saliency map will be generated from predictions. Generate dummy saliency_map.
+            saliency_map = torch.empty(1, dtype=torch.uint8)
             return (*out, tile_prob, feature_vector, saliency_map)

         return (*out, tile_prob)
25 changes: 16 additions & 9 deletions otx/algorithms/detection/adapters/mmdet/task.py
@@ -61,6 +61,7 @@
 from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
 from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
     DetClassProbabilityMapHook,
+    MaskRCNNRecordingForwardHook,
 )
 from otx.algorithms.detection.adapters.mmdet.utils import (
     patch_input_preprocessing,
@@ -399,7 +400,8 @@ def hook(module, inp, outp):
         if raw_model.__class__.__name__ == "NNCFNetwork":
             raw_model = raw_model.get_nncf_wrapped_model()
         if isinstance(raw_model, TwoStageDetector):
-            saliency_hook = ActivationMapHook(feature_model)
+            height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
+            saliency_hook = MaskRCNNRecordingForwardHook(feature_model, input_img_shape=(height, width))
         else:
             saliency_hook = DetClassProbabilityMapHook(feature_model)
@@ -522,15 +524,9 @@ def _explain_model(
         explain_parameters: Optional[ExplainParameters] = None,
     ) -> Dict[str, Any]:
         """Main explain function of MMDetectionTask."""
-
         for item in dataset:
             item.subset = Subset.TESTING

-        explainer_hook_selector = {
-            "classwisesaliencymap": DetClassProbabilityMapHook,
-            "eigencam": EigenCamHook,
-            "activationmap": ActivationMapHook,
-        }
         self._data_cfg = ConfigDict(
             data=ConfigDict(
                 train=ConfigDict(
@@ -599,6 +595,18 @@ def hook(module, inp, outp):
         model.register_forward_pre_hook(pre_hook)
         model.register_forward_hook(hook)

+        if isinstance(feature_model, TwoStageDetector):
+            height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
+            per_class_xai_algorithm = partial(MaskRCNNRecordingForwardHook, input_img_shape=(height, width))
+        else:
+            per_class_xai_algorithm = DetClassProbabilityMapHook  # type: ignore
+
+        explainer_hook_selector = {
+            "classwisesaliencymap": per_class_xai_algorithm,
+            "eigencam": EigenCamHook,
+            "activationmap": ActivationMapHook,
+        }
+
         explainer = explain_parameters.explainer if explain_parameters else None
         if explainer is not None:
             explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
@@ -608,9 +616,8 @@
             raise NotImplementedError(f"Explainer algorithm {explainer} not supported!")
         logger.info(f"Explainer algorithm: {explainer}")

-        # Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map.
         eval_predictions = []
-        with explainer_hook(feature_model) as saliency_hook:
+        with explainer_hook(feature_model) as saliency_hook:  # type: ignore
             for data in dataloader:
                 with torch.no_grad():
                     result = model(return_loss=False, rescale=True, **data)
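The reason for `functools.partial` here: `MaskRCNNRecordingForwardHook` needs `input_img_shape` at construction time, while the selector expects every entry to be instantiable as `hook_cls(feature_model)`. A minimal sketch of that pattern with hypothetical hook classes (none of these names are from the PR):

from functools import partial


class ShapeAwareHook:
    """Stands in for a hook that needs the input resolution at construction."""

    def __init__(self, model, input_img_shape):
        self.model = model
        self.input_img_shape = input_img_shape


class SimpleHook:
    """Stands in for a hook built from the model alone."""

    def __init__(self, model):
        self.model = model


selector = {
    "classwisesaliencymap": partial(ShapeAwareHook, input_img_shape=(800, 1344)),
    "activationmap": SimpleHook,
}

hook = selector["classwisesaliencymap"](object())  # uniform hook_cls(model) call
assert hook.input_img_shape == (800, 1344)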
@@ -16,6 +16,7 @@

 from typing import Dict

+import cv2
 import numpy as np

 try:
@@ -107,6 +108,85 @@ def postprocess(self, outputs, meta):

         return scores, classes, boxes, resized_masks

+    def get_saliency_map_from_prediction(self, outputs, meta, num_classes):
+        """Post-process function for saliency maps of the OTX MaskRCNN model."""
+        boxes = outputs[self.output_blob_name["boxes"]]
+        if boxes.shape[0] == 1:
+            boxes = boxes.squeeze(0)
+        scores = boxes[:, 4]
+        boxes = boxes[:, :4]
+        masks = outputs[self.output_blob_name["masks"]]
+        if masks.shape[0] == 1:
+            masks = masks.squeeze(0)
+        classes = outputs[self.output_blob_name["labels"]].astype(np.uint32)
+        if classes.shape[0] == 1:
+            classes = classes.squeeze(0)
+
+        scale_x = meta["resized_shape"][1] / meta["original_shape"][1]
+        scale_y = meta["resized_shape"][0] / meta["original_shape"][0]
+        boxes[:, 0::2] /= scale_x
+        boxes[:, 1::2] /= scale_y
+
+        saliency_maps = [None for _ in range(num_classes)]
+        for box, score, cls, raw_mask in zip(boxes, scores, classes, masks):
+            resized_mask = self._resize_mask(box, raw_mask * score, *meta["original_shape"][:-1])
+            if saliency_maps[cls] is None:
+                saliency_maps[cls] = [resized_mask]
+            else:
+                saliency_maps[cls].append(resized_mask)
+
+        saliency_maps = self._average_and_normalize(saliency_maps, num_classes)
+        return saliency_maps
+
+    def _resize_mask(self, box, raw_cls_mask, im_h, im_w):
+        # Add a zero border to prevent upsampling artifacts on segment borders.
+        raw_cls_mask = np.pad(raw_cls_mask, ((1, 1), (1, 1)), "constant", constant_values=0)
+        extended_box = self._expand_box(box, raw_cls_mask.shape[0] / (raw_cls_mask.shape[0] - 2.0)).astype(int)
+        w, h = np.maximum(extended_box[2:] - extended_box[:2] + 1, 1)
+        x0, y0 = np.clip(extended_box[:2], a_min=0, a_max=[im_w, im_h])
+        x1, y1 = np.clip(extended_box[2:] + 1, a_min=0, a_max=[im_w, im_h])
+
+        raw_cls_mask = cv2.resize(raw_cls_mask.astype(np.float32), (w, h))
+        # Put the object mask into the full-size image mask.
+        im_mask = np.zeros((im_h, im_w), dtype=np.float32)
+        im_mask[y0:y1, x0:x1] = raw_cls_mask[
+            (y0 - extended_box[1]) : (y1 - extended_box[1]), (x0 - extended_box[0]) : (x1 - extended_box[0])
+        ]
+        return im_mask
+
+    @staticmethod
+    def _average_and_normalize(saliency_maps, num_classes):
+        for i in range(num_classes):
+            if saliency_maps[i] is not None:
+                saliency_maps[i] = np.array(saliency_maps[i]).mean(0)
+
+        for i in range(num_classes):
+            per_class_map = saliency_maps[i]
+            if per_class_map is not None:
+                max_values = np.max(per_class_map)
+                per_class_map = 255 * per_class_map / (max_values + 1e-12)
+                per_class_map = per_class_map.astype(np.uint8)
+                saliency_maps[i] = per_class_map
+        return saliency_maps
+
+    def get_tiling_saliency_map_from_prediction(self, detections, num_classes):
+        """Post-process function for saliency maps of the OTX MaskRCNN model with tiling."""
+        saliency_maps = [None for _ in range(num_classes)]
+
+        # No-detection case.
+        if isinstance(detections, np.ndarray) and detections.size == 0:
+            return saliency_maps
+
+        classes = [int(cls) - 1 for cls in detections[1]]
+        masks = detections[3]
+        for mask, cls in zip(masks, classes):
+            if saliency_maps[cls] is None:
+                saliency_maps[cls] = [mask]
+            else:
+                saliency_maps[cls].append(mask)
+        saliency_maps = self._average_and_normalize(saliency_maps, num_classes)
+        return saliency_maps
+
     def segm_postprocess(self, *args, **kwargs):
         """Post-process for segmentation masks."""
         return self._segm_postprocess(*args, **kwargs)
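Both post-processing paths funnel into `_average_and_normalize`, which averages the soft masks accumulated per class and rescales each map into the uint8 range [0, 255]. A self-contained numeric sketch of that step, mirroring the method above (shapes and values are illustrative only):

import numpy as np


def average_and_normalize(saliency_maps, num_classes):
    for i in range(num_classes):
        if saliency_maps[i] is not None:
            saliency_maps[i] = np.array(saliency_maps[i]).mean(0)  # average all detections of the class
    for i in range(num_classes):
        per_class_map = saliency_maps[i]
        if per_class_map is not None:
            max_value = np.max(per_class_map)
            # Rescale to [0, 255]; the 1e-12 term guards against division by zero for all-zero maps.
            saliency_maps[i] = (255 * per_class_map / (max_value + 1e-12)).astype(np.uint8)
    return saliency_maps


maps = [None, [np.full((4, 4), 0.2), np.full((4, 4), 0.6)]]  # two detections of class 1
out = average_and_normalize(maps, num_classes=2)
assert out[0] is None and out[1].max() == 254  # 255 * 0.4 / (0.4 + 1e-12) truncates to 254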