From e45e8d36136bd88cd28b1777c4e0a0e4d727718c Mon Sep 17 00:00:00 2001 From: Jaeguk Hyun Date: Tue, 16 Apr 2024 09:41:43 +0900 Subject: [PATCH] Decoupling mmdet structures part2 (#3315) * Decouple anchor generator * Decouple base head * Decouple SSD class * Fix pre-commit --- src/otx/algo/detection/heads/base_head.py | 56 ++++---- .../heads/custom_anchor_generator.py | 11 +- src/otx/algo/detection/ssd.py | 61 +++++---- src/otx/algo/detection/utils/utils.py | 120 +++++++++++++++++- 4 files changed, 187 insertions(+), 61 deletions(-) diff --git a/src/otx/algo/detection/heads/base_head.py b/src/otx/algo/detection/heads/base_head.py index 89a883bb620..b50ddcb0236 100644 --- a/src/otx/algo/detection/heads/base_head.py +++ b/src/otx/algo/detection/heads/base_head.py @@ -10,18 +10,15 @@ from typing import TYPE_CHECKING import torch -from mmdet.models.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances -from mmdet.structures.bbox import cat_boxes, get_box_tensor, get_box_wh, scale_boxes +from mmcv.ops import batched_nms from mmengine.model import constant_init from mmengine.structures import InstanceData from torch import Tensor, nn -from otx.algo.detection.ops.nms import batched_nms +from otx.algo.detection.utils.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances if TYPE_CHECKING: - from mmdet.structures import SampleList - from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig - from mmengine.config import ConfigDict + from mmengine import ConfigDict # This class and its supporting functions below lightly adapted from the mmdet BaseDenseHead available at: @@ -63,7 +60,7 @@ class BaseDenseHead(nn.Module): loss_and_predict(): forward() -> loss_by_feat() -> predict_by_feat() """ - def __init__(self, init_cfg: OptMultiConfig = None) -> None: + def __init__(self, init_cfg: ConfigDict | list[ConfigDict] | dict | list[dict] | None = None) -> None: super().__init__() self._is_init = False @@ -83,7 +80,7 @@ def init_weights(self) -> None: if hasattr(m, "conv_offset"): constant_init(m.conv_offset, 0) - def get_positive_infos(self) -> InstanceList: + def get_positive_infos(self) -> list[InstanceData] | None: """Get positive information from sampling results. Returns: @@ -106,7 +103,7 @@ def get_positive_infos(self) -> InstanceList: positive_infos.append(pos_info) return positive_infos - def loss(self, x: tuple[Tensor], batch_data_samples: SampleList) -> dict: + def loss(self, x: tuple[Tensor], batch_data_samples: list[InstanceData]) -> dict: """Perform forward propagation and loss calculation of the detection head. Args: @@ -132,18 +129,18 @@ def loss_by_feat( self, cls_scores: list[Tensor], bbox_preds: list[Tensor], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, ) -> dict: """Calculate the loss based on the features extracted by the detection head.""" def loss_and_predict( self, x: tuple[Tensor], - batch_data_samples: SampleList, + batch_data_samples: list[InstanceData], proposal_cfg: ConfigDict | None = None, - ) -> tuple[dict, InstanceList]: + ) -> tuple[dict, list[InstanceData]]: """Perform forward propagation of the head, then calculate loss and predictions. 
Args: @@ -173,7 +170,12 @@ def loss_and_predict( predictions = self.predict_by_feat(cls_scores, bbox_preds, batch_img_metas=batch_img_metas, cfg=proposal_cfg) return losses, predictions - def predict(self, x: tuple[Tensor], batch_data_samples: SampleList, rescale: bool = False) -> InstanceList: + def predict( + self, + x: tuple[Tensor], + batch_data_samples: list[InstanceData], + rescale: bool = False, + ) -> list[InstanceData]: """Perform forward propagation of the detection head and predict detection results. Args: @@ -204,7 +206,7 @@ def predict_by_feat( cfg: ConfigDict | None = None, rescale: bool = False, with_nms: bool = True, - ) -> InstanceList: + ) -> list[InstanceData]: """Transform a batch of output features extracted from the head into bbox results. Note: When score_factors is not None, the cls_scores are @@ -242,8 +244,6 @@ def predict_by_feat( - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). """ - with_score_factors = score_factors is not None - num_levels = len(cls_scores) featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] @@ -259,7 +259,7 @@ def predict_by_feat( img_meta = batch_img_metas[img_id] cls_score_list = select_single_mlvl(cls_scores, img_id, detach=True) bbox_pred_list = select_single_mlvl(bbox_preds, img_id, detach=True) - if with_score_factors: + if score_factors is not None: score_factor_list = select_single_mlvl(score_factors, img_id, detach=True) else: score_factor_list = [None for _ in range(num_levels)] @@ -370,8 +370,13 @@ def _predict_by_feat_single( # `nms_pre` than before. score_thr = cfg.get("score_thr", 0) - results = filter_scores_and_topk(scores, score_thr, nms_pre, {"bbox_pred": bbox_pred, "priors": priors}) - scores, labels, keep_idxs, filtered_results = results + filtered_results: dict + scores, labels, keep_idxs, filtered_results = filter_scores_and_topk( # type: ignore[assignment] + scores, + score_thr, + nms_pre, + {"bbox_pred": bbox_pred, "priors": priors}, + ) bbox_pred = filtered_results["bbox_pred"] # noqa: PLW2901 priors = filtered_results["priors"] # noqa: PLW2901 @@ -388,7 +393,7 @@ def _predict_by_feat_single( mlvl_score_factors.append(score_factor) bbox_pred = torch.cat(mlvl_bbox_preds) - priors = cat_boxes(mlvl_valid_priors) + priors = torch.cat(mlvl_valid_priors) bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) results = InstanceData() @@ -438,7 +443,9 @@ def _bbox_post_process( """ if rescale: scale_factor = [1 / s for s in img_meta["scale_factor"]] - results.bboxes = scale_boxes(results.bboxes, scale_factor) + results.bboxes = results.bboxes * results.bboxes.new_tensor(scale_factor).repeat( + (1, int(results.bboxes.size(-1) / 2)), + ) if hasattr(results, "score_factors"): score_factors = results.pop("score_factors") @@ -446,13 +453,14 @@ def _bbox_post_process( # filter small size bboxes if cfg.get("min_bbox_size", -1) >= 0: - w, h = get_box_wh(results.bboxes) + w = results.bboxes[:, 2] - results.bboxes[:, 0] + h = results.bboxes[:, 3] - results.bboxes[:, 1] valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) if not valid_mask.all(): results = results[valid_mask] if with_nms and results.bboxes.numel() > 0: - bboxes = get_box_tensor(results.bboxes) + bboxes = results.bboxes det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, results.labels, cfg.nms) results = results[keep_idxs] # some nms would reweight the score, such as softnms diff --git a/src/otx/algo/detection/heads/custom_anchor_generator.py 
b/src/otx/algo/detection/heads/custom_anchor_generator.py index 8d706477a4c..bff9aa7d59b 100644 --- a/src/otx/algo/detection/heads/custom_anchor_generator.py +++ b/src/otx/algo/detection/heads/custom_anchor_generator.py @@ -10,7 +10,6 @@ import numpy as np import torch from mmdet.registry import TASK_UTILS -from mmdet.structures.bbox import HorizontalBoxes from torch.nn.modules.utils import _pair @@ -44,8 +43,6 @@ class AnchorGenerator: float is given, they will be used to shift the centers of anchors. center_offset (float): The offset of center in proportion to anchors' width and height. By default it is 0 in V2.0. - use_box_type (bool): Whether to warp anchors with the box type data - structure. Defaults to False. Examples: >>> from mmdet.models.task_modules. @@ -78,7 +75,6 @@ def __init__( scales_per_octave: int | None = None, centers: list[tuple[float, float]] | None = None, center_offset: float = 0.0, - use_box_type: bool = False, ) -> None: # check center and center_offset if center_offset != 0 and centers is None: @@ -112,7 +108,6 @@ def __init__( self.centers = centers self.center_offset = center_offset self.base_anchors = self.gen_base_anchors() - self.use_box_type = use_box_type @property def num_base_anchors(self) -> list[int]: @@ -278,12 +273,9 @@ def single_level_grid_priors( # shifted anchors (K, A, 4), reshape to (K*A, 4) all_anchors = base_anchors[None, :, :] + shifts[:, None, :] - all_anchors = all_anchors.view(-1, 4) # first A rows correspond to A anchors of (0, 0) in feature map, # then (0, 1), (0, 2), ... - if self.use_box_type: - all_anchors = HorizontalBoxes(all_anchors) - return all_anchors + return all_anchors.view(-1, 4) def sparse_priors( self, @@ -506,7 +498,6 @@ def __init__( self.center_offset = 0 self.gen_base_anchors() - self.use_box_type = False def gen_base_anchors(self) -> None: # type: ignore[override] """Generate base anchor for SSD.""" diff --git a/src/otx/algo/detection/ssd.py b/src/otx/algo/detection/ssd.py index 6f37dfa5e21..ec7c74b4296 100644 --- a/src/otx/algo/detection/ssd.py +++ b/src/otx/algo/detection/ssd.py @@ -31,8 +31,8 @@ if TYPE_CHECKING: import torch from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable - from mmdet.structures import DetDataSample, OptSampleList, SampleList - from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig + from mmengine import ConfigDict + from mmengine.structures import InstanceData from omegaconf import DictConfig from torch import Tensor, device @@ -51,12 +51,12 @@ class SingleStageDetector(nn.Module): def __init__( self, - backbone: ConfigType, - bbox_head: OptConfigType = None, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None, + backbone: ConfigDict | dict, + bbox_head: ConfigDict | dict, + train_cfg: ConfigDict | dict | None = None, + test_cfg: ConfigDict | dict | None = None, + data_preprocessor: ConfigDict | dict | None = None, + init_cfg: ConfigDict | list[ConfigDict] | dict | list[dict] | None = None, ) -> None: super().__init__() self._is_init = False @@ -156,9 +156,9 @@ def init_weights(self) -> None: def forward( self, inputs: torch.Tensor, - data_samples: OptSampleList = None, + data_samples: list[InstanceData] | None = None, mode: str = "tensor", - ) -> dict[str, torch.Tensor] | list[DetDataSample] | tuple[torch.Tensor] | torch.Tensor: + ) -> dict[str, torch.Tensor] | list[InstanceData] | tuple[torch.Tensor] | torch.Tensor: """The unified entry for a forward process in both
training and test. The method should accept three modes: "tensor", "predict" and "loss": @@ -166,7 +166,7 @@ def forward( - "tensor": Forward the whole network and return tensor or tuple of tensor without any post-processing, same as a common nn.Module. - "predict": Forward and return the predictions, which are fully - processed to a list of :obj:`DetDataSample`. + processed to a list of :obj:`InstanceData`. - "loss": Forward and return a dict of losses according to the given inputs and data samples. @@ -176,7 +176,7 @@ def forward( Args: inputs (torch.Tensor): The input tensor with shape (N, C, ...) in general. - data_samples (list[:obj:`DetDataSample`], optional): A batch of + data_samples (list[:obj:`InstanceData`], optional): A batch of data samples that contain annotations and predictions. Defaults to None. mode (str): Return what kind of value. Defaults to 'tensor'. @@ -185,7 +185,7 @@ def forward( The return type depends on ``mode``. - If ``mode="tensor"``, return a tensor or a tuple of tensor. - - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="predict"``, return a list of :obj:`InstanceData`. - If ``mode="loss"``, return a dict of tensor. """ if mode == "loss": @@ -201,14 +201,14 @@ def forward( def loss( self, batch_inputs: Tensor, - batch_data_samples: SampleList, + batch_data_samples: list[InstanceData], ) -> dict | list: """Calculate losses from a batch of inputs and data samples. Args: batch_inputs (Tensor): Input images of shape (N, C, H, W). These should usually be mean centered and std scaled. - batch_data_samples (list[:obj:`DetDataSample`]): The batch + batch_data_samples (list[:obj:`InstanceData`]): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. @@ -218,20 +218,25 @@ def loss( x = self.extract_feat(batch_inputs) return self.bbox_head.loss(x, batch_data_samples) - def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: bool = True) -> SampleList: + def predict( + self, + batch_inputs: Tensor, + batch_data_samples: list[InstanceData], + rescale: bool = True, + ) -> list[InstanceData]: """Predict results from a batch of inputs and data samples with post-processing. Args: batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (List[:obj:`DetDataSample`]): The Data + batch_data_samples (List[:obj:`InstanceData`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. rescale (bool): Whether to rescale the results. Defaults to True. Returns: - list[:obj:`DetDataSample`]: Detection results of the - input images. Each DetDataSample usually contain + list[:obj:`InstanceData`]: Detection results of the + input images. Each InstanceData usually contains 'pred_instances'. And the ``pred_instances`` usually contains following keys. @@ -249,13 +254,13 @@ def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: def _forward( self, batch_inputs: Tensor, - batch_data_samples: OptSampleList = None, + batch_data_samples: list[InstanceData] | None = None, ) -> tuple[list[Tensor], list[Tensor]]: """Network forward process. Args: batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + batch_data_samples (list[:obj:`InstanceData`]): Each item contains the meta information of each image and corresponding annotations.
@@ -280,18 +285,22 @@ def extract_feat(self, batch_inputs: Tensor) -> tuple[Tensor]: x = self.neck(x) return x - def add_pred_to_datasample(self, data_samples: SampleList, results_list: InstanceList) -> SampleList: - """Add predictions to `DetDataSample`. + def add_pred_to_datasample( + self, + data_samples: list[InstanceData], + results_list: list[InstanceData], + ) -> list[InstanceData]: + """Add predictions to `InstanceData`. Args: - data_samples (list[:obj:`DetDataSample`], optional): A batch of + data_samples (list[:obj:`InstanceData`]): A batch of data samples that contain annotations and predictions. results_list (list[:obj:`InstanceData`]): Detection results of each image. Returns: - list[:obj:`DetDataSample`]: Detection results of the - input images. Each DetDataSample usually contain + list[:obj:`InstanceData`]: Detection results of the + input images. Each InstanceData usually contains 'pred_instances'. And the ``pred_instances`` usually contains following keys. diff --git a/src/otx/algo/detection/utils/utils.py b/src/otx/algo/detection/utils/utils.py index 3abcb1178dc..5a869cde8e6 100644 --- a/src/otx/algo/detection/utils/utils.py +++ b/src/otx/algo/detection/utils/utils.py @@ -5,12 +5,17 @@ from __future__ import annotations from functools import partial -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torch import Tensor +if TYPE_CHECKING: + from mmengine.structures import InstanceData + +# Methods below come from mmdet.utils and are slightly modified. +# https://github.com/open-mmlab/mmdetection/blob/3.x/mmdet/models/utils/misc.py def multi_apply(func: Callable, *args, **kwargs) -> tuple: """Apply function to a list of arguments. @@ -92,3 +97,116 @@ def unmap(data: Tensor, count: int, inds: Tensor, fill: int = 0) -> Tensor: ret = data.new_full(new_size, fill) ret[inds.type(torch.bool), :] = data return ret + + +def filter_scores_and_topk( + scores: Tensor, + score_thr: float, + topk: int, + results: dict | list | Tensor | None = None, +) -> tuple[Tensor, Tensor, Tensor, dict | list | Tensor | None]: + """Filter results using score threshold and topk candidates. + + Args: + scores (Tensor): The scores, shape (num_bboxes, K). + score_thr (float): The score filter threshold. + topk (int): The number of topk candidates. + results (dict or list or Tensor, Optional): The results to + which the filtering rule is to be applied. The shape + of each item is (num_bboxes, N). + + Returns: + tuple: Filtered results + - scores (Tensor): The scores after being filtered, \ shape (num_bboxes_filtered, ). + - labels (Tensor): The class labels, shape \ (num_bboxes_filtered, ). + - anchor_idxs (Tensor): The anchor indexes, shape \ (num_bboxes_filtered, ). + - filtered_results (dict or list or Tensor, Optional): \ The filtered results. The shape of each item is \ (num_bboxes_filtered, N).
+ """ + valid_mask = scores > score_thr + scores = scores[valid_mask] + valid_idxs = torch.nonzero(valid_mask) + + num_topk = min(topk, valid_idxs.size(0)) + # torch.sort is actually faster than .topk (at least on GPUs) + scores, idxs = scores.sort(descending=True) + scores = scores[:num_topk] + topk_idxs = valid_idxs[idxs[:num_topk]] + keep_idxs, labels = topk_idxs.unbind(dim=1) + + filtered_results: dict | list | Tensor | None = None + if results is not None: + if isinstance(results, dict): + filtered_results = {k: v[keep_idxs] for k, v in results.items()} + elif isinstance(results, list): + filtered_results = [result[keep_idxs] for result in results] + elif isinstance(results, torch.Tensor): + filtered_results = results[keep_idxs] + else: + msg = f"Only supports dict or list or Tensor, but get {type(results)}." + raise NotImplementedError(msg) + return scores, labels, keep_idxs, filtered_results + + +def select_single_mlvl(mlvl_tensors: list[Tensor], batch_id: int, detach: bool = True) -> list[Tensor]: + """Extract a multi-scale single image tensor from a multi-scale batch tensor based on batch index. + + Note: The default value of detach is True, because the proposal gradient + needs to be detached during the training of the two-stage model. E.g + Cascade Mask R-CNN. + + Args: + mlvl_tensors (list[Tensor]): Batch tensor for all scale levels, + each is a 4D-tensor. + batch_id (int): Batch index. + detach (bool): Whether detach gradient. Default True. + + Returns: + list[Tensor]: Multi-scale single image tensor. + """ + num_levels = len(mlvl_tensors) + + if detach: + mlvl_tensor_list = [mlvl_tensors[i][batch_id].detach() for i in range(num_levels)] + else: + mlvl_tensor_list = [mlvl_tensors[i][batch_id] for i in range(num_levels)] + return mlvl_tensor_list + + +def unpack_gt_instances(batch_data_samples: list[InstanceData]) -> tuple: + """Unpack gt_instances, gt_instances_ignore and img_metas based on batch_data_samples. + + Args: + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple: + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + - batch_gt_instances_ignore (list[:obj:`InstanceData`]): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + - batch_img_metas (list[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + """ + batch_gt_instances = [] + batch_gt_instances_ignore = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if "ignored_instances" in data_sample: + batch_gt_instances_ignore.append(data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + + return batch_gt_instances, batch_gt_instances_ignore, batch_img_metas