From 03e26d39097e110f6e6c48a11bfbd8b7fc6435f4 Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Thu, 18 Apr 2024 14:53:08 +0100 Subject: [PATCH] add mmdeploy maskrcnn opset --- .../detection/heads/delta_xywh_bbox_coder.py | 163 ++++++++++++++-- .../mmconfigs/maskrcnn_efficientnetb2b.yaml | 3 + .../mmdet/models/bbox_heads/bbox_head.py | 132 +++++++++++++ .../mmdet/models/custom_roi_head.py | 123 ++++++++++++ .../mmdet/models/dense_heads/rpn_head.py | 178 ++++++++++++++++++ .../mmdet/models/detectors/two_stage.py | 2 +- .../mmdet/models/mask_heads/fcn_mask_head.py | 128 +++++++++++++ .../single_level_roi_extractor.py | 75 ++++++++ 8 files changed, 788 insertions(+), 16 deletions(-) diff --git a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py index 4ae2dc1ffdf..fa6b162d160 100644 --- a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py +++ b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py @@ -9,6 +9,8 @@ from mmengine.registry import TASK_UTILS from torch import Tensor +from otx.algo.detection.deployment import is_mmdeploy_enabled + # This class and its supporting functions below lightly adapted from the mmdet DeltaXYWHBBoxCoder available at: # https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py @@ -199,21 +201,6 @@ def delta2bbox( References: .. [1] https://arxiv.org/abs/1311.2524 - - Example: - >>> rois = torch.Tensor([[ 0., 0., 1., 1.], - >>> [ 0., 0., 1., 1.], - >>> [ 0., 0., 1., 1.], - >>> [ 5., 5., 5., 5.]]) - >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], - >>> [ 1., 1., 1., 1.], - >>> [ 0., 0., 2., -1.], - >>> [ 0.7, -1.9, -0.5, 0.3]]) - >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) - tensor([[0.0000, 0.0000, 1.0000, 1.0000], - [0.1409, 0.1409, 2.8591, 2.8591], - [0.0000, 0.3161, 4.1945, 0.6839], - [5.0000, 5.0000, 5.0000, 5.0000]]) """ num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4 if num_bboxes == 0: @@ -251,3 +238,149 @@ def delta2bbox( bboxes[..., 0::2].clamp_(min=0, max=max_shape[1]) bboxes[..., 1::2].clamp_(min=0, max=max_shape[0]) return bboxes.reshape(num_bboxes, -1) + + +if is_mmdeploy_enabled(): + from mmdeploy.codebase.mmdet.deploy import clip_bboxes + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.DeltaXYWHBBoxCoder.decode", + backend="default", + ) + def deltaxywhbboxcoder__decode( + self: DeltaXYWHBBoxCoder, + bboxes: Tensor, + pred_bboxes: Tensor, + max_shape: Tensor | None = None, + wh_ratio_clip: float = 16 / 1000, + ) -> Tensor: + """Rewrite `decode` of `DeltaXYWHBBoxCoder` for default backend. + + Rewrite this func to call `delta2bbox` directly. + + Args: + bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4) + pred_bboxes (Tensor): Encoded offsets with respect to each roi. + Has shape (B, N, num_classes * 4) or (B, N, 4) or + (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H + when rois is a grid of anchors.Offset encoding follows [1]_. + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]] + and the length of max_shape should also be B. + wh_ratio_clip (float, optional): The allowed ratio between + width and height. + + Returns: + torch.Tensor: Decoded boxes. 
+ """ + if pred_bboxes.size(0) != bboxes.size(0): + msg = "The batch size of pred_bboxes and bboxes should be equal." + raise ValueError(msg) + if pred_bboxes.ndim == 3 and pred_bboxes.size(1) != bboxes.size(1): + msg = "The number of bboxes should be equal." + raise ValueError(msg) + return delta2bbox( + bboxes, + pred_bboxes, + self.means, + self.stds, + max_shape, + wh_ratio_clip, + self.clip_border, + self.add_ctr_clamp, + self.ctr_clamp, + ) + + @FUNCTION_REWRITER.register_rewriter( + func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.delta2bbox", + backend="default", + ) + def delta2bbox_opset( + rois: Tensor, + deltas: Tensor, + means: Tensor = (0.0, 0.0, 0.0, 0.0), + stds: Tensor = (1.0, 1.0, 1.0, 1.0), + max_shape: Tensor | None = None, + wh_ratio_clip: float = 16 / 1000, + clip_border: bool = True, + add_ctr_clamp: bool = False, + ctr_clamp: int = 32, + ) -> Tensor: + """Rewrite `delta2bbox` for default backend. + + Since the need of clip op with dynamic min and max, this function uses + clip_bboxes function to support dynamic shape. + + Args: + ctx (ContextCaller): The context with additional information. + rois (Tensor): Boxes to be transformed. Has shape (N, 4). + deltas (Tensor): Encoded offsets relative to each roi. + Has shape (N, num_classes * 4) or (N, 4). Note + N = num_base_anchors * W * H, when rois is a grid of + anchors. Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates. + Default (0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. Default (1., 1., 1., 1.). + max_shape (tuple[int, int]): Maximum bounds for boxes, specifies + (H, W). Default None. + wh_ratio_clip (float): Maximum aspect ratio for boxes. Default + 16 / 1000. + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Default True. + add_ctr_clamp (bool): Whether to add center clamp. When set to True, + the center of the prediction bounding box will be clamped to + avoid being too far away from the center of the anchor. + Only used by YOLOF. Default False. + ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. + Default 32. + + Return: + bboxes (Tensor): Boxes with shape (N, num_classes * 4) or (N, 4), + where 4 represent tl_x, tl_y, br_x, br_y. 
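+
+        Example:
+            Carried over from the original ``delta2bbox`` docstring; the rewritten
+            op is intended to produce the same values:
+
+            >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
+            >>>                      [ 0.,  0.,  1.,  1.],
+            >>>                      [ 0.,  0.,  1.,  1.],
+            >>>                      [ 5.,  5.,  5.,  5.]])
+            >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
+            >>>                        [  1.,   1.,   1.,   1.],
+            >>>                        [  0.,   0.,   2.,  -1.],
+            >>>                        [ 0.7, -1.9, -0.5,  0.3]])
+            >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
+            tensor([[0.0000, 0.0000, 1.0000, 1.0000],
+                    [0.1409, 0.1409, 2.8591, 2.8591],
+                    [0.0000, 0.3161, 4.1945, 0.6839],
+                    [5.0000, 5.0000, 5.0000, 5.0000]])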
+ """ + means = deltas.new_tensor(means).view(1, -1) + stds = deltas.new_tensor(stds).view(1, -1) + delta_shape = deltas.shape + reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 4)) + denorm_deltas = reshaped_deltas * stds + means + + dxy = denorm_deltas[..., :2] + dwh = denorm_deltas[..., 2:] + + # fix openvino on torch1.13 + xy1 = rois[..., :2].unsqueeze(2) + xy2 = rois[..., 2:].unsqueeze(2) + + pxy = (xy1 + xy2) * 0.5 + pwh = xy2 - xy1 + dxy_wh = pwh * dxy + + max_ratio = np.abs(np.log(wh_ratio_clip)) + if add_ctr_clamp: + dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp) + dwh = torch.clamp(dwh, max=max_ratio) + else: + dwh = dwh.clamp(min=-max_ratio, max=max_ratio) + + # Use exp(network energy) to enlarge/shrink each roi + half_gwh = pwh * dwh.exp() * 0.5 + # Use network energy to shift the center of each roi + gxy = pxy + dxy_wh + + # Convert center-xy/width/height to top-left, bottom-right + xy1 = gxy - half_gwh + xy2 = gxy + half_gwh + + x1 = xy1[..., 0] + y1 = xy1[..., 1] + x2 = xy2[..., 0] + y2 = xy2[..., 1] + + if clip_border and max_shape is not None: + x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape) + + return torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) diff --git a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml index 13ce4962ebf..aec40371011 100644 --- a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml +++ b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml @@ -14,6 +14,9 @@ data_preprocessor: - 1.0 - 1.0 type: MaskRCNN +# TODO(Eugene): Update scope after removing mmdet as pytorchcv_model is register under mmdet.registry.MODELS. +# https://github.com/openvinotoolkit/training_extensions/pull/3281 +# _scope_: mmengine backbone: type: efficientnet_b2b out_indices: diff --git a/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/bbox_head.py b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/bbox_head.py index b4e072eef15..74200da476d 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/bbox_head.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/bbox_head.py @@ -17,6 +17,7 @@ from torch import Tensor, nn from torch.nn.modules.utils import _pair +from otx.algo.detection.deployment import is_mmdeploy_enabled from otx.algo.instance_segmentation.mmdet.models.layers import multiclass_nms from otx.algo.instance_segmentation.mmdet.models.utils import ( InstanceList, @@ -310,3 +311,134 @@ def _predict_by_feat_single( results.scores = det_bboxes[:, -1] results.labels = det_labels return results + + +if is_mmdeploy_enabled(): + from mmdeploy.codebase.mmdet.deploy import get_post_processing_params + from mmdeploy.core import FUNCTION_REWRITER, mark + from mmdeploy.mmcv.ops import multiclass_nms as multiclass_nms_ops + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.bbox_heads.bbox_head.BBoxHead.forward", + ) + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.custom_roi_head.CustomConvFCBBoxHead.forward", + ) + def bbox_head__forward(self, x): + """Rewrite `forward` for default backend. + + This function uses the specific `forward` function for the BBoxHead + or ConvFCBBoxHead after adding marks. + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the original class. + x (Tensor): Input image tensor. 
+ + Returns: + tuple(Tensor, Tensor): The (cls_score, bbox_pred). The cls_score + has shape (N, num_det, num_cls) and the bbox_pred has shape + (N, num_det, 4). + """ + ctx = FUNCTION_REWRITER.get_context() + + @mark("bbox_head_forward", inputs=["bbox_feats"], outputs=["cls_score", "bbox_pred"]) + def __forward(self, x): + return ctx.origin_func(self, x) + + return __forward(self, x) + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.bbox_heads.bbox_head.BBoxHead.predict_by_feat", + ) + def bbox_head__predict_by_feat( + self: BBoxHead, + rois: Tensor, + cls_scores: tuple[Tensor], + bbox_preds: tuple[Tensor], + batch_img_metas: list[dict], + rcnn_test_cfg: dict, + rescale: bool = False, + ) -> tuple[Tensor]: + """Rewrite `predict_by_feat` of `BBoxHead` for default backend. + + Transform network output for a batch into bbox predictions. Support + `reg_class_agnostic == False` case. + + Args: + rois (tuple[Tensor]): Tuple of boxes to be transformed. + Each has shape (num_boxes, 5). last dimension 5 arrange as + (batch_index, x1, y1, x2, y2). + cls_scores (tuple[Tensor]): Tuple of box scores, each has shape + (num_boxes, num_classes + 1). + bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each + has shape (num_boxes, num_classes * 4). + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + - dets (Tensor): Classification bboxes and scores, has a shape + (num_instance, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + """ + ctx = FUNCTION_REWRITER.get_context() + if rois.ndim != 3: + raise ValueError("Only support export two stage model to ONNX with batch dimension.") + + img_shape = batch_img_metas[0]["img_shape"] + if self.custom_cls_channels: + scores = self.loss_cls.get_activation(cls_scores) + else: + scores = torch.nn.functional.softmax(cls_scores, dim=-1) if cls_scores is not None else None + + if bbox_preds is not None: + # num_classes = 1 if self.reg_class_agnostic else self.num_classes + # if num_classes > 1: + # rois = rois.repeat_interleave(num_classes, dim=1) + bboxes = self.bbox_coder.decode(rois[..., 1:], bbox_preds, max_shape=img_shape) + else: + bboxes = rois[..., 1:].clone() + if img_shape is not None: + max_shape = bboxes.new_tensor(img_shape)[..., :2] + min_xy = bboxes.new_tensor(0) + max_xy = torch.cat([max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2) + bboxes = torch.where(bboxes < min_xy, min_xy, bboxes) + bboxes = torch.where(bboxes > max_xy, max_xy, bboxes) + + batch_size = scores.shape[0] + device = scores.device + # ignore background class + scores = scores[..., : self.num_classes] + if not self.reg_class_agnostic: + # only keep boxes with the max scores + max_inds = scores.reshape(-1, self.num_classes).argmax(1, keepdim=True) + encode_size = self.bbox_coder.encode_size + bboxes = bboxes.reshape(-1, self.num_classes, encode_size) + dim0_inds = torch.arange(bboxes.shape[0], device=device).unsqueeze(-1) + bboxes = bboxes[dim0_inds, max_inds].reshape(batch_size, -1, encode_size) + # get nms params + post_params = get_post_processing_params(ctx.cfg) + max_output_boxes_per_class = post_params.max_output_boxes_per_class + iou_threshold = rcnn_test_cfg["nms"].get("iou_threshold", post_params.iou_threshold) + score_threshold = rcnn_test_cfg.get("score_thr", post_params.score_threshold) + if 
torch.onnx.is_in_onnx_export(): + pre_top_k = post_params.pre_top_k + else: + # For two stage partition post processing + pre_top_k = -1 if post_params.pre_top_k >= bboxes.shape[1] else post_params.pre_top_k + keep_top_k = rcnn_test_cfg.get("max_per_img", post_params.keep_top_k) + nms_type = rcnn_test_cfg["nms"].get("type") + return multiclass_nms_ops( + bboxes, + scores, + max_output_boxes_per_class, + nms_type=nms_type, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + ) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/custom_roi_head.py b/src/otx/algo/instance_segmentation/mmdet/models/custom_roi_head.py index 8c1f92bf58f..702b981daa0 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/custom_roi_head.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/custom_roi_head.py @@ -13,6 +13,7 @@ from mmengine.registry import MODELS, TASK_UTILS from torch import Tensor +from otx.algo.detection.deployment import is_mmdeploy_enabled from otx.algo.detection.heads.class_incremental_mixin import ( ClassIncrementalMixin, ) @@ -615,3 +616,125 @@ def loss( else: losses["loss_bbox"] = bbox_pred[pos_inds].sum() return losses + + +if is_mmdeploy_enabled(): + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.custom_roi_head.StandardRoIHead.predict_bbox", + ) + def standard_roi_head__predict_bbox( + self: StandardRoIHead, + x: tuple[Tensor], + batch_img_metas: list[dict], + rpn_results_list: list[Tensor], + rcnn_test_cfg: ConfigType, + rescale: bool = False, + ) -> list[Tensor]: + """Rewrite `predict_bbox` of `StandardRoIHead` for default backend. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[Tensor]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[Tensor]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - dets (Tensor): Classification bboxes and scores, has a shape + (num_instance, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). 
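+
+        Note:
+            In this deploy rewrite the proposals arrive as a single batched tensor
+            (``rpn_results_list[0]``): a per-image batch index is prepended, the
+            trailing score column is dropped, the batch dimension is flattened
+            before ``self._bbox_forward``, and it is restored again before calling
+            ``predict_by_feat``.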
+ """ + rois = rpn_results_list[0] + rois_dims = int(rois.shape[-1]) + batch_index = ( + torch.arange(rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(rois.size(0), rois.size(1), 1) + ) + rois = torch.cat([batch_index, rois[..., : rois_dims - 1]], dim=-1) + batch_size = rois.shape[0] + num_proposals_per_img = rois.shape[1] + + # Eliminate the batch dimension + rois = rois.view(-1, rois_dims) + bbox_results = self._bbox_forward(x, rois) + cls_scores = bbox_results["cls_score"] + bbox_preds = bbox_results["bbox_pred"] + + # Recover the batch dimension + rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1)) + cls_scores = cls_scores.reshape(batch_size, num_proposals_per_img, cls_scores.size(-1)) + + bbox_preds = bbox_preds.reshape(batch_size, num_proposals_per_img, bbox_preds.size(-1)) + return self.bbox_head.predict_by_feat( + rois=rois, + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=rcnn_test_cfg, + rescale=rescale, + ) + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.custom_roi_head.StandardRoIHead.predict_mask", + ) + def standard_roi_head__predict_mask( + self: StandardRoIHead, + x: tuple[Tensor], + batch_img_metas: list[dict], + results_list: list[Tensor], + rescale: bool = False, + ) -> list[Tensor]: + """Forward the mask head and predict detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[Tensor]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). 
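+
+        Note:
+            In this deploy rewrite ``results_list`` is the ``(dets, det_labels)``
+            tensor pair produced by the bbox branch rather than a list of
+            ``InstanceData``, and the method returns the tuple
+            ``(dets, det_labels, segm_results)`` with ``segm_results`` reshaped to
+            (batch_size, num_det, H, W).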
+ """ + dets, det_labels = results_list + batch_size = dets.size(0) + det_bboxes = dets[..., :4] + # expand might lead to static shape, use broadcast instead + batch_index = torch.arange(det_bboxes.size(0), device=det_bboxes.device).float().view( + -1, + 1, + 1, + ) + det_bboxes.new_zeros((det_bboxes.size(0), det_bboxes.size(1))).unsqueeze(-1) + mask_rois = torch.cat([batch_index, det_bboxes], dim=-1) + mask_rois = mask_rois.view(-1, 5) + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results["mask_preds"] + num_det = det_bboxes.shape[1] + segm_results = self.mask_head.predict_by_feat( + mask_preds, + results_list, + batch_img_metas, + self.test_cfg, + rescale=rescale, + ) + segm_results = segm_results.reshape(batch_size, num_det, segm_results.shape[-2], segm_results.shape[-1]) + return dets, det_labels, segm_results diff --git a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rpn_head.py b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rpn_head.py index 3433f878a4a..8d93095f070 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rpn_head.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rpn_head.py @@ -8,6 +8,7 @@ from __future__ import annotations import copy +import warnings from typing import TYPE_CHECKING import torch @@ -24,6 +25,7 @@ from mmengine.structures import InstanceData from torch import Tensor, nn +from otx.algo.detection.deployment import is_mmdeploy_enabled from otx.algo.detection.heads.anchor_head import AnchorHead from otx.algo.instance_segmentation.mmdet.structures.bbox import ( empty_box_as, @@ -300,3 +302,179 @@ def _bbox_post_process( results_.labels = results.scores.new_zeros(0) results = results_ return results + + +if is_mmdeploy_enabled(): + from mmdeploy.codebase.mmdet.deploy import gather_topk, get_post_processing_params, pad_with_value_if_necessary + from mmdeploy.core import FUNCTION_REWRITER + from mmdeploy.mmcv.ops import multiclass_nms as multiclass_nms_ops + from mmdeploy.utils import is_dynamic_shape + + @FUNCTION_REWRITER.register_rewriter( + func_name="otx.algo.instance_segmentation.mmdet.models.dense_heads.rpn_head.RPNHead.predict_by_feat", + ) + def rpn_head__predict_by_feat( + self: RPNHead, + cls_scores: list[Tensor], + bbox_preds: list[Tensor], + batch_img_metas: list[dict], + score_factors: list[Tensor] | None = None, + cfg: ConfigDict | None = None, + rescale: bool = False, + with_nms: bool = True, + **kwargs, + ) -> tuple: + """Rewrite `predict_by_feat` of `RPNHead` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx (ContextCaller): The context with additional information. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. 
+ Defaults to True. + + Returns: + If with_nms == True: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. + Else: + tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, + batch_mlvl_scores, batch_mlvl_centerness + """ + warnings.warn(f"score_factors: {score_factors} is not used in RPNHead", stacklevel=2) + warnings.warn(f"rescale: {rescale} is not used in RPNHead", stacklevel=2) + warnings.warn(f"kwargs: {kwargs} is not used in RPNHead", stacklevel=2) + ctx = FUNCTION_REWRITER.get_context() + img_metas = batch_img_metas + if len(cls_scores) != len(bbox_preds): + msg = "cls_scores and bbox_preds should have the same length" + raise ValueError(msg) + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_anchors = self.anchor_generator.grid_anchors(featmap_sizes, device=device) + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + if len(mlvl_cls_scores) != len(mlvl_bbox_preds) != len(mlvl_anchors): + msg = "mlvl_cls_scores, mlvl_bbox_preds and mlvl_anchors should have the same length" + raise ValueError(msg) + + cfg = self.test_cfg if cfg is None else cfg + if cfg is None: + warnings.warn("cfg is None, use default cfg", stacklevel=2) + cfg = { + "max_per_img": 1000, + "min_bbox_size": 0, + "nms": {"iou_threshold": 0.7, "type": "nms"}, + "nms_pre": 1000, + } + batch_size = mlvl_cls_scores[0].shape[0] + pre_topk = cfg.get("nms_pre", -1) + + # loop over features, decode boxes + mlvl_valid_bboxes = [] + mlvl_scores = [] + mlvl_valid_anchors = [] + for cls_score, bbox_pred, anchors in zip( + mlvl_cls_scores, + mlvl_bbox_preds, + mlvl_anchors, + ): + if cls_score.size()[-2:] != bbox_pred.size()[-2:]: + msg = "cls_score and bbox_pred should have the same size" + raise ValueError(msg) + cls_score = cls_score.permute(0, 2, 3, 1) + if self.use_sigmoid_cls: + cls_score = cls_score.reshape(batch_size, -1) + scores = cls_score.sigmoid() + else: + cls_score = cls_score.reshape(batch_size, -1, 2) + # We set FG labels to [0, num_class-1] and BG label to + # num_class in RPN head since mmdet v2.5, which is unified to + # be consistent with other head since mmdet v2.0. In mmdet v2.0 + # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. 
+ scores = cls_score.softmax(-1)[..., 0] + scores = scores.reshape(batch_size, -1, 1) + dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, dim) + + # use static anchor if input shape is static + if not is_dynamic_flag: + anchors = anchors.data + + anchors = anchors.unsqueeze(0) + + # topk in tensorrt does not support shape 0: + _, topk_inds = scores.squeeze(2).topk(pre_topk) + bbox_pred, scores = gather_topk( + bbox_pred, + scores, + inds=topk_inds, + batch_size=batch_size, + is_batched=True, + ) + anchors = gather_topk(anchors, inds=topk_inds, batch_size=batch_size, is_batched=False) + mlvl_valid_bboxes.append(bbox_pred) + mlvl_scores.append(scores) + mlvl_valid_anchors.append(anchors) + + batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1) + batch_mlvl_bboxes = self.bbox_coder.decode( + batch_mlvl_anchors, + batch_mlvl_bboxes, + max_shape=img_metas[0]["img_shape"], + ) + # ignore background class + if not self.use_sigmoid_cls: + batch_mlvl_scores = batch_mlvl_scores[..., : self.num_classes] + if not with_nms: + return batch_mlvl_bboxes, batch_mlvl_scores + + post_params = get_post_processing_params(deploy_cfg) + iou_threshold = cfg["nms"].get("iou_threshold", post_params.iou_threshold) + score_threshold = cfg.get("score_thr", post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get("max_per_img", post_params.keep_top_k) + # only one class in rpn + max_output_boxes_per_class = keep_top_k + nms_type = cfg["nms"].get("type") + return multiclass_nms_ops( + batch_mlvl_bboxes, + batch_mlvl_scores, + max_output_boxes_per_class, + nms_type=nms_type, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + ) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/detectors/two_stage.py b/src/otx/algo/instance_segmentation/mmdet/models/detectors/two_stage.py index dfe12c0dbd0..1d826ff070e 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/detectors/two_stage.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/detectors/two_stage.py @@ -323,7 +323,7 @@ def two_stage_detector__forward( - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). 
""" - del mode, kwargs + warnings.warn(f"{mode}, {kwargs} not used", stacklevel=2) ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg diff --git a/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/fcn_mask_head.py b/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/fcn_mask_head.py index cdf2454878c..3a48c8e50c4 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/fcn_mask_head.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/fcn_mask_head.py @@ -21,6 +21,7 @@ from torch import Tensor, nn from torch.nn.modules.utils import _pair +from otx.algo.detection.deployment import is_mmdeploy_enabled from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss from otx.algo.detection.utils.structures import SamplingResult from otx.algo.instance_segmentation.mmdet.models.utils import ( @@ -423,3 +424,130 @@ def _do_paste_mask(masks: Tensor, boxes: Tensor, img_h: int, img_w: int, skip_em if skip_empty: return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) return img_masks[:, 0], () + + +if is_mmdeploy_enabled(): + from mmdeploy.codebase.mmdet.deploy import get_post_processing_params + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.mask_heads.fcn_mask_head.FCNMaskHead.predict_by_feat", + ) + def fcn_mask_head__predict_by_feat( + self: FCNMaskHead, + mask_preds: Tensor, + results_list: list[Tensor], + batch_img_metas: list[dict], + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + activate_map: bool = False, + ) -> list[Tensor]: + """Transform a batch of output features extracted from the head into mask results. + + Args: + mask_preds (tuple[Tensor]): Tuple of predicted foreground masks, + each has shape (n, num_classes, h, w). + results_list (list[Tensor]): Detection results of + each image. + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + activate_map (book): Whether get results with augmentations test. + If True, the `mask_preds` will not process with sigmoid. + Defaults to False. + + Returns: + list[Tensor]: Detection results of each image + after the post process. Each item usually contains following keys. + + - dets (Tensor): Classification scores, has a shape + (num_instance, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + ctx = FUNCTION_REWRITER.get_context() + ori_shape = batch_img_metas[0]["img_shape"] + dets, det_labels = results_list + dets = dets.view(-1, 5) + det_labels = det_labels.view(-1) + mask_preds = mask_preds.sigmoid() + bboxes = dets[:, :4] + labels = det_labels + threshold = rcnn_test_cfg.mask_thr_binary + if not self.class_agnostic: + box_inds = torch.arange(mask_preds.shape[0], device=bboxes.device) + mask_pred = mask_preds[box_inds, labels][:, None] + + # grid sample is not supported by most engine + # so we add a flag to disable it. 
+ mmdet_params = get_post_processing_params(ctx.cfg) + export_postprocess_mask = mmdet_params.get("export_postprocess_mask", False) + if not export_postprocess_mask: + return mask_pred + + masks, _ = _do_paste_mask_ops(mask_pred, bboxes, ori_shape[0], ori_shape[1], skip_empty=False) + if threshold >= 0: + masks = (masks >= threshold).to(dtype=torch.bool) + return masks + + def _do_paste_mask_ops( + masks: Tensor, + boxes: Tensor, + img_h: int, + img_w: int, + skip_empty: bool = True, + ) -> Tensor: + """Paste instance masks according to boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + + Returns: + tuple: (Tensor, tuple). The first item is mask tensor, the second one + is the slice object. + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + If skip_empty == True, only area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. + device = masks.device + if skip_empty: + x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(dtype=torch.int32) + x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + N = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = torch.nn.functional.grid_sample(masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + return img_masks[:, 0], () diff --git a/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/single_level_roi_extractor.py b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/single_level_roi_extractor.py index 30cde589a1d..648d6bebef6 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/single_level_roi_extractor.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/single_level_roi_extractor.py @@ -12,6 +12,7 @@ from mmengine.registry import MODELS from torch import Tensor +from otx.algo.detection.deployment import is_mmdeploy_enabled from otx.algo.instance_segmentation.mmdet.models.utils import ConfigType, OptMultiConfig from .base_roi_extractor import BaseRoIExtractor @@ -116,3 +117,77 @@ def forward(self, feats: tuple[Tensor], rois: Tensor, roi_scale_factor: float | # included in the computation 
graph to avoid runtime bugs. roi_feats += sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + feats[i].sum() * 0.0 return roi_feats + + +if is_mmdeploy_enabled(): + from mmdeploy.core.rewriters import FUNCTION_REWRITER + from torch.autograd import Function + + class SingleRoIExtractorOpenVINO(Function): + """This class adds support for ExperimentalDetectronROIFeatureExtractor when exporting to OpenVINO. + + The `forward` method returns the original output, which is calculated in + advance and added to the SingleRoIExtractorOpenVINO class. In addition, the + list of arguments is changed here to be more suitable for + ExperimentalDetectronROIFeatureExtractor. + """ + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def forward(g, output_size, featmap_strides, sample_num, rois, *feats): + """Run forward.""" + return SingleRoIExtractorOpenVINO.origin_output + + @staticmethod + def symbolic(g, output_size, featmap_strides, sample_num, rois, *feats): + """Symbolic function for creating onnx op.""" + from torch.onnx.symbolic_opset10 import _slice + + rois = _slice(g, rois, axes=[1], starts=[1], ends=[5]) + domain = "org.openvinotoolkit" + op_name = "ExperimentalDetectronROIFeatureExtractor" + return g.op( + f"{domain}::{op_name}", + rois, + *feats, + output_size_i=output_size, + pyramid_scales_i=featmap_strides, + sampling_ratio_i=sample_num, + image_id_i=0, + distribute_rois_between_levels_i=1, + preserve_rois_order_i=0, + aligned_i=1, + outputs=1, + ) + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.roi_extractors." + "single_level_roi_extractor.SingleRoIExtractor.forward", + backend="openvino", + ) + def single_roi_extractor__forward__openvino( + self: SingleRoIExtractor, + feats: tuple[Tensor], + rois: Tensor, + roi_scale_factor: float | None = None, + ) -> Tensor: + """Replaces SingleRoIExtractor with SingleRoIExtractorOpenVINO when exporting to OpenVINO. + + This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO. + """ + ctx = FUNCTION_REWRITER.get_context() + + # Adding original output to SingleRoIExtractorOpenVINO. + state = torch._C._get_tracing_state() + origin_output = ctx.origin_func(self, feats, rois, roi_scale_factor) + SingleRoIExtractorOpenVINO.origin_output = origin_output + torch._C._set_tracing_state(state) + + output_size = self.roi_layers[0].output_size[0] + featmap_strides = self.featmap_strides + sample_num = self.roi_layers[0].sampling_ratio + + args = (output_size, featmap_strides, sample_num, rois, *feats) + return SingleRoIExtractorOpenVINO.apply(*args)
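
All of the rewriter blocks added in this patch follow the same mechanism: they are registered only when is_mmdeploy_enabled() reports that mmdeploy is importable, and each one replaces an OTX function with an export-friendly implementation while mmdeploy traces the model; outside of export the original functions run unchanged. A minimal sketch of that registration pattern (not part of the patch; the target path "my_pkg.postprocess.decode" is a hypothetical stand-in and mmdeploy is assumed to be installed):

    from mmdeploy.core import FUNCTION_REWRITER

    @FUNCTION_REWRITER.register_rewriter(
        func_name="my_pkg.postprocess.decode",  # hypothetical rewrite target
        backend="default",
    )
    def decode__default(*args, **kwargs):
        """Replacement body used only while mmdeploy exports/traces the model."""
        # The context exposes the deploy config and the original function,
        # mirroring the ctx.cfg / ctx.origin_func usage in the rewriters above.
        ctx = FUNCTION_REWRITER.get_context()
        # A real rewriter re-implements the call with ONNX/OpenVINO-friendly ops;
        # falling back to the original keeps this sketch behaviour-preserving.
        return ctx.origin_func(*args, **kwargs)

Guarding the registrations behind is_mmdeploy_enabled() keeps mmdeploy an optional dependency: importing these modules without it installed simply skips the export-time overrides.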