Skip to content

Commit

Permalink
Decoupling mmdet structures Part 1. (#3301)
Browse files Browse the repository at this point in the history
* Remove BaseBoxes

* Migrate losses

* Decouple ssd_head

* Decoupling anchor head

* Fix unit tests
  • Loading branch information
jaegukhyun authored Apr 12, 2024
1 parent 3266f24 commit babae3b
Show file tree
Hide file tree
Showing 14 changed files with 944 additions and 111 deletions.
32 changes: 14 additions & 18 deletions src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@
from typing import TYPE_CHECKING

import torch
from mmdet.models.task_modules.prior_generators import anchor_inside_flags
from mmdet.models.utils import images_to_levels, multi_apply, unmap
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor
from mmengine.structures import InstanceData
from torch import Tensor, nn

from otx.algo.detection.heads.base_head import BaseDenseHead
from otx.algo.detection.heads.base_sampler import PseudoSampler
from otx.algo.detection.heads.custom_anchor_generator import AnchorGenerator
from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap

if TYPE_CHECKING:
from mmdet.utils import InstanceList, OptConfigType, OptInstanceList, OptMultiConfig
from mmengine import ConfigDict


# This class and its supporting functions below lightly adapted from the mmdet AnchorHead available at:
Expand Down Expand Up @@ -56,11 +54,11 @@ def __init__(
bbox_coder: dict,
loss_cls: dict,
loss_bbox: dict,
train_cfg: ConfigDict | dict,
feat_channels: int = 256,
reg_decoded_bbox: bool = False,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None,
test_cfg: ConfigDict | dict | None = None,
init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict] | None = None,
) -> None:
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
Expand Down Expand Up @@ -142,7 +140,7 @@ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]:
bbox_pred = self.conv_reg(x)
return cls_score, bbox_pred

def forward(self, x: tuple[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
def forward(self, x: tuple[Tensor]) -> tuple:
"""Forward features from the upstream network.
Args:
Expand Down Expand Up @@ -199,7 +197,7 @@ def get_anchors(

def _get_targets_single(
self,
flat_anchors: Tensor | BaseBoxes,
flat_anchors: Tensor,
valid_flags: Tensor,
gt_instances: InstanceData,
img_meta: dict,
Expand All @@ -209,7 +207,7 @@ def _get_targets_single(
"""Compute regression and classification targets for anchors in a single image.
Args:
flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors
flat_anchors (Tensor): Multi-level anchors
of the image, which are concatenated into a single tensor
or box type of shape (num_anchors, 4)
valid_flags (Tensor): Multi level valid flags of the image,
Expand Down Expand Up @@ -277,7 +275,6 @@ def _get_targets_single(
pos_bbox_targets = self.bbox_coder.encode(sampling_result.pos_priors, sampling_result.pos_gt_bboxes)
else:
pos_bbox_targets = sampling_result.pos_gt_bboxes
pos_bbox_targets = get_box_tensor(pos_bbox_targets)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0

Expand All @@ -303,9 +300,9 @@ def get_targets(
self,
anchor_list: list[list[Tensor]],
valid_flag_list: list[list[Tensor]],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
unmap_outputs: bool = True,
) -> tuple:
"""Compute regression and classification targets for anchors in multiple images.
Expand Down Expand Up @@ -364,7 +361,7 @@ def get_targets(
concat_anchor_list = []
concat_valid_flag_list = []
for i in range(num_imgs):
concat_anchor_list.append(cat_boxes(anchor_list[i]))
concat_anchor_list.append(torch.cat(anchor_list[i]))
concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

# compute targets for each image
Expand Down Expand Up @@ -455,17 +452,16 @@ def loss_by_feat_single(
# decodes the already encoded coordinates to absolute format.
anchors = anchors.reshape(-1, anchors.size(-1))
bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
bbox_pred = get_box_tensor(bbox_pred)
loss_bbox = self.loss_bbox(bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
return loss_cls, loss_bbox

def loss_by_feat(
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
) -> dict:
"""Calculate the loss based on the features extracted by the detection head.
Expand Down Expand Up @@ -504,7 +500,7 @@ def loss_by_feat(
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# concat all level anchors and flags to a single tensor
concat_anchor_list = [cat_boxes(anchor) for anchor in anchor_list]
concat_anchor_list = [torch.cat(anchor) for anchor in anchor_list]
all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors)

losses_cls, losses_bbox = multi_apply(
Expand Down
34 changes: 4 additions & 30 deletions src/otx/algo/detection/heads/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from abc import ABCMeta, abstractmethod

import torch
from mmdet.models.task_modules.assigners import AssignResult
from mmdet.models.task_modules.samplers.sampling_result import SamplingResult
from mmdet.structures.bbox import BaseBoxes, cat_boxes
from mmengine.structures import InstanceData

from otx.algo.detection.utils.structures import AssignResult, SamplingResult


class BaseSampler(metaclass=ABCMeta):
"""Base class of samplers.
Expand Down Expand Up @@ -72,26 +71,6 @@ def sample(
Returns:
:obj:`SamplingResult`: Sampling result.
Example:
>>> from mmengine.structures import InstanceData
>>> from mmdet.models.task_modules.samplers import RandomSampler,
>>> from mmdet.models.task_modules.assigners import AssignResult
>>> from mmdet.models.task_modules.samplers.
... sampling_result import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> pred_instances = InstanceData()
>>> pred_instances.priors = random_boxes(assign_result.num_preds,
... rng=rng)
>>> gt_instances = InstanceData()
>>> gt_instances.bboxes = random_boxes(assign_result.num_gts,
... rng=rng)
>>> gt_instances.labels = torch.randint(
... 0, 5, (assign_result.num_gts,), dtype=torch.long)
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, pred_instances, gt_instances)
"""
gt_bboxes = gt_instances.bboxes
priors = pred_instances.priors
Expand All @@ -101,13 +80,8 @@ def sample(

gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
# When `gt_bboxes` and `priors` are all box type, convert
# `gt_bboxes` type to `priors` type.
if isinstance(gt_bboxes, BaseBoxes) and isinstance(priors, BaseBoxes):
gt_bboxes_ = gt_bboxes.convert_to(type(priors))
else:
gt_bboxes_ = gt_bboxes
priors = cat_boxes([gt_bboxes_, priors], dim=0)
gt_bboxes_ = gt_bboxes
priors = torch.cat([gt_bboxes_, priors], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8)
gt_flags = torch.cat([gt_ones, gt_flags])
Expand Down
51 changes: 15 additions & 36 deletions src/otx/algo/detection/heads/custom_ssd_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@
from typing import TYPE_CHECKING

import torch
from mmdet.models.losses import smooth_l1_loss
from mmdet.models.utils import multi_apply
from mmdet.registry import MODELS
from torch import Tensor, nn

from otx.algo.detection.heads.anchor_head import AnchorHead
from otx.algo.detection.heads.base_sampler import PseudoSampler
from otx.algo.detection.heads.custom_anchor_generator import SSDAnchorGeneratorClustered
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from otx.algo.detection.heads.max_iou_assigner import MaxIoUAssigner
from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss
from otx.algo.detection.losses.weighted_loss import smooth_l1_loss
from otx.algo.detection.utils.utils import multi_apply

if TYPE_CHECKING:
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList
from mmengine.config import Config
from mmengine.config import ConfigDict, InstanceData


# This class and its supporting functions below lightly adapted from the mmdet SSDHead available at:
Expand All @@ -39,12 +38,6 @@ class SSDHead(AnchorHead):
> 0. Defaults to 256.
use_depthwise (bool): Whether to use DepthwiseSeparableConv.
Defaults to False.
conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
and config conv layer. Defaults to None.
norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
and config norm layer. Defaults to None.
act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
and config activation layer. Defaults to None.
anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor
generator.
bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder.
Expand All @@ -63,31 +56,26 @@ class SSDHead(AnchorHead):

def __init__(
self,
anchor_generator: ConfigType,
bbox_coder: ConfigType,
init_cfg: MultiConfig,
act_cfg: ConfigType,
anchor_generator: ConfigDict | dict,
bbox_coder: ConfigDict | dict,
init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict],
act_cfg: ConfigDict | dict,
train_cfg: ConfigDict | dict,
num_classes: int = 80,
in_channels: tuple[int, ...] = (512, 1024, 512, 256, 256, 256),
stacked_convs: int = 0,
feat_channels: int = 256,
use_depthwise: bool = False,
conv_cfg: ConfigType | None = None,
norm_cfg: ConfigType | None = None,
reg_decoded_bbox: bool = False,
train_cfg: ConfigType | None = None,
test_cfg: ConfigType | None = None,
loss_cls: Config | dict | None = None,
test_cfg: ConfigDict | dict | None = None,
) -> None:
super(AnchorHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.in_channels = in_channels
self.stacked_convs = stacked_convs
self.feat_channels = feat_channels
self.use_depthwise = use_depthwise
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.act_cfg = act_cfg # TODO(Jaeguk): act_cfg will be deprecated after implementing export.

self.cls_out_channels = num_classes + 1 # add background class
anchor_generator.pop("type")
Expand All @@ -98,14 +86,7 @@ def __init__(
# heads but a list of int in SSDHead
self.num_base_priors = self.prior_generator.num_base_priors

if loss_cls is None:
loss_cls = {
"type": "CrossEntropyLoss",
"use_sigmoid": False,
"reduction": "none",
"loss_weight": 1.0,
}
self.loss_cls = MODELS.build(loss_cls)
self.loss_cls = CrossEntropyLoss(use_sigmoid=False, reduction="none", loss_weight=1.0)

self._init_layers()

Expand Down Expand Up @@ -218,9 +199,9 @@ def loss_by_feat(
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
) -> dict[str, list[Tensor]]:
"""Compute losses of the head.
Expand Down Expand Up @@ -298,11 +279,9 @@ def _init_layers(self) -> None:
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()

activation_config = self.act_cfg.copy()
activation_config.setdefault("inplace", True)
for in_channel, num_base_priors in zip(self.in_channels, self.num_base_priors):
if self.use_depthwise:
activation_layer = MODELS.build(activation_config)
activation_layer = nn.ReLU(inplace=True)

self.reg_convs.append(
nn.Sequential(
Expand Down
25 changes: 9 additions & 16 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import torch
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
from torch import Tensor


Expand Down Expand Up @@ -44,40 +43,39 @@ def __init__(
ctr_clamp: int = 32,
) -> None:
self.encode_size = encode_size
# TODO(Jaeguk): use_box_type should be deprecated.
self.use_box_type = use_box_type
self.means = target_means
self.stds = target_stds
self.clip_border = clip_border
self.add_ctr_clamp = add_ctr_clamp
self.ctr_clamp = ctr_clamp

def encode(self, bboxes: Tensor | BaseBoxes, gt_bboxes: Tensor | BaseBoxes) -> Tensor:
def encode(self, bboxes: Tensor, gt_bboxes: Tensor) -> Tensor:
"""Get box regression transformation deltas that can be used to transform the bboxes into the gt_bboxes.
Args:
bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
bboxes (torch.Tensor): Source boxes,
e.g., object proposals.
gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
gt_bboxes (torch.Tensor): Target of the
transformation, e.g., ground-truth boxes.
Returns:
torch.Tensor: Box transformation deltas
"""
bboxes = get_box_tensor(bboxes)
gt_bboxes = get_box_tensor(gt_bboxes)
return bbox2delta(bboxes, gt_bboxes, self.means, self.stds)

def decode(
self,
bboxes: Tensor | BaseBoxes,
bboxes: Tensor,
pred_bboxes: Tensor,
max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] | None = None,
wh_ratio_clip: float = 16 / 1000,
) -> Tensor | BaseBoxes:
) -> Tensor:
"""Apply transformation `pred_bboxes` to `boxes`.
Args:
bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape
bboxes (torch.Tensor): Basic boxes. Shape
(B, N, 4) or (N, 4)
pred_bboxes (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
Expand All @@ -92,10 +90,9 @@ def decode(
width and height.
Returns:
Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
torch.Tensor: Decoded boxes.
"""
bboxes = get_box_tensor(bboxes)
decoded_bboxes = delta2bbox(
return delta2bbox(
bboxes,
pred_bboxes,
self.means,
Expand All @@ -107,10 +104,6 @@ def decode(
self.ctr_clamp,
)

if self.use_box_type:
decoded_bboxes = HorizontalBoxes(decoded_bboxes)
return decoded_bboxes


def bbox2delta(
proposals: Tensor,
Expand Down
Loading

0 comments on commit babae3b

Please sign in to comment.