From 6133b4c38e4f10306b4913db360fbd8289ad5e79 Mon Sep 17 00:00:00 2001
From: Evgeny Tsykunov
Date: Thu, 8 Jun 2023 00:50:06 +0900
Subject: [PATCH] per class sal maps for maskrcnn

---
 .../mmcv/hooks/recording_forward_hook.py      |  22 ++-
 .../hooks/det_class_probability_map_hook.py   | 125 +++++++++++++++++-
 .../detectors/custom_maskrcnn_detector.py     |   6 +-
 .../custom_maskrcnn_tile_optimized.py         |   6 +-
 .../detection/adapters/mmdet/task.py          |  36 +++--
 .../model_wrappers/openvino_models.py         |  55 ++++++++
 .../detection/adapters/openvino/task.py       |  17 ++-
 otx/api/utils/dataset_utils.py                |  49 ++++---
 8 files changed, 276 insertions(+), 40 deletions(-)

diff --git a/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py b/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py
index cb9a83a25a9..03eb9bcde1f 100644
--- a/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py
+++ b/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py
@@ -16,8 +16,9 @@
 from __future__ import annotations

 from abc import ABC
-from typing import List, Sequence, Union
+from typing import List, Optional, Sequence, Union

+import numpy as np
 import torch

 from otx.algorithms.classification import MMCLS_AVAILABLE
@@ -69,10 +70,25 @@ def _recording_forward(
         self, _: torch.nn.Module, x: torch.Tensor, output: torch.Tensor
     ):  # pylint: disable=unused-argument
         tensors = self.func(output)
-        tensors = tensors.detach().cpu().numpy()
-        for tensor in tensors:
+        if isinstance(tensors, torch.Tensor):
+            tensors_np = tensors.detach().cpu().numpy()
+        elif isinstance(tensors, np.ndarray):
+            tensors_np = tensors
+        else:
+            self._torch_to_numpy_from_list(tensors)
+            tensors_np = tensors
+
+        for tensor in tensors_np:
             self._records.append(tensor)

+    def _torch_to_numpy_from_list(self, tensor_list: List[Optional[torch.Tensor]]):
+        if isinstance(tensor_list[0], list):
+            self._torch_to_numpy_from_list(tensor_list[0])
+        else:
+            for i in range(len(tensor_list)):
+                if isinstance(tensor_list[i], torch.Tensor):
+                    tensor_list[i] = tensor_list[i].detach().cpu().numpy()
+
     def __enter__(self) -> BaseRecordingForwardHook:
         """Enter."""
         self._handle = self._module.backbone.register_forward_hook(self._recording_forward)
diff --git a/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py b/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py
index a82e8c50260..23ce3234bd7 100644
--- a/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py
+++ b/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py
@@ -2,10 +2,13 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
-from typing import List, Tuple, Union
+import copy
+from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.nn.functional as F
+from mmdet.core import bbox2roi

 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
     BaseRecordingForwardHook,
@@ -121,3 +124,123 @@ def forward_single(x, cls_convs, conv_cls):
                 "YOLOXHead, ATSSHead, SSDHead, VFNetHead."
             )
         return cls_scores
+
+
+class MaskRCNNHook(BaseRecordingForwardHook):
+    """Saliency map hook for the Mask R-CNN model. Supports only the torch model, not the OpenVINO IR model.
+
+    Args:
+        module (torch.nn.Module): Mask R-CNN model.
+        input_img_shape (Tuple[int, int]): Resolution of the model input image.
+        saliency_map_shape (Tuple[int, int]): Resolution of the output saliency map.
+        normalize (bool): Flag that defines if the output saliency map will be normalized.
+            Note that partial normalization is already done by the segmentation mask head.
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        input_img_shape: Tuple[int, int],
+        saliency_map_shape: Tuple[int, int] = (224, 224),
+        normalize: bool = True,
+    ) -> None:
+        super().__init__(module)
+        self._neck = module.neck if module.with_neck else None
+        self._input_img_shape = input_img_shape
+        self._saliency_map_shape = saliency_map_shape
+        self._norm_saliency_maps = normalize
+
+    def func(
+        self,
+        feature_map: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
+        _: int = -1,
+    ) -> List[List[Optional[np.ndarray]]]:
+        """Generate saliency maps by aggregating the per-class soft mask-head predictions of all detected boxes.
+
+        :param feature_map: Feature maps from the backbone.
+        :return: Class-wise saliency maps: one saliency map per predicted class.
+        """
+        with torch.no_grad():
+            if self._neck is not None:
+                feature_map = self._module.neck(feature_map)
+
+            det_bboxes, det_labels = self._get_detections(feature_map)
+            saliency_maps = self._get_saliency_maps_from_mask_predictions(feature_map, det_bboxes, det_labels)
+            if self._norm_saliency_maps:
+                saliency_maps = self._normalize(saliency_maps)
+        return saliency_maps
+
+    def _get_detections(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        batch_size = x[0].shape[0]
+        img_metas = [
+            {
+                "scale_factor": [1, 1, 1, 1],
+                "img_shape": self._input_img_shape,
+            }
+        ]
+        img_metas *= batch_size
+        proposals = self._module.rpn_head.simple_test_rpn(x, img_metas)
+        test_cfg = copy.deepcopy(self._module.roi_head.test_cfg)
+        test_cfg["max_per_img"] = 300
+        test_cfg["nms"]["iou_threshold"] = 1
+        test_cfg["nms"]["max_num"] = 300
+        det_bboxes, det_labels = self._module.roi_head.simple_test_bboxes(
+            x, img_metas, proposals, test_cfg, rescale=False
+        )
+        return det_bboxes, det_labels
+
+    def _get_saliency_maps_from_mask_predictions(
+        self, x: torch.Tensor, det_bboxes: List[torch.Tensor], det_labels: List[torch.Tensor]
+    ) -> List[List[Optional[np.ndarray]]]:
+        _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
+        mask_rois = bbox2roi(_bboxes)
+        mask_results = self._module.roi_head._mask_forward(x, mask_rois)
+        mask_pred = mask_results["mask_pred"]
+        num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+        mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+
+        batch_size = x[0].shape[0]
+
+        scale_x = self._input_img_shape[1] / self._saliency_map_shape[1]
+        scale_y = self._input_img_shape[0] / self._saliency_map_shape[0]
+        scale_factor = torch.FloatTensor((scale_x, scale_y, scale_x, scale_y))
+        test_cfg = self._module.roi_head.test_cfg.copy()
+        test_cfg["mask_thr_binary"] = -1
+
+        saliency_maps = []  # type: List[List[Optional[np.ndarray]]]
+        for i in range(batch_size):
+            saliency_maps.append([])
+            for j in range(self._module.roi_head.mask_head.num_classes):
+                saliency_maps[i].append(None)
+
+        for i in range(batch_size):
+            if det_bboxes[i].shape[0] == 0:
+                continue
+            else:
+                segm_result = self._module.roi_head.mask_head.get_seg_masks(
+                    mask_preds[i],
+                    _bboxes[i],
+                    det_labels[i],
+                    test_cfg,
+                    self._saliency_map_shape,
+                    scale_factor=scale_factor,
+                    rescale=True,
+                )
+                for class_id, segm_res in enumerate(segm_result):
+                    if segm_res:
+                        saliency_maps[i][class_id] = np.mean(np.array(segm_res), axis=0)
+        return saliency_maps
+
+    @staticmethod
+    def _normalize(saliency_maps: List[List[Optional[np.ndarray]]]) -> List[List[Optional[np.ndarray]]]:
+        batch_size = len(saliency_maps)
+        num_classes = len(saliency_maps[0])
+        for i in range(batch_size):
+            for class_id in range(num_classes):
+                per_class_map = saliency_maps[i][class_id]
+                if per_class_map is not None:
+                    max_values = np.max(per_class_map)
+                    per_class_map = 255 * per_class_map / (max_values + 1e-12)
+                    per_class_map = per_class_map.astype(np.uint8)
+                    saliency_maps[i][class_id] = per_class_map
+        return saliency_maps
diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py
index 7216dbba158..afd8b0bcf0e 100644
--- a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py
+++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py
@@ -10,7 +10,6 @@
 from mmdet.models.detectors.mask_rcnn import MaskRCNN

 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )
 from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
@@ -99,7 +98,7 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
 def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
     """Function for custom_mask_rcnn__simple_test."""
     assert self.with_bbox, "Bbox head must be implemented."
-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -108,7 +107,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **k

     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # The saliency map will be generated from the predictions; return a dummy saliency_map here.
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, feature_vector, saliency_map)

     return out
diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py
index 5fb2338d7bd..0622230dba7 100644
--- a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py
+++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py
@@ -191,7 +191,6 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False):

 # pylint: disable=ungrouped-imports
 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )

@@ -309,7 +308,7 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
     assert self.with_bbox, "Bbox head must be implemented."
     tile_prob = self.tile_classifier.simple_test(img)

-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -318,7 +317,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):

     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # The saliency map will be generated from the predictions; return a dummy saliency_map here.
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, tile_prob, feature_vector, saliency_map)

     return (*out, tile_prob)
diff --git a/otx/algorithms/detection/adapters/mmdet/task.py b/otx/algorithms/detection/adapters/mmdet/task.py
index 97b98f26041..aa332fd2f69 100644
--- a/otx/algorithms/detection/adapters/mmdet/task.py
+++ b/otx/algorithms/detection/adapters/mmdet/task.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.

+import functools
 import glob
 import io
 import os
@@ -61,7 +62,9 @@
 from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
 from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
     DetClassProbabilityMapHook,
+    MaskRCNNHook,
 )
+from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomMaskRCNN
 from otx.algorithms.detection.adapters.mmdet.utils import (
     patch_input_preprocessing,
     patch_input_shape,
@@ -398,7 +401,13 @@ def hook(module, inp, outp):
             if raw_model.__class__.__name__ == "NNCFNetwork":
                 raw_model = raw_model.get_nncf_wrapped_model()
             if isinstance(raw_model, TwoStageDetector):
-                saliency_hook = ActivationMapHook(feature_model)
+                test_pipeline = cfg.data.test.pipeline
+                width, height = None, None
+                for pipeline in test_pipeline:
+                    width, height = pipeline.get("img_scale", (None, None))
+                if height is None:
+                    raise ValueError("img_scale has to be defined in the test pipeline.")
+                saliency_hook = MaskRCNNHook(feature_model, input_img_shape=(height, width))
             else:
                 saliency_hook = DetClassProbabilityMapHook(feature_model)

@@ -516,12 +525,6 @@ def _explain_model(
         explain_parameters: Optional[ExplainParameters] = None,
     ) -> Dict[str, Any]:
         """Main explain function of MMDetectionTask."""
-
-        explainer_hook_selector = {
-            "classwisesaliencymap": DetClassProbabilityMapHook,
-            "eigencam": EigenCamHook,
-            "activationmap": ActivationMapHook,
-        }
         self._data_cfg = ConfigDict(
             data=ConfigDict(
                 train=ConfigDict(
@@ -590,6 +593,23 @@ def hook(module, inp, outp):
             model.register_forward_pre_hook(pre_hook)
             model.register_forward_hook(hook)

+            if isinstance(feature_model, CustomMaskRCNN):
+                test_pipeline = cfg.data.test.pipeline
+                width, height = None, None
+                for pipeline in test_pipeline:
+                    width, height = pipeline.get("img_scale", (None, None))
+                if height is None:
+                    raise ValueError("img_scale has to be defined in the test pipeline.")
+                per_class_xai_algorithm = functools.partial(MaskRCNNHook, input_img_shape=(height, width))
+            else:
+                per_class_xai_algorithm = DetClassProbabilityMapHook  # type: ignore
+
+            explainer_hook_selector = {
+                "classwisesaliencymap": per_class_xai_algorithm,
+                "eigencam": EigenCamHook,
+                "activationmap": ActivationMapHook,
+            }
+
             explainer = explain_parameters.explainer if explain_parameters else None
             if explainer is not None:
                 explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
@@ -601,7 +621,7 @@ def hook(module, inp, outp):
             # Class-wise saliency map for single-stage detector, otherwise use class-agnostic saliency map.
             eval_predictions = []
-            with explainer_hook(feature_model) as saliency_hook:
+            with explainer_hook(feature_model) as saliency_hook:  # type: ignore
                 for data in dataloader:
                     with torch.no_grad():
                         result = model(return_loss=False, rescale=True, **data)
diff --git a/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py b/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py
index 5fd9a0248b5..1e9e731dbb7 100644
--- a/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py
+++ b/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py
@@ -16,6 +16,7 @@

 from typing import Dict

+import cv2
 import numpy as np

 try:
@@ -107,6 +108,60 @@ def postprocess(self, outputs, meta):

         return scores, classes, boxes, resized_masks

+    def postprocess_saliency_map(self, outputs, meta, num_classes):
+        """Post-process function for the saliency map of the OTX MaskRCNN model."""
+        boxes = outputs[self.output_blob_name["boxes"]]
+        if boxes.shape[0] == 1:
+            boxes = boxes.squeeze(0)
+        scores = boxes[:, 4]
+        boxes = boxes[:, :4]
+        masks = outputs[self.output_blob_name["masks"]]
+        if masks.shape[0] == 1:
+            masks = masks.squeeze(0)
+        classes = outputs[self.output_blob_name["labels"]].astype(np.uint32)
+        if classes.shape[0] == 1:
+            classes = classes.squeeze(0)
+
+        scale_x = meta["resized_shape"][1] / meta["original_shape"][1]
+        scale_y = meta["resized_shape"][0] / meta["original_shape"][0]
+        boxes[:, 0::2] /= scale_x
+        boxes[:, 1::2] /= scale_y
+
+        saliency_maps = [None for _ in range(num_classes)]
+        for box, score, cls, raw_mask in zip(boxes, scores, classes, masks):
+            resized_mask = self._resize_mask(box, raw_mask * score, *meta["original_shape"][:-1])
+            if saliency_maps[cls] is None:
+                saliency_maps[cls] = [resized_mask]
+            else:
+                saliency_maps[cls].append(resized_mask)
+
+        # Normalize the per-class maps to the uint8 range
+        for i in range(num_classes):
+            per_class_map = saliency_maps[i]
+            if per_class_map is not None:
+                per_class_map = np.array(per_class_map).mean(0)
+                max_values = np.max(per_class_map)
+                per_class_map = 255 * per_class_map / (max_values + 1e-12)
+                per_class_map = per_class_map.astype(np.uint8)
+                saliency_maps[i] = per_class_map
+        return saliency_maps
+
+    def _resize_mask(self, box, raw_cls_mask, im_h, im_w):
+        # Add a zero border to prevent upsampling artifacts on segment borders.
+        raw_cls_mask = np.pad(raw_cls_mask, ((1, 1), (1, 1)), "constant", constant_values=0)
+        extended_box = self._expand_box(box, raw_cls_mask.shape[0] / (raw_cls_mask.shape[0] - 2.0)).astype(int)
+        w, h = np.maximum(extended_box[2:] - extended_box[:2] + 1, 1)
+        x0, y0 = np.clip(extended_box[:2], a_min=0, a_max=[im_w, im_h])
+        x1, y1 = np.clip(extended_box[2:] + 1, a_min=0, a_max=[im_w, im_h])
+
+        raw_cls_mask = cv2.resize(raw_cls_mask.astype(np.float32), (w, h))
+        # Put the object mask into the full image mask.
+        im_mask = np.zeros((im_h, im_w), dtype=np.float32)
+        im_mask[y0:y1, x0:x1] = raw_cls_mask[
+            (y0 - extended_box[1]) : (y1 - extended_box[1]), (x0 - extended_box[0]) : (x1 - extended_box[0])
+        ]
+        return im_mask
+
     def segm_postprocess(self, *args, **kwargs):
         """Post-process for segmentation masks."""
         return self._segm_postprocess(*args, **kwargs)
diff --git a/otx/algorithms/detection/adapters/openvino/task.py b/otx/algorithms/detection/adapters/openvino/task.py
index eca6de758f6..b91bdc5dce1 100644
--- a/otx/algorithms/detection/adapters/openvino/task.py
+++ b/otx/algorithms/detection/adapters/openvino/task.py
@@ -39,6 +39,7 @@
 from otx.algorithms.common.utils.logger import get_logger
 from otx.algorithms.common.utils.utils import get_default_async_reqs_num
 from otx.algorithms.detection.adapters.openvino import model_wrappers
+from otx.algorithms.detection.adapters.openvino.model_wrappers import OTXMaskRCNNModel
 from otx.algorithms.detection.configs.base import DetectionConfig
 from otx.api.configuration.helper.utils import (
     config_to_bytes,
@@ -118,6 +119,16 @@ def post_process(self, prediction: Dict[str, np.ndarray], metadata: Dict[str, An

         return self.converter.convert_to_annotation(detections, metadata)

+    def post_process_saliency_map(self, prediction: Dict[str, np.ndarray], metadata: Dict[str, Any]):
+        """Saliency map post-process function of the OpenVINO Detection Inferencer."""
+        if isinstance(self.model, OTXMaskRCNNModel):
+            # The MaskRCNN IR model does not include saliency map post-processing, so it is done externally.
+            num_classes = len(self.converter.labels)  # type: ignore
+            return self.model.postprocess_saliency_map(prediction, metadata, num_classes)
+        else:
+            # All other IR models include saliency map post-processing.
+            return prediction["saliency_map"][0]
+
     def predict(self, image: np.ndarray):
         """Predict function of OpenVINO Detection Inferencer."""
         image, metadata = self.pre_process(image)
@@ -132,7 +143,7 @@ def predict(self, image: np.ndarray):
         else:
             features = (
                 raw_predictions["feature_vector"].reshape(-1),
-                raw_predictions["saliency_map"][0],
+                self.post_process_saliency_map(raw_predictions, metadata),
             )
         return predictions, features

@@ -526,7 +537,7 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu

                 if add_saliency_map and saliency_map is not None:
                     labels = self.task_environment.get_labels().copy()
-                    if saliency_map.shape[0] == len(labels) + 1:
+                    if len(saliency_map) == len(labels) + 1:
                         # Include the background as the last category
                         labels.append(LabelEntity("background", Domain.DETECTION))

@@ -596,7 +607,7 @@ def explain(
             )

             labels = self.task_environment.get_labels().copy()
-            if saliency_map.shape[0] == len(labels) + 1:
+            if len(saliency_map) == len(labels) + 1:
                 # Include the background as the last category
                 labels.append(LabelEntity("background", Domain.DETECTION))

diff --git a/otx/api/utils/dataset_utils.py b/otx/api/utils/dataset_utils.py
index c3058de81ee..b233a5fe85c 100644
--- a/otx/api/utils/dataset_utils.py
+++ b/otx/api/utils/dataset_utils.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -195,28 +195,29 @@ def contains_anomalous_images(dataset: DatasetEntity) -> bool:

 # pylint: disable-msg=too-many-locals
 def add_saliency_maps_to_dataset_item(
     dataset_item: DatasetItemEntity,
-    saliency_map: np.ndarray,
+    saliency_map: Union[List[Optional[np.ndarray]], np.ndarray],
     model: Optional[ModelEntity],
     labels: List[LabelEntity],
     predicted_scored_labels: Optional[List[ScoredLabel]] = None,
     explain_predicted_classes: bool = True,
     process_saliency_maps: bool = False,
 ):
-    """Add saliency maps(2d for class-ignore saliency map, 3d for class-wise saliency maps) to a single dataset item."""
-    if saliency_map.ndim == 2:
-        # Single saliency map per image, support e.g. EigenCAM use case
-        if process_saliency_maps:
-            saliency_map = get_actmap(saliency_map, (dataset_item.width, dataset_item.height))
-        saliency_media = ResultMediaEntity(
-            name="Saliency Map",
-            type="saliency_map",
-            annotation_scene=dataset_item.annotation_scene,
-            numpy=saliency_map,
-            roi=dataset_item.roi,
-        )
-        dataset_item.append_metadata_item(saliency_media, model=model)
-    elif saliency_map.ndim == 3:
-        # Multiple saliency maps per image (class-wise saliency map)
+    """Add saliency maps (a 2D array for a class-agnostic saliency map,
+    a 3D array or a list of 2D arrays for class-wise saliency maps) to a single dataset item."""
+    if isinstance(saliency_map, list):
+        class_wise_saliency_map = True
+    elif isinstance(saliency_map, np.ndarray):
+        if saliency_map.ndim == 2:
+            class_wise_saliency_map = False
+        elif saliency_map.ndim == 3:
+            class_wise_saliency_map = True
+        else:
+            raise ValueError(f"Saliency map has to be a 2- or 3-dimensional array, but got {saliency_map.ndim} dims.")
+    else:
+        raise TypeError("Saliency map has to be a list or np.ndarray.")
+
+    if class_wise_saliency_map:
+        # Multiple saliency maps per image (class-wise saliency maps), support e.g. the ReciproCAM use case
         if explain_predicted_classes:
             # Explain only predicted classes
             if predicted_scored_labels is None:
@@ -232,7 +233,7 @@ def add_saliency_maps_to_dataset_item(

         for class_id, class_wise_saliency_map in enumerate(saliency_map):
             label = labels[class_id]
-            if label in explain_targets:
+            if class_wise_saliency_map is not None and label in explain_targets:
                 if process_saliency_maps:
                     class_wise_saliency_map = get_actmap(
                         class_wise_saliency_map, (dataset_item.width, dataset_item.height)
@@ -247,4 +248,14 @@ def add_saliency_maps_to_dataset_item(
                 )
                 dataset_item.append_metadata_item(saliency_media, model=model)
     else:
-        raise RuntimeError(f"Single saliency map has to be 2 or 3-dimensional, but got {saliency_map.ndim} dims")
+        # Single saliency map per image, support e.g. the ActivationMap use case
+        if process_saliency_maps:
+            saliency_map = get_actmap(saliency_map, (dataset_item.width, dataset_item.height))
+        saliency_media = ResultMediaEntity(
+            name="Saliency Map",
+            type="saliency_map",
+            annotation_scene=dataset_item.annotation_scene,
+            numpy=saliency_map,
+            roi=dataset_item.roi,
+        )
+        dataset_item.append_metadata_item(saliency_media, model=model)
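---
Usage sketch (illustrative, not part of the patch): the new MaskRCNNHook is used
as a context manager, mirroring the `with explainer_hook(feature_model) as
saliency_hook:` call in task.py above. The sketch assumes an mmdet Mask R-CNN
`model` built from an OTX config, a `data` batch taken from a test dataloader,
a test pipeline that resizes inputs to 800x1344, and the `records` accessor of
BaseRecordingForwardHook; these names and values are assumptions, not part of
this diff.

    import torch

    from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
        MaskRCNNHook,
    )

    model.eval()
    # input_img_shape is the (height, width) that the test pipeline resizes to.
    with MaskRCNNHook(model, input_img_shape=(800, 1344)) as saliency_hook:
        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)
        # One record per image: a list with num_classes entries, each either
        # None (class was not detected) or a uint8 map of saliency_map_shape size.
        saliency_maps = saliency_hook.records

Since a class with no detections yields None, consumers such as
add_saliency_maps_to_dataset_item() must check each per-class entry before use,
as the dataset_utils.py hunk above does.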