Commit 6133b4c
per class sal maps for maskrcnn

negvet committed Jun 7, 2023
1 parent ca29281 commit 6133b4c

Showing 8 changed files with 276 additions and 40 deletions.
otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py

@@ -16,8 +16,9 @@
 from __future__ import annotations

 from abc import ABC
-from typing import List, Sequence, Union
+from typing import List, Optional, Sequence, Union

+import numpy as np
 import torch

 from otx.algorithms.classification import MMCLS_AVAILABLE
@@ -69,10 +70,25 @@ def _recording_forward(
         self, _: torch.nn.Module, x: torch.Tensor, output: torch.Tensor
     ):  # pylint: disable=unused-argument
         tensors = self.func(output)
-        tensors = tensors.detach().cpu().numpy()
-        for tensor in tensors:
+        if isinstance(tensors, torch.Tensor):
+            tensors_np = tensors.detach().cpu().numpy()
+        elif isinstance(tensors, np.ndarray):
+            tensors_np = tensors
+        else:
+            self._torch_to_numpy_from_list(tensors)
+            tensors_np = tensors
+
+        for tensor in tensors_np:
             self._records.append(tensor)
+
+    def _torch_to_numpy_from_list(self, tensor_list: List[Optional[torch.Tensor]]):
+        if isinstance(tensor_list[0], list):
+            self._torch_to_numpy_from_list(tensor_list[0])
+        else:
+            for i in range(len(tensor_list)):
+                if isinstance(tensor_list[i], torch.Tensor):
+                    tensor_list[i] = tensor_list[i].detach().cpu().numpy()

     def __enter__(self) -> BaseRecordingForwardHook:
         """Enter."""
         self._handle = self._module.backbone.register_forward_hook(self._recording_forward)
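For reference, the nested-list handling added above can be exercised in isolation. A minimal sketch, assuming the same in-place recursion as _torch_to_numpy_from_list (the free function below is illustrative, not part of the commit):

    import numpy as np
    import torch

    def to_numpy_in_place(tensor_list):
        # Recurse into nested lists; convert torch tensors in place, skip None entries.
        if isinstance(tensor_list[0], list):
            to_numpy_in_place(tensor_list[0])
        else:
            for i, item in enumerate(tensor_list):
                if isinstance(item, torch.Tensor):
                    tensor_list[i] = item.detach().cpu().numpy()

    records = [[torch.ones(2, 2), None]]  # e.g. one image: one class with a map, one without
    to_numpy_in_place(records)
    assert isinstance(records[0][0], np.ndarray) and records[0][1] is None
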
otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py

@@ -2,10 +2,13 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
-from typing import List, Tuple, Union
+import copy
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
+import torch.nn.functional as F
+from mmdet.core import bbox2roi

 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
     BaseRecordingForwardHook,
@@ -121,3 +124,123 @@ def forward_single(x, cls_convs, conv_cls):
                 "YOLOXHead, ATSSHead, SSDHead, VFNetHead."
             )
         return cls_scores


+class MaskRCNNHook(BaseRecordingForwardHook):
+    """Saliency map hook for the Mask R-CNN model. Works only with the torch model; the OV IR model is not supported.
+
+    Args:
+        module (torch.nn.Module): Mask R-CNN model.
+        input_img_shape (Tuple[int, int]): Resolution of the model input image.
+        saliency_map_shape (Tuple[int, int]): Resolution of the output saliency map.
+        normalize (bool): Flag that defines if the output saliency map will be normalized.
+            Note that partial normalization is already performed by the segmentation mask head.
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        input_img_shape: Tuple[int, int],
+        saliency_map_shape: Tuple[int, int] = (224, 224),
+        normalize: bool = True,
+    ) -> None:
+        super().__init__(module)
+        self._neck = module.neck if module.with_neck else None
+        self._input_img_shape = input_img_shape
+        self._saliency_map_shape = saliency_map_shape
+        self._norm_saliency_maps = normalize
+
+    def func(
+        self,
+        feature_map: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
+        _: int = -1,
+    ) -> List[List[Optional[np.ndarray]]]:
+        """Generate saliency maps by aggregating the per-class soft predictions of the mask head for all detected boxes.
+
+        :param feature_map: Feature maps from the backbone.
+        :return: Class-wise saliency maps, one per predicted class.
+        """
+        with torch.no_grad():
+            if self._neck is not None:
+                feature_map = self._module.neck(feature_map)
+
+            det_bboxes, det_labels = self._get_detections(feature_map)
+            saliency_maps = self._get_saliency_maps_from_mask_predictions(feature_map, det_bboxes, det_labels)
+            if self._norm_saliency_maps:
+                saliency_maps = self._normalize(saliency_maps)
+        return saliency_maps
+
+    def _get_detections(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        batch_size = x[0].shape[0]
+        img_metas = [
+            {
+                "scale_factor": [1, 1, 1, 1],
+                "img_shape": self._input_img_shape,
+            }
+        ]
+        img_metas *= batch_size
+        proposals = self._module.rpn_head.simple_test_rpn(x, img_metas)
+        test_cfg = copy.deepcopy(self._module.roi_head.test_cfg)
+        test_cfg["max_per_img"] = 300
+        test_cfg["nms"]["iou_threshold"] = 1
+        test_cfg["nms"]["max_num"] = 300
+        det_bboxes, det_labels = self._module.roi_head.simple_test_bboxes(
+            x, img_metas, proposals, test_cfg, rescale=False
+        )
+        return det_bboxes, det_labels
+
+    def _get_saliency_maps_from_mask_predictions(
+        self, x: torch.Tensor, det_bboxes: List[torch.Tensor], det_labels: List[torch.Tensor]
+    ) -> List[List[Optional[np.ndarray]]]:
+        _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
+        mask_rois = bbox2roi(_bboxes)
+        mask_results = self._module.roi_head._mask_forward(x, mask_rois)
+        mask_pred = mask_results["mask_pred"]
+        num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+        mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+
+        batch_size = x[0].shape[0]
+
+        scale_x = self._input_img_shape[1] / self._saliency_map_shape[1]
+        scale_y = self._input_img_shape[0] / self._saliency_map_shape[0]
+        scale_factor = torch.FloatTensor((scale_x, scale_y, scale_x, scale_y))
+        test_cfg = self._module.roi_head.test_cfg.copy()
+        test_cfg["mask_thr_binary"] = -1
+
+        saliency_maps = []  # type: List[List[Optional[np.ndarray]]]
+        for i in range(batch_size):
+            saliency_maps.append([])
+            for j in range(self._module.roi_head.mask_head.num_classes):
+                saliency_maps[i].append(None)
+
+        for i in range(batch_size):
+            if det_bboxes[i].shape[0] == 0:
+                continue
+            else:
+                segm_result = self._module.roi_head.mask_head.get_seg_masks(
+                    mask_preds[i],
+                    _bboxes[i],
+                    det_labels[i],
+                    test_cfg,
+                    self._saliency_map_shape,
+                    scale_factor=scale_factor,
+                    rescale=True,
+                )
+                for class_id, segm_res in enumerate(segm_result):
+                    if segm_res:
+                        saliency_maps[i][class_id] = np.mean(np.array(segm_res), axis=0)
+        return saliency_maps
+
+    @staticmethod
+    def _normalize(saliency_maps: List[List[Optional[np.ndarray]]]) -> List[List[Optional[np.ndarray]]]:
+        batch_size = len(saliency_maps)
+        num_classes = len(saliency_maps[0])
+        for i in range(batch_size):
+            for class_id in range(num_classes):
+                per_class_map = saliency_maps[i][class_id]
+                if per_class_map is not None:
+                    max_values = np.max(per_class_map)
+                    per_class_map = 255 * per_class_map / (max_values + 1e-12)
+                    per_class_map = per_class_map.astype(np.uint8)
+                    saliency_maps[i][class_id] = per_class_map
+        return saliency_maps
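A usage sketch for the new hook, mirroring how the task code below wires it up. The model and data objects are illustrative, and `records` is assumed to be the accessor exposed by BaseRecordingForwardHook:

    from typing import Any, Dict
    import torch

    def explain_with_maskrcnn_hook(model: torch.nn.Module, data: Dict[str, Any], input_hw=(800, 800)):
        """Run one inference pass and collect per-class saliency maps."""
        saliency_hook = MaskRCNNHook(model, input_img_shape=input_hw)
        with saliency_hook:
            # The forward hook registered on model.backbone fires here and records the maps.
            with torch.no_grad():
                model(return_loss=False, rescale=True, **data)
        # One List[Optional[np.ndarray]] per image: an entry per class id, None when the class was not detected.
        return saliency_hook.records
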
@@ -10,7 +10,6 @@
 from mmdet.models.detectors.mask_rcnn import MaskRCNN

 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )
 from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
@@ -99,7 +98,7 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
 def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
     """Function for custom_mask_rcnn__simple_test."""
     assert self.with_bbox, "Bbox head must be implemented."
-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -108,7 +107,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **k

     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # Saliency map will be generated from predictions. Generate a dummy saliency_map.
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, feature_vector, saliency_map)

     return out
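The exported graph keeps its output signature but now returns a one-element placeholder instead of a backbone activation map; the real per-class maps come from the MaskRCNNHook post-processing. A trivial check of the placeholder's shape and dtype:

    import torch

    saliency_map = torch.empty(1, dtype=torch.uint8)  # same dummy tensor as in the diff
    print(saliency_map.shape, saliency_map.dtype)  # torch.Size([1]) torch.uint8
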
@@ -191,7 +191,6 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False):

         # pylint: disable=ungrouped-imports
         from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-            ActivationMapHook,
             FeatureVectorHook,
         )

@@ -309,7 +308,7 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
     assert self.with_bbox, "Bbox head must be implemented."
     tile_prob = self.tile_classifier.simple_test(img)

-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -318,7 +317,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):

     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # Saliency map will be generated from predictions. Generate a dummy saliency_map.
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, tile_prob, feature_vector, saliency_map)

     return (*out, tile_prob)

otx/algorithms/detection/adapters/mmdet/task.py (36 changes: 28 additions & 8 deletions)

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.

+import functools
 import glob
 import io
 import os
@@ -61,7 +62,9 @@
 from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
 from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
     DetClassProbabilityMapHook,
+    MaskRCNNHook,
 )
+from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomMaskRCNN
 from otx.algorithms.detection.adapters.mmdet.utils import (
     patch_input_preprocessing,
     patch_input_shape,
@@ -398,7 +401,13 @@ def hook(module, inp, outp):
         if raw_model.__class__.__name__ == "NNCFNetwork":
             raw_model = raw_model.get_nncf_wrapped_model()
         if isinstance(raw_model, TwoStageDetector):
-            saliency_hook = ActivationMapHook(feature_model)
+            test_pipeline = cfg.data.test.pipeline
+            width, height = None, None
+            for pipeline in test_pipeline:
+                width, height = pipeline.get("img_scale", (None, None))
+            if height is None:
+                raise ValueError("img_scale has to be defined in the test pipeline.")
+            saliency_hook = MaskRCNNHook(feature_model, input_img_shape=(height, width))
         else:
             saliency_hook = DetClassProbabilityMapHook(feature_model)
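A sketch of the img_scale lookup above, run on a typical mmdet-style test pipeline (the pipeline contents are illustrative):

    test_pipeline = [
        {"type": "LoadImageFromFile"},
        {"type": "MultiScaleFlipAug", "img_scale": (1344, 800), "flip": False},
    ]
    width, height = None, None
    for step in test_pipeline:
        # Each step overwrites the previous value, so the last step wins;
        # steps without img_scale reset the pair to (None, None).
        width, height = step.get("img_scale", (None, None))
    assert (width, height) == (1344, 800)
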

@@ -516,12 +525,6 @@ def _explain_model(
         explain_parameters: Optional[ExplainParameters] = None,
     ) -> Dict[str, Any]:
         """Main explain function of MMDetectionTask."""
-
-        explainer_hook_selector = {
-            "classwisesaliencymap": DetClassProbabilityMapHook,
-            "eigencam": EigenCamHook,
-            "activationmap": ActivationMapHook,
-        }
         self._data_cfg = ConfigDict(
             data=ConfigDict(
                 train=ConfigDict(
@@ -590,6 +593,23 @@ def hook(module, inp, outp):
         model.register_forward_pre_hook(pre_hook)
         model.register_forward_hook(hook)

+        if isinstance(feature_model, CustomMaskRCNN):
+            test_pipeline = cfg.data.test.pipeline
+            width, height = None, None
+            for pipeline in test_pipeline:
+                width, height = pipeline.get("img_scale", (None, None))
+            if height is None:
+                raise ValueError("img_scale has to be defined in the test pipeline.")
+            per_class_xai_algorithm = functools.partial(MaskRCNNHook, input_img_shape=(height, width))
+        else:
+            per_class_xai_algorithm = DetClassProbabilityMapHook  # type: ignore
+
+        explainer_hook_selector = {
+            "classwisesaliencymap": per_class_xai_algorithm,
+            "eigencam": EigenCamHook,
+            "activationmap": ActivationMapHook,
+        }
+
         explainer = explain_parameters.explainer if explain_parameters else None
         if explainer is not None:
             explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
@@ -601,7 +621,7 @@ def hook(module, inp, outp):

         # Class-wise saliency map for single-stage detectors; otherwise, use the class-agnostic saliency map.
         eval_predictions = []
-        with explainer_hook(feature_model) as saliency_hook:
+        with explainer_hook(feature_model) as saliency_hook:  # type: ignore
             for data in dataloader:
                 with torch.no_grad():
                     result = model(return_loss=False, rescale=True, **data)
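Why functools.partial above: both branches of the selector must be constructible as explainer_hook(feature_model), while MaskRCNNHook additionally needs input_img_shape. A minimal sketch (feature_model and the shape are illustrative):

    import functools

    make_hook = functools.partial(MaskRCNNHook, input_img_shape=(800, 800))
    saliency_hook = make_hook(feature_model)
    # Equivalent to: MaskRCNNHook(feature_model, input_img_shape=(800, 800))
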
@@ -16,6 +16,7 @@

 from typing import Dict

+import cv2
 import numpy as np

 try:
@@ -107,6 +108,60 @@ def postprocess(self, outputs, meta):

         return scores, classes, boxes, resized_masks

+    def postprocess_saliency_map(self, outputs, meta, num_classes):
+        """Post-process function for the saliency map of the OTX MaskRCNN model."""
+        boxes = outputs[self.output_blob_name["boxes"]]
+        if boxes.shape[0] == 1:
+            boxes = boxes.squeeze(0)
+        scores = boxes[:, 4]
+        boxes = boxes[:, :4]
+        masks = outputs[self.output_blob_name["masks"]]
+        if masks.shape[0] == 1:
+            masks = masks.squeeze(0)
+        classes = outputs[self.output_blob_name["labels"]].astype(np.uint32)
+        if classes.shape[0] == 1:
+            classes = classes.squeeze(0)
+
+        scale_x = meta["resized_shape"][1] / meta["original_shape"][1]
+        scale_y = meta["resized_shape"][0] / meta["original_shape"][0]
+        boxes[:, 0::2] /= scale_x
+        boxes[:, 1::2] /= scale_y
+
+        saliency_maps = [None for _ in range(num_classes)]
+        for box, score, cls, raw_mask in zip(boxes, scores, classes, masks):
+            resized_mask = self._resize_mask(box, raw_mask * score, *meta["original_shape"][:-1])
+            if saliency_maps[cls] is None:
+                saliency_maps[cls] = [resized_mask]
+            else:
+                saliency_maps[cls].append(resized_mask)
+
+        # Normalize
+        for i in range(num_classes):
+            per_class_map = saliency_maps[i]
+            if per_class_map is not None:
+                per_class_map = np.array(per_class_map).mean(0)
+                max_values = np.max(per_class_map)
+                per_class_map = 255 * per_class_map / (max_values + 1e-12)
+                per_class_map = per_class_map.astype(np.uint8)
+                saliency_maps[i] = per_class_map
+        return saliency_maps
+
+    def _resize_mask(self, box, raw_cls_mask, im_h, im_w):
+        # Add a zero border to prevent upsampling artifacts on segment borders.
+        raw_cls_mask = np.pad(raw_cls_mask, ((1, 1), (1, 1)), "constant", constant_values=0)
+        extended_box = self._expand_box(box, raw_cls_mask.shape[0] / (raw_cls_mask.shape[0] - 2.0)).astype(int)
+        w, h = np.maximum(extended_box[2:] - extended_box[:2] + 1, 1)
+        x0, y0 = np.clip(extended_box[:2], a_min=0, a_max=[im_w, im_h])
+        x1, y1 = np.clip(extended_box[2:] + 1, a_min=0, a_max=[im_w, im_h])
+
+        raw_cls_mask = cv2.resize(raw_cls_mask.astype(np.float32), (w, h))
+        # Put the object mask into the image-sized mask.
+        im_mask = np.zeros((im_h, im_w), dtype=np.float32)
+        im_mask[y0:y1, x0:x1] = raw_cls_mask[
+            (y0 - extended_box[1]) : (y1 - extended_box[1]), (x0 - extended_box[0]) : (x1 - extended_box[0])
+        ]
+        return im_mask
+
     def segm_postprocess(self, *args, **kwargs):
         """Post-process for segmentation masks."""
         return self._segm_postprocess(*args, **kwargs)
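The per-class aggregation and normalization above reduce to a few numpy operations. A minimal sketch with illustrative shapes:

    import numpy as np

    # Two detected instances of the same class, each a score-weighted soft mask.
    masks_for_class = [np.random.rand(480, 640), np.random.rand(480, 640)]
    per_class_map = np.array(masks_for_class).mean(0)  # average the instances
    per_class_map = 255 * per_class_map / (np.max(per_class_map) + 1e-12)  # scale to [0, 255]
    saliency_map = per_class_map.astype(np.uint8)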