From 7a3b01d8eae348db309530596695935c0c43635b Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Fri, 16 Aug 2024 18:02:03 +0900 Subject: [PATCH] Create criterion modules --- src/otx/algo/detection/atss.py | 53 +++- .../detectors/single_stage_detector.py | 10 +- src/otx/algo/detection/heads/anchor_head.py | 10 +- src/otx/algo/detection/heads/atss_head.py | 211 +--------------- src/otx/algo/detection/heads/rtmdet_head.py | 100 +------- src/otx/algo/detection/heads/ssd_head.py | 106 +------- src/otx/algo/detection/heads/yolox_head.py | 58 ++--- src/otx/algo/detection/losses/__init__.py | 6 +- src/otx/algo/detection/losses/atss_loss.py | 234 ++++++++++++++++++ src/otx/algo/detection/losses/rtmdet_loss.py | 114 +++++++++ src/otx/algo/detection/losses/ssd_loss.py | 107 ++++++++ src/otx/algo/detection/losses/yolox_loss.py | 108 ++++++++ src/otx/algo/detection/rtmdet.py | 14 +- src/otx/algo/detection/ssd.py | 22 +- src/otx/algo/detection/yolox.py | 50 +++- .../heads/rtmdet_ins_head.py | 80 ++---- .../instance_segmentation/losses/__init__.py | 3 +- .../losses/rtmdet_inst_loss.py | 89 +++++++ .../algo/instance_segmentation/rtmdet_inst.py | 10 +- 19 files changed, 871 insertions(+), 514 deletions(-) create mode 100644 src/otx/algo/detection/losses/atss_loss.py create mode 100644 src/otx/algo/detection/losses/rtmdet_loss.py create mode 100644 src/otx/algo/detection/losses/ssd_loss.py create mode 100644 src/otx/algo/detection/losses/yolox_loss.py create mode 100644 src/otx/algo/instance_segmentation/losses/rtmdet_inst_loss.py diff --git a/src/otx/algo/detection/atss.py b/src/otx/algo/detection/atss.py index 20dda84c364..bc49d4f0834 100644 --- a/src/otx/algo/detection/atss.py +++ b/src/otx/algo/detection/atss.py @@ -14,6 +14,7 @@ from otx.algo.common.utils.samplers import PseudoSampler from otx.algo.detection.detectors import SingleStageDetector from otx.algo.detection.heads import ATSSHead +from otx.algo.detection.losses import ATSSCriterion from otx.algo.detection.necks 
import FPN from otx.algo.detection.utils.assigners import ATSSAssigner from otx.algo.utils.support_otx_v1 import OTXv1Helper @@ -145,6 +146,23 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: target_means=(0.0, 0.0, 0.0, 0.0), target_stds=(0.1, 0.1, 0.2, 0.2), ), + feat_channels=64, + train_cfg=train_cfg, + test_cfg=test_cfg, + loss_cls=CrossSigmoidFocalLoss( # TODO (eugene): deprecated + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_bbox=GIoULoss(loss_weight=2.0), # TODO (eugene): deprecated + ) + criterion = ATSSCriterion( + num_classes=num_classes, + bbox_coder=DeltaXYWHBBoxCoder( + target_means=(0.0, 0.0, 0.0, 0.0), + target_stds=(0.1, 0.1, 0.2, 0.2), + ), loss_cls=CrossSigmoidFocalLoss( use_sigmoid=True, gamma=2.0, @@ -153,11 +171,15 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: ), loss_bbox=GIoULoss(loss_weight=2.0), loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0), - feat_channels=64, + ) + return SingleStageDetector( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, ) - return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg) class ResNeXt101ATSS(ATSS): @@ -210,6 +232,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: target_means=(0.0, 0.0, 0.0, 0.0), target_stds=(0.1, 0.1, 0.2, 0.2), ), + num_classes=num_classes, + in_channels=256, + train_cfg=train_cfg, + test_cfg=test_cfg, + loss_cls=CrossSigmoidFocalLoss( # TODO (eugene): deprecated + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_bbox=GIoULoss(loss_weight=2.0), # TODO (eugene): deprecated + ) + criterion = ATSSCriterion( + num_classes=num_classes, + bbox_coder=DeltaXYWHBBoxCoder( + target_means=(0.0, 0.0, 0.0, 0.0), + target_stds=(0.1, 0.1, 0.2, 0.2), + ), loss_cls=CrossSigmoidFocalLoss( use_sigmoid=True, gamma=2.0, @@ -218,12 +258,15 @@ def 
_build_model(self, num_classes: int) -> SingleStageDetector: ), loss_bbox=GIoULoss(loss_weight=2.0), loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0), - num_classes=num_classes, - in_channels=256, + ) + return SingleStageDetector( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, ) - return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg) def to(self, *args, **kwargs) -> Self: """Return a model with specified device.""" diff --git a/src/otx/algo/detection/detectors/single_stage_detector.py b/src/otx/algo/detection/detectors/single_stage_detector.py index c83626aa29c..b90d813e46a 100644 --- a/src/otx/algo/detection/detectors/single_stage_detector.py +++ b/src/otx/algo/detection/detectors/single_stage_detector.py @@ -26,6 +26,7 @@ class SingleStageDetector(BaseModule): Args: backbone (nn.Module): Backbone module. bbox_head (nn.Module): Bbox head module. + criterion (nn.Module): Criterion module. neck (nn.Module | None, optional): Neck module. Defaults to None. train_cfg (dict | None, optional): Training config. Defaults to None. test_cfg (dict | None, optional): Test config. Defaults to None. @@ -36,6 +37,7 @@ def __init__( self, backbone: nn.Module, bbox_head: nn.Module, + criterion: nn.Module, neck: nn.Module | None = None, train_cfg: dict | None = None, test_cfg: dict | None = None, @@ -46,6 +48,7 @@ def __init__( self.backbone = backbone self.bbox_head = bbox_head self.neck = neck + self.criterion = criterion self.init_cfg = init_cfg self.train_cfg = train_cfg self.test_cfg = test_cfg @@ -129,7 +132,7 @@ def forward( def loss( self, entity: DetBatchDataEntity, - ) -> dict | list: + ) -> dict: """Calculate losses from a batch of inputs and data samples. Args: @@ -143,7 +146,10 @@ def loss( dict: A dictionary of loss components. 
""" x = self.extract_feat(entity.images) - return self.bbox_head.loss(x, entity) + # TODO (sungchul): compare .loss with other forwards and remove duplicated code + outputs = self.bbox_head.loss(x, entity) + + return self.criterion(outputs) def predict( self, diff --git a/src/otx/algo/detection/heads/anchor_head.py b/src/otx/algo/detection/heads/anchor_head.py index 115741c7619..42dbe247399 100644 --- a/src/otx/algo/detection/heads/anchor_head.py +++ b/src/otx/algo/detection/heads/anchor_head.py @@ -31,7 +31,9 @@ class AnchorHead(BaseDenseHead): anchor_generator (nn.Module): Module for anchor generator bbox_coder (nn.Module): Module of bounding box coder. loss_cls (nn.Module): Module of classification loss. + It is related to RPNHead for iseg, will be deprecated. loss_bbox (nn.Module): Module of localization loss. + It is related to RPNHead for iseg, will be deprecated. train_cfg (dict): Training config of anchor head. test_cfg (dict, optional): Testing config of anchor head. feat_channels (int): Number of hidden channels. Used in child classes. @@ -49,8 +51,8 @@ def __init__( in_channels: tuple[int, ...] | int, anchor_generator: nn.Module, bbox_coder: nn.Module, - loss_cls: nn.Module, - loss_bbox: nn.Module, + loss_cls: nn.Module, # TODO (eugene): deprecated + loss_bbox: nn.Module, # TODO (eugene): deprecated train_cfg: dict, test_cfg: dict | None = None, feat_channels: int = 256, @@ -410,6 +412,8 @@ def loss_by_feat_single( ) -> tuple: """Calculate the loss of a single scale level based on the features extracted by the detection head. + TODO (eugene): it is related to RPNHead for iseg, will be deprecated + Args: cls_score (Tensor): Box scores for each scale level Has shape (N, num_anchors * num_classes, H, W). @@ -459,6 +463,8 @@ def loss_by_feat( ) -> dict: """Calculate the loss based on the features extracted by the detection head. 
+ TODO (eugene): it is related to RPNHead for iseg, will be deprecated + Args: cls_scores (list[Tensor]): Box scores for each scale level has shape (N, num_anchors * num_classes, H, W). diff --git a/src/otx/algo/detection/heads/atss_head.py b/src/otx/algo/detection/heads/atss_head.py index 9d85dbf0b77..67fb7a06175 100644 --- a/src/otx/algo/detection/heads/atss_head.py +++ b/src/otx/algo/detection/heads/atss_head.py @@ -11,8 +11,6 @@ import torch from torch import Tensor, nn -from otx.algo.common.losses import CrossEntropyLoss, CrossSigmoidFocalLoss -from otx.algo.common.utils.bbox_overlaps import bbox_overlaps from otx.algo.common.utils.utils import multi_apply, reduce_mean from otx.algo.detection.heads.anchor_head import AnchorHead from otx.algo.detection.heads.class_incremental_mixin import ( @@ -46,7 +44,6 @@ class ATSSHead(ClassIncrementalMixin, AnchorHead): the predicted boxes and regression targets to absolute coordinates format. Defaults to False. It should be `True` when using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. - loss_centerness (nn.Module, optinoal): Module of centerness loss. Defaults to None. init_cfg (dict, list[dict], optional): Initialization config dict. 
""" @@ -58,11 +55,7 @@ def __init__( stacked_convs: int = 4, norm_cfg: dict | None = None, reg_decoded_bbox: bool = True, - loss_centerness: nn.Module | None = None, init_cfg: dict | None = None, - bg_loss_weight: float = -1.0, - use_qfl: bool = False, - qfl_cfg: dict | None = None, **kwargs, ) -> None: self.pred_kernel_size = pred_kernel_size @@ -83,21 +76,6 @@ def __init__( ) self.sampling = False - self.loss_centerness = loss_centerness or CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0) - - if use_qfl: - kwargs["loss_cls"] = ( - qfl_cfg - if qfl_cfg - else { - "type": "QualityFocalLoss", - "use_sigmoid": True, - "beta": 2.0, - "loss_weight": 1.0, - } - ) - self.bg_loss_weight = bg_loss_weight - self.use_qfl = use_qfl def _init_layers(self) -> None: """Initialize layers of the head.""" @@ -221,7 +199,7 @@ def loss_by_feat( # type: ignore[override] Defaults to None. Returns: - dict[str, Tensor]: A dictionary of loss components. + dict[str, Tensor]: A dictionary of raw outputs. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] if len(featmap_sizes) != self.prior_generator.num_levels: @@ -250,182 +228,17 @@ def loss_by_feat( # type: ignore[override] ) = cls_reg_targets avg_factor = reduce_mean(torch.tensor(avg_factor, dtype=torch.float, device=device)).item() - losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = multi_apply( - self.loss_by_feat_single, - anchor_list, - cls_scores, - bbox_preds, - centernesses, - labels_list, - label_weights_list, - bbox_targets_list, - valid_label_mask, - avg_factor=avg_factor, - ) - - bbox_avg_factor = sum(bbox_avg_factor) - bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item() - losses_bbox = [loss_bbox / bbox_avg_factor for loss_bbox in losses_bbox] - return {"loss_cls": losses_cls, "loss_bbox": losses_bbox, "loss_centerness": loss_centerness} - - def loss_by_feat_single( # type: ignore[override] - self, - anchors: Tensor, - cls_score: Tensor, - bbox_pred: Tensor, - centerness: Tensor, 
- labels: Tensor, - label_weights: Tensor, - bbox_targets: Tensor, - valid_label_mask: Tensor, - avg_factor: float, - ) -> tuple: - """Compute loss of a single scale level. - - Args: - anchors (Tensor): Box reference for each scale level with shape - (N, num_total_anchors, 4). - cls_score (Tensor): Box scores for each scale level - Has shape (N, num_anchors * num_classes, H, W). - bbox_pred (Tensor): Box energies / deltas for each scale - level with shape (N, num_anchors * 4, H, W). - centerness(Tensor): Centerness scores for each scale level. - labels (Tensor): Labels of each anchors with shape - (N, num_total_anchors). - label_weights (Tensor): Label weights of each anchor with shape - (N, num_total_anchors) - bbox_targets (Tensor): BBox regression targets of each anchor with - shape (N, num_total_anchors, 4). - valid_label_mask (Tensor): Label mask for consideration of ignored - label with shape (N, num_total_anchors, 1). - avg_factor (float): Average factor that is used to average - the loss. When using sampling method, avg_factor is usually - the sum of positive and negative priors. When using - `PseudoSampler`, `avg_factor` is usually equal to the number - of positive priors. - - Returns: - tuple[Tensor]: A tuple of loss components. 
- """ - anchors = anchors.reshape(-1, 4) - cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous() - bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) - centerness = centerness.permute(0, 2, 3, 1).reshape(-1) - bbox_targets = bbox_targets.reshape(-1, 4) - labels = labels.reshape(-1) - label_weights = label_weights.reshape(-1) - valid_label_mask = valid_label_mask.reshape(-1, self.cls_out_channels) - - # FG cat_id: [0, num_classes -1], BG cat_id: num_classes - pos_inds = self._get_pos_inds(labels) - - if self.use_qfl: - quality = label_weights.new_zeros(labels.shape) - - if len(pos_inds) > 0: - pos_bbox_targets = bbox_targets[pos_inds] - pos_bbox_pred = bbox_pred[pos_inds] - pos_anchors = anchors[pos_inds] - pos_centerness = centerness[pos_inds] - - centerness_targets = self.centerness_target(pos_anchors, pos_bbox_targets) - if self.reg_decoded_bbox: - pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred) - - if self.use_qfl: - quality[pos_inds] = bbox_overlaps(pos_bbox_pred.detach(), pos_bbox_targets, is_aligned=True).clamp( - min=1e-6, - ) - - # regression loss - loss_bbox = self._get_loss_bbox(pos_bbox_targets, pos_bbox_pred, centerness_targets) - - # centerness loss - loss_centerness = self._get_loss_centerness(avg_factor, pos_centerness, centerness_targets) - - else: - loss_bbox = bbox_pred.sum() * 0 - loss_centerness = centerness.sum() * 0 - centerness_targets = bbox_targets.new_tensor(0.0) - - # Re-weigting BG loss - if self.bg_loss_weight >= 0.0: - neg_indices = labels == self.num_classes - label_weights[neg_indices] = self.bg_loss_weight - - if self.use_qfl: - labels = (labels, quality) # For quality focal loss arg spec - - # classification loss - loss_cls = self._get_loss_cls(cls_score, labels, label_weights, valid_label_mask, avg_factor) - - return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum() - - def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor: - """Calculate the 
centerness between anchors and gts. - - Only calculate pos centerness targets, otherwise there may be nan. - - Args: - anchors (Tensor): Anchors with shape (N, 4), "xyxy" format. - gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format. - - Returns: - Tensor: Centerness between anchors and gts. - """ - anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 - anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 - l_ = anchors_cx - gts[:, 0] - t_ = anchors_cy - gts[:, 1] - r_ = gts[:, 2] - anchors_cx - b_ = gts[:, 3] - anchors_cy - - left_right = torch.stack([l_, r_], dim=1) - top_bottom = torch.stack([t_, b_], dim=1) - return torch.sqrt( - (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) - * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]), - ) - - def _get_pos_inds(self, labels: Tensor) -> Tensor: - bg_class_ind = self.num_classes - return ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1) - - def _get_loss_cls( - self, - cls_score: Tensor, - labels: Tensor, - label_weights: Tensor, - valid_label_mask: Tensor, - avg_factor: Tensor, - ) -> Tensor: - if isinstance(self.loss_cls, CrossSigmoidFocalLoss): - loss_cls = self.loss_cls( - cls_score, - labels, - label_weights, - avg_factor=avg_factor, - valid_label_mask=valid_label_mask, - ) - else: - loss_cls = self.loss_cls(cls_score, labels, label_weights, avg_factor=avg_factor) - return loss_cls - - def _get_loss_centerness( - self, - avg_factor: Tensor, - pos_centerness: Tensor, - centerness_targets: Tensor, - ) -> Tensor: - return self.loss_centerness(pos_centerness, centerness_targets, avg_factor=avg_factor) - - def _get_loss_bbox( - self, - pos_bbox_targets: Tensor, - pos_bbox_pred: Tensor, - centerness_targets: Tensor, - ) -> Tensor: - return self.loss_bbox(pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0) + return { + "anchors": anchor_list, + "cls_score": cls_scores, + "bbox_pred": bbox_preds, + "centerness": centernesses, + "labels": labels_list, + 
"label_weights": label_weights_list, + "bbox_targets": bbox_targets_list, + "valid_label_mask": valid_label_mask, + "avg_factor": avg_factor, + } def get_targets( self, diff --git a/src/otx/algo/detection/heads/rtmdet_head.py b/src/otx/algo/detection/heads/rtmdet_head.py index 71623aad9e7..c8590c8053f 100644 --- a/src/otx/algo/detection/heads/rtmdet_head.py +++ b/src/otx/algo/detection/heads/rtmdet_head.py @@ -14,7 +14,7 @@ from torch import Tensor, nn from otx.algo.common.utils.nms import multiclass_nms -from otx.algo.common.utils.utils import distance2bbox, inverse_sigmoid, multi_apply, reduce_mean +from otx.algo.common.utils.utils import distance2bbox, inverse_sigmoid, multi_apply from otx.algo.detection.heads import ATSSHead from otx.algo.detection.utils.prior_generators.utils import anchor_inside_flags from otx.algo.detection.utils.utils import ( @@ -153,75 +153,6 @@ def forward(self, feats: tuple[Tensor, ...]) -> tuple: bbox_preds.append(reg_dist) return tuple(cls_scores), tuple(bbox_preds) - def loss_by_feat_single( # type: ignore[override] - self, - cls_score: Tensor, - bbox_pred: Tensor, - labels: Tensor, - label_weights: Tensor, - bbox_targets: Tensor, - assign_metrics: Tensor, - stride: list[int], - ) -> tuple[Tensor, ...]: - """Compute loss of a single scale level. - - Args: - cls_score (Tensor): Box scores for each scale level - Has shape (N, num_anchors * num_classes, H, W). - bbox_pred (Tensor): Decoded bboxes for each scale - level with shape (N, num_anchors * 4, H, W). - labels (Tensor): Labels of each anchors with shape - (N, num_total_anchors). - label_weights (Tensor): Label weights of each anchor with shape - (N, num_total_anchors). - bbox_targets (Tensor): BBox regression targets of each anchor with - shape (N, num_total_anchors, 4). - assign_metrics (Tensor): Assign metrics with shape - (N, num_total_anchors). - stride (list[int]): Downsample stride of the feature map. - - Returns: - dict[str, Tensor]: A dictionary of loss components. 
- """ - if stride[0] != stride[1]: - msg = "h stride is not equal to w stride!" - raise ValueError(msg) - cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous() - bbox_pred = bbox_pred.reshape(-1, 4) - bbox_targets = bbox_targets.reshape(-1, 4) - labels = labels.reshape(-1) - assign_metrics = assign_metrics.reshape(-1) - label_weights = label_weights.reshape(-1) - targets = (labels, assign_metrics) - - loss_cls = self.loss_cls(cls_score, targets, label_weights, avg_factor=1.0) - - # FG cat_id: [0, num_classes -1], BG cat_id: num_classes - bg_class_ind = self.num_classes - pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1) - - if len(pos_inds) > 0: - pos_bbox_targets = bbox_targets[pos_inds] - pos_bbox_pred = bbox_pred[pos_inds] - - pos_decode_bbox_pred = pos_bbox_pred - pos_decode_bbox_targets = pos_bbox_targets - - # regression loss - pos_bbox_weight = assign_metrics[pos_inds] - - loss_bbox = self.loss_bbox( - pos_decode_bbox_pred, - pos_decode_bbox_targets, - weight=pos_bbox_weight, - avg_factor=1.0, - ) - else: - loss_bbox = bbox_pred.sum() * 0 - pos_bbox_weight = bbox_targets.new_tensor(0.0) - - return loss_cls, loss_bbox, assign_metrics.sum(), pos_bbox_weight.sum() - def loss_by_feat( # type: ignore[override] self, cls_scores: list[Tensor], @@ -249,7 +180,7 @@ def loss_by_feat( # type: ignore[override] Defaults to None. Returns: - dict[str, Tensor]: A dictionary of loss components. + dict[str, Tensor]: A dictionary of raw outputs. 
""" num_imgs = len(batch_img_metas) featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] @@ -292,23 +223,16 @@ def loss_by_feat( # type: ignore[override] batch_gt_instances_ignore=batch_gt_instances_ignore, ) - losses_cls, losses_bbox, cls_avg_factors, bbox_avg_factors = multi_apply( - self.loss_by_feat_single, - cls_scores, - decoded_bboxes, - labels_list, - label_weights_list, - bbox_targets_list, - assign_metrics_list, - self.prior_generator.strides, - ) - - cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() - losses_cls = [x / cls_avg_factor for x in losses_cls] - - bbox_avg_factor = reduce_mean(sum(bbox_avg_factors)).clamp_(min=1).item() - losses_bbox = [x / bbox_avg_factor for x in losses_bbox] - return {"loss_cls": losses_cls, "loss_bbox": losses_bbox} + return { + "cls_score": cls_scores, + "bbox_pred": decoded_bboxes, + "labels": labels_list, + "label_weights": label_weights_list, + "bbox_targets": bbox_targets_list, + "assign_metrics": assign_metrics_list, + "stride": self.prior_generator.strides, + "sampling_results_list": sampling_results_list, + } def export_by_feat( # type: ignore[override] self, diff --git a/src/otx/algo/detection/heads/ssd_head.py b/src/otx/algo/detection/heads/ssd_head.py index fb6aa316153..4d7532af3a8 100644 --- a/src/otx/algo/detection/heads/ssd_head.py +++ b/src/otx/algo/detection/heads/ssd_head.py @@ -13,9 +13,7 @@ import torch from torch import Tensor, nn -from otx.algo.common.losses import CrossEntropyLoss, smooth_l1_loss from otx.algo.common.utils.samplers import PseudoSampler -from otx.algo.common.utils.utils import multi_apply from otx.algo.detection.heads.anchor_head import AnchorHead if TYPE_CHECKING: @@ -27,7 +25,6 @@ class SSDHead(AnchorHead): Args: anchor_generator (nn.Module): Config dict for anchor generator. - bbox_coder (nn.Module): Config of bounding box coder. init_cfg (dict, list[dict]): Initialization config dict. train_cfg (dict): Training config of anchor head. 
num_classes (int): Number of categories excluding the background category. @@ -38,18 +35,12 @@ class SSDHead(AnchorHead): Defaults to 256. use_depthwise (bool): Whether to use DepthwiseSeparableConv. Defaults to False. - reg_decoded_bbox (bool): If true, the regression loss would be - applied directly on decoded bounding boxes, converting both - the predicted boxes and regression targets to absolute - coordinates format. Defaults to False. It should be `True` when - using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. test_cfg (dict, Optional): Testing config of anchor head. """ def __init__( self, anchor_generator: nn.Module, - bbox_coder: nn.Module, init_cfg: dict | list[dict], train_cfg: dict, num_classes: int = 80, @@ -57,7 +48,6 @@ def __init__( stacked_convs: int = 0, feat_channels: int = 256, use_depthwise: bool = False, - reg_decoded_bbox: bool = False, test_cfg: dict | None = None, ) -> None: super(AnchorHead, self).__init__(init_cfg=init_cfg) @@ -75,12 +65,8 @@ def __init__( # heads but a list of int in SSDHead self.num_base_priors = self.prior_generator.num_base_priors - self.loss_cls = CrossEntropyLoss(use_sigmoid=False, reduction="none", loss_weight=1.0) - self._init_layers() - self.bbox_coder = bbox_coder - self.reg_decoded_bbox = reg_decoded_bbox self.use_sigmoid_cls = False self.cls_focal_loss = False self.train_cfg = train_cfg @@ -113,66 +99,6 @@ def forward(self, x: tuple[Tensor]) -> tuple[list[Tensor], list[Tensor]]: bbox_preds.append(reg_conv(feat)) return cls_scores, bbox_preds - def loss_by_feat_single( - self, - cls_score: Tensor, - bbox_pred: Tensor, - anchor: Tensor, - labels: Tensor, - label_weights: Tensor, - bbox_targets: Tensor, - bbox_weights: Tensor, - avg_factor: int, - ) -> tuple[Tensor, Tensor]: - """Compute loss of a single image. - - Args: - cls_score (Tensor): Box scores for each image has shape (num_total_anchors, num_classes). 
- bbox_pred (Tensor): Box energies / deltas for each image level with shape (num_total_anchors, 4). - anchors (Tensor): Box reference for each scale level with shape (num_total_anchors, 4). - labels (Tensor): Labels of each anchors with shape (num_total_anchors,). - label_weights (Tensor): Label weights of each anchor with shape (num_total_anchors,) - bbox_targets (Tensor): BBox regression targets of each anchor with shape (num_total_anchors, 4). - bbox_weights (Tensor): BBox regression loss weights of each anchor with shape (num_total_anchors, 4). - avg_factor (int): Average factor that is used to average - the loss. When using sampling method, avg_factor is usually - the sum of positive and negative priors. When using - `PseudoSampler`, `avg_factor` is usually equal to the number - of positive priors. - - Returns: - tuple[Tensor, Tensor]: A tuple of cls loss and bbox loss of one - feature map. - """ - loss_cls_all = nn.functional.cross_entropy(cls_score, labels, reduction="none") * label_weights - # FG cat_id: [0, num_classes -1], BG cat_id: num_classes - pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(as_tuple=False).reshape(-1) - neg_inds = (labels == self.num_classes).nonzero(as_tuple=False).view(-1) - - num_pos_samples = pos_inds.size(0) - num_neg_samples = self.train_cfg["neg_pos_ratio"] * num_pos_samples - if num_neg_samples > neg_inds.size(0): - num_neg_samples = neg_inds.size(0) - topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples) - loss_cls_pos = loss_cls_all[pos_inds].sum() - loss_cls_neg = topk_loss_cls_neg.sum() - loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor - - if self.reg_decoded_bbox: - # When the regression loss (e.g. `IouLoss`, `GIouLoss`) - # is applied directly on the decoded bounding boxes, it - # decodes the already encoded coordinates to absolute format. 
- bbox_pred = self.bbox_coder.decode(anchor, bbox_pred) - - loss_bbox = smooth_l1_loss( - bbox_pred, - bbox_targets, - bbox_weights, - beta=self.train_cfg["smoothl1_beta"], - avg_factor=avg_factor, - ) - return loss_cls[None], loss_bbox - def loss_by_feat( self, cls_scores: list[Tensor], @@ -180,7 +106,7 @@ def loss_by_feat( batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], batch_gt_instances_ignore: list[InstanceData] | None = None, - ) -> dict[str, list[Tensor]]: + ) -> dict[str, Tensor]: """Compute losses of the head. Args: @@ -195,13 +121,7 @@ def loss_by_feat( Defaults to None. Returns: - dict[str, list[Tensor]]: A dictionary of loss components. the dict - has components below: - - - loss_cls (list[Tensor]): A list containing each feature map \ - classification loss. - - loss_bbox (list[Tensor]): A list containing each feature map \ - regression loss. + dict[str, Tensor]: A dictionary of raw outputs. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] @@ -232,18 +152,16 @@ def loss_by_feat( # concat all level anchors to a single tensor all_anchors = [torch.cat(anchor) for anchor in anchor_list] - losses_cls, losses_bbox = multi_apply( - self.loss_by_feat_single, - all_cls_scores, - all_bbox_preds, - all_anchors, - all_labels, - all_label_weights, - all_bbox_targets, - all_bbox_weights, - avg_factor=avg_factor, - ) - return {"loss_cls": losses_cls, "loss_bbox": losses_bbox} + return { + "all_cls_scores": all_cls_scores, + "all_bbox_preds": all_bbox_preds, + "all_anchors": all_anchors, + "all_labels": all_labels, + "all_label_weights": all_label_weights, + "all_bbox_targets": all_bbox_targets, + "all_bbox_weights": all_bbox_weights, + "avg_factor": avg_factor, + } def _init_layers(self) -> None: """Initialize layers of the head. 
diff --git a/src/otx/algo/detection/heads/yolox_head.py b/src/otx/algo/detection/heads/yolox_head.py index af5dd677af0..d8830ecb1aa 100644 --- a/src/otx/algo/detection/heads/yolox_head.py +++ b/src/otx/algo/detection/heads/yolox_head.py @@ -17,13 +17,11 @@ from torch import Tensor, nn from torchvision.ops import box_convert -from otx.algo.common.losses import CrossEntropyLoss, L1Loss from otx.algo.common.utils.nms import batched_nms, multiclass_nms from otx.algo.common.utils.prior_generators import MlvlPointGenerator from otx.algo.common.utils.samplers import PseudoSampler from otx.algo.common.utils.utils import multi_apply, reduce_mean from otx.algo.detection.heads.base_head import BaseDenseHead -from otx.algo.detection.losses import IoULoss from otx.algo.modules.activation import Swish from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule from otx.algo.utils.mmengine_utils import InstanceData @@ -54,10 +52,6 @@ class YOLOXHead(BaseDenseHead): Defaults to dict(type='BN', momentum=0.03, eps=0.001). activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to `Swish`. - loss_cls (nn.Module, optional): Module of classification loss. - loss_bbox (nn.Module, optional): Module of localization loss. - loss_obj (nn.Module, optional): Module of objectness loss. - loss_l1 (nn.Module, optional): Module of L1 loss. train_cfg (dict, optional): Training config of anchor head. Defaults to None. test_cfg (dict, optional): Testing config of anchor head. 
@@ -78,10 +72,6 @@ def __init__( conv_bias: bool | str = "auto", norm_cfg: dict | None = None, activation_callable: Callable[..., nn.Module] = Swish, - loss_cls: nn.Module | None = None, - loss_bbox: nn.Module | None = None, - loss_obj: nn.Module | None = None, - loss_l1: nn.Module | None = None, train_cfg: dict | None = None, test_cfg: dict | None = None, init_cfg: dict | list[dict] | None = None, @@ -118,12 +108,7 @@ def __init__( self.norm_cfg = norm_cfg self.activation_callable = activation_callable - self.loss_cls = loss_cls or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0) - self.loss_bbox = loss_bbox or IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0) - self.loss_obj = loss_obj or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0) - self.use_l1 = False # This flag will be modified by hooks. - self.loss_l1 = loss_l1 or L1Loss(reduction="sum", loss_weight=1.0) self.prior_generator = MlvlPointGenerator(strides, offset=0) # type: ignore[arg-type] @@ -491,7 +476,7 @@ def loss_by_feat( # type: ignore[override] Defaults to None. Returns: - dict[str, Tensor]: A dictionary of losses. + dict[str, Tensor]: A dictionary of raw outputs. """ num_imgs = len(batch_img_metas) if batch_gt_instances_ignore is None: @@ -543,34 +528,19 @@ def loss_by_feat( # type: ignore[override] if self.use_l1: l1_targets = torch.cat(l1_targets, 0) - loss_obj = self.loss_obj(flatten_objectness.view(-1, 1), obj_targets) / num_total_samples - if num_pos > 0: - loss_cls = ( - self.loss_cls(flatten_cls_preds.view(-1, self.num_classes)[pos_masks], cls_targets) / num_total_samples - ) - loss_bbox = self.loss_bbox(flatten_bboxes.view(-1, 4)[pos_masks], bbox_targets) / num_total_samples - else: - # Avoid cls and reg branch not participating in the gradient - # propagation when there is no ground-truth in the images. 
- # For more details, please refer to - # https://github.com/open-mmlab/mmdetection/issues/7298 - loss_cls = flatten_cls_preds.sum() * 0 - loss_bbox = flatten_bboxes.sum() * 0 - - loss_dict = {"loss_cls": loss_cls, "loss_bbox": loss_bbox, "loss_obj": loss_obj} - - if self.use_l1: - if num_pos > 0: - loss_l1 = self.loss_l1(flatten_bbox_preds.view(-1, 4)[pos_masks], l1_targets) / num_total_samples - else: - # Avoid cls and reg branch not participating in the gradient - # propagation when there is no ground-truth in the images. - # For more details, please refer to - # https://github.com/open-mmlab/mmdetection/issues/7298 - loss_l1 = flatten_bbox_preds.sum() * 0 - loss_dict.update(loss_l1=loss_l1) - - return loss_dict + return { + "flatten_objectness": flatten_objectness, + "flatten_cls_preds": flatten_cls_preds, + "flatten_bbox_preds": flatten_bbox_preds, + "flatten_bboxes": flatten_bboxes, + "obj_targets": obj_targets, + "cls_targets": cls_targets, + "bbox_targets": bbox_targets, + "l1_targets": l1_targets, + "num_total_samples": num_total_samples, + "num_pos": num_pos, + "pos_masks": pos_masks, + } @torch.no_grad() def _get_targets_single( diff --git a/src/otx/algo/detection/losses/__init__.py b/src/otx/algo/detection/losses/__init__.py index 768be6e9778..91b4ad733a4 100644 --- a/src/otx/algo/detection/losses/__init__.py +++ b/src/otx/algo/detection/losses/__init__.py @@ -3,7 +3,11 @@ # """Custom OTX Losses for Object Detection.""" +from .atss_loss import ATSSCriterion from .iou_loss import IoULoss from .rtdetr_loss import DetrCriterion +from .rtmdet_loss import RTMDetCriterion +from .ssd_loss import SSDCriterion +from .yolox_loss import YOLOXCriterion -__all__ = ["IoULoss", "DetrCriterion"] +__all__ = ["ATSSCriterion", "IoULoss", "DetrCriterion", "RTMDetCriterion", "SSDCriterion", "YOLOXCriterion"] diff --git a/src/otx/algo/detection/losses/atss_loss.py b/src/otx/algo/detection/losses/atss_loss.py new file mode 100644 index 00000000000..f62fcfef634 --- /dev/null 
+++ b/src/otx/algo/detection/losses/atss_loss.py @@ -0,0 +1,234 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. +# +"""ATSS criterion.""" + +from __future__ import annotations + +import torch +from torch import Tensor, nn + +from otx.algo.common.losses import CrossEntropyLoss, CrossSigmoidFocalLoss, QualityFocalLoss +from otx.algo.common.utils.bbox_overlaps import bbox_overlaps +from otx.algo.common.utils.utils import reduce_mean + + +class ATSSCriterion(nn.Module): + """ATSSCriterion is a loss criterion used in the Adaptive Training Sample Selection (ATSS) algorithm. + + Args: + num_classes (int): The number of object classes. + bbox_coder (nn.Module): The module used for encoding and decoding bounding box coordinates. + loss_cls (nn.Module): The module used for calculating the classification loss. + loss_bbox (nn.Module): The module used for calculating the bounding box regression loss. + loss_centerness (nn.Module | None, optional): The module used for calculating the centerness loss. + Defaults to None, which falls back to ``CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)``. + use_qfl (bool, optional): Whether to use the Quality Focal Loss (QFL). + Defaults to False. + reg_decoded_bbox (bool, optional): Whether to use the decoded bounding box coordinates + for regression loss calculation. Defaults to True. + bg_loss_weight (float, optional): The weight for the background loss. + Defaults to -1.0.
+ """ + + def __init__( + self, + num_classes: int, + bbox_coder: nn.Module, + loss_cls: nn.Module, + loss_bbox: nn.Module, + loss_centerness: nn.Module | None = None, + use_qfl: bool = False, + qfl_cfg: dict | None = None, + reg_decoded_bbox: bool = True, + bg_loss_weight: float = -1.0, + ) -> None: + super().__init__() + self.num_classes = num_classes + self.bbox_coder = bbox_coder + self.use_qfl = use_qfl + self.reg_decoded_bbox = reg_decoded_bbox + self.bg_loss_weight = bg_loss_weight + + self.loss_bbox = loss_bbox + self.loss_centerness = loss_centerness or CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0) + + if use_qfl: + loss_cls = qfl_cfg or QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0) + + self.loss_cls = loss_cls + + self.use_sigmoid_cls = loss_cls.use_sigmoid + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + + if self.cls_out_channels <= 0: + msg = f"num_classes={num_classes} is too small" + raise ValueError(msg) + + def forward( + self, + anchors: Tensor, + cls_score: Tensor, + bbox_pred: Tensor, + centerness: Tensor, + labels: Tensor, + label_weights: Tensor, + bbox_targets: Tensor, + valid_label_mask: Tensor, + avg_factor: float, + ) -> dict[str, Tensor]: + """Compute loss of a single scale level. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + centerness(Tensor): Centerness scores for each scale level. + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). 
+ valid_label_mask (Tensor): Label mask for consideration of ignored + label with shape (N, num_total_anchors, 1). + avg_factor (float): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous() + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + centerness = centerness.permute(0, 2, 3, 1).reshape(-1) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + valid_label_mask = valid_label_mask.reshape(-1, self.cls_out_channels) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + pos_inds = self._get_pos_inds(labels) + + if self.use_qfl: + quality = label_weights.new_zeros(labels.shape) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_centerness = centerness[pos_inds] + + centerness_targets = self.centerness_target(pos_anchors, pos_bbox_targets) + if self.reg_decoded_bbox: + pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred) + + if self.use_qfl: + quality[pos_inds] = bbox_overlaps(pos_bbox_pred.detach(), pos_bbox_targets, is_aligned=True).clamp( + min=1e-6, + ) + + # regression loss + loss_bbox = self._get_loss_bbox(pos_bbox_targets, pos_bbox_pred, centerness_targets) + + # centerness loss + loss_centerness = self._get_loss_centerness(avg_factor, pos_centerness, centerness_targets) + + else: + loss_bbox = bbox_pred.sum() * 0 + loss_centerness = centerness.sum() * 0 + centerness_targets = bbox_targets.new_tensor(0.0) + + # Re-weighting BG loss + if self.bg_loss_weight >= 0.0: + neg_indices = labels
== self.num_classes + label_weights[neg_indices] = self.bg_loss_weight + + if self.use_qfl: + labels = (labels, quality) # For quality focal loss arg spec + + # classification loss + loss_cls = self._get_loss_cls(cls_score, labels, label_weights, valid_label_mask, avg_factor) + + bbox_avg_factor = centerness_targets.sum() + + bbox_avg_factor = sum(bbox_avg_factor) + bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item() + loss_bbox = [x / bbox_avg_factor for x in loss_bbox] + return {"loss_cls": loss_cls, "loss_bbox": loss_bbox, "loss_centerness": loss_centerness} + + def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor: + """Calculate the centerness between anchors and gts. + + Only calculate pos centerness targets, otherwise there may be nan. + + Args: + anchors (Tensor): Anchors with shape (N, 4), "xyxy" format. + gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Centerness between anchors and gts. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + l_ = anchors_cx - gts[:, 0] + t_ = anchors_cy - gts[:, 1] + r_ = gts[:, 2] - anchors_cx + b_ = gts[:, 3] - anchors_cy + + left_right = torch.stack([l_, r_], dim=1) + top_bottom = torch.stack([t_, b_], dim=1) + return torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) + * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]), + ) + + def _get_pos_inds(self, labels: Tensor) -> Tensor: + bg_class_ind = self.num_classes + return ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1) + + def _get_loss_bbox( + self, + pos_bbox_targets: Tensor, + pos_bbox_pred: Tensor, + centerness_targets: Tensor, + ) -> Tensor: + return self.loss_bbox(pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0) + + def _get_loss_centerness( + self, + avg_factor: Tensor, + pos_centerness: Tensor, + centerness_targets: Tensor, + ) -> Tensor: + return 
self.loss_centerness(pos_centerness, centerness_targets, avg_factor=avg_factor) + + def _get_loss_cls( + self, + cls_score: Tensor, + labels: Tensor, + label_weights: Tensor, + valid_label_mask: Tensor, + avg_factor: Tensor, + ) -> Tensor: + if isinstance(self.loss_cls, CrossSigmoidFocalLoss): + loss_cls = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + valid_label_mask=valid_label_mask, + ) + else: + loss_cls = self.loss_cls(cls_score, labels, label_weights, avg_factor=avg_factor) + return loss_cls diff --git a/src/otx/algo/detection/losses/rtmdet_loss.py b/src/otx/algo/detection/losses/rtmdet_loss.py new file mode 100644 index 00000000000..691b810079a --- /dev/null +++ b/src/otx/algo/detection/losses/rtmdet_loss.py @@ -0,0 +1,114 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. +# +"""RTMDet criterion.""" + +from __future__ import annotations + +from torch import Tensor, nn + +from otx.algo.common.utils.utils import reduce_mean + + +class RTMDetCriterion(nn.Module): + """RTMDetCriterion is a criterion module for RTM-based object detection. + + Args: + num_classes (int): Number of object classes. + loss_cls (nn.Module): Classification loss module. + loss_bbox (nn.Module): Bounding box regression loss module. 
+ """ + + def __init__(self, num_classes: int, loss_cls: nn.Module, loss_bbox: nn.Module) -> None: + super().__init__() + self.num_classes = num_classes + self.loss_cls = loss_cls + self.loss_bbox = loss_bbox + self.use_sigmoid_cls = loss_cls.use_sigmoid + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + + if self.cls_out_channels <= 0: + msg = f"num_classes={num_classes} is too small" + raise ValueError(msg) + + def forward( # type: ignore[override] + self, + cls_score: Tensor, + bbox_pred: Tensor, + labels: Tensor, + label_weights: Tensor, + bbox_targets: Tensor, + assign_metrics: Tensor, + stride: list[int], + **kwargs, + ) -> dict[str, Tensor]: + """Compute loss of a single scale level. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Decoded bboxes for each scale + level with shape (N, num_anchors * 4, H, W). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors). + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + assign_metrics (Tensor): Assign metrics with shape + (N, num_total_anchors). + stride (list[int]): Downsample stride of the feature map. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + if stride[0] != stride[1]: + msg = "h stride is not equal to w stride!" 
+ raise ValueError(msg) + cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous() + bbox_pred = bbox_pred.reshape(-1, 4) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + assign_metrics = assign_metrics.reshape(-1) + label_weights = label_weights.reshape(-1) + targets = (labels, assign_metrics) + + loss_cls = self.loss_cls(cls_score, targets, label_weights, avg_factor=1.0) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + + pos_decode_bbox_pred = pos_bbox_pred + pos_decode_bbox_targets = pos_bbox_targets + + # regression loss + pos_bbox_weight = assign_metrics[pos_inds] + + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=pos_bbox_weight, + avg_factor=1.0, + ) + else: + loss_bbox = bbox_pred.sum() * 0 + pos_bbox_weight = bbox_targets.new_tensor(0.0) + + cls_avg_factors = assign_metrics.sum() + bbox_avg_factors = pos_bbox_weight.sum() + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + loss_cls = [x / cls_avg_factor for x in loss_cls] + + bbox_avg_factor = reduce_mean(sum(bbox_avg_factors)).clamp_(min=1).item() + loss_bbox = [x / bbox_avg_factor for x in loss_bbox] + return {"loss_cls": loss_cls, "loss_bbox": loss_bbox} diff --git a/src/otx/algo/detection/losses/ssd_loss.py b/src/otx/algo/detection/losses/ssd_loss.py new file mode 100644 index 00000000000..3b6d2c7b27c --- /dev/null +++ b/src/otx/algo/detection/losses/ssd_loss.py @@ -0,0 +1,107 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. 
+# +"""SSD criterion.""" + +from __future__ import annotations + +from torch import Tensor, nn + +from otx.algo.common.losses import smooth_l1_loss + + +class SSDCriterion(nn.Module): + """SSDCriterion is a loss criterion for Single Shot MultiBox Detector (SSD). + + Args: + num_classes (int): Number of object classes, excluding background (label ``num_classes`` is background). + bbox_coder (nn.Module | None, optional): Bounding box coder module. Defaults to None. + neg_pos_ratio (int, optional): Ratio of negative to positive samples. Defaults to 3. + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Defaults to False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + smoothl1_beta (float, optional): Beta parameter for the smooth L1 loss. Defaults to 1.0. + """ + + def __init__( + self, + num_classes: int, + bbox_coder: nn.Module | None = None, + neg_pos_ratio: int = 3, + reg_decoded_bbox: bool = False, + smoothl1_beta: float = 1.0, + ) -> None: + super().__init__() + self.num_classes = num_classes + self.bbox_coder = bbox_coder + self.neg_pos_ratio = neg_pos_ratio + self.reg_decoded_bbox = reg_decoded_bbox + self.smoothl1_beta = smoothl1_beta + + def forward( + self, + cls_score: Tensor, + bbox_pred: Tensor, + anchor: Tensor, + labels: Tensor, + label_weights: Tensor, + bbox_targets: Tensor, + bbox_weights: Tensor, + avg_factor: int, + ) -> dict[str, Tensor]: + """Compute loss of a single image. + + Args: + cls_score (Tensor): Box scores for each image has shape (num_total_anchors, num_classes). + bbox_pred (Tensor): Box energies / deltas for each image level with shape (num_total_anchors, 4). + anchor (Tensor): Box reference for each scale level with shape (num_total_anchors, 4). + labels (Tensor): Labels of each anchors with shape (num_total_anchors,).
+ label_weights (Tensor): Label weights of each anchor with shape (num_total_anchors,) + bbox_targets (Tensor): BBox regression targets of each anchor with shape (num_total_anchors, 4). + bbox_weights (Tensor): BBox regression loss weights of each anchor with shape (num_total_anchors, 4). + avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + Returns: + dict[str, Tensor]: A dictionary of loss components. the dict + has components below: + + - loss_cls (list[Tensor]): A list containing each feature map \ + classification loss. + - loss_bbox (list[Tensor]): A list containing each feature map \ + regression loss. + """ + loss_cls_all = nn.functional.cross_entropy(cls_score, labels, reduction="none") * label_weights + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(as_tuple=False).reshape(-1) + neg_inds = (labels == self.num_classes).nonzero(as_tuple=False).view(-1) + + num_pos_samples = pos_inds.size(0) + num_neg_samples = self.neg_pos_ratio * num_pos_samples + if num_neg_samples > neg_inds.size(0): + num_neg_samples = neg_inds.size(0) + topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples) + loss_cls_pos = loss_cls_all[pos_inds].sum() + loss_cls_neg = topk_loss_cls_neg.sum() + loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor + + if self.reg_decoded_bbox and self.bbox_coder is not None: + # When the regression loss (e.g. `IouLoss`, `GIouLoss`) + # is applied directly on the decoded bounding boxes, it + # decodes the already encoded coordinates to absolute format. 
+ bbox_pred = self.bbox_coder.decode(anchor, bbox_pred) + + loss_bbox = smooth_l1_loss( + bbox_pred, + bbox_targets, + bbox_weights, + beta=self.smoothl1_beta, + avg_factor=avg_factor, + ) + return {"loss_cls": loss_cls[None], "loss_bbox": loss_bbox} diff --git a/src/otx/algo/detection/losses/yolox_loss.py b/src/otx/algo/detection/losses/yolox_loss.py new file mode 100644 index 00000000000..350688b3b4b --- /dev/null +++ b/src/otx/algo/detection/losses/yolox_loss.py @@ -0,0 +1,108 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. +# +"""YOLOX criterion.""" + +from __future__ import annotations + +from torch import Tensor, nn + +from otx.algo.common.losses.cross_entropy_loss import CrossEntropyLoss +from otx.algo.common.losses.smooth_l1_loss import L1Loss +from otx.algo.detection.losses.iou_loss import IoULoss + + +class YOLOXCriterion(nn.Module): + """YOLOX criterion module. + + This module calculates the loss for YOLOX object detection model. + + Args: + num_classes (int): The number of classes. + loss_cls (nn.Module | None): The classification loss module. Defaults to None. + loss_bbox (nn.Module | None): The bounding box regression loss module. Defaults to None. + loss_obj (nn.Module | None): The objectness loss module. Defaults to None. + loss_l1 (nn.Module | None): The L1 loss module. Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary containing the calculated losses. 
+ + """ + + def __init__( + self, + num_classes: int, + loss_cls: nn.Module | None = None, + loss_bbox: nn.Module | None = None, + loss_obj: nn.Module | None = None, + loss_l1: nn.Module | None = None, + ) -> None: + super().__init__() + self.num_classes = num_classes + self.loss_cls = loss_cls or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0) + self.loss_bbox = loss_bbox or IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0) + self.loss_obj = loss_obj or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0) + self.loss_l1 = loss_l1 or L1Loss(reduction="sum", loss_weight=1.0) + + def forward( + self, + flatten_objectness: Tensor, + flatten_cls_preds: Tensor, + flatten_bbox_preds: Tensor, + flatten_bboxes: Tensor, + obj_targets: Tensor, + cls_targets: Tensor, + bbox_targets: Tensor, + l1_targets: Tensor, + num_total_samples: Tensor, + num_pos: Tensor, + pos_masks: Tensor, + ) -> dict[str, Tensor]: + """Forward pass of the YOLOX criterion module. + + Args: + flatten_objectness (Tensor): Flattened objectness predictions. + flatten_cls_preds (Tensor): Flattened class predictions. + flatten_bbox_preds (Tensor): Flattened bounding box predictions. + flatten_bboxes (Tensor): Flattened ground truth bounding boxes. + obj_targets (Tensor): Objectness targets. + cls_targets (Tensor): Class targets. + bbox_targets (Tensor): Bounding box targets. + l1_targets (Tensor): L1 targets. + num_total_samples (Tensor): Total number of samples. + num_pos (Tensor): Number of positive samples. + pos_masks (Tensor): Positive masks. + + Returns: + dict[str, Tensor]: A dictionary containing the calculated losses. 
+ + """ + loss_obj = self.loss_obj(flatten_objectness.view(-1, 1), obj_targets) / num_total_samples + if num_pos > 0: + loss_cls = ( + self.loss_cls(flatten_cls_preds.view(-1, self.num_classes)[pos_masks], cls_targets) / num_total_samples + ) + loss_bbox = self.loss_bbox(flatten_bboxes.view(-1, 4)[pos_masks], bbox_targets) / num_total_samples + else: + # Avoid cls and reg branch not participating in the gradient + # propagation when there is no ground-truth in the images. + # For more details, please refer to + # https://github.com/open-mmlab/mmdetection/issues/7298 + loss_cls = flatten_cls_preds.sum() * 0 + loss_bbox = flatten_bboxes.sum() * 0 + + loss_dict = {"loss_cls": loss_cls, "loss_bbox": loss_bbox, "loss_obj": loss_obj} + + if self.use_l1: + if num_pos > 0: + loss_l1 = self.loss_l1(flatten_bbox_preds.view(-1, 4)[pos_masks], l1_targets) / num_total_samples + else: + # Avoid cls and reg branch not participating in the gradient + # propagation when there is no ground-truth in the images. 
+ # For more details, please refer to + # https://github.com/open-mmlab/mmdetection/issues/7298 + loss_l1 = flatten_bbox_preds.sum() * 0 + loss_dict.update(loss_l1=loss_l1) + + return loss_dict diff --git a/src/otx/algo/detection/rtmdet.py b/src/otx/algo/detection/rtmdet.py index 791eb863697..855d9828404 100644 --- a/src/otx/algo/detection/rtmdet.py +++ b/src/otx/algo/detection/rtmdet.py @@ -12,13 +12,13 @@ from otx.algo.common.backbones import CSPNeXt from otx.algo.common.losses import GIoULoss, QualityFocalLoss -from otx.algo.common.losses.cross_entropy_loss import CrossEntropyLoss from otx.algo.common.utils.assigners import DynamicSoftLabelAssigner from otx.algo.common.utils.coders import DistancePointBBoxCoder from otx.algo.common.utils.prior_generators import MlvlPointGenerator from otx.algo.common.utils.samplers import PseudoSampler from otx.algo.detection.detectors import SingleStageDetector from otx.algo.detection.heads import RTMDetSepBNHead +from otx.algo.detection.losses import RTMDetCriterion from otx.algo.detection.necks import CSPNeXtPAFPN from otx.core.config.data import TileConfig from otx.core.exporter.base import OTXModelExporter @@ -144,19 +144,25 @@ def _build_model(self, num_classes: int) -> RTMDet: with_objectness=False, anchor_generator=MlvlPointGenerator(offset=0, strides=[8, 16, 32]), bbox_coder=DistancePointBBoxCoder(), - loss_cls=QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0), - loss_bbox=GIoULoss(loss_weight=2.0), - loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0), norm_cfg={"type": "BN"}, activation_callable=partial(nn.SiLU, inplace=True), train_cfg=train_cfg, test_cfg=test_cfg, + loss_cls=QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0), # TODO (eugene): deprecated + loss_bbox=GIoULoss(loss_weight=2.0), # TODO (eugene): deprecated + ) + + criterion = RTMDetCriterion( + num_classes=num_classes, + loss_cls=QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0), + 
loss_bbox=GIoULoss(loss_weight=2.0), ) return SingleStageDetector( backbone=backbone, neck=neck, bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, ) diff --git a/src/otx/algo/detection/ssd.py b/src/otx/algo/detection/ssd.py index 775a45a2aa2..df765f58f98 100644 --- a/src/otx/algo/detection/ssd.py +++ b/src/otx/algo/detection/ssd.py @@ -20,6 +20,7 @@ from otx.algo.common.utils.coders import DeltaXYWHBBoxCoder from otx.algo.detection.detectors import SingleStageDetector from otx.algo.detection.heads import SSDHead +from otx.algo.detection.losses import SSDCriterion from otx.algo.detection.utils.prior_generators import SSDAnchorGeneratorClustered from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.config.data import TileConfig @@ -81,10 +82,8 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: pos_iou_thr=0.4, neg_iou_thr=0.4, ), - "smoothl1_beta": 1.0, "allowed_border": -1, "pos_weight": -1, - "neg_pos_ratio": 3, "debug": False, "use_giou": False, "use_focal": False, @@ -116,10 +115,6 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: [587.6216059488938, 381.60024152086544, 323.5988913027747, 702.7486097568518, 741.4865860938451], ], ), - bbox_coder=DeltaXYWHBBoxCoder( - target_means=(0.0, 0.0, 0.0, 0.0), - target_stds=(0.1, 0.1, 0.2, 0.2), - ), num_classes=num_classes, in_channels=(96, 320), use_depthwise=True, @@ -127,7 +122,20 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: train_cfg=train_cfg, test_cfg=test_cfg, ) - return SingleStageDetector(backbone, bbox_head, train_cfg=train_cfg, test_cfg=test_cfg) + criterion = SSDCriterion( + num_classes=num_classes, + bbox_coder=DeltaXYWHBBoxCoder( + target_means=(0.0, 0.0, 0.0, 0.0), + target_stds=(0.1, 0.1, 0.2, 0.2), + ), + ) + return SingleStageDetector( + backbone=backbone, + bbox_head=bbox_head, + criterion=criterion, + train_cfg=train_cfg, + test_cfg=test_cfg, + ) def setup(self, stage: str) -> None: """Callback 
for setup OTX SSD Model. diff --git a/src/otx/algo/detection/yolox.py b/src/otx/algo/detection/yolox.py index a5e1583d3ca..a5760cb29c3 100644 --- a/src/otx/algo/detection/yolox.py +++ b/src/otx/algo/detection/yolox.py @@ -11,7 +11,7 @@ from otx.algo.detection.backbones import CSPDarknet from otx.algo.detection.detectors import SingleStageDetector from otx.algo.detection.heads import YOLOXHead -from otx.algo.detection.losses import IoULoss +from otx.algo.detection.losses import IoULoss, YOLOXCriterion from otx.algo.detection.necks import YOLOXPAFPN from otx.algo.detection.utils.assigners import SimOTAAssigner from otx.algo.utils.support_otx_v1 import OTXv1Helper @@ -184,14 +184,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: num_classes=num_classes, in_channels=96, feat_channels=96, + train_cfg=train_cfg, + test_cfg=test_cfg, + ) + criterion = YOLOXCriterion( + num_classes=num_classes, loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0), loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_l1=L1Loss(reduction="sum", loss_weight=1.0), + ) + return SingleStageDetector( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, ) - return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg) class YOLOXS(YOLOX): @@ -221,14 +231,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: num_classes=num_classes, in_channels=128, feat_channels=128, + train_cfg=train_cfg, + test_cfg=test_cfg, + ) + criterion = YOLOXCriterion( + num_classes=num_classes, loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0), loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), 
loss_l1=L1Loss(reduction="sum", loss_weight=1.0), + ) + return SingleStageDetector( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, ) - return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg) class YOLOXL(YOLOX): @@ -253,14 +273,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: bbox_head = YOLOXHead( num_classes=num_classes, in_channels=256, + train_cfg=train_cfg, + test_cfg=test_cfg, + ) + criterion = YOLOXCriterion( + num_classes=num_classes, loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0), loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_l1=L1Loss(reduction="sum", loss_weight=1.0), + ) + return SingleStageDetector( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, ) - return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg) class YOLOXX(YOLOX): @@ -290,11 +320,21 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: num_classes=num_classes, in_channels=320, feat_channels=320, + train_cfg=train_cfg, + test_cfg=test_cfg, + ) + criterion = YOLOXCriterion( + num_classes=num_classes, loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0), loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_l1=L1Loss(reduction="sum", loss_weight=1.0), + ) + return SingleStageDetector( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, ) - return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg) diff --git 
a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py b/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py index 87c5c3dbc99..e8db7790366 100644 --- a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py +++ b/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py @@ -21,10 +21,8 @@ from otx.algo.common.utils.nms import batched_nms, multiclass_nms from otx.algo.common.utils.utils import ( - distance2bbox, filter_scores_and_topk, inverse_sigmoid, - multi_apply, reduce_mean, select_single_mlvl, ) @@ -49,7 +47,6 @@ class RTMDetInsHead(RTMDetHead): """Detection Head of RTMDet-Ins. Args: - loss_mask (nn.Module): A module for mask loss. num_prototypes (int): Number of mask prototype features extracted from the mask head. Defaults to 8. dyconv_channels (int): Channel of the dynamic conv layers. @@ -63,7 +60,6 @@ class RTMDetInsHead(RTMDetHead): def __init__( self, *args, - loss_mask: nn.Module, num_prototypes: int = 8, dyconv_channels: int = 8, num_dyconvs: int = 3, @@ -75,7 +71,6 @@ def __init__( self.dyconv_channels = dyconv_channels self.mask_loss_stride = mask_loss_stride super().__init__(*args, **kwargs) - self.loss_mask = loss_mask def _init_layers(self) -> None: """Initialize layers of the head.""" @@ -540,7 +535,7 @@ def loss_mask_by_feat( flatten_kernels: Tensor, sampling_results_list: list, batch_gt_instances: list[InstanceData], - ) -> Tensor: + ) -> dict[str, Tensor]: """Compute instance segmentation loss. Args: @@ -555,7 +550,7 @@ def loss_mask_by_feat( attributes. Returns: - Tensor: The mask loss tensor. + dict[str, Tensor]: A dictionary of raw outputs. 
""" batch_pos_mask_logits = [] pos_gt_masks = [] @@ -611,7 +606,11 @@ def loss_mask_by_feat( self.mask_loss_stride // 2 :: self.mask_loss_stride, ] - return self.loss_mask(batch_pos_mask_logits, pos_gt_masks, weight=None, avg_factor=num_pos) + return { + "batch_pos_mask_logits": batch_pos_mask_logits, + "pos_gt_masks": pos_gt_masks, + "num_pos": num_pos, + } def loss_by_feat( self, @@ -625,17 +624,15 @@ def loss_by_feat( ) -> dict[str, Tensor]: """Compute losses of the head.""" num_imgs = len(batch_img_metas) - featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - if len(featmap_sizes) != self.prior_generator.num_levels: - msg = "The number of featmap sizes should be equal to the number of levels." - raise ValueError(msg) - device = cls_scores[0].device - anchor_list, valid_flag_list = self.get_anchors(featmap_sizes, batch_img_metas, device=device) - flatten_cls_scores = torch.cat( - [cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) for cls_score in cls_scores], - 1, + raw_outputs = super().loss_by_feat( + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, ) + flatten_kernels = torch.cat( [ kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_gen_params) @@ -643,14 +640,7 @@ def loss_by_feat( ], 1, ) - decoded_bboxes = [] - for anchor, bbox_pred in zip(anchor_list[0], bbox_preds): - anchor = anchor.reshape(-1, 4) # noqa: PLW2901 - bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) # noqa: PLW2901 - bbox_pred = distance2bbox(anchor, bbox_pred) # noqa: PLW2901 - decoded_bboxes.append(bbox_pred) - - flatten_bboxes = torch.cat(decoded_bboxes, 1) + # Convert polygon masks to bitmap masks if isinstance(batch_gt_instances[0].masks[0], Polygon): for gt_instances, img_meta in zip(batch_gt_instances, batch_img_metas): @@ -659,43 +649,15 @@ def loss_by_feat( ndarray_masks = np.empty((0, 
*img_meta["img_shape"]), dtype=np.uint8) gt_instances.masks = torch.tensor(ndarray_masks, dtype=torch.bool, device=device) - cls_reg_targets = self.get_targets( - flatten_cls_scores, - flatten_bboxes, - anchor_list, - valid_flag_list, + raw_iseg_outputs = self.loss_mask_by_feat( + mask_feat, + flatten_kernels, + raw_outputs["sampling_results_list"], batch_gt_instances, - batch_img_metas, - batch_gt_instances_ignore=batch_gt_instances_ignore, ) - ( - anchor_list, - labels_list, - label_weights_list, - bbox_targets_list, - assign_metrics_list, - sampling_results_list, - ) = cls_reg_targets - - losses_cls, losses_bbox, cls_avg_factors, bbox_avg_factors = multi_apply( - self.loss_by_feat_single, - cls_scores, - decoded_bboxes, - labels_list, - label_weights_list, - bbox_targets_list, - assign_metrics_list, - self.prior_generator.strides, - ) - - cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() - losses_cls = [x / cls_avg_factor for x in losses_cls] - - bbox_avg_factor = reduce_mean(sum(bbox_avg_factors)).clamp_(min=1).item() - losses_bbox = [x / bbox_avg_factor for x in losses_bbox] + raw_outputs.update(raw_iseg_outputs) - loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels, sampling_results_list, batch_gt_instances) - return {"loss_cls": losses_cls, "loss_bbox": losses_bbox, "loss_mask": loss_mask} + return raw_outputs class MaskFeatModule(BaseModule): diff --git a/src/otx/algo/instance_segmentation/losses/__init__.py b/src/otx/algo/instance_segmentation/losses/__init__.py index 903ea4c077e..bcd3bf03a51 100644 --- a/src/otx/algo/instance_segmentation/losses/__init__.py +++ b/src/otx/algo/instance_segmentation/losses/__init__.py @@ -5,5 +5,6 @@ from .accuracy import accuracy from .dice_loss import DiceLoss +from .rtmdet_inst_loss import RTMDetInstCriterion -__all__ = ["accuracy", "DiceLoss"] +__all__ = ["accuracy", "DiceLoss", "RTMDetInstCriterion"] diff --git a/src/otx/algo/instance_segmentation/losses/rtmdet_inst_loss.py 
b/src/otx/algo/instance_segmentation/losses/rtmdet_inst_loss.py new file mode 100644 index 00000000000..3935c0990db --- /dev/null +++ b/src/otx/algo/instance_segmentation/losses/rtmdet_inst_loss.py @@ -0,0 +1,89 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. +# +"""RTMDet for instance segmentation criterion.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from otx.algo.detection.losses import RTMDetCriterion + +if TYPE_CHECKING: + from torch import Tensor, nn + + +class RTMDetInstCriterion(RTMDetCriterion): + """Criterion of RTMDet for instance segmentation. + + Args: + num_classes (int): Number of object classes. + loss_cls (nn.Module): Classification loss module. + loss_bbox (nn.Module): Bounding box regression loss module. + loss_mask (nn.Module): Mask loss module. + """ + + def __init__( + self, + num_classes: int, + loss_cls: nn.Module, + loss_bbox: nn.Module, + loss_mask: nn.Module, + ) -> None: + super().__init__(num_classes, loss_cls, loss_bbox) + self.loss_mask = loss_mask + + def forward( + self, + cls_score: Tensor, + bbox_pred: Tensor, + labels: Tensor, + label_weights: Tensor, + bbox_targets: Tensor, + assign_metrics: Tensor, + stride: list[int], + **kwargs, + ) -> dict[str, Tensor]: + """Compute loss of a single scale level. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Decoded bboxes for each scale + level with shape (N, num_anchors * 4, H, W). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors). + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + assign_metrics (Tensor): Assign metrics with shape + (N, num_total_anchors). 
+ stride (list[int]): Downsample stride of the feature map. + batch_pos_mask_logits (Tensor): The mask prediction, passed via ``kwargs``, has a shape (n, *). + pos_gt_masks (Tensor): The label of the prediction, passed via ``kwargs``, + shape (n, *), same shape as the prediction. + num_pos (int): Average factor that is used to average + the mask loss, passed via ``kwargs``. Required; there is no default. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + loss_dict = super().forward( + cls_score=cls_score, + bbox_pred=bbox_pred, + labels=labels, + label_weights=label_weights, + bbox_targets=bbox_targets, + assign_metrics=assign_metrics, + stride=stride, + ) + + batch_pos_mask_logits: Tensor = kwargs.pop("batch_pos_mask_logits") + pos_gt_masks: Tensor = kwargs.pop("pos_gt_masks") + num_pos: int = kwargs.pop("num_pos") + + loss_mask = self.loss_mask(batch_pos_mask_logits, pos_gt_masks, weight=None, avg_factor=num_pos) + loss_dict.update({"loss_mask": loss_mask}) + return loss_dict diff --git a/src/otx/algo/instance_segmentation/rtmdet_inst.py b/src/otx/algo/instance_segmentation/rtmdet_inst.py index c3ff45ce5f6..f2d610488e5 100644 --- a/src/otx/algo/instance_segmentation/rtmdet_inst.py +++ b/src/otx/algo/instance_segmentation/rtmdet_inst.py @@ -19,7 +19,7 @@ from otx.algo.detection.detectors import SingleStageDetector from otx.algo.detection.necks import CSPNeXtPAFPN from otx.algo.instance_segmentation.heads import RTMDetInsSepBNHead -from otx.algo.instance_segmentation.losses import DiceLoss +from otx.algo.instance_segmentation.losses import DiceLoss, RTMDetInstCriterion from otx.core.config.data import TileConfig from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter @@ -168,6 +168,11 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: ), bbox_coder=DistancePointBBoxCoder(), loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0), + train_cfg=train_cfg, + test_cfg=test_cfg, + ) + criterion = RTMDetInstCriterion( + num_classes=num_classes,
loss_cls=QualityFocalLoss( use_sigmoid=True, beta=2.0, @@ -179,14 +184,13 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: eps=5.0e-06, reduction="mean", ), - train_cfg=train_cfg, - test_cfg=test_cfg, ) return SingleStageDetector( backbone=backbone, neck=neck, bbox_head=bbox_head, + criterion=criterion, train_cfg=train_cfg, test_cfg=test_cfg, )