diff --git a/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py b/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py
index cb9a83a25a9..d1b816c3a6e 100644
--- a/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py
+++ b/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py
@@ -16,8 +16,9 @@ from __future__ import annotations
 from abc import ABC
-from typing import List, Sequence, Union
+from typing import List, Optional, Sequence, Union
+import numpy as np
 import torch
 from otx.algorithms.classification import MMCLS_AVAILABLE
@@ -69,10 +70,24 @@ def _recording_forward(
         self, _: torch.nn.Module, x: torch.Tensor, output: torch.Tensor
     ):  # pylint: disable=unused-argument
         tensors = self.func(output)
-        tensors = tensors.detach().cpu().numpy()
-        for tensor in tensors:
+        if isinstance(tensors, torch.Tensor):
+            tensors_np = tensors.detach().cpu().numpy()
+        elif isinstance(tensors, np.ndarray):
+            tensors_np = tensors
+        else:
+            self._torch_to_numpy_from_list(tensors)
+            tensors_np = tensors
+
+        for tensor in tensors_np:
             self._records.append(tensor)
+    def _torch_to_numpy_from_list(self, tensor_list: List[Optional[torch.Tensor]]):
+        for i in range(len(tensor_list)):
+            if isinstance(tensor_list[i], list):
+                self._torch_to_numpy_from_list(tensor_list[i])
+            elif isinstance(tensor_list[i], torch.Tensor):
+                tensor_list[i] = tensor_list[i].detach().cpu().numpy()
+
     def __enter__(self) -> BaseRecordingForwardHook:
         """Enter."""
         self._handle = self._module.backbone.register_forward_hook(self._recording_forward)
diff --git a/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py b/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py
index a82e8c50260..acf3c07ccd6 100644
--- a/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py
+++ b/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py
@@ -2,10 +2,13 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
-from typing import List, Tuple, Union
+import copy
+from typing import List, Optional, Tuple, Union
+import numpy as np
 import torch
 import torch.nn.functional as F
+from mmdet.core import bbox2roi
 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
     BaseRecordingForwardHook,
@@ -121,3 +124,127 @@ def forward_single(x, cls_convs, conv_cls):
                 "YOLOXHead, ATSSHead, SSDHead, VFNetHead."
             )
         return cls_scores
+
+
+class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
+    """Saliency map hook for Mask R-CNN models. Supports torch models only; OpenVINO IR models are not supported.
+
+    Args:
+        module (torch.nn.Module): Mask R-CNN model.
+        input_img_shape (Tuple[int]): Resolution of the model input image.
+        saliency_map_shape (Tuple[int]): Resolution of the output saliency map.
+        max_detections_per_img (int): Upper limit of the number of detections
+            from which soft mask predictions are aggregated.
+        normalize (bool): Flag that defines if the output saliency map will be normalized.
+            Note that partial normalization is already performed by the segmentation mask head.
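+
+    Example:
+        A minimal usage sketch (assuming ``model`` is a built MMDetection Mask R-CNN
+        module in eval mode and ``data`` is a preprocessed batch; both names are illustrative):
+
+        >>> with MaskRCNNRecordingForwardHook(model, input_img_shape=(800, 1344)) as hook:
+        ...     _ = model(return_loss=False, rescale=True, **data)
+        >>> saliency_maps = hook.records  # per image: one map per class, None if the class was not detected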
+ """ + + def __init__( + self, + module: torch.nn.Module, + input_img_shape: Tuple[int, int], + saliency_map_shape: Tuple[int, int] = (224, 224), + max_detections_per_img: int = 300, + normalize: bool = True, + ) -> None: + super().__init__(module) + self._neck = module.neck if module.with_neck else None + self._input_img_shape = input_img_shape + self._saliency_map_shape = saliency_map_shape + self._max_detections_per_img = max_detections_per_img + self._norm_saliency_maps = normalize + + def func( + self, + feature_map: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + _: int = -1, + ) -> List[List[Optional[np.ndarray]]]: + """Generate saliency maps by aggregating per-class soft predictions of mask head for all detected boxes. + + :param feature_map: Feature maps from backbone. + :return: Class-wise Saliency Maps. One saliency map per each predicted class. + """ + with torch.no_grad(): + if self._neck is not None: + feature_map = self._module.neck(feature_map) + + det_bboxes, det_labels = self._get_detections(feature_map) + saliency_maps = self._get_saliency_maps_from_mask_predictions(feature_map, det_bboxes, det_labels) + if self._norm_saliency_maps: + saliency_maps = self._normalize(saliency_maps) + return saliency_maps + + def _get_detections(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + batch_size = x[0].shape[0] + img_metas = [ + { + "scale_factor": [1, 1, 1, 1], # dummy scale_factor, not used + "img_shape": self._input_img_shape, + } + ] + img_metas *= batch_size + proposals = self._module.rpn_head.simple_test_rpn(x, img_metas) + test_cfg = copy.deepcopy(self._module.roi_head.test_cfg) + test_cfg["max_per_img"] = self._max_detections_per_img + test_cfg["nms"]["iou_threshold"] = 1 + test_cfg["nms"]["max_num"] = self._max_detections_per_img + det_bboxes, det_labels = self._module.roi_head.simple_test_bboxes( + x, img_metas, proposals, test_cfg, rescale=False + ) + return det_bboxes, det_labels + + def _get_saliency_maps_from_mask_predictions( + self, x: torch.Tensor, det_bboxes: List[torch.Tensor], det_labels: List[torch.Tensor] + ) -> List[List[Optional[np.ndarray]]]: + _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))] + mask_rois = bbox2roi(_bboxes) + mask_results = self._module.roi_head._mask_forward(x, mask_rois) + mask_pred = mask_results["mask_pred"] + num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes] + mask_preds = mask_pred.split(num_mask_roi_per_img, 0) + + batch_size = x[0].shape[0] + + scale_x = self._input_img_shape[1] / self._saliency_map_shape[1] + scale_y = self._input_img_shape[0] / self._saliency_map_shape[0] + scale_factor = torch.FloatTensor((scale_x, scale_y, scale_x, scale_y)) + test_cfg = self._module.roi_head.test_cfg.copy() + test_cfg["mask_thr_binary"] = -1 + + saliency_maps = [] # type: List[List[Optional[np.ndarray]]] + for i in range(batch_size): + saliency_maps.append([]) + for j in range(self._module.roi_head.mask_head.num_classes): + saliency_maps[i].append(None) + + for i in range(batch_size): + if det_bboxes[i].shape[0] == 0: + continue + else: + segm_result = self._module.roi_head.mask_head.get_seg_masks( + mask_preds[i], + _bboxes[i], + det_labels[i], + test_cfg, + self._saliency_map_shape, + scale_factor=scale_factor, + rescale=True, + ) + for class_id, segm_res in enumerate(segm_result): + if segm_res: + saliency_maps[i][class_id] = np.mean(np.array(segm_res), axis=0) + return saliency_maps + + @staticmethod + def _normalize(saliency_maps: List[List[Optional[np.ndarray]]]) 
-> List[List[Optional[np.ndarray]]]:
+        batch_size = len(saliency_maps)
+        num_classes = len(saliency_maps[0])
+        for i in range(batch_size):
+            for class_id in range(num_classes):
+                per_class_map = saliency_maps[i][class_id]
+                if per_class_map is not None:
+                    max_values = np.max(per_class_map)
+                    per_class_map = 255 * per_class_map / (max_values + 1e-12)
+                    per_class_map = per_class_map.astype(np.uint8)
+                    saliency_maps[i][class_id] = per_class_map
+        return saliency_maps
diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py
index 7216dbba158..afd8b0bcf0e 100644
--- a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py
+++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py
@@ -10,7 +10,6 @@ from mmdet.models.detectors.mask_rcnn import MaskRCNN
 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )
 from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
@@ -99,7 +98,7 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
 def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
     """Function for custom_mask_rcnn__simple_test."""
     assert self.with_bbox, "Bbox head must be implemented."
-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -108,7 +107,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **k
     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # Saliency map will be generated from predictions. Generate dummy saliency_map.
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, feature_vector, saliency_map)
     return out
diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py
index 44e02c6c523..7743a6c53a1 100644
--- a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py
+++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py
@@ -209,7 +209,6 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False, full_res_im
 # pylint: disable=ungrouped-imports
 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )
@@ -327,7 +326,7 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
     assert self.with_bbox, "Bbox head must be implemented."
     tile_prob = self.tile_classifier.simple_test(img)
-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -336,7 +335,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # Saliency map will be generated from predictions. Generate dummy saliency_map.
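+        # The placeholder keeps the dumped-features output tuple (*out, tile_prob, feature_vector, saliency_map)
+        # well-formed; the real per-class maps are generated later by MaskRCNNRecordingForwardHook.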
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, tile_prob, feature_vector, saliency_map)
     return (*out, tile_prob)
diff --git a/otx/algorithms/detection/adapters/mmdet/task.py b/otx/algorithms/detection/adapters/mmdet/task.py
index 61daea12e7e..ebe04040ad0 100644
--- a/otx/algorithms/detection/adapters/mmdet/task.py
+++ b/otx/algorithms/detection/adapters/mmdet/task.py
@@ -61,6 +61,7 @@
 from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
 from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
     DetClassProbabilityMapHook,
+    MaskRCNNRecordingForwardHook,
 )
 from otx.algorithms.detection.adapters.mmdet.utils import (
     patch_input_preprocessing,
@@ -399,7 +400,8 @@ def hook(module, inp, outp):
         if raw_model.__class__.__name__ == "NNCFNetwork":
             raw_model = raw_model.get_nncf_wrapped_model()
         if isinstance(raw_model, TwoStageDetector):
-            saliency_hook = ActivationMapHook(feature_model)
+            height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
+            saliency_hook = MaskRCNNRecordingForwardHook(feature_model, input_img_shape=(height, width))
         else:
             saliency_hook = DetClassProbabilityMapHook(feature_model)
@@ -522,15 +524,9 @@ def _explain_model(
         explain_parameters: Optional[ExplainParameters] = None,
     ) -> Dict[str, Any]:
         """Main explain function of MMDetectionTask."""
-
         for item in dataset:
             item.subset = Subset.TESTING
-        explainer_hook_selector = {
-            "classwisesaliencymap": DetClassProbabilityMapHook,
-            "eigencam": EigenCamHook,
-            "activationmap": ActivationMapHook,
-        }
         self._data_cfg = ConfigDict(
             data=ConfigDict(
                 train=ConfigDict(
@@ -599,6 +595,18 @@ def hook(module, inp, outp):
         model.register_forward_pre_hook(pre_hook)
         model.register_forward_hook(hook)
+        if isinstance(feature_model, TwoStageDetector):
+            height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
+            per_class_xai_algorithm = partial(MaskRCNNRecordingForwardHook, input_img_shape=(height, width))
+        else:
+            per_class_xai_algorithm = DetClassProbabilityMapHook  # type: ignore
+
+        explainer_hook_selector = {
+            "classwisesaliencymap": per_class_xai_algorithm,
+            "eigencam": EigenCamHook,
+            "activationmap": ActivationMapHook,
+        }
+
         explainer = explain_parameters.explainer if explain_parameters else None
         if explainer is not None:
             explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
@@ -608,9 +616,8 @@ def hook(module, inp, outp):
             raise NotImplementedError(f"Explainer algorithm {explainer} not supported!")
         logger.info(f"Explainer algorithm: {explainer}")
-        # Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map.
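+        # Run inference with the selected hook registered; per-image class-wise saliency
+        # maps accumulate in saliency_hook.records during the forward passes below.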
eval_predictions = [] - with explainer_hook(feature_model) as saliency_hook: + with explainer_hook(feature_model) as saliency_hook: # type: ignore for data in dataloader: with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) diff --git a/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py b/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py index 5fd9a0248b5..8867cfe1494 100644 --- a/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py +++ b/otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py @@ -16,6 +16,7 @@ from typing import Dict +import cv2 import numpy as np try: @@ -107,6 +108,85 @@ def postprocess(self, outputs, meta): return scores, classes, boxes, resized_masks + def get_saliency_map_from_prediction(self, outputs, meta, num_classes): + """Post process function for saliency map of OTX MaskRCNN model.""" + boxes = outputs[self.output_blob_name["boxes"]] + if boxes.shape[0] == 1: + boxes = boxes.squeeze(0) + scores = boxes[:, 4] + boxes = boxes[:, :4] + masks = outputs[self.output_blob_name["masks"]] + if masks.shape[0] == 1: + masks = masks.squeeze(0) + classes = outputs[self.output_blob_name["labels"]].astype(np.uint32) + if classes.shape[0] == 1: + classes = classes.squeeze(0) + + scale_x = meta["resized_shape"][1] / meta["original_shape"][1] + scale_y = meta["resized_shape"][0] / meta["original_shape"][0] + boxes[:, 0::2] /= scale_x + boxes[:, 1::2] /= scale_y + + saliency_maps = [None for _ in range(num_classes)] + for box, score, cls, raw_mask in zip(boxes, scores, classes, masks): + resized_mask = self._resize_mask(box, raw_mask * score, *meta["original_shape"][:-1]) + if saliency_maps[cls] is None: + saliency_maps[cls] = [resized_mask] + else: + saliency_maps[cls].append(resized_mask) + + saliency_maps = self._average_and_normalize(saliency_maps, num_classes) + return saliency_maps + + def _resize_mask(self, box, raw_cls_mask, im_h, im_w): + # Add zero border to prevent upsampling artifacts on segment borders. + raw_cls_mask = np.pad(raw_cls_mask, ((1, 1), (1, 1)), "constant", constant_values=0) + extended_box = self._expand_box(box, raw_cls_mask.shape[0] / (raw_cls_mask.shape[0] - 2.0)).astype(int) + w, h = np.maximum(extended_box[2:] - extended_box[:2] + 1, 1) + x0, y0 = np.clip(extended_box[:2], a_min=0, a_max=[im_w, im_h]) + x1, y1 = np.clip(extended_box[2:] + 1, a_min=0, a_max=[im_w, im_h]) + + raw_cls_mask = cv2.resize(raw_cls_mask.astype(np.float32), (w, h)) + # Put an object mask in an image mask. 
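+        # The soft mask was predicted over the zero-padded extended box; crop the part that
+        # overlaps the image and paste it at the clipped (x0, y0)-(x1, y1) location.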
+ im_mask = np.zeros((im_h, im_w), dtype=np.float32) + im_mask[y0:y1, x0:x1] = raw_cls_mask[ + (y0 - extended_box[1]) : (y1 - extended_box[1]), (x0 - extended_box[0]) : (x1 - extended_box[0]) + ] + return im_mask + + @staticmethod + def _average_and_normalize(saliency_maps, num_classes): + for i in range(num_classes): + if saliency_maps[i] is not None: + saliency_maps[i] = np.array(saliency_maps[i]).mean(0) + + for i in range(num_classes): + per_class_map = saliency_maps[i] + if per_class_map is not None: + max_values = np.max(per_class_map) + per_class_map = 255 * (per_class_map) / (max_values + 1e-12) + per_class_map = per_class_map.astype(np.uint8) + saliency_maps[i] = per_class_map + return saliency_maps + + def get_tiling_saliency_map_from_prediction(self, detections, num_classes): + """Post process function for saliency map of OTX MaskRCNN model for tiling.""" + saliency_maps = [None for _ in range(num_classes)] + + # No detection case + if isinstance(detections, np.ndarray) and detections.size == 0: + return saliency_maps + + classes = [int(cls) - 1 for cls in detections[1]] + masks = detections[3] + for mask, cls in zip(masks, classes): + if saliency_maps[cls] is None: + saliency_maps[cls] = [mask] + else: + saliency_maps[cls].append(mask) + saliency_maps = self._average_and_normalize(saliency_maps, num_classes) + return saliency_maps + def segm_postprocess(self, *args, **kwargs): """Post-process for segmentation masks.""" return self._segm_postprocess(*args, **kwargs) diff --git a/otx/algorithms/detection/adapters/openvino/task.py b/otx/algorithms/detection/adapters/openvino/task.py index cddd4c9b91f..4c2793d8285 100644 --- a/otx/algorithms/detection/adapters/openvino/task.py +++ b/otx/algorithms/detection/adapters/openvino/task.py @@ -39,6 +39,7 @@ from otx.algorithms.common.utils.logger import get_logger from otx.algorithms.common.utils.utils import get_default_async_reqs_num from otx.algorithms.detection.adapters.openvino import model_wrappers +from otx.algorithms.detection.adapters.openvino.model_wrappers import OTXMaskRCNNModel from otx.algorithms.detection.configs.base import DetectionConfig from otx.api.configuration.helper.utils import ( config_to_bytes, @@ -118,6 +119,14 @@ def post_process(self, prediction: Dict[str, np.ndarray], metadata: Dict[str, An return self.converter.convert_to_annotation(detections, metadata) + def get_saliency_map(self, prediction: Dict[str, np.ndarray], metadata: Dict[str, Any]): + """Saliency map function of OpenVINO Detection Inferencer.""" + if isinstance(self.model, OTXMaskRCNNModel): + num_classes = len(self.converter.labels) # type: ignore + return self.model.get_saliency_map_from_prediction(prediction, metadata, num_classes) + else: + return prediction["saliency_map"][0] + def predict(self, image: np.ndarray): """Predict function of OpenVINO Detection Inferencer.""" image, metadata = self.pre_process(image) @@ -132,7 +141,7 @@ def predict(self, image: np.ndarray): else: features = ( raw_predictions["feature_vector"].reshape(-1), - raw_predictions["saliency_map"][0], + self.get_saliency_map(raw_predictions, metadata), ) return predictions, features @@ -364,9 +373,16 @@ def predict( Returns: detections: AnnotationSceneEntity - features: list including saliency map and feature vector + features: list including feature vector and saliency map """ detections, features = self.tiler.predict(image, mode) + + _, saliency_map = features + if saliency_map is not None and isinstance(self.model, OTXMaskRCNNModel): + num_classes = 
len(self.converter.labels) # type: ignore + saliency_map = self.model.get_tiling_saliency_map_from_prediction(detections, num_classes) + features = features[0], saliency_map + detections = self.converter.convert_to_annotation(detections, metadata={"original_shape": image.shape}) return detections, features @@ -527,7 +543,7 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu if add_saliency_map and saliency_map is not None: labels = self.task_environment.get_labels().copy() - if saliency_map.shape[0] == len(labels) + 1: + if len(saliency_map) == len(labels) + 1: # Include the background as the last category labels.append(LabelEntity("background", Domain.DETECTION)) @@ -597,7 +613,7 @@ def explain( ) labels = self.task_environment.get_labels().copy() - if saliency_map.shape[0] == len(labels) + 1: + if len(saliency_map) == len(labels) + 1: # Include the background as the last category labels.append(LabelEntity("background", Domain.DETECTION)) diff --git a/otx/algorithms/detection/task.py b/otx/algorithms/detection/task.py index 38d81809e7e..6b5f092d2f5 100644 --- a/otx/algorithms/detection/task.py +++ b/otx/algorithms/detection/task.py @@ -462,7 +462,7 @@ def _add_predictions_to_dataset( if saliency_map is not None: labels = self._labels.copy() - if saliency_map.shape[0] == len(labels) + 1: + if len(saliency_map) == len(labels) + 1: # Include the background as the last category labels.append(LabelEntity("background", Domain.DETECTION)) @@ -573,7 +573,7 @@ def _add_explanations_to_dataset( """Add saliency map to the dataset.""" for dataset_item, detection, saliency_map in zip(dataset, detections, explain_results): labels = self._labels.copy() - if saliency_map.shape[0] == len(labels) + 1: + if len(saliency_map) == len(labels) + 1: # Include the background as the last category labels.append(LabelEntity("background", Domain.DETECTION)) diff --git a/otx/api/utils/dataset_utils.py b/otx/api/utils/dataset_utils.py index c3058de81ee..b233a5fe85c 100644 --- a/otx/api/utils/dataset_utils.py +++ b/otx/api/utils/dataset_utils.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions # and limitations under the License. -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np @@ -195,28 +195,29 @@ def contains_anomalous_images(dataset: DatasetEntity) -> bool: # pylint: disable-msg=too-many-locals def add_saliency_maps_to_dataset_item( dataset_item: DatasetItemEntity, - saliency_map: np.ndarray, + saliency_map: Union[List[Optional[np.ndarray]], np.ndarray], model: Optional[ModelEntity], labels: List[LabelEntity], predicted_scored_labels: Optional[List[ScoredLabel]] = None, explain_predicted_classes: bool = True, process_saliency_maps: bool = False, ): - """Add saliency maps(2d for class-ignore saliency map, 3d for class-wise saliency maps) to a single dataset item.""" - if saliency_map.ndim == 2: - # Single saliency map per image, support e.g. 
EigenCAM use case
-        if process_saliency_maps:
-            saliency_map = get_actmap(saliency_map, (dataset_item.width, dataset_item.height))
-        saliency_media = ResultMediaEntity(
-            name="Saliency Map",
-            type="saliency_map",
-            annotation_scene=dataset_item.annotation_scene,
-            numpy=saliency_map,
-            roi=dataset_item.roi,
-        )
-        dataset_item.append_metadata_item(saliency_media, model=model)
-    elif saliency_map.ndim == 3:
-        # Multiple saliency maps per image (class-wise saliency map)
+    """Add saliency maps (2D array for a class-agnostic saliency map,
+    3D array or list of 2D arrays for class-wise saliency maps) to a single dataset item."""
+    if isinstance(saliency_map, list):
+        class_wise_saliency_map = True
+    elif isinstance(saliency_map, np.ndarray):
+        if saliency_map.ndim == 2:
+            class_wise_saliency_map = False
+        elif saliency_map.ndim == 3:
+            class_wise_saliency_map = True
+        else:
+            raise ValueError(f"Saliency map has to be a 2- or 3-dimensional array, but got {saliency_map.ndim} dims.")
+    else:
+        raise TypeError(f"saliency_map has to be a list or np.ndarray, but got {type(saliency_map)}.")
+
+    if class_wise_saliency_map:
+        # Multiple saliency maps per image (class-wise saliency map), support e.g. ReciproCAM
         if explain_predicted_classes:
             # Explain only predicted classes
             if predicted_scored_labels is None:
@@ -232,7 +233,7 @@ def add_saliency_maps_to_dataset_item(
         for class_id, class_wise_saliency_map in enumerate(saliency_map):
             label = labels[class_id]
-            if label in explain_targets:
+            if class_wise_saliency_map is not None and label in explain_targets:
                 if process_saliency_maps:
                     class_wise_saliency_map = get_actmap(
                         class_wise_saliency_map, (dataset_item.width, dataset_item.height)
                     )
@@ -247,4 +248,14 @@ def add_saliency_maps_to_dataset_item(
             )
             dataset_item.append_metadata_item(saliency_media, model=model)
     else:
-        raise RuntimeError(f"Single saliency map has to be 2 or 3-dimensional, but got {saliency_map.ndim} dims")
+        # Single saliency map per image, support e.g.
ActivationMap + if process_saliency_maps: + saliency_map = get_actmap(saliency_map, (dataset_item.width, dataset_item.height)) + saliency_media = ResultMediaEntity( + name="Saliency Map", + type="saliency_map", + annotation_scene=dataset_item.annotation_scene, + numpy=saliency_map, + roi=dataset_item.roi, + ) + dataset_item.append_metadata_item(saliency_media, model=model) diff --git a/tests/e2e/test_api_xai_sanity.py b/tests/e2e/test_api_xai_sanity.py index 76cdd3250a5..fbef08d253e 100644 --- a/tests/e2e/test_api_xai_sanity.py +++ b/tests/e2e/test_api_xai_sanity.py @@ -5,6 +5,7 @@ import os import os.path as osp import tempfile +from copy import deepcopy import pytest import torch @@ -12,27 +13,33 @@ from otx.algorithms.classification.adapters.mmcls.task import MMClassificationTask from otx.algorithms.classification.adapters.openvino.task import ClassificationOpenVINOTask -# from otx.algorithms.detection.tasks import ( -# DetectionInferenceTask, -# DetectionTrainTask, -# OpenVINODetectionTask, -# ) +from otx.algorithms.detection.adapters.mmdet.task import MMDetectionTask +from otx.algorithms.detection.adapters.openvino.task import OpenVINODetectionTask +from otx.algorithms.detection.configs.base import DetectionConfig +from otx.api.configuration.helper import create from otx.api.entities.inference_parameters import InferenceParameters -from otx.api.entities.model import ModelEntity +from otx.api.entities.model import ( + ModelConfiguration, + ModelEntity, +) from otx.api.entities.result_media import ResultMediaEntity +from otx.api.entities.subset import Subset from otx.api.entities.train_parameters import TrainParameters +from otx.api.entities.model_template import parse_model_template, TaskType +from otx.api.entities.label_schema import LabelGroup, LabelGroupType, LabelSchemaEntity from otx.api.usecases.tasks.interfaces.export_interface import ExportType from otx.cli.utils.io import read_model, save_model_data from tests.integration.api.classification.test_api_classification import ( DEFAULT_CLS_TEMPLATE_DIR, ClassificationTaskAPIBase, ) - -# from tests.integration.api.detection.test_api_detection import ( -# DEFAULT_DET_TEMPLATE_DIR, -# DetectionTaskAPIBase, -# ) +from tests.integration.api.detection.api_detection import DetectionTaskAPIBase, DEFAULT_DET_TEMPLATE_DIR from tests.test_suite.e2e_test_system import e2e_pytest_api +from tests.unit.algorithms.detection.test_helpers import ( + DEFAULT_ISEG_TEMPLATE_DIR, + init_environment, + generate_det_dataset, +) torch.manual_seed(0) @@ -138,77 +145,148 @@ def test_inference_xai(self, multilabel, hierarchical): ) -# class TestOVDetXAIAPI(DetectionTaskAPIBase): -# ref_raw_saliency_shapes = { -# "ATSS": (6, 8), -# "SSD": (13, 13), -# "YOLOX": (13, 13), -# } -# -# @e2e_pytest_api -# @pytest.mark.skip(reason="Detection task refactored.") -# def test_inference_xai(self): -# with tempfile.TemporaryDirectory() as temp_dir: -# hyper_parameters, model_template = self.setup_configurable_parameters( -# DEFAULT_DET_TEMPLATE_DIR, num_iters=15 -# ) -# detection_environment, dataset = self.init_environment(hyper_parameters, model_template, 10) -# -# train_task = DetectionTrainTask(task_environment=detection_environment) -# trained_model = ModelEntity( -# dataset, -# detection_environment.get_model_configuration(), -# ) -# train_task.train(dataset, trained_model, TrainParameters()) -# save_model_data(trained_model, temp_dir) -# -# from otx.api.entities.subset import Subset -# -# for processed_saliency_maps, only_predicted in [[True, False], [False, True]]: 
-# detection_environment, dataset = self.init_environment(hyper_parameters, model_template, 10) -# inference_parameters = InferenceParameters( -# is_evaluation=False, -# process_saliency_maps=processed_saliency_maps, -# explain_predicted_classes=only_predicted, -# ) -# -# # Infer torch model -# detection_environment.model = trained_model -# inference_task = DetectionInferenceTask(task_environment=detection_environment) -# val_dataset = dataset.get_subset(Subset.VALIDATION) -# predicted_dataset = inference_task.infer(val_dataset.with_empty_annotations(), inference_parameters) -# -# # Check saliency maps torch task -# task_labels = trained_model.configuration.get_label_schema().get_labels(include_empty=False) -# saliency_maps_check( -# predicted_dataset, -# task_labels, -# self.ref_raw_saliency_shapes[model_template.name], -# processed_saliency_maps=processed_saliency_maps, -# only_predicted=only_predicted, -# ) -# -# # Save OV IR model -# inference_task._model_ckpt = osp.join(temp_dir, "weights.pth") -# exported_model = ModelEntity(None, detection_environment.get_model_configuration()) -# inference_task.export(ExportType.OPENVINO, exported_model, dump_features=True) -# os.makedirs(temp_dir, exist_ok=True) -# save_model_data(exported_model, temp_dir) -# -# # Infer OV IR model -# load_weights_ov = osp.join(temp_dir, "openvino.xml") -# detection_environment.model = read_model( -# detection_environment.get_model_configuration(), load_weights_ov, None -# ) -# task = OpenVINODetectionTask(task_environment=detection_environment) -# _, dataset = self.init_environment(hyper_parameters, model_template, 10) -# predicted_dataset_ov = task.infer(dataset.with_empty_annotations(), inference_parameters) -# -# # Check saliency maps OV task -# saliency_maps_check( -# predicted_dataset_ov, -# task_labels, -# self.ref_raw_saliency_shapes[model_template.name], -# processed_saliency_maps=processed_saliency_maps, -# only_predicted=only_predicted, -# ) +class TestOVDetXAIAPI(DetectionTaskAPIBase): + ref_raw_saliency_shapes = { + "ATSS": (6, 8), + } + + @e2e_pytest_api + def test_inference_xai(self): + with tempfile.TemporaryDirectory() as temp_dir: + hyper_parameters, model_template = self.setup_configurable_parameters( + DEFAULT_DET_TEMPLATE_DIR, num_iters=15 + ) + task_env, dataset = self.init_environment(hyper_parameters, model_template, 10) + + train_task = MMDetectionTask(task_environment=task_env) + trained_model = ModelEntity( + dataset, + task_env.get_model_configuration(), + ) + train_task.train(dataset, trained_model, TrainParameters()) + save_model_data(trained_model, temp_dir) + + for processed_saliency_maps, only_predicted in [[True, False], [False, True]]: + task_env, dataset = self.init_environment(hyper_parameters, model_template, 10) + inference_parameters = InferenceParameters( + is_evaluation=False, + process_saliency_maps=processed_saliency_maps, + explain_predicted_classes=only_predicted, + ) + + # Infer torch model + task_env.model = trained_model + inference_task = MMDetectionTask(task_environment=task_env) + val_dataset = dataset.get_subset(Subset.VALIDATION) + predicted_dataset = inference_task.infer(val_dataset.with_empty_annotations(), inference_parameters) + + # Check saliency maps torch task + task_labels = trained_model.configuration.get_label_schema().get_labels(include_empty=False) + saliency_maps_check( + predicted_dataset, + task_labels, + self.ref_raw_saliency_shapes[model_template.name], + processed_saliency_maps=processed_saliency_maps, + only_predicted=only_predicted, + ) + + 
# Save OV IR model + inference_task._model_ckpt = osp.join(temp_dir, "weights.pth") + exported_model = ModelEntity(None, task_env.get_model_configuration()) + inference_task.export(ExportType.OPENVINO, exported_model, dump_features=True) + os.makedirs(temp_dir, exist_ok=True) + save_model_data(exported_model, temp_dir) + + # Infer OV IR model + load_weights_ov = osp.join(temp_dir, "openvino.xml") + task_env.model = read_model(task_env.get_model_configuration(), load_weights_ov, None) + task = OpenVINODetectionTask(task_environment=task_env) + _, dataset = self.init_environment(hyper_parameters, model_template, 10) + predicted_dataset_ov = task.infer(dataset.with_empty_annotations(), inference_parameters) + + # Check saliency maps OV task + saliency_maps_check( + predicted_dataset_ov, + task_labels, + self.ref_raw_saliency_shapes[model_template.name], + processed_saliency_maps=processed_saliency_maps, + only_predicted=only_predicted, + ) + + +class TestOVISegmXAIAPI: + @e2e_pytest_api + def test_inference_xai(self): + with tempfile.TemporaryDirectory() as temp_dir: + model_template = parse_model_template(os.path.join(DEFAULT_ISEG_TEMPLATE_DIR, "template.yaml")) + hyper_parameters = create(model_template.hyper_parameters.data) + hyper_parameters.learning_parameters.num_iters = 5 + task_env = init_environment(hyper_parameters, model_template, task_type=TaskType.INSTANCE_SEGMENTATION) + + train_task = MMDetectionTask(task_env) + + iseg_dataset, iseg_labels = generate_det_dataset(TaskType.INSTANCE_SEGMENTATION, 100) + iseg_label_schema = LabelSchemaEntity() + iseg_label_group = LabelGroup( + name="labels", + labels=iseg_labels, + group_type=LabelGroupType.EXCLUSIVE, + ) + iseg_label_schema.add_group(iseg_label_group) + + _config = ModelConfiguration(DetectionConfig(), iseg_label_schema) + trained_model = ModelEntity( + iseg_dataset, + _config, + ) + + train_task.train(iseg_dataset, trained_model, TrainParameters()) + + save_model_data(trained_model, temp_dir) + + processed_saliency_maps, only_predicted = False, True + task_env = init_environment(hyper_parameters, model_template, task_type=TaskType.INSTANCE_SEGMENTATION) + inference_parameters = InferenceParameters( + is_evaluation=False, + process_saliency_maps=processed_saliency_maps, + explain_predicted_classes=only_predicted, + ) + + # Infer torch model + task_env.model = trained_model + inference_task = MMDetectionTask(task_environment=task_env) + val_dataset = iseg_dataset.get_subset(Subset.VALIDATION) + val_dataset_copy = deepcopy(val_dataset) + predicted_dataset = inference_task.infer(val_dataset.with_empty_annotations(), inference_parameters) + + # Check saliency maps torch task + task_labels = trained_model.configuration.get_label_schema().get_labels(include_empty=False) + saliency_maps_check( + predicted_dataset, + task_labels, + (224, 224), + processed_saliency_maps=processed_saliency_maps, + only_predicted=only_predicted, + ) + + # Save OV IR model + inference_task._model_ckpt = osp.join(temp_dir, "weights.pth") + exported_model = ModelEntity(None, task_env.get_model_configuration()) + inference_task.export(ExportType.OPENVINO, exported_model, dump_features=True) + os.makedirs(temp_dir, exist_ok=True) + save_model_data(exported_model, temp_dir) + + # Infer OV IR model + load_weights_ov = osp.join(temp_dir, "openvino.xml") + task_env.model = read_model(task_env.get_model_configuration(), load_weights_ov, None) + task = OpenVINODetectionTask(task_environment=task_env) + predicted_dataset_ov = 
task.infer(val_dataset_copy.with_empty_annotations(), inference_parameters) + + # Check saliency maps OV task + saliency_maps_check( + predicted_dataset_ov, + task_labels, + (480, 640), + processed_saliency_maps=processed_saliency_maps, + only_predicted=only_predicted, + ) diff --git a/tests/integration/api/detection/api_detection.py b/tests/integration/api/detection/api_detection.py new file mode 100644 index 00000000000..4aad43e2777 --- /dev/null +++ b/tests/integration/api/detection/api_detection.py @@ -0,0 +1,98 @@ +"""API Tests for detection training""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import glob +import warnings +import random +import os.path as osp + +from otx.algorithms.detection.utils import generate_label_schema +from otx.api.configuration.helper import create +from otx.api.entities.annotation import AnnotationSceneEntity, AnnotationSceneKind +from otx.api.entities.dataset_item import DatasetItemEntity +from otx.api.entities.datasets import DatasetEntity +from otx.api.entities.image import Image +from otx.api.entities.model_template import ( + TaskType, + parse_model_template, + task_type_to_label_domain, +) +from otx.api.entities.subset import Subset +from otx.api.entities.task_environment import TaskEnvironment +from otx.api.utils.shape_factory import ShapeFactory +from tests.test_helpers import generate_random_annotated_image + +DEFAULT_DET_TEMPLATE_DIR = osp.join("otx/algorithms/detection/configs", "detection", "mobilenetv2_atss") + + +class DetectionTaskAPIBase: + """ + Collection of tests for OTX API and OTX Model Templates + """ + + def init_environment(self, params, model_template, number_of_images=500, task_type=TaskType.DETECTION): + + labels_names = ("rectangle", "ellipse", "triangle") + labels_schema = generate_label_schema(labels_names, task_type_to_label_domain(task_type)) + labels_list = labels_schema.get_labels(False) + environment = TaskEnvironment( + model=None, + hyper_parameters=params, + label_schema=labels_schema, + model_template=model_template, + ) + + warnings.filterwarnings("ignore", message=".* coordinates .* are out of bounds.*") + items = [] + for i in range(0, number_of_images): + image_numpy, annos = generate_random_annotated_image( + image_width=640, + image_height=480, + labels=labels_list, + max_shapes=20, + min_size=50, + max_size=100, + random_seed=None, + ) + # Convert shapes according to task + for anno in annos: + if task_type == TaskType.INSTANCE_SEGMENTATION: + anno.shape = ShapeFactory.shape_as_polygon(anno.shape) + else: + anno.shape = ShapeFactory.shape_as_rectangle(anno.shape) + + image = Image(data=image_numpy) + annotation_scene = AnnotationSceneEntity(kind=AnnotationSceneKind.ANNOTATION, annotations=annos) + items.append(DatasetItemEntity(media=image, annotation_scene=annotation_scene)) + warnings.resetwarnings() + + rng = random.Random() + rng.shuffle(items) + for i, _ in enumerate(items): + subset_region = i / number_of_images + if subset_region >= 0.8: + subset = Subset.TESTING + elif subset_region >= 0.6: + subset = Subset.VALIDATION + else: + subset = Subset.TRAINING + items[i].subset = subset + + dataset = DatasetEntity(items) + return environment, dataset + + @staticmethod + def setup_configurable_parameters(template_dir, num_iters=10): + glb = glob.glob(f"{template_dir}/template*.yaml") + template_path = glb[0] if glb else None + if not template_path: + raise RuntimeError(f"Template YAML not found: {template_dir}") + + model_template = parse_model_template(template_path) + 
hyper_parameters = create(model_template.hyper_parameters.data) + hyper_parameters.learning_parameters.num_iters = num_iters + hyper_parameters.postprocessing.result_based_confidence_threshold = False + hyper_parameters.postprocessing.confidence_threshold = 0.1 + return hyper_parameters, model_template diff --git a/tests/integration/api/detection/test_api_detection.py b/tests/integration/api/detection/test_api_detection.py index b262c38b52a..963d41b5cb9 100644 --- a/tests/integration/api/detection/test_api_detection.py +++ b/tests/integration/api/detection/test_api_detection.py @@ -1,13 +1,10 @@ """API Tests for detection training""" -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # -import glob import os.path as osp -import random import time -import warnings from concurrent.futures import ThreadPoolExecutor from typing import Optional @@ -15,27 +12,19 @@ from otx.algorithms.common.tasks.training_base import BaseTask from otx.algorithms.detection.tasks import DetectionInferenceTask, DetectionTrainTask -from otx.algorithms.detection.utils import generate_label_schema -from otx.api.configuration.helper import create -from otx.api.entities.annotation import AnnotationSceneEntity, AnnotationSceneKind -from otx.api.entities.dataset_item import DatasetItemEntity from otx.api.entities.datasets import DatasetEntity -from otx.api.entities.image import Image from otx.api.entities.inference_parameters import InferenceParameters from otx.api.entities.metrics import Performance from otx.api.entities.model import ModelEntity from otx.api.entities.model_template import ( TaskType, parse_model_template, - task_type_to_label_domain, ) from otx.api.entities.resultset import ResultSetEntity from otx.api.entities.subset import Subset -from otx.api.entities.task_environment import TaskEnvironment from otx.api.entities.train_parameters import TrainParameters from otx.api.usecases.tasks.interfaces.export_interface import ExportType -from otx.api.utils.shape_factory import ShapeFactory -from tests.test_helpers import generate_random_annotated_image +from tests.integration.api.detection.api_detection import DetectionTaskAPIBase from tests.test_suite.e2e_test_system import e2e_pytest_api DEFAULT_DET_TEMPLATE_DIR = osp.join("otx/algorithms/detection/configs", "detection", "mobilenetv2_atss") @@ -52,77 +41,6 @@ def task_eval(task: BaseTask, model: ModelEntity, dataset: DatasetEntity) -> Per return result_set.performance -class DetectionTaskAPIBase: - """ - Collection of tests for OTX API and OTX Model Templates - """ - - def init_environment(self, params, model_template, number_of_images=500, task_type=TaskType.DETECTION): - - labels_names = ("rectangle", "ellipse", "triangle") - labels_schema = generate_label_schema(labels_names, task_type_to_label_domain(task_type)) - labels_list = labels_schema.get_labels(False) - environment = TaskEnvironment( - model=None, - hyper_parameters=params, - label_schema=labels_schema, - model_template=model_template, - ) - - warnings.filterwarnings("ignore", message=".* coordinates .* are out of bounds.*") - items = [] - for i in range(0, number_of_images): - image_numpy, annos = generate_random_annotated_image( - image_width=640, - image_height=480, - labels=labels_list, - max_shapes=20, - min_size=50, - max_size=100, - random_seed=None, - ) - # Convert shapes according to task - for anno in annos: - if task_type == TaskType.INSTANCE_SEGMENTATION: - anno.shape = ShapeFactory.shape_as_polygon(anno.shape) - else: - 
anno.shape = ShapeFactory.shape_as_rectangle(anno.shape) - - image = Image(data=image_numpy) - annotation_scene = AnnotationSceneEntity(kind=AnnotationSceneKind.ANNOTATION, annotations=annos) - items.append(DatasetItemEntity(media=image, annotation_scene=annotation_scene)) - warnings.resetwarnings() - - rng = random.Random() - rng.shuffle(items) - for i, _ in enumerate(items): - subset_region = i / number_of_images - if subset_region >= 0.8: - subset = Subset.TESTING - elif subset_region >= 0.6: - subset = Subset.VALIDATION - else: - subset = Subset.TRAINING - items[i].subset = subset - - dataset = DatasetEntity(items) - return environment, dataset - - @staticmethod - def setup_configurable_parameters(template_dir, num_iters=10): - glb = glob.glob(f"{template_dir}/template*.yaml") - template_path = glb[0] if glb else None - if not template_path: - raise RuntimeError(f"Template YAML not found: {template_dir}") - - model_template = parse_model_template(template_path) - hyper_parameters = create(model_template.hyper_parameters.data) - hyper_parameters.learning_parameters.num_iters = num_iters - hyper_parameters.postprocessing.result_based_confidence_threshold = False - hyper_parameters.postprocessing.confidence_threshold = 0.1 - return hyper_parameters, model_template - - class TestDetectionTaskAPI(DetectionTaskAPIBase): """ Collection of tests for OTX API and OTX Model Templates diff --git a/tests/test_suite/run_test_command.py b/tests/test_suite/run_test_command.py index 79f38f3995b..cc6a6c510bd 100644 --- a/tests/test_suite/run_test_command.py +++ b/tests/test_suite/run_test_command.py @@ -679,10 +679,7 @@ def xfail_templates(templates, xfail_template_ids_reasons): def otx_explain_testing(template, root, otx_dir, args, trained=False): template_work_dir = get_template_dir(template, root) - if "RCNN" in template.model_template_id: - test_algorithm = "ActivationMap" - else: - test_algorithm = "ClassWiseSaliencyMap" + test_algorithm = "ClassWiseSaliencyMap" train_ann_file = args.get("--train-ann-file", "") if "hierarchical" in train_ann_file: @@ -717,10 +714,7 @@ def otx_explain_testing(template, root, otx_dir, args, trained=False): def otx_explain_testing_all_classes(template, root, otx_dir, args): template_work_dir = get_template_dir(template, root) - if "RCNN" in template.model_template_id: - test_algorithm = "ActivationMap" - else: - test_algorithm = "ClassWiseSaliencyMap" + test_algorithm = "ClassWiseSaliencyMap" train_ann_file = args.get("--train-ann-file", "") if "hierarchical" in train_ann_file: @@ -761,10 +755,7 @@ def otx_explain_testing_all_classes(template, root, otx_dir, args): def otx_explain_testing_process_saliency_maps(template, root, otx_dir, args, trained=False): template_work_dir = get_template_dir(template, root) - if "RCNN" in template.model_template_id: - test_algorithm = "ActivationMap" - else: - test_algorithm = "ClassWiseSaliencyMap" + test_algorithm = "ClassWiseSaliencyMap" train_ann_file = args.get("--train-ann-file", "") if "hierarchical" in train_ann_file: @@ -800,10 +791,7 @@ def otx_explain_testing_process_saliency_maps(template, root, otx_dir, args, tra def otx_explain_openvino_testing(template, root, otx_dir, args, trained=False): template_work_dir = get_template_dir(template, root) - if "RCNN" in template.model_template_id: - test_algorithm = "ActivationMap" - else: - test_algorithm = "ClassWiseSaliencyMap" + test_algorithm = "ClassWiseSaliencyMap" train_ann_file = args.get("--train-ann-file", "") if "hierarchical" in train_ann_file: @@ -839,10 +827,7 @@ def 
otx_explain_openvino_testing(template, root, otx_dir, args, trained=False): def otx_explain_all_classes_openvino_testing(template, root, otx_dir, args): template_work_dir = get_template_dir(template, root) - if "RCNN" in template.model_template_id: - test_algorithm = "ActivationMap" - else: - test_algorithm = "ClassWiseSaliencyMap" + test_algorithm = "ClassWiseSaliencyMap" train_ann_file = args.get("--train-ann-file", "") if "hierarchical" in train_ann_file: @@ -884,10 +869,7 @@ def otx_explain_all_classes_openvino_testing(template, root, otx_dir, args): def otx_explain_process_saliency_maps_openvino_testing(template, root, otx_dir, args, trained=False): template_work_dir = get_template_dir(template, root) - if "RCNN" in template.model_template_id: - test_algorithm = "ActivationMap" - else: - test_algorithm = "ClassWiseSaliencyMap" + test_algorithm = "ClassWiseSaliencyMap" train_ann_file = args.get("--train-ann-file", "") if "hierarchical" in train_ann_file: diff --git a/tests/unit/algorithms/detection/test_xai_detection_validity.py b/tests/unit/algorithms/detection/test_xai_detection_validity.py index 2a4b5c5f1fa..08e0442a5fc 100644 --- a/tests/unit/algorithms/detection/test_xai_detection_validity.py +++ b/tests/unit/algorithms/detection/test_xai_detection_validity.py @@ -11,12 +11,16 @@ from otx.algorithms.common.adapters.mmcv.utils.config_utils import MPAConfig from otx.algorithms.detection.adapters.mmdet.hooks import DetClassProbabilityMapHook +from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import MaskRCNNRecordingForwardHook from otx.cli.registry import Registry from tests.test_suite.e2e_test_system import e2e_pytest_unit templates_det = Registry("otx/algorithms").filter(task_type="DETECTION").templates templates_det_ids = [template.model_template_id for template in templates_det] +templates_two_stage_det = Registry("otx/algorithms/detection").filter(task_type="INSTANCE_SEGMENTATION").templates +templates_two_stage_det_ids = [template.model_template_id for template in templates_two_stage_det] + class TestExplainMethods: ref_saliency_shapes = { @@ -31,9 +35,8 @@ class TestExplainMethods: "SSD": np.array([119, 72, 118, 35, 39, 30, 31, 31, 36, 28, 44, 23, 61], dtype=np.uint8), } - @e2e_pytest_unit - @pytest.mark.parametrize("template", templates_det, ids=templates_det_ids) - def test_saliency_map_det(self, template): + @staticmethod + def _get_model(template): torch.manual_seed(0) base_dir = os.path.abspath(os.path.dirname(template.model_template_path)) @@ -42,19 +45,26 @@ def test_saliency_map_det(self, template): model = build_detector(cfg.model) model = model.eval() + return model + @staticmethod + def _get_data(): img = torch.ones(2, 3, 416, 416) - 0.5 img_metas = [ { "img_shape": (416, 416, 3), + "ori_shape": (416, 416, 3), "scale_factor": np.array([1.1784703, 0.832, 1.1784703, 0.832], dtype=np.float32), }, - { - "img_shape": (416, 416, 3), - "scale_factor": np.array([1.1784703, 0.832, 1.1784703, 0.832], dtype=np.float32), - }, - ] + ] * 2 data = {"img_metas": [img_metas], "img": [img]} + return data + + @e2e_pytest_unit + @pytest.mark.parametrize("template", templates_det, ids=templates_det_ids) + def test_saliency_map_det(self, template): + model = self._get_model(template) + data = self._get_data() with DetClassProbabilityMapHook(model) as det_hook: with torch.no_grad(): @@ -65,3 +75,18 @@ def test_saliency_map_det(self, template): assert saliency_maps[0].ndim == 3 assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name] assert 
(saliency_maps[0][0][0] == self.ref_saliency_vals_det[template.name]).all()
+
+    @e2e_pytest_unit
+    @pytest.mark.parametrize("template", templates_two_stage_det, ids=templates_two_stage_det_ids)
+    def test_saliency_map_two_stage_det(self, template):
+        model = self._get_model(template)
+        data = self._get_data()
+
+        with MaskRCNNRecordingForwardHook(model, input_img_shape=(800, 1344)) as det_hook:
+            with torch.no_grad():
+                _ = model(return_loss=False, rescale=True, **data)
+        saliency_maps = det_hook.records
+
+        # MaskRCNNRecordingForwardHook generates saliency maps from predictions.
+        # This test runs an untrained model, so there are no predictions and therefore no saliency maps.
+        assert saliency_maps == [[None] * model.roi_head.mask_head.num_classes] * 2
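Both the torch hook and the OpenVINO wrapper aggregate per-detection soft masks with the same scheme, mirroring the _average_and_normalize post-processing in this patch: average the masks class-wise, then scale each map by its maximum into uint8. A minimal self-contained sketch of that scheme follows; the per_class_masks dict and the 4x4 grids are illustrative stand-ins for real mask predictions:

import numpy as np

def average_and_normalize(per_class_masks, num_classes):
    """Average per-detection soft masks class-wise, then scale each map to uint8 [0, 255]."""
    saliency_maps = [None] * num_classes
    for cls, masks in per_class_masks.items():
        mean_map = np.mean(np.asarray(masks), axis=0)           # average all detections of one class
        mean_map = 255 * mean_map / (np.max(mean_map) + 1e-12)  # normalize by the per-class maximum
        saliency_maps[cls] = mean_map.astype(np.uint8)
    return saliency_maps

# Two detections of class 0 and one of class 2 on a 4x4 grid; class 1 stays None.
rng = np.random.default_rng(0)
per_class_masks = {0: [rng.random((4, 4)), rng.random((4, 4))], 2: [rng.random((4, 4))]}
print(average_and_normalize(per_class_masks, num_classes=3))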