Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring detection task #3860

Merged
merged 28 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4d421d1
Rename directory from `base_models` to `detectors`
sungchul2 Aug 16, 2024
e1d156d
Update pytorchcv version
sungchul2 Aug 16, 2024
7a3b01d
Create criterion modules
sungchul2 Aug 16, 2024
ef0c1c2
Fix unit tests
sungchul2 Aug 21, 2024
b0abc18
Update comment
sungchul2 Aug 21, 2024
32b5a2a
Merge branch 'develop' into refactoring-detection
sungchul2 Aug 21, 2024
a13ad86
Fix iseg unit test
sungchul2 Aug 22, 2024
8098edd
Merge branch 'develop' into refactoring-detection
sungchul2 Aug 22, 2024
9cf1e24
Enable factory for ATSS
sungchul2 Aug 22, 2024
b4ee9f1
Enable factory for YOLOX and update
sungchul2 Aug 22, 2024
2d595df
Enable factory for RTMDet
sungchul2 Aug 22, 2024
93f3f5c
Update recipes
sungchul2 Aug 22, 2024
d56b1ce
Update
sungchul2 Aug 22, 2024
055fe84
Enable factory for SSD
sungchul2 Aug 22, 2024
3d7187b
Fix unit tests
sungchul2 Aug 22, 2024
bbb8116
Enable factory for RTDETR
sungchul2 Aug 23, 2024
6428882
Reduce default parameters
sungchul2 Aug 23, 2024
f469b90
Update huggingface, keypoint, and iseg
sungchul2 Aug 23, 2024
1846d0a
Revert default `input_size` argument
sungchul2 Aug 23, 2024
0a2471e
Fix unit test
sungchul2 Aug 23, 2024
cf6485a
Fix integration test
sungchul2 Aug 23, 2024
c3dfef0
Add ABC
sungchul2 Aug 23, 2024
4277eba
Fix
sungchul2 Aug 23, 2024
49f1383
Merge branch 'develop' into refactoring-detection
sungchul2 Aug 23, 2024
6287f82
Change `model_version` to `model_name`
sungchul2 Aug 23, 2024
68abb0a
precommit
sungchul2 Aug 23, 2024
f38b613
Remove `DetectionBackboneFactory`
sungchul2 Aug 26, 2024
c3f5668
Use `Literal` for `model_name`
sungchul2 Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/otx/algo/common/backbones/cspnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import math
from functools import partial
from typing import Callable, ClassVar
from typing import Any, Callable, ClassVar

from otx.algo.common.layers import SPPBottleneck
from otx.algo.detection.layers import CSPLayer
Expand All @@ -22,7 +22,7 @@
from torch.nn.modules.batchnorm import _BatchNorm


class CSPNeXt(BaseModule):
class CSPNeXtModule(BaseModule):
"""CSPNeXt backbone used in RTMDet.

Args:
Expand Down Expand Up @@ -225,3 +225,43 @@ def forward(self, x: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
if i in self.out_indices:
outs.append(x)
return tuple(outs)


class CSPNeXt:
    """Factory that builds a preconfigured ``CSPNeXtModule`` backbone by name.

    Calling ``CSPNeXt(model_name)`` does not yield a ``CSPNeXt`` instance:
    ``__new__`` looks ``model_name`` up in ``CSPNEXT_CFG`` and returns the
    ``CSPNeXtModule`` built from the matching keyword arguments.
    """

    # Supported model names mapped to the keyword arguments passed to ``CSPNeXtModule``.
    CSPNEXT_CFG: ClassVar[dict[str, Any]] = {
        "rtmdet_tiny": {
            "deepen_factor": 0.167,
            "widen_factor": 0.375,
            "normalization": nn.BatchNorm2d,
            "activation": partial(nn.SiLU, inplace=True),
        },
        "rtmpose_tiny": {
            "arch": "P5",
            "expand_ratio": 0.5,
            "deepen_factor": 0.167,
            "widen_factor": 0.375,
            "out_indices": (4,),
            "channel_attention": True,
            "normalization": nn.BatchNorm2d,
            "activation": partial(nn.SiLU, inplace=True),
        },
        "rtmdet_inst_tiny": {
            "arch": "P5",
            "expand_ratio": 0.5,
            "deepen_factor": 0.167,
            "widen_factor": 0.375,
            "channel_attention": True,
            "normalization": nn.BatchNorm2d,
            "activation": partial(nn.SiLU, inplace=True),
        },
    }

    def __new__(cls, model_name: str) -> CSPNeXtModule:
        """Build the backbone registered under ``model_name``.

        Args:
            model_name (str): Key of the desired configuration in ``CSPNEXT_CFG``.

        Returns:
            CSPNeXtModule: Backbone constructed with the selected configuration.

        Raises:
            KeyError: If ``model_name`` is not a key of ``CSPNEXT_CFG``.
        """
        # Every registered configuration is a dict, so ``None`` can only mean "unknown name".
        selected_cfg = cls.CSPNEXT_CFG.get(model_name)
        if selected_cfg is None:
            msg = f"model type '{model_name}' is not supported"
            raise KeyError(msg)

        return CSPNeXtModule(**selected_cfg)
3 changes: 2 additions & 1 deletion src/otx/algo/common/utils/coders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#
"""Custom coder implementations."""

from .base_bbox_coder import BaseBBoxCoder
from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from .distance_point_bbox_coder import DistancePointBBoxCoder

__all__ = ["DeltaXYWHBBoxCoder", "DistancePointBBoxCoder"]
__all__ = ["BaseBBoxCoder", "DeltaXYWHBBoxCoder", "DistancePointBBoxCoder"]
26 changes: 26 additions & 0 deletions src/otx/algo/common/utils/coders/base_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Base bounding box coder."""

from abc import ABCMeta, abstractmethod

from torch import Tensor


class BaseBBoxCoder(metaclass=ABCMeta):
    """Abstract interface shared by bounding box coders.

    Concrete coders must implement ``encode``, ``decode`` and ``decode_export``;
    the class cannot be instantiated until all three are overridden.
    """

    # Declared for type checkers only; no value is assigned at this level.
    encode_size: int

    @abstractmethod
    def encode(self, *args, **kwargs) -> Tensor:
        """Encode bounding boxes and return the result as a tensor."""

    @abstractmethod
    def decode(self, *args, **kwargs) -> Tensor:
        """Decode bounding boxes and return the result as a tensor."""

    @abstractmethod
    def decode_export(self, *args, **kwargs) -> Tensor:
        """Decode bounding boxes on the export path and return the result as a tensor."""
4 changes: 3 additions & 1 deletion src/otx/algo/common/utils/coders/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from otx.algo.detection.utils.utils import clip_bboxes
from torch import Tensor

from .base_bbox_coder import BaseBBoxCoder

class DeltaXYWHBBoxCoder:

class DeltaXYWHBBoxCoder(BaseBBoxCoder):
"""Delta XYWH BBox coder.

Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
from otx.algo.common.utils.utils import bbox2distance, distance2bbox
from otx.algo.detection.utils.utils import distance2bbox_export

from .base_bbox_coder import BaseBBoxCoder

if TYPE_CHECKING:
from torch import Tensor


class DistancePointBBoxCoder:
class DistancePointBBoxCoder(BaseBBoxCoder):
"""Distance Point BBox coder.

This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
Expand Down
3 changes: 2 additions & 1 deletion src/otx/algo/common/utils/prior_generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Anchor generators for detection task."""

from .anchor_generator import AnchorGenerator, SSDAnchorGeneratorClustered
from .base_prior_generator import BasePriorGenerator
from .point_generator import MlvlPointGenerator

__all__ = ["AnchorGenerator", "SSDAnchorGeneratorClustered", "MlvlPointGenerator"]
__all__ = ["AnchorGenerator", "SSDAnchorGeneratorClustered", "BasePriorGenerator", "MlvlPointGenerator"]
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from torch import Tensor
from torch.nn.modules.utils import _pair

from .base_prior_generator import BasePriorGenerator

class AnchorGenerator:

class AnchorGenerator(BasePriorGenerator):
"""Standard anchor generator for 2D anchor-based detectors.

# TODO (sungchul): change strides format from (w, h) to (h, w)
Expand Down Expand Up @@ -72,7 +74,7 @@ def __init__(
raise ValueError(msg)

# calculate base sizes of anchors
self.strides = [_pair(stride) for stride in strides]
self.strides: list[tuple[int, int]] = [_pair(stride) for stride in strides]
self.base_sizes = [min(stride) for stride in self.strides] if base_sizes is None else base_sizes

if scales is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Base prior generator."""

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
from torch import Tensor


class BasePriorGenerator(metaclass=ABCMeta):
    """Abstract interface shared by prior (anchor/point) generators.

    Subclasses must provide the ``num_base_priors`` and ``num_levels`` properties
    as well as the ``grid_priors`` and ``valid_flags`` methods.
    """

    # Declared for type checkers only; no values are assigned at this level.
    strides: list[tuple[int, int]]
    grid_anchors: Callable[..., list[Tensor]]

    @property
    @abstractmethod
    def num_base_priors(self) -> list[int]:
        """list[int]: Number of priors (anchors/points) at a point on the feature grid."""

    @property
    @abstractmethod
    def num_levels(self) -> int:
        """int: Number of feature levels the generator is applied to."""

    @abstractmethod
    def grid_priors(self, *args, **kwargs) -> list[Tensor]:
        """Generate the grid anchors/points for multiple feature levels."""

    @abstractmethod
    def valid_flags(self, *args, **kwargs) -> list[Tensor]:
        """Generate the valid flags of anchors/points for multiple feature levels."""
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from torch import Tensor
from torch.nn.modules.utils import _pair

from .base_prior_generator import BasePriorGenerator

DeviceType = Union[str, torch.device]


class MlvlPointGenerator:
class MlvlPointGenerator(BasePriorGenerator):
"""Standard points generator for multi-level (Mlvl) feature maps in 2D points-based detectors.

# TODO (sungchul): change strides format from (w, h) to (h, w)
Expand All @@ -31,7 +33,7 @@ class MlvlPointGenerator:
"""

def __init__(self, strides: list[int] | list[tuple[int, int]], offset: float = 0.5) -> None:
self.strides = [_pair(stride) for stride in strides]
self.strides: list[tuple[int, int]] = [_pair(stride) for stride in strides]
self.offset = offset

@property
Expand Down
Loading
Loading