Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MMDet MaskRCNN ResNet50/SwinTransformer Decouple #3281

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
fa0d462
migrate mmdet maskrcnn modules
eugene123tw Apr 5, 2024
763ce7d
style reformat
eugene123tw Apr 5, 2024
5ace6b0
style reformat
eugene123tw Apr 5, 2024
b2b6e0c
style reformat
eugene123tw Apr 5, 2024
a305c7c
ignore mypy, ruff errors
eugene123tw Apr 8, 2024
6b682b7
skip mypy error
eugene123tw Apr 8, 2024
727cf2b
update
eugene123tw Apr 8, 2024
22f8f81
fix loss
eugene123tw Apr 8, 2024
08d766f
add maskrcnn
eugene123tw Apr 8, 2024
873a101
update import
eugene123tw Apr 8, 2024
bfca5f0
update import
eugene123tw Apr 8, 2024
47f3744
add necks
eugene123tw Apr 8, 2024
7bb7f39
update
eugene123tw Apr 8, 2024
d37f8a1
update
eugene123tw Apr 8, 2024
abd1f21
add cross-entropy loss
eugene123tw Apr 9, 2024
4978171
style changes
eugene123tw Apr 9, 2024
330f721
mypy changes and style changes
eugene123tw Apr 10, 2024
92dc0bf
update style
eugene123tw Apr 11, 2024
87a2415
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 11, 2024
c4dec94
remove box structures
eugene123tw Apr 11, 2024
5723ec2
add resnet
eugene123tw Apr 11, 2024
47de4f1
update
eugene123tw Apr 11, 2024
f5a51da
modify resnet
eugene123tw Apr 15, 2024
63f46c4
add annotation
eugene123tw Apr 15, 2024
46b37d5
style changes
eugene123tw Apr 15, 2024
cef1e24
update
eugene123tw Apr 15, 2024
655baea
fix all mypy issues
eugene123tw Apr 15, 2024
b1ed150
fix mypy issues
eugene123tw Apr 15, 2024
27b6a4a
style changes
eugene123tw Apr 16, 2024
e87cfa9
remove unused losses
eugene123tw Apr 16, 2024
c2c2394
remove focal_loss_pb
eugene123tw Apr 16, 2024
edd85e0
fix all ruff and mypy issues
eugene123tw Apr 16, 2024
742b6fd
fix conflicts
eugene123tw Apr 16, 2024
194a6c2
style change
eugene123tw Apr 16, 2024
c1938de
update
eugene123tw Apr 16, 2024
734a459
update license
eugene123tw Apr 16, 2024
93fcd79
update
eugene123tw Apr 16, 2024
1d0926d
remove duplicates
eugene123tw Apr 17, 2024
1238d28
remove as F
eugene123tw Apr 17, 2024
55d77f5
remove as F
eugene123tw Apr 17, 2024
1598929
remove mmdet mask structures
eugene123tw Apr 17, 2024
498b750
remove duplicates
eugene123tw Apr 17, 2024
a0f52de
style changes
eugene123tw Apr 17, 2024
549d0ef
add new test
eugene123tw Apr 17, 2024
0a074ae
test style change
eugene123tw Apr 17, 2024
9106bc1
fix test
eugene123tw Apr 17, 2024
8197825
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 17, 2024
07cd25e
change device for unit test
eugene123tw Apr 17, 2024
ab07675
add deployment files
eugene123tw Apr 17, 2024
70717ce
remove deployment from inst-seg
eugene123tw Apr 17, 2024
e5027d9
update deployment
eugene123tw Apr 18, 2024
03e26d3
add mmdeploy maskrcnn opset
eugene123tw Apr 18, 2024
bbeaa32
fix linter
eugene123tw Apr 18, 2024
478158a
update test
eugene123tw Apr 18, 2024
052a582
update test
eugene123tw Apr 18, 2024
28cea6f
update test
eugene123tw Apr 18, 2024
ed7275e
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 22, 2024
04cd223
replace mmcv.cnn module
eugene123tw Apr 22, 2024
11ae260
remove upsample building
eugene123tw Apr 22, 2024
d8a78ca
remove upsample building
eugene123tw Apr 22, 2024
34ea7d7
use batch_nms from otx
eugene123tw Apr 22, 2024
617e1ac
add swintransformer
eugene123tw Apr 22, 2024
5a6737f
add transformers
eugene123tw Apr 22, 2024
fe34145
add swin transformer
eugene123tw Apr 23, 2024
4dd8bdc
style changes
eugene123tw Apr 23, 2024
a6de9b4
merge upstream
eugene123tw Apr 23, 2024
1b199a0
solve conflicts
eugene123tw Apr 23, 2024
c332430
update instance_segmentation/maskrcnn.py
eugene123tw Apr 23, 2024
6d92199
update nms
eugene123tw Apr 23, 2024
4d75001
fix xai
eugene123tw Apr 24, 2024
7b4b43e
change rotate detection recipe
eugene123tw Apr 24, 2024
0b2a326
fix swint recipe
eugene123tw Apr 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/otx/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
action_classification,
classification,
detection,
instance_segmentation,
plugins,
segmentation,
strategies,
Expand All @@ -23,4 +24,5 @@
"strategies",
"accelerators",
"plugins",
"instance_segmentation",
]
19 changes: 19 additions & 0 deletions src/otx/algo/detection/deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Functions for mmdeploy adapters."""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import importlib


def is_mmdeploy_enabled() -> bool:
    """Check whether the optional 'mmdeploy' package can be imported.

    The lookup only searches for the module spec; 'mmdeploy' itself is not
    imported by this call.

    Returns:
        bool: True if a module spec for 'mmdeploy' is found, False otherwise.

    Example:
        >>> isinstance(is_mmdeploy_enabled(), bool)
        True
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule is bound as an attribute, so import the helper explicitly.
    from importlib.util import find_spec

    return find_spec("mmdeploy") is not None
2 changes: 1 addition & 1 deletion src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AnchorHead(BaseDenseHead):
def __init__(
self,
num_classes: int,
in_channels: tuple[int, ...],
in_channels: tuple[int, ...] | int,
anchor_generator: dict,
bbox_coder: dict,
loss_cls: dict,
Expand Down
9 changes: 4 additions & 5 deletions src/otx/algo/detection/heads/base_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from typing import TYPE_CHECKING

import torch
from mmcv.ops import batched_nms
from mmengine.model import constant_init
from mmengine.model import BaseModule, constant_init
from mmengine.structures import InstanceData
from torch import Tensor, nn
from torch import Tensor

from otx.algo.detection.ops.nms import multiclass_nms
from otx.algo.detection.ops.nms import batched_nms, multiclass_nms
from otx.algo.detection.utils.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances

if TYPE_CHECKING:
Expand All @@ -24,7 +23,7 @@

# This class and its supporting functions below are lightly adapted from the mmdet BaseDenseHead available at:
# https://github.com/open-mmlab/mmdetection/blob/fe3f809a0a514189baf889aa358c498d51ee36cd/mmdet/models/dense_heads/base_dense_head.py
class BaseDenseHead(nn.Module):
class BaseDenseHead(BaseModule):
"""Base class for DenseHeads.

1. The ``init_weights`` method is used to initialize densehead's
Expand Down
9 changes: 5 additions & 4 deletions src/otx/algo/detection/heads/class_incremental_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from typing import TYPE_CHECKING

import torch
from mmdet.models.utils.misc import images_to_levels, multi_apply
from mmdet.registry import MODELS
from torch import Tensor

from otx.algo.detection.utils.utils import images_to_levels, multi_apply

if TYPE_CHECKING:
from mmdet.utils import InstanceList, OptInstanceList
from mmengine.structures import InstanceData


@MODELS.register_module()
Expand All @@ -24,9 +25,9 @@ def get_atss_targets(
self,
anchor_list: list,
valid_flag_list: list[list[Tensor]],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
unmap_outputs: bool = True,
) -> tuple:
"""Get targets for ATSS head.
Expand Down
3 changes: 2 additions & 1 deletion src/otx/algo/detection/heads/custom_anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

import numpy as np
import torch
from mmdet.registry import TASK_UTILS
from mmengine.registry import TASK_UTILS
from torch.nn.modules.utils import _pair


# This class and its supporting functions below are lightly adapted from the mmdet AnchorGenerator available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/prior_generators/anchor_generator.py
@TASK_UTILS.register_module()
jaegukhyun marked this conversation as resolved.
Show resolved Hide resolved
class AnchorGenerator:
"""Standard anchor generator for 2D anchor-based detectors.

Expand Down
5 changes: 4 additions & 1 deletion src/otx/algo/detection/heads/custom_ssd_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict],
train_cfg: ConfigDict | dict,
num_classes: int = 80,
in_channels: tuple[int, ...] = (512, 1024, 512, 256, 256, 256),
in_channels: tuple[int, ...] | int = (512, 1024, 512, 256, 256, 256),
stacked_convs: int = 0,
feat_channels: int = 256,
use_depthwise: bool = False,
Expand Down Expand Up @@ -274,6 +274,9 @@ def _init_layers(self) -> None:
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()

if isinstance(self.in_channels, int):
self.in_channels = (self.in_channels,)

for in_channel, num_base_priors in zip(self.in_channels, self.num_base_priors):
if self.use_depthwise:
activation_layer = nn.ReLU(inplace=True)
Expand Down
73 changes: 58 additions & 15 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

import numpy as np
import torch
from mmengine.registry import TASK_UTILS
from torch import Tensor

from otx.algo.detection.deployment import is_mmdeploy_enabled


# This class and its supporting functions below are lightly adapted from the mmdet DeltaXYWHBBoxCoder available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
@TASK_UTILS.register_module()
class DeltaXYWHBBoxCoder:
"""Delta XYWH BBox coder.

Expand Down Expand Up @@ -236,21 +240,6 @@ def delta2bbox(

References:
.. [1] https://arxiv.org/abs/1311.2524

Example:
>>> rois = torch.Tensor([[ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 5., 5., 5., 5.]])
>>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
>>> [ 1., 1., 1., 1.],
>>> [ 0., 0., 2., -1.],
>>> [ 0.7, -1.9, -0.5, 0.3]])
>>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
tensor([[0.0000, 0.0000, 1.0000, 1.0000],
[0.1409, 0.1409, 2.8591, 2.8591],
[0.0000, 0.3161, 4.1945, 0.6839],
[5.0000, 5.0000, 5.0000, 5.0000]])
"""
num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
if num_bboxes == 0:
Expand Down Expand Up @@ -426,3 +415,57 @@ def clip_bboxes(
x2 = torch.clamp(x2, 0, max_shape[1])
y2 = torch.clamp(y2, 0, max_shape[0])
return x1, y1, x2, y2


if is_mmdeploy_enabled():
    # The rewriter is registered only when mmdeploy is importable, so this
    # module can still be imported in environments without mmdeploy.
    from mmdeploy.core import FUNCTION_REWRITER

    @FUNCTION_REWRITER.register_rewriter(
        func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.DeltaXYWHBBoxCoder.decode",
        backend="default",
    )
    def deltaxywhbboxcoder__decode(
        self: DeltaXYWHBBoxCoder,
        bboxes: Tensor,
        pred_bboxes: Tensor,
        max_shape: Tensor | None = None,
        wh_ratio_clip: float = 16 / 1000,
    ) -> Tensor:
        """Rewrite `decode` of `DeltaXYWHBBoxCoder` for default backend.

        Checks that the leading dimensions of ``bboxes`` and ``pred_bboxes``
        agree, then forwards everything to `delta2bbox_export` together with
        the coder's ``means``, ``stds``, ``clip_border``, ``add_ctr_clamp``
        and ``ctr_clamp`` attributes.

        Args:
            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
            pred_bboxes (Tensor): Encoded offsets with respect to each roi.
                Has shape (B, N, num_classes * 4) or (B, N, 4) or
                (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
                when rois is a grid of anchors. Offset encoding follows [1]_.
            max_shape (Sequence[int] or torch.Tensor or Sequence[
                Sequence[int]], optional): Maximum bounds for boxes, specifies
                (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
                the max_shape should be a Sequence[Sequence[int]]
                and the length of max_shape should also be B.
                NOTE(review): the annotation only allows ``Tensor | None``;
                confirm which input types `delta2bbox_export` accepts.
            wh_ratio_clip (float, optional): The allowed ratio between
                width and height.

        Returns:
            torch.Tensor: Decoded boxes.

        Raises:
            ValueError: If the first dimensions of ``pred_bboxes`` and
                ``bboxes`` differ, or if ``pred_bboxes`` is 3-D and the
                second dimensions differ.
        """
        # Dimension 0 must match in every layout (batched or not).
        if pred_bboxes.size(0) != bboxes.size(0):
            msg = "The batch size of pred_bboxes and bboxes should be equal."
            raise ValueError(msg)
        # For 3-D predictions, dimension 1 must match as well.
        if pred_bboxes.ndim == 3 and pred_bboxes.size(1) != bboxes.size(1):
            msg = "The number of bboxes should be equal."
            raise ValueError(msg)
        return delta2bbox_export(
            bboxes,
            pred_bboxes,
            self.means,
            self.stds,
            max_shape,
            wh_ratio_clip,
            self.clip_border,
            self.add_ctr_clamp,
            self.ctr_clamp,
        )
2 changes: 2 additions & 0 deletions src/otx/algo/detection/heads/iou2d_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from __future__ import annotations

import torch
from mmengine.registry import TASK_UTILS

from otx.algo.detection.utils.bbox_overlaps import bbox_overlaps


# This class and its supporting functions below are lightly adapted from the mmdet BboxOverlaps2D available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/assigners/iou2d_calculator.py
@TASK_UTILS.register_module()
class BboxOverlaps2D:
"""2D Overlaps (e.g. IoUs, GIoUs) Calculator."""

Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/detection/heads/max_iou_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Callable

import torch
from mmengine.registry import TASK_UTILS
from torch import Tensor

from otx.algo.detection.heads.iou2d_calculator import BboxOverlaps2D
Expand All @@ -19,6 +20,7 @@

# This class and its supporting functions below are lightly adapted from the mmdet MaxIoUAssigner available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/assigners/max_iou_assigner.py
@TASK_UTILS.register_module()
class MaxIoUAssigner:
"""Assign a corresponding gt bbox or background to each bbox.

Expand Down
12 changes: 10 additions & 2 deletions src/otx/algo/detection/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
#
"""Custom OTX Losses for Object Detection."""
from .accuracy import accuracy
from .cross_entropy_loss import CrossEntropyLoss
from .cross_focal_loss import CrossSigmoidFocalLoss
from .smooth_l1_loss import L1Loss


__all__ = ["CrossSigmoidFocalLoss, OrdinaryFocalLoss"]
__all__ = [
"CrossEntropyLoss",
"CrossSigmoidFocalLoss",
"accuracy",
"L1Loss",
]
73 changes: 73 additions & 0 deletions src/otx/algo/detection/losses/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# This class and its supporting functions are adapted from the mmdet.
# Please refer to https://github.com/open-mmlab/mmdetection/

"""MMDet Accuracy."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from torch import Tensor


def accuracy(
    pred: Tensor,
    target: Tensor,
    topk: int | tuple[int] = 1,
    thresh: float | None = None,
) -> list[Tensor] | Tensor:
    """Compute top-k accuracy (in percent) of predictions against targets.

    Args:
        pred (torch.Tensor): Prediction scores, shape (N, num_class).
        target (torch.Tensor): Ground-truth label per prediction, shape (N, ).
        topk (int | tuple[int], optional): A prediction counts as correct when
            the target is among its ``k`` highest-scoring classes.
            Defaults to 1.
        thresh (float, optional): When given, a match is only counted if its
            score is strictly greater than this value. Defaults to None.

    Returns:
        Tensor | list[Tensor]: One accuracy tensor when ``topk`` is an int,
            otherwise a list with one accuracy tensor per entry of ``topk``.
            For empty ``pred`` (N == 0) the accuracies are scalar zeros.

    Raises:
        TypeError: If ``topk`` is neither an int nor a tuple.
        ValueError: If the tensor ranks or leading sizes are inconsistent, or
            the largest ``k`` exceeds the number of classes.
    """
    # Normalise ``topk`` to a tuple and remember which return form is wanted.
    if isinstance(topk, int):
        ks, single = (topk,), True
    elif isinstance(topk, tuple):
        ks, single = topk, False
    else:
        msg = f"topk must be int or tuple of int, got {type(topk)}"
        raise TypeError(msg)

    largest_k = max(ks)
    num_samples = pred.size(0)

    # Nothing to score: report zero accuracy for every requested k.
    if num_samples == 0:
        zeros = [pred.new_tensor(0.0) for _ in range(len(ks))]
        return zeros[0] if single else zeros

    if pred.ndim != 2 or target.ndim != 1:
        msg = "Input tensors must have 2 dims for pred and 1 dim for target"
        raise ValueError(msg)
    if num_samples != target.size(0):
        msg = "Input tensors must have the same size along the 0th dim"
        raise ValueError(msg)
    if largest_k > pred.size(1):
        msg = f"maxk {largest_k} exceeds pred dimension {pred.size(1)}"
        raise ValueError(msg)

    top_scores, top_labels = pred.topk(largest_k, dim=1)
    top_labels = top_labels.t()  # shape (maxk, N)
    hits = top_labels.eq(target.view(1, -1).expand_as(top_labels))
    if thresh is not None:
        # A hit only counts when its score is above the threshold.
        hits = hits & (top_scores > thresh).t()

    percent_per_hit = 100.0 / num_samples
    results = [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit) for k in ks]
    return results[0] if single else results
2 changes: 2 additions & 0 deletions src/otx/algo/detection/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import torch
from mmengine.registry import MODELS
from torch import nn

from otx.algo.detection.losses.weighted_loss import weight_reduce_loss
Expand Down Expand Up @@ -181,6 +182,7 @@ def mask_cross_entropy(
)[None]


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
"""Base Cross Entropy Loss implementation from mmdet."""

Expand Down
Loading
Loading