Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MMDet RTMDet Inst decoupling #3433

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
126 commits
Select commit Hold shift + click to select a range
fa0d462
migrate mmdet maskrcnn modules
eugene123tw Apr 5, 2024
763ce7d
style reformat
eugene123tw Apr 5, 2024
5ace6b0
style reformat
eugene123tw Apr 5, 2024
b2b6e0c
stype reformat
eugene123tw Apr 5, 2024
a305c7c
ignore mypy, ruff errors
eugene123tw Apr 8, 2024
6b682b7
skip mypy error
eugene123tw Apr 8, 2024
727cf2b
update
eugene123tw Apr 8, 2024
22f8f81
fix loss
eugene123tw Apr 8, 2024
08d766f
add maskrcnn
eugene123tw Apr 8, 2024
873a101
update import
eugene123tw Apr 8, 2024
bfca5f0
update import
eugene123tw Apr 8, 2024
47f3744
add necks
eugene123tw Apr 8, 2024
7bb7f39
update
eugene123tw Apr 8, 2024
d37f8a1
update
eugene123tw Apr 8, 2024
abd1f21
add cross-entropy loss
eugene123tw Apr 9, 2024
4978171
style changes
eugene123tw Apr 9, 2024
330f721
mypy changes and style changes
eugene123tw Apr 10, 2024
92dc0bf
update style
eugene123tw Apr 11, 2024
87a2415
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 11, 2024
c4dec94
remove box structures
eugene123tw Apr 11, 2024
5723ec2
add resnet
eugene123tw Apr 11, 2024
47de4f1
udpate
eugene123tw Apr 11, 2024
f5a51da
modify resnet
eugene123tw Apr 15, 2024
63f46c4
add annotation
eugene123tw Apr 15, 2024
46b37d5
style changes
eugene123tw Apr 15, 2024
cef1e24
update
eugene123tw Apr 15, 2024
655baea
fix all mypy issues
eugene123tw Apr 15, 2024
b1ed150
fix mypy issues
eugene123tw Apr 15, 2024
27b6a4a
style changes
eugene123tw Apr 16, 2024
e87cfa9
remove unused losses
eugene123tw Apr 16, 2024
c2c2394
remove focal_loss_pb
eugene123tw Apr 16, 2024
edd85e0
fix all rull and mypy issues
eugene123tw Apr 16, 2024
742b6fd
fix conflicts
eugene123tw Apr 16, 2024
194a6c2
style change
eugene123tw Apr 16, 2024
c1938de
update
eugene123tw Apr 16, 2024
734a459
udpate license
eugene123tw Apr 16, 2024
93fcd79
udpate
eugene123tw Apr 16, 2024
1d0926d
remove duplicates
eugene123tw Apr 17, 2024
1238d28
remove as F
eugene123tw Apr 17, 2024
55d77f5
remove as F
eugene123tw Apr 17, 2024
1598929
remove mmdet mask structures
eugene123tw Apr 17, 2024
498b750
remove duplicates
eugene123tw Apr 17, 2024
a0f52de
style changes
eugene123tw Apr 17, 2024
549d0ef
add new test
eugene123tw Apr 17, 2024
0a074ae
test style change
eugene123tw Apr 17, 2024
9106bc1
fix test
eugene123tw Apr 17, 2024
8197825
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 17, 2024
07cd25e
chagne device for unit test
eugene123tw Apr 17, 2024
ab07675
add deployment files
eugene123tw Apr 17, 2024
70717ce
remove deployment from inst-seg
eugene123tw Apr 17, 2024
e5027d9
update deployment
eugene123tw Apr 18, 2024
03e26d3
add mmdeploy maskrcnn opset
eugene123tw Apr 18, 2024
bbeaa32
fix linter
eugene123tw Apr 18, 2024
478158a
update test
eugene123tw Apr 18, 2024
052a582
update test
eugene123tw Apr 18, 2024
28cea6f
update test
eugene123tw Apr 18, 2024
ed7275e
Merge branch 'develop' into eugene/CVS-137823-mmdet-maskrcnn-decouple
eugene123tw Apr 22, 2024
04cd223
replace mmcv.cnn module
eugene123tw Apr 22, 2024
11ae260
remove upsample building
eugene123tw Apr 22, 2024
d8a78ca
remove upsample building
eugene123tw Apr 22, 2024
34ea7d7
use batch_nms from otx
eugene123tw Apr 22, 2024
617e1ac
add swintransformer
eugene123tw Apr 22, 2024
5a6737f
add transformers
eugene123tw Apr 22, 2024
fe34145
add swin transformer
eugene123tw Apr 23, 2024
4dd8bdc
style changes
eugene123tw Apr 23, 2024
a6de9b4
merge upstream
eugene123tw Apr 23, 2024
1b199a0
solve conflicts
eugene123tw Apr 23, 2024
c332430
update instance_segmentation/maskrcnn.py
eugene123tw Apr 23, 2024
6d92199
update nms
eugene123tw Apr 23, 2024
4d75001
fix xai
eugene123tw Apr 24, 2024
7b4b43e
change rotate detection recipe
eugene123tw Apr 24, 2024
0b2a326
fix swint recipe
eugene123tw Apr 24, 2024
ee70fe7
solve merge develop conflicts
eugene123tw Apr 25, 2024
16bd98f
remove some files
eugene123tw Apr 25, 2024
8d94815
decopule mmdeploy and replace with native exporter
eugene123tw Apr 26, 2024
9c6b3f9
solve conflicts
eugene123tw Apr 26, 2024
6734dcd
remove duplicates import
eugene123tw Apr 26, 2024
260a5d1
todo
eugene123tw Apr 26, 2024
e015d02
update
eugene123tw Apr 29, 2024
e186f8f
fix conflicts
eugene123tw Apr 29, 2024
74b963d
fix rpn_head training issue
eugene123tw Apr 29, 2024
e6eafa8
remove maskrcnn r50 mmconfigs
eugene123tw Apr 29, 2024
695084d
merge develop
eugene123tw Apr 30, 2024
f2c7bdc
fix anchor head and related fixes
eugene123tw Apr 30, 2024
88237ee
remove gather_topk
eugene123tw Apr 30, 2024
2070be4
remove maskrcnn efficientnet mmconfig
eugene123tw Apr 30, 2024
9cea2af
remove maskrcnn-swint mmconfig
eugene123tw Apr 30, 2024
50b0912
revert some changes
eugene123tw Apr 30, 2024
85127a1
update recipes
eugene123tw Apr 30, 2024
9d8028f
replace mmcv.ops.roi_align with torchvision.ops.roi_align
eugene123tw Apr 30, 2024
b4dd3d7
Merge branch 'develop' into eugene/CVS-139700-maskrcnn-naive-exporter
eugene123tw Apr 30, 2024
820f6b3
fix format issue
eugene123tw Apr 30, 2024
87368fc
update anchor head
eugene123tw Apr 30, 2024
cd8e1da
add CrossSigmoidFocalLoss back
eugene123tw May 1, 2024
e551698
remove mmdet decouple test
eugene123tw May 1, 2024
5538233
fix test
eugene123tw May 1, 2024
362e0af
skip xai test for inst-seg for now
eugene123tw May 1, 2024
6fdeffd
remove code comment
eugene123tw May 1, 2024
896b32f
add more rtmdet modules
eugene123tw May 1, 2024
20662cb
Merge branch 'develop' into eugene/CVS-139534-rtmdet-decoupling
eugene123tw May 2, 2024
7e264eb
reformat
eugene123tw May 2, 2024
1588c23
add native exporter
eugene123tw May 2, 2024
cda51b9
fix export issue
eugene123tw May 2, 2024
b3bd52e
fix format
eugene123tw May 2, 2024
70fd20d
update todo roi_align comment
eugene123tw May 2, 2024
b3b372c
add custom otx roi align
eugene123tw May 2, 2024
451b111
reformat OTXRoIAlign
eugene123tw May 2, 2024
be1e328
remove files
eugene123tw May 2, 2024
a09b5c9
remove config from MMDetInstanceSegCompatibleModel
eugene123tw May 3, 2024
aa35f00
add rtmdet inst test
eugene123tw May 3, 2024
af7b388
Merge branch 'develop' into eugene/CVS-139534-rtmdet-decoupling
eugene123tw May 3, 2024
9b1f838
add unit tests
eugene123tw May 3, 2024
2e983ea
rename unit tests
eugene123tw May 3, 2024
351fa2f
update rtmdet recipe
eugene123tw May 3, 2024
00f3531
remove RTMDetSepBNHead
eugene123tw May 3, 2024
f45be85
fix failures
eugene123tw May 7, 2024
1b68c71
Merge branch 'releases/2.0.0' into eugene/CVS-139534-rtmdet-decoupling
eugene123tw May 7, 2024
8d1da12
Merge branch 'releases/2.0.0' into eugene/CVS-139534-rtmdet-decoupling
eugene123tw May 7, 2024
f3b7677
skip xai test for rtmdet inst for now
eugene123tw May 7, 2024
cc5c1f7
update license docstring
eugene123tw May 8, 2024
a126f24
revert src/otx/core/model/instance_segmentation.py
eugene123tw May 8, 2024
62b866b
fix broken import
eugene123tw May 8, 2024
ba28ea0
Merge branch 'releases/2.0.0' into eugene/CVS-139534-rtmdet-decoupling
eugene123tw May 8, 2024
b2717d6
fix broken imports
eugene123tw May 8, 2024
b42e1d8
remove is_norm
eugene123tw May 8, 2024
ef0be7b
add back is_norm
eugene123tw May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions src/otx/algo/detection/backbones/cspnext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.

"""CSPNeXt backbone used in RTMDet."""

from __future__ import annotations

import math
from typing import ClassVar

from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

from otx.algo.detection.backbones.csp_darknet import SPPBottleneck
from otx.algo.detection.layers.csp_layer import CSPLayer
from otx.algo.modules.base_module import BaseModule
from otx.algo.modules.conv_module import ConvModule
from otx.algo.modules.depthwise_separable_conv_module import DepthwiseSeparableConvModule


class CSPNeXt(BaseModule):
"""CSPNeXt backbone used in RTMDet.

Args:
arch (str): Architecture of CSPNeXt, from {P5, P6}.
Defaults to P5.
expand_ratio (float): Ratio to adjust the number of channels of the
hidden layer. Defaults to 0.5.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
out_indices (Sequence[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
use_depthwise (bool): Whether to use depthwise separable convolution.
Defaults to False.
arch_ovewrite (list): Overwrite default arch settings.
Defaults to None.
spp_kernel_sizes: (tuple[int]): Sequential of kernel sizes of SPP
layers. Defaults to (5, 9, 13).
channel_attention (bool): Whether to add channel attention in each
stage. Defaults to True.
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
convolution layer. Defaults to None.
norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
config norm layer. Defaults to dict(type='BN', requires_grad=True).
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='SiLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
list[:obj:`ConfigDict`]): Initialization config dict.
"""

# From left to right:
# in_channels, out_channels, num_blocks, add_identity, use_spp
arch_settings: ClassVar = {
"P5": [
[64, 128, 3, True, False],
[128, 256, 6, True, False],
[256, 512, 6, True, False],
[512, 1024, 3, False, True],
],
"P6": [
[64, 128, 3, True, False],
[128, 256, 6, True, False],
[256, 512, 6, True, False],
[512, 768, 3, True, False],
[768, 1024, 3, False, True],
],
}

def __init__(
self,
arch: str = "P5",
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
out_indices: tuple[int, int, int] = (2, 3, 4),
frozen_stages: int = -1,
use_depthwise: bool = False,
expand_ratio: float = 0.5,
arch_ovewrite: dict | None = None,
spp_kernel_sizes: tuple[int, int, int] = (5, 9, 13),
channel_attention: bool = True,
conv_cfg: dict | None = None,
norm_cfg: dict | None = None,
act_cfg: dict | None = None,
norm_eval: bool = False,
init_cfg: dict | None = None,
) -> None:
if init_cfg is None:
init_cfg = {
"type": "Kaiming",
"layer": "Conv2d",
"a": math.sqrt(5),
"distribution": "uniform",
"mode": "fan_in",
"nonlinearity": "leaky_relu",
}

super().__init__(init_cfg=init_cfg)
arch_setting = self.arch_settings[arch]
if arch_ovewrite:
arch_setting = arch_ovewrite # type: ignore[assignment]

if not set(out_indices).issubset(i for i in range(len(arch_setting) + 1)):
msg = f"out_indices must be in range(0, len(arch_setting) + 1). But received {out_indices}"
raise ValueError(msg)

if frozen_stages not in range(-1, len(arch_setting) + 1):
msg = f"frozen_stages must be in (-1, len(arch_setting) + 1). But received {frozen_stages}"
raise ValueError(msg)

if norm_cfg is None:
norm_cfg = {"type": "BN", "momentum": 0.03, "eps": 0.001}

if act_cfg is None:
act_cfg = {"type": "SiLU"}

self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.use_depthwise = use_depthwise
self.norm_eval = norm_eval
conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
self.stem = nn.Sequential(
ConvModule(
3,
int(arch_setting[0][0] * widen_factor // 2),
3,
padding=1,
stride=2,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
),
ConvModule(
int(arch_setting[0][0] * widen_factor // 2),
int(arch_setting[0][0] * widen_factor // 2),
3,
padding=1,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
),
ConvModule(
int(arch_setting[0][0] * widen_factor // 2),
int(arch_setting[0][0] * widen_factor),
3,
padding=1,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
),
)
self.layers = ["stem"]

for i, (in_channels, out_channels, num_blocks, add_identity, use_spp) in enumerate(arch_setting):
in_channels = int(in_channels * widen_factor) # noqa: PLW2901
out_channels = int(out_channels * widen_factor) # noqa: PLW2901
num_blocks = max(round(num_blocks * deepen_factor), 1) # noqa: PLW2901
stage = []
conv_layer = conv(
in_channels,
out_channels,
3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
stage.append(conv_layer)
if use_spp:
spp = SPPBottleneck(
out_channels,
out_channels,
kernel_sizes=spp_kernel_sizes,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
stage.append(spp)
csp_layer = CSPLayer(
out_channels,
out_channels,
num_blocks=num_blocks,
add_identity=add_identity,
use_depthwise=use_depthwise,
use_cspnext_block=True,
expand_ratio=expand_ratio,
channel_attention=channel_attention,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
stage.append(csp_layer)
self.add_module(f"stage{i + 1}", nn.Sequential(*stage))
self.layers.append(f"stage{i + 1}")

def _freeze_stages(self) -> None:
"""Freeze stages param and norm stats."""
if self.frozen_stages >= 0:
for i in range(self.frozen_stages + 1):
m = getattr(self, self.layers[i])
m.eval()
for param in m.parameters():
param.requires_grad = False

def train(self, mode: bool = True) -> None:
"""Set modules in training mode."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()

def forward(self, x: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
"""Forward function."""
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
7 changes: 3 additions & 4 deletions src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from otx.algo.detection.heads.anchor_generator import AnchorGenerator
from otx.algo.detection.heads.base_head import BaseDenseHead
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap
from otx.algo.utils.mmengine_utils import InstanceData

Expand Down Expand Up @@ -49,14 +48,14 @@ def __init__(
self,
num_classes: int,
in_channels: tuple[int, ...] | int,
anchor_generator: AnchorGenerator,
bbox_coder: DeltaXYWHBBoxCoder,
anchor_generator: nn.Module,
bbox_coder: nn.Module,
loss_cls: nn.Module,
loss_bbox: nn.Module,
train_cfg: dict,
test_cfg: DictConfig,
feat_channels: int = 256,
reg_decoded_bbox: bool = False,
test_cfg: DictConfig | None = None,
init_cfg: dict | list[dict] | None = None,
) -> None:
super().__init__(init_cfg=init_cfg)
Expand Down
86 changes: 86 additions & 0 deletions src/otx/algo/detection/heads/distance_point_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
""""Distance Point BBox coder."""

from __future__ import annotations

from typing import TYPE_CHECKING

from otx.algo.detection.utils.utils import bbox2distance, distance2bbox

if TYPE_CHECKING:
from torch import Tensor


class DistancePointBBoxCoder:
"""Distance Point BBox coder.

This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
right) and decode it back to the original.

Args:
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
"""

def __init__(
self,
clip_border: bool = True,
encode_size: int = 4,
use_box_type: bool = False,
) -> None:
self.clip_border = clip_border
self.encode_size = encode_size
self.use_box_type = use_box_type

def encode(
self,
points: Tensor,
gt_bboxes: Tensor,
max_dis: float | None = None,
eps: float = 0.1,
) -> Tensor:
"""Encode bounding box to distances.

Args:
points (Tensor): Shape (N, 2), The format is [x, y].
gt_bboxes (Tensor or :obj:`BaseBoxes`): Shape (N, 4), The format
is "xyxy"
max_dis (float): Upper bound of the distance. Default None.
eps (float): a small value to ensure target < max_dis, instead <=.
Default 0.1.

Returns:
Tensor: Box transformation deltas. The shape is (N, 4).
"""
if points.size(0) != gt_bboxes.size(0):
msg = "The number of points should be equal to the number of boxes."
raise ValueError(msg)
if points.size(-1) != 2:
msg = "The last dimension of points should be 2."
raise ValueError(msg)
if gt_bboxes.size(-1) != 4:
msg = "The last dimension of gt_bboxes should be 4."
raise ValueError(msg)
return bbox2distance(points, gt_bboxes, max_dis, eps)

def decode(
self,
points: Tensor,
pred_bboxes: Tensor,
max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] | None = None,
) -> Tensor:
"""Decode distance prediction to bounding box."""
if points.size(0) != pred_bboxes.size(0):
msg = "The number of points should be equal to the number of boxes."
raise ValueError(msg)
if points.size(-1) != 2:
msg = "The last dimension of points should be 2."
raise ValueError(msg)
if pred_bboxes.size(-1) != 4:
msg = "The last dimension of pred_bboxes should be 4."
raise ValueError(msg)
if self.clip_border is False:
max_shape = None
return distance2bbox(points, pred_bboxes, max_shape)
Loading
Loading