MaskRCNN Native Exporter (#3412)
* migrate mmdet maskrcnn modules

* style reformat

* style reformat

* style reformat

* ignore mypy, ruff errors

* skip mypy error

* update

* fix loss

* add maskrcnn

* update import

* update import

* add necks

* update

* update

* add cross-entropy loss

* style changes

* mypy changes and style changes

* update style

* remove box structures

* add resnet

* update

* modify resnet

* add annotation

* style changes

* update

* fix all mypy issues

* fix mypy issues

* style changes

* remove unused losses

* remove focal_loss_pb

* fix all ruff and mypy issues

* style change

* update

* update license

* update

* remove duplicates

* remove as F

* remove as F

* remove mmdet mask structures

* remove duplicates

* style changes

* add new test

* test style change

* fix test

* change device for unit test

* add deployment files

* remove deployment from inst-seg

* update deployment

* add mmdeploy maskrcnn opset

* fix linter

* update test

* update test

* update test

* replace mmcv.cnn module

* remove upsample building

* remove upsample building

* use batch_nms from otx

* add swintransformer

* add transformers

* add swin transformer

* style changes

* solve conflicts

* update instance_segmentation/maskrcnn.py

* update nms

* fix xai

* change rotate detection recipe

* fix swint recipe

* remove some files

* decouple mmdeploy and replace with native exporter

* remove duplicate imports

* todo

* update

* fix rpn_head training issue

* remove maskrcnn r50 mmconfigs

* fix anchor head and related fixes

* remove gather_topk

* remove maskrcnn efficientnet mmconfig

* remove maskrcnn-swint mmconfig

* revert some changes

* update recipes

* replace mmcv.ops.roi_align with torchvision.ops.roi_align (see the sketch after this list)

* fix format issue

* update anchor head

* add CrossSigmoidFocalLoss back

* remove mmdet decouple test

* fix test

* skip xai test for inst-seg for now

* remove code comment

* Disable deterministic in test

* reformat
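
One of the entries above swaps mmcv.ops.roi_align for torchvision.ops.roi_align. For reference, a minimal sketch of the torchvision call; the feature-map shape, box coordinates, and scale values below are illustrative, not taken from this commit:

import torch
from torchvision.ops import roi_align

features = torch.rand(1, 256, 50, 50)                 # (N, C, H, W) feature map
rois = torch.tensor([[0.0, 10.0, 10.0, 30.0, 30.0]])  # (batch_index, x1, y1, x2, y2)
pooled = roi_align(features, rois, output_size=(7, 7), spatial_scale=0.25, sampling_ratio=0, aligned=True)
print(pooled.shape)                                    # torch.Size([1, 256, 7, 7])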
eugene123tw authored May 2, 2024
1 parent 0f0f943 commit 151a94e
Showing 48 changed files with 1,404 additions and 1,845 deletions.
3 changes: 3 additions & 0 deletions src/otx/algo/detection/atss.py
@@ -16,6 +16,7 @@
from otx.algo.detection.heads.anchor_generator import AnchorGenerator
from otx.algo.detection.heads.atss_assigner import ATSSAssigner
from otx.algo.detection.heads.atss_head import ATSSHead
from otx.algo.detection.heads.base_sampler import PseudoSampler
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss
from otx.algo.detection.losses.cross_focal_loss import CrossSigmoidFocalLoss
@@ -233,6 +234,7 @@ class MobileNetV2ATSS(ATSS):
def _build_model(self, num_classes: int) -> SingleStageDetector:
train_cfg = {
"assigner": ATSSAssigner(topk=9),
"sampler": PseudoSampler(),
"allowed_border": -1,
"pos_weight": -1,
"debug": False,
@@ -304,6 +306,7 @@ class ResNeXt101ATSS(ATSS):
def _build_model(self, num_classes: int) -> SingleStageDetector:
train_cfg = {
"assigner": ATSSAssigner(topk=9),
"sampler": PseudoSampler(),
"allowed_border": -1,
"pos_weight": -1,
"debug": False,
19 changes: 0 additions & 19 deletions src/otx/algo/detection/deployment.py

This file was deleted.
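
The 19 deleted lines are not reproduced here. A helper of this kind is usually just an availability probe for the optional mmdeploy dependency; the following is a hypothetical sketch, not the removed implementation:

import importlib.util

def is_mmdeploy_enabled() -> bool:
    """Hypothetical sketch: report whether the optional mmdeploy package is importable."""
    return importlib.util.find_spec("mmdeploy") is not None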

3 changes: 0 additions & 3 deletions src/otx/algo/detection/heads/anchor_generator.py
@@ -9,13 +9,11 @@

import numpy as np
import torch
from mmengine.registry import TASK_UTILS
from torch.nn.modules.utils import _pair


# This class and its supporting functions below lightly adapted from the mmdet AnchorGenerator available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/prior_generators/anchor_generator.py
@TASK_UTILS.register_module()
class AnchorGenerator:
"""Standard anchor generator for 2D anchor-based detectors.
@@ -475,7 +473,6 @@ def __repr__(self) -> str:
return repr_str


@TASK_UTILS.register_module()
class SSDAnchorGeneratorClustered(AnchorGenerator):
"""Custom Anchor Generator for SSD."""

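With the TASK_UTILS decorators removed, these generators are no longer built from a registry config but constructed directly. A minimal sketch, assuming the constructor keeps the upstream mmdet AnchorGenerator parameters (the values are illustrative):

from otx.algo.detection.heads.anchor_generator import AnchorGenerator

# Direct construction instead of TASK_UTILS.build({"type": "AnchorGenerator", ...});
# strides, ratios, and octave settings follow the upstream signature and are example values.
prior_generator = AnchorGenerator(
    strides=[8, 16, 32, 64, 128],
    ratios=[1.0],
    octave_base_scale=8,
    scales_per_octave=1,
)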
7 changes: 2 additions & 5 deletions src/otx/algo/detection/heads/anchor_head.py
@@ -12,11 +12,8 @@
from torch import Tensor, nn

from otx.algo.detection.heads.anchor_generator import AnchorGenerator
from otx.algo.detection.heads.atss_assigner import ATSSAssigner
from otx.algo.detection.heads.base_head import BaseDenseHead
from otx.algo.detection.heads.base_sampler import PseudoSampler
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from otx.algo.detection.heads.max_iou_assigner import MaxIoUAssigner
from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap
from otx.algo.utils.mmengine_utils import InstanceData

@@ -83,8 +80,8 @@ def __init__(
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if self.train_cfg:
self.assigner: MaxIoUAssigner | ATSSAssigner = self.train_cfg["assigner"]
self.sampler = PseudoSampler(context=self) # type: ignore[no-untyped-call]
self.assigner = self.train_cfg.get("assigner", None)
self.sampler = self.train_cfg.get("sampler", None)

self.fp16_enabled = False

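Because the head no longer instantiates its own PseudoSampler, both the assigner and the sampler must now arrive through train_cfg, which is exactly what the atss.py hunks above add. A short sketch of the calling convention, reusing the values shown in that diff:

from otx.algo.detection.heads.atss_assigner import ATSSAssigner
from otx.algo.detection.heads.base_sampler import PseudoSampler

train_cfg = {
    "assigner": ATSSAssigner(topk=9),  # values mirror the atss.py hunk above
    "sampler": PseudoSampler(),
    "allowed_border": -1,
    "pos_weight": -1,
    "debug": False,
}
# Inside AnchorHead.__init__ (per this diff), both entries are looked up from the config:
#   self.assigner = self.train_cfg.get("assigner", None)
#   self.sampler = self.train_cfg.get("sampler", None)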
188 changes: 187 additions & 1 deletion src/otx/algo/detection/heads/base_sampler.py
@@ -1,14 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Base Sampler implementation from mmdet."""

from __future__ import annotations

from abc import ABCMeta, abstractmethod

import numpy as np
import torch

from otx.algo.detection.utils.structures import AssignResult, SamplingResult
from otx.algo.utils.mmengine_utils import InstanceData


def ensure_rng(rng: int | np.random.RandomState | None = None) -> np.random.RandomState:
"""Coerces input into a random number generator.
If the input is None, then a global random state is returned.
If the input is a numeric value, then that is used as a seed to construct a
random state. Otherwise the input is returned as-is.
Adapted from [1]_.
Args:
rng (int | numpy.random.RandomState | None):
if None, then defaults to the global rng. Otherwise this can be an
integer or a RandomState class
Returns:
(numpy.random.RandomState) : rng -
a numpy random number generator
References:
.. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
"""
if rng is None:
return np.random.mtrand._rand # noqa: SLF001
if isinstance(rng, int):
return np.random.RandomState(rng)
return rng


class BaseSampler(metaclass=ABCMeta):
"""Base class of samplers.
@@ -124,7 +155,7 @@ def sample(
class PseudoSampler(BaseSampler):
"""A pseudo sampler that does not do sampling actually."""

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
pass

def _sample_pos(self, assign_result: AssignResult, num_expected: int, **kwargs) -> torch.Tensor:
@@ -174,3 +205,158 @@ def sample(
gt_flags=gt_flags,
avg_factor_with_neg=False,
)


class RandomSampler(BaseSampler):
"""Random sampler.
Args:
num (int): Number of samples
pos_fraction (float): Fraction of positive samples
neg_pos_ub (int): Upper bound on the number of negative samples
relative to positive samples. Defaults to -1 (no bound).
add_gt_as_proposals (bool): Whether to add ground truth
boxes as proposals. Defaults to True.
"""

def __init__(
self,
num: int,
pos_fraction: float,
neg_pos_ub: int = -1,
add_gt_as_proposals: bool = True,
**kwargs,
):
super().__init__(
num=num,
pos_fraction=pos_fraction,
neg_pos_ub=neg_pos_ub,
add_gt_as_proposals=add_gt_as_proposals,
)
self.rng = ensure_rng(kwargs.get("rng", None))

def random_choice(self, gallery: torch.Tensor | np.ndarray | list, num: int) -> torch.Tensor | np.ndarray:
"""Random select some elements from the gallery.
If `gallery` is a Tensor, the returned indices will be a Tensor;
If `gallery` is a ndarray or list, the returned indices will be a
ndarray.
Args:
gallery (Tensor | ndarray | list): indices pool.
num (int): expected sample num.
Returns:
Tensor or ndarray: sampled indices.
"""
if len(gallery) < num:
msg = f"Cannot sample {num} elements from a set of size {len(gallery)}"
raise ValueError(msg)

is_tensor = isinstance(gallery, torch.Tensor)
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
_gallery: torch.Tensor = torch.tensor(gallery, dtype=torch.long, device=device) if not is_tensor else gallery
perm = torch.randperm(_gallery.numel())[:num].to(device=_gallery.device)
rand_inds = _gallery[perm]
if not is_tensor:
rand_inds = rand_inds.cpu().numpy()
return rand_inds

def _sample_pos(self, assign_result: AssignResult, num_expected: int, **kwargs: dict) -> torch.Tensor | np.ndarray:
"""Randomly sample some positive samples.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
num_expected (int): The number of expected positive samples
Returns:
Tensor or ndarray: sampled indices.
"""
pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
return self.random_choice(pos_inds, num_expected)

def _sample_neg(self, assign_result: AssignResult, num_expected: int, **kwargs: dict) -> torch.Tensor | np.ndarray:
"""Randomly sample some negative samples.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
num_expected (int): The number of expected negative samples
Returns:
Tensor or ndarray: sampled indices.
"""
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
return self.random_choice(neg_inds, num_expected)

def sample(
self,
assign_result: AssignResult,
pred_instances: InstanceData,
gt_instances: InstanceData,
**kwargs,
) -> SamplingResult:
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
assigning results and ground truth bboxes.
Args:
assign_result (:obj:`AssignResult`): Assigning results.
pred_instances (:obj:`InstanceData`): Instances of model
predictions. It includes ``priors``, and the priors can
be anchors or points, or the bboxes predicted by the
previous stage, has shape (n, 4). The bboxes predicted by
the current model or stage will be named ``bboxes``,
``labels``, and ``scores``, the same as the ``InstanceData``
in other places.
gt_instances (:obj:`InstanceData`): Ground truth of instance
annotations. It usually includes ``bboxes``, with shape (k, 4),
and ``labels``, with shape (k, ).
Returns:
:obj:`SamplingResult`: Sampling result.
"""
gt_bboxes = gt_instances.bboxes # type: ignore[attr-defined]
priors = pred_instances.priors # type: ignore[attr-defined]
gt_labels = gt_instances.labels # type: ignore[attr-defined]
if len(priors.shape) < 2:
priors = priors[None, :]

gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
priors = torch.cat([gt_bboxes, priors], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = priors.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
gt_flags = torch.cat([gt_ones, gt_flags])

num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=priors, **kwargs) # noqa: SLF001
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_expected_neg = self.num - num_sampled_pos
if self.neg_pos_ub >= 0:
_pos = max(1, num_sampled_pos)
neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=priors, **kwargs) # noqa: SLF001
neg_inds = neg_inds.unique()

return SamplingResult(
pos_inds=pos_inds,
neg_inds=neg_inds,
priors=priors,
gt_bboxes=gt_bboxes,
assign_result=assign_result,
gt_flags=gt_flags,
)
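A short usage sketch of the helpers added above, with toy inputs; the constructor arguments are example values and assume BaseSampler.__init__ accepts them as passed in the super().__init__ call:

import torch
from otx.algo.detection.heads.base_sampler import RandomSampler, ensure_rng

rng = ensure_rng(42)                                # deterministic numpy RandomState from an int seed
sampler = RandomSampler(num=256, pos_fraction=0.5)  # example budget: 256 samples, half positive

pool = torch.arange(1000)                           # candidate indices
picked = sampler.random_choice(pool, 10)            # 10 indices drawn without replacement
print(picked.shape)                                 # torch.Size([10])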
2 changes: 1 addition & 1 deletion src/otx/algo/detection/heads/class_incremental_mixin.py
@@ -106,7 +106,7 @@ def get_valid_label_mask(
all_labels: list[Tensor],
use_bg: bool = False,
) -> list[Tensor]:
"""Calcualte valid label mask with ignored labels."""
"""Calculate valid label mask with ignored labels."""
num_classes = self.num_classes + 1 if use_bg else self.num_classes # type: ignore[attr-defined]
valid_label_mask = []
for i, meta in enumerate(img_metas):
58 changes: 0 additions & 58 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
@@ -6,15 +6,11 @@

import numpy as np
import torch
from mmengine.registry import TASK_UTILS
from torch import Tensor

from otx.algo.detection.deployment import is_mmdeploy_enabled


# This class and its supporting functions below lightly adapted from the mmdet DeltaXYWHBBoxCoder available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
@TASK_UTILS.register_module()
class DeltaXYWHBBoxCoder:
"""Delta XYWH BBox coder.
@@ -415,57 +411,3 @@ def clip_bboxes(
x2 = torch.clamp(x2, 0, max_shape[1])
y2 = torch.clamp(y2, 0, max_shape[0])
return x1, y1, x2, y2


if is_mmdeploy_enabled():
from mmdeploy.core import FUNCTION_REWRITER

@FUNCTION_REWRITER.register_rewriter(
func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.DeltaXYWHBBoxCoder.decode",
backend="default",
)
def deltaxywhbboxcoder__decode(
self: DeltaXYWHBBoxCoder,
bboxes: Tensor,
pred_bboxes: Tensor,
max_shape: Tensor | None = None,
wh_ratio_clip: float = 16 / 1000,
) -> Tensor:
"""Rewrite `decode` of `DeltaXYWHBBoxCoder` for default backend.
Rewrite this func to call `delta2bbox` directly.
Args:
bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
pred_bboxes (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
when rois is a grid of anchors.Offset encoding follows [1]_.
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]],optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
wh_ratio_clip (float, optional): The allowed ratio between
width and height.
Returns:
torch.Tensor: Decoded boxes.
"""
if pred_bboxes.size(0) != bboxes.size(0):
msg = "The batch size of pred_bboxes and bboxes should be equal."
raise ValueError(msg)
if pred_bboxes.ndim == 3 and pred_bboxes.size(1) != bboxes.size(1):
msg = "The number of bboxes should be equal."
raise ValueError(msg)
return delta2bbox_export(
bboxes,
pred_bboxes,
self.means,
self.stds,
max_shape,
wh_ratio_clip,
self.clip_border,
self.add_ctr_clamp,
self.ctr_clamp,
)
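
With the FUNCTION_REWRITER registration gone, the same decode() implementation serves both eager inference and the native export path. A minimal sketch with made-up inputs, assuming the constructor's default means/stds carry over from the upstream mmdet coder:

import torch
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder

coder = DeltaXYWHBBoxCoder()  # default target means/stds assumed
anchors = torch.tensor([[10.0, 10.0, 50.0, 50.0],
                        [20.0, 30.0, 80.0, 90.0]])  # (N, 4) boxes in x1, y1, x2, y2
deltas = torch.randn(2, 4) * 0.1                    # (N, 4) predicted offsets
decoded = coder.decode(anchors, deltas, max_shape=(800, 800))
print(decoded.shape)                                # torch.Size([2, 4])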
Diffs for the remaining changed files were not loaded.
