Skip to content

Commit

Permalink
Refactoring detection task (#3860)
Browse files Browse the repository at this point in the history
* Rename directory from `base_models` to `detectors`

* Update pytorchcv version

* Create criterion modules

* Fix unit tests

* Update comment

* FIx iseg unit test

* Enable factory for ATSS

* Enable factory for YOLOX and update

* Enable factory for RTMDet

* Update recipes

* Update

* Enable factory for SSD

* Fix unit tests

* Enable factory for RTDETR

* Reduce default parameters

* Update huggingface, keypoint, and iseg

* Revert default `input_size` argument

* Fix unit test

* Fix integration test

* Add ABC

* Fix

* Change `model_version` to `model_name`

* precommit

* Remove `DetectionBackboneFactory`

* Use `Literal` for `model_name`
  • Loading branch information
sungchul2 authored Aug 27, 2024
1 parent 9833212 commit 0a395b2
Show file tree
Hide file tree
Showing 78 changed files with 2,079 additions and 1,302 deletions.
44 changes: 42 additions & 2 deletions src/otx/algo/common/backbones/cspnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import math
from functools import partial
from typing import Callable, ClassVar
from typing import Any, Callable, ClassVar

from otx.algo.common.layers import SPPBottleneck
from otx.algo.detection.layers import CSPLayer
Expand All @@ -22,7 +22,7 @@
from torch.nn.modules.batchnorm import _BatchNorm


class CSPNeXt(BaseModule):
class CSPNeXtModule(BaseModule):
"""CSPNeXt backbone used in RTMDet.
Args:
Expand Down Expand Up @@ -225,3 +225,43 @@ def forward(self, x: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
if i in self.out_indices:
outs.append(x)
return tuple(outs)


class CSPNeXt:
"""CSPNeXt factory for detection."""

CSPNEXT_CFG: ClassVar[dict[str, Any]] = {
"rtmdet_tiny": {
"deepen_factor": 0.167,
"widen_factor": 0.375,
"normalization": nn.BatchNorm2d,
"activation": partial(nn.SiLU, inplace=True),
},
"rtmpose_tiny": {
"arch": "P5",
"expand_ratio": 0.5,
"deepen_factor": 0.167,
"widen_factor": 0.375,
"out_indices": (4,),
"channel_attention": True,
"normalization": nn.BatchNorm2d,
"activation": partial(nn.SiLU, inplace=True),
},
"rtmdet_inst_tiny": {
"arch": "P5",
"expand_ratio": 0.5,
"deepen_factor": 0.167,
"widen_factor": 0.375,
"channel_attention": True,
"normalization": nn.BatchNorm2d,
"activation": partial(nn.SiLU, inplace=True),
},
}

def __new__(cls, model_name: str) -> CSPNeXtModule:
"""Constructor for CSPNeXt."""
if model_name not in cls.CSPNEXT_CFG:
msg = f"model type '{model_name}' is not supported"
raise KeyError(msg)

return CSPNeXtModule(**cls.CSPNEXT_CFG[model_name])
3 changes: 2 additions & 1 deletion src/otx/algo/common/utils/coders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#
"""Custom coder implementations."""

from .base_bbox_coder import BaseBBoxCoder
from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from .distance_point_bbox_coder import DistancePointBBoxCoder

__all__ = ["DeltaXYWHBBoxCoder", "DistancePointBBoxCoder"]
__all__ = ["BaseBBoxCoder", "DeltaXYWHBBoxCoder", "DistancePointBBoxCoder"]
26 changes: 26 additions & 0 deletions src/otx/algo/common/utils/coders/base_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Base bounding box coder."""

from abc import ABCMeta, abstractmethod

from torch import Tensor


class BaseBBoxCoder(metaclass=ABCMeta):
"""Base class for bounding box coder."""

encode_size: int

@abstractmethod
def encode(self, *args, **kwargs) -> Tensor:
"""Encode bounding boxes."""

@abstractmethod
def decode(self, *args, **kwargs) -> Tensor:
"""Decode bounding boxes."""

@abstractmethod
def decode_export(self, *args, **kwargs) -> Tensor:
"""Decode bounding boxes for export."""
4 changes: 3 additions & 1 deletion src/otx/algo/common/utils/coders/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from otx.algo.detection.utils.utils import clip_bboxes
from torch import Tensor

from .base_bbox_coder import BaseBBoxCoder

class DeltaXYWHBBoxCoder:

class DeltaXYWHBBoxCoder(BaseBBoxCoder):
"""Delta XYWH BBox coder.
Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
from otx.algo.common.utils.utils import bbox2distance, distance2bbox
from otx.algo.detection.utils.utils import distance2bbox_export

from .base_bbox_coder import BaseBBoxCoder

if TYPE_CHECKING:
from torch import Tensor


class DistancePointBBoxCoder:
class DistancePointBBoxCoder(BaseBBoxCoder):
"""Distance Point BBox coder.
This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
Expand Down
3 changes: 2 additions & 1 deletion src/otx/algo/common/utils/prior_generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Anchor generators for detection task."""

from .anchor_generator import AnchorGenerator, SSDAnchorGeneratorClustered
from .base_prior_generator import BasePriorGenerator
from .point_generator import MlvlPointGenerator

__all__ = ["AnchorGenerator", "SSDAnchorGeneratorClustered", "MlvlPointGenerator"]
__all__ = ["AnchorGenerator", "SSDAnchorGeneratorClustered", "BasePriorGenerator", "MlvlPointGenerator"]
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from torch import Tensor
from torch.nn.modules.utils import _pair

from .base_prior_generator import BasePriorGenerator

class AnchorGenerator:

class AnchorGenerator(BasePriorGenerator):
"""Standard anchor generator for 2D anchor-based detectors.
# TODO (sungchul): change strides format from (w, h) to (h, w)
Expand Down Expand Up @@ -72,7 +74,7 @@ def __init__(
raise ValueError(msg)

# calculate base sizes of anchors
self.strides = [_pair(stride) for stride in strides]
self.strides: list[tuple[int, int]] = [_pair(stride) for stride in strides]
self.base_sizes = [min(stride) for stride in self.strides] if base_sizes is None else base_sizes

if scales is not None:
Expand Down
37 changes: 37 additions & 0 deletions src/otx/algo/common/utils/prior_generators/base_prior_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Base prior generator."""

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
from torch import Tensor


class BasePriorGenerator(metaclass=ABCMeta):
"""Base class for prior generator."""

strides: list[tuple[int, int]]
grid_anchors: Callable[..., list[Tensor]]

@property
@abstractmethod
def num_base_priors(self) -> list[int]:
"""Return the number of priors (anchors/points) at a point on the feature grid."""

@property
@abstractmethod
def num_levels(self) -> int:
"""int: number of feature levels that the generator will be applied."""

@abstractmethod
def grid_priors(self, *args, **kwargs) -> list[Tensor]:
"""Generate grid anchors/points of multiple feature levels."""

@abstractmethod
def valid_flags(self, *args, **kwargs) -> list[Tensor]:
"""Generate valid flags of anchors/points of multiple feature levels."""
6 changes: 4 additions & 2 deletions src/otx/algo/common/utils/prior_generators/point_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from torch import Tensor
from torch.nn.modules.utils import _pair

from .base_prior_generator import BasePriorGenerator

DeviceType = Union[str, torch.device]


class MlvlPointGenerator:
class MlvlPointGenerator(BasePriorGenerator):
"""Standard points generator for multi-level (Mlvl) feature maps in 2D points-based detectors.
# TODO (sungchul): change strides format from (w, h) to (h, w)
Expand All @@ -31,7 +33,7 @@ class MlvlPointGenerator:
"""

def __init__(self, strides: list[int] | list[tuple[int, int]], offset: float = 0.5) -> None:
self.strides = [_pair(stride) for stride in strides]
self.strides: list[tuple[int, int]] = [_pair(stride) for stride in strides]
self.offset = offset

@property
Expand Down
Loading

0 comments on commit 0a395b2

Please sign in to comment.