
Commit

add mmdeploy maskrcnn opset
eugene123tw committed Apr 18, 2024
1 parent e5027d9 commit 03e26d3
Showing 8 changed files with 788 additions and 16 deletions.
163 changes: 148 additions & 15 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
@@ -9,6 +9,8 @@
from mmengine.registry import TASK_UTILS
from torch import Tensor

from otx.algo.detection.deployment import is_mmdeploy_enabled


# This class and its supporting functions below lightly adapted from the mmdet DeltaXYWHBBoxCoder available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
@@ -199,21 +201,6 @@ def delta2bbox(
References:
.. [1] https://arxiv.org/abs/1311.2524
Example:
>>> rois = torch.Tensor([[ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 5., 5., 5., 5.]])
>>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
>>> [ 1., 1., 1., 1.],
>>> [ 0., 0., 2., -1.],
>>> [ 0.7, -1.9, -0.5, 0.3]])
>>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
tensor([[0.0000, 0.0000, 1.0000, 1.0000],
[0.1409, 0.1409, 2.8591, 2.8591],
[0.0000, 0.3161, 4.1945, 0.6839],
[5.0000, 5.0000, 5.0000, 5.0000]])
"""
num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
if num_bboxes == 0:
@@ -251,3 +238,149 @@ def delta2bbox(
bboxes[..., 0::2].clamp_(min=0, max=max_shape[1])
bboxes[..., 1::2].clamp_(min=0, max=max_shape[0])
return bboxes.reshape(num_bboxes, -1)


if is_mmdeploy_enabled():
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
from mmdeploy.core import FUNCTION_REWRITER

@FUNCTION_REWRITER.register_rewriter(
func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.DeltaXYWHBBoxCoder.decode",
backend="default",
)
def deltaxywhbboxcoder__decode(
self: DeltaXYWHBBoxCoder,
bboxes: Tensor,
pred_bboxes: Tensor,
max_shape: Tensor | None = None,
wh_ratio_clip: float = 16 / 1000,
) -> Tensor:
"""Rewrite `decode` of `DeltaXYWHBBoxCoder` for default backend.
Rewrite this func to call `delta2bbox` directly.
Args:
bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
pred_bboxes (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
when rois is a grid of anchors.Offset encoding follows [1]_.
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]],optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
wh_ratio_clip (float, optional): The allowed ratio between
width and height.
Returns:
torch.Tensor: Decoded boxes.
"""
if pred_bboxes.size(0) != bboxes.size(0):
msg = "The batch size of pred_bboxes and bboxes should be equal."
raise ValueError(msg)
if pred_bboxes.ndim == 3 and pred_bboxes.size(1) != bboxes.size(1):
msg = "The number of bboxes should be equal."
raise ValueError(msg)
return delta2bbox(
bboxes,
pred_bboxes,
self.means,
self.stds,
max_shape,
wh_ratio_clip,
self.clip_border,
self.add_ctr_clamp,
self.ctr_clamp,
)
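
How this rewrite takes effect: mmdeploy applies registered rewriters only while a `RewriterContext` is active, which its export pipeline opens around model tracing. A minimal usage sketch (illustrative, not part of this commit; `deploy_cfg`, `model`, and `dummy_input` are hypothetical placeholders):

import torch
from mmdeploy.core import RewriterContext

# Rewriters registered above (e.g. for DeltaXYWHBBoxCoder.decode) are
# applied only inside this context, such as while tracing for ONNX export.
with RewriterContext(cfg=deploy_cfg, backend="onnxruntime"), torch.no_grad():
    torch.onnx.export(model, dummy_input, "maskrcnn.onnx")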

@FUNCTION_REWRITER.register_rewriter(
func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.delta2bbox",
backend="default",
)
def delta2bbox_opset(
rois: Tensor,
deltas: Tensor,
means: tuple[float, ...] = (0.0, 0.0, 0.0, 0.0),
stds: tuple[float, ...] = (1.0, 1.0, 1.0, 1.0),
max_shape: Tensor | None = None,
wh_ratio_clip: float = 16 / 1000,
clip_border: bool = True,
add_ctr_clamp: bool = False,
ctr_clamp: int = 32,
) -> Tensor:
"""Rewrite `delta2bbox` for default backend.
Since the need of clip op with dynamic min and max, this function uses
clip_bboxes function to support dynamic shape.
Args:
ctx (ContextCaller): The context with additional information.
rois (Tensor): Boxes to be transformed. Has shape (N, 4).
deltas (Tensor): Encoded offsets relative to each roi.
Has shape (N, num_classes * 4) or (N, 4). Note
N = num_base_anchors * W * H, when rois is a grid of
anchors. Offset encoding follows [1]_.
means (Sequence[float]): Denormalizing means for delta coordinates.
Default (0., 0., 0., 0.).
stds (Sequence[float]): Denormalizing standard deviation for delta
coordinates. Default (1., 1., 1., 1.).
max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
(H, W). Default None.
wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
16 / 1000.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Default True.
add_ctr_clamp (bool): Whether to add center clamp. When set to True,
the center of the prediction bounding box will be clamped to
avoid being too far away from the center of the anchor.
Only used by YOLOF. Default False.
ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
Default 32.
Return:
bboxes (Tensor): Boxes with shape (N, num_classes * 4) or (N, 4),
where 4 represent tl_x, tl_y, br_x, br_y.
"""
means = deltas.new_tensor(means).view(1, -1)
stds = deltas.new_tensor(stds).view(1, -1)
delta_shape = deltas.shape
reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 4))
denorm_deltas = reshaped_deltas * stds + means

dxy = denorm_deltas[..., :2]
dwh = denorm_deltas[..., 2:]

# workaround for OpenVINO export on torch 1.13
xy1 = rois[..., :2].unsqueeze(2)
xy2 = rois[..., 2:].unsqueeze(2)

pxy = (xy1 + xy2) * 0.5
pwh = xy2 - xy1
dxy_wh = pwh * dxy

max_ratio = np.abs(np.log(wh_ratio_clip))
if add_ctr_clamp:
dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
dwh = torch.clamp(dwh, max=max_ratio)
else:
dwh = dwh.clamp(min=-max_ratio, max=max_ratio)

# Use exp(network energy) to enlarge/shrink each roi
half_gwh = pwh * dwh.exp() * 0.5
# Use network energy to shift the center of each roi
gxy = pxy + dxy_wh

# Convert center-xy/width/height to top-left, bottom-right
xy1 = gxy - half_gwh
xy2 = gxy + half_gwh

x1 = xy1[..., 0]
y1 = xy1[..., 1]
x2 = xy2[..., 0]
y2 = xy2[..., 1]

if clip_border and max_shape is not None:
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape)

return torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
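
As a quick sanity check of the arithmetic above (an illustrative snippet, not part of the commit): a zero delta should decode a roi back to itself, since a zero offset keeps the center and exp(0) = 1 keeps the width and height.

import torch

roi = torch.tensor([2.0, 2.0, 6.0, 10.0])  # x1, y1, x2, y2
pxy = (roi[:2] + roi[2:]) * 0.5            # center: tensor([4., 6.])
pwh = roi[2:] - roi[:2]                    # size:   tensor([4., 8.])
dxy = torch.zeros(2)
dwh = torch.zeros(2)
gxy = pxy + pwh * dxy                      # center is unchanged
gwh = pwh * dwh.exp()                      # size is unchanged
xy1, xy2 = gxy - gwh * 0.5, gxy + gwh * 0.5
# xy1 == roi[:2] and xy2 == roi[2:]: the zero delta decodes to the input roi.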
@@ -14,6 +14,9 @@ data_preprocessor:
- 1.0
- 1.0
type: MaskRCNN
# TODO(Eugene): Update scope after removing mmdet, as pytorchcv_model is registered under mmdet.registry.MODELS.
# https://github.com/openvinotoolkit/training_extensions/pull/3281
# _scope_: mmengine
backbone:
type: efficientnet_b2b
out_indices:
@@ -17,6 +17,7 @@
from torch import Tensor, nn
from torch.nn.modules.utils import _pair

from otx.algo.detection.deployment import is_mmdeploy_enabled
from otx.algo.instance_segmentation.mmdet.models.layers import multiclass_nms
from otx.algo.instance_segmentation.mmdet.models.utils import (
InstanceList,
@@ -310,3 +311,134 @@ def _predict_by_feat_single(
results.scores = det_bboxes[:, -1]
results.labels = det_labels
return results


if is_mmdeploy_enabled():
from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import multiclass_nms as multiclass_nms_ops

@FUNCTION_REWRITER.register_rewriter(
"otx.algo.instance_segmentation.mmdet.models.bbox_heads.bbox_head.BBoxHead.forward",
)
@FUNCTION_REWRITER.register_rewriter(
"otx.algo.instance_segmentation.mmdet.models.custom_roi_head.CustomConvFCBBoxHead.forward",
)
def bbox_head__forward(self: BBoxHead, x: Tensor) -> tuple[Tensor, Tensor]:
"""Rewrite `forward` for default backend.
This function uses the specific `forward` function for the BBoxHead
or ConvFCBBoxHead after adding marks.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
x (Tensor): Input image tensor.
Returns:
tuple(Tensor, Tensor): The (cls_score, bbox_pred). The cls_score
has shape (N, num_det, num_cls) and the bbox_pred has shape
(N, num_det, 4).
"""
ctx = FUNCTION_REWRITER.get_context()

@mark("bbox_head_forward", inputs=["bbox_feats"], outputs=["cls_score", "bbox_pred"])
def __forward(self, x):
return ctx.origin_func(self, x)

return __forward(self, x)
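
Here `mark` tags the input and output tensors of the wrapped call; during tracing, mmdeploy inserts Mark nodes at those points so its graph-partition utilities can split the exported model there. A minimal, illustrative sketch of the decorator's shape (not part of this commit):

from mmdeploy.core import mark

# Tags "inp"/"out" in the traced graph as named partition points.
@mark("identity", inputs=["inp"], outputs=["out"])
def identity(x):
    return x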

@FUNCTION_REWRITER.register_rewriter(
"otx.algo.instance_segmentation.mmdet.models.bbox_heads.bbox_head.BBoxHead.predict_by_feat",
)
def bbox_head__predict_by_feat(
self: BBoxHead,
rois: Tensor,
cls_scores: tuple[Tensor],
bbox_preds: tuple[Tensor],
batch_img_metas: list[dict],
rcnn_test_cfg: dict,
rescale: bool = False,
) -> tuple[Tensor]:
"""Rewrite `predict_by_feat` of `BBoxHead` for default backend.
Transform network output for a batch into bbox predictions. Support
`reg_class_agnostic == False` case.
Args:
rois (tuple[Tensor]): Tuple of boxes to be transformed.
Each has shape (num_boxes, 5). last dimension 5 arrange as
(batch_index, x1, y1, x2, y2).
cls_scores (tuple[Tensor]): Tuple of box scores, each has shape
(num_boxes, num_classes + 1).
bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each
has shape (num_boxes, num_classes * 4).
batch_img_metas (list[dict]): List of image information.
rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
Returns:
- dets (Tensor): Classification bboxes and scores, has a shape
(num_instance, 5)
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
"""
ctx = FUNCTION_REWRITER.get_context()
if rois.ndim != 3:
msg = "Only exporting two-stage models to ONNX with a batch dimension is supported."
raise ValueError(msg)

img_shape = batch_img_metas[0]["img_shape"]
if self.custom_cls_channels:
scores = self.loss_cls.get_activation(cls_scores)
else:
scores = torch.nn.functional.softmax(cls_scores, dim=-1) if cls_scores is not None else None

if bbox_preds is not None:
# num_classes = 1 if self.reg_class_agnostic else self.num_classes
# if num_classes > 1:
# rois = rois.repeat_interleave(num_classes, dim=1)
bboxes = self.bbox_coder.decode(rois[..., 1:], bbox_preds, max_shape=img_shape)
else:
bboxes = rois[..., 1:].clone()
if img_shape is not None:
max_shape = bboxes.new_tensor(img_shape)[..., :2]
min_xy = bboxes.new_tensor(0)
max_xy = torch.cat([max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

batch_size = scores.shape[0]
device = scores.device
# ignore background class
scores = scores[..., : self.num_classes]
if not self.reg_class_agnostic:
# only keep boxes with the max scores
max_inds = scores.reshape(-1, self.num_classes).argmax(1, keepdim=True)
encode_size = self.bbox_coder.encode_size
bboxes = bboxes.reshape(-1, self.num_classes, encode_size)
dim0_inds = torch.arange(bboxes.shape[0], device=device).unsqueeze(-1)
bboxes = bboxes[dim0_inds, max_inds].reshape(batch_size, -1, encode_size)
# get nms params
post_params = get_post_processing_params(ctx.cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = rcnn_test_cfg["nms"].get("iou_threshold", post_params.iou_threshold)
score_threshold = rcnn_test_cfg.get("score_thr", post_params.score_threshold)
if torch.onnx.is_in_onnx_export():
pre_top_k = post_params.pre_top_k
else:
# For two stage partition post processing
pre_top_k = -1 if post_params.pre_top_k >= bboxes.shape[1] else post_params.pre_top_k
keep_top_k = rcnn_test_cfg.get("max_per_img", post_params.keep_top_k)
nms_type = rcnn_test_cfg["nms"].get("type")
return multiclass_nms_ops(
bboxes,
scores,
max_output_boxes_per_class,
nms_type=nms_type,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
)
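
The NMS parameters above are resolved from the deploy config via `get_post_processing_params`, with `rcnn_test_cfg` taking precedence where set. A typical `post_processing` block in an mmdeploy-style detection deploy config looks roughly like this (field names follow mmdeploy's detection configs; the values are illustrative defaults, not taken from this commit):

codebase_config = dict(
    type="mmdet",
    task="ObjectDetection",
    post_processing=dict(
        score_threshold=0.05,
        iou_threshold=0.5,
        max_output_boxes_per_class=200,
        pre_top_k=5000,
        keep_top_k=100,
    ),
)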