diff --git a/src/otx/algo/detection/atss.py b/src/otx/algo/detection/atss.py
index 20dda84c364..bc49d4f0834 100644
--- a/src/otx/algo/detection/atss.py
+++ b/src/otx/algo/detection/atss.py
@@ -14,6 +14,7 @@
 from otx.algo.common.utils.samplers import PseudoSampler
 from otx.algo.detection.detectors import SingleStageDetector
 from otx.algo.detection.heads import ATSSHead
+from otx.algo.detection.losses import ATSSCriterion
 from otx.algo.detection.necks import FPN
 from otx.algo.detection.utils.assigners import ATSSAssigner
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -145,6 +146,23 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
                 target_means=(0.0, 0.0, 0.0, 0.0),
                 target_stds=(0.1, 0.1, 0.2, 0.2),
             ),
+            feat_channels=64,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            loss_cls=CrossSigmoidFocalLoss(  # TODO (eugene): deprecated
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0,
+            ),
+            loss_bbox=GIoULoss(loss_weight=2.0),  # TODO (eugene): deprecated
+        )
+        criterion = ATSSCriterion(
+            num_classes=num_classes,
+            bbox_coder=DeltaXYWHBBoxCoder(
+                target_means=(0.0, 0.0, 0.0, 0.0),
+                target_stds=(0.1, 0.1, 0.2, 0.2),
+            ),
             loss_cls=CrossSigmoidFocalLoss(
                 use_sigmoid=True,
                 gamma=2.0,
@@ -153,11 +171,15 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
             ),
             loss_bbox=GIoULoss(loss_weight=2.0),
             loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0),
-            feat_channels=64,
+        )
+        return SingleStageDetector(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
-        return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)
 
 
 class ResNeXt101ATSS(ATSS):
@@ -210,6 +232,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
                 target_means=(0.0, 0.0, 0.0, 0.0),
                 target_stds=(0.1, 0.1, 0.2, 0.2),
             ),
+            num_classes=num_classes,
+            in_channels=256,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            loss_cls=CrossSigmoidFocalLoss(  # TODO (eugene): deprecated
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0,
+            ),
+            loss_bbox=GIoULoss(loss_weight=2.0),  # TODO (eugene): deprecated
+        )
+        criterion = ATSSCriterion(
+            num_classes=num_classes,
+            bbox_coder=DeltaXYWHBBoxCoder(
+                target_means=(0.0, 0.0, 0.0, 0.0),
+                target_stds=(0.1, 0.1, 0.2, 0.2),
+            ),
             loss_cls=CrossSigmoidFocalLoss(
                 use_sigmoid=True,
                 gamma=2.0,
@@ -218,12 +258,15 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
             ),
             loss_bbox=GIoULoss(loss_weight=2.0),
             loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0),
-            num_classes=num_classes,
-            in_channels=256,
+        )
+        return SingleStageDetector(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
-        return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)
 
     def to(self, *args, **kwargs) -> Self:
         """Return a model with specified device."""
diff --git a/src/otx/algo/detection/detectors/single_stage_detector.py b/src/otx/algo/detection/detectors/single_stage_detector.py
index c83626aa29c..b90d813e46a 100644
--- a/src/otx/algo/detection/detectors/single_stage_detector.py
+++ b/src/otx/algo/detection/detectors/single_stage_detector.py
@@ -26,6 +26,7 @@ class SingleStageDetector(BaseModule):
     Args:
         backbone (nn.Module): Backbone module.
         bbox_head (nn.Module): Bbox head module.
+        criterion (nn.Module | None, optional): Criterion module.
         neck (nn.Module | None, optional): Neck module. Defaults to None.
         train_cfg (dict | None, optional): Training config. Defaults to None.
         test_cfg (dict | None, optional): Test config. Defaults to None.
@@ -36,6 +37,7 @@ def __init__(
         self,
         backbone: nn.Module,
         bbox_head: nn.Module,
+        criterion: nn.Module,
         neck: nn.Module | None = None,
         train_cfg: dict | None = None,
         test_cfg: dict | None = None,
@@ -46,6 +48,7 @@ def __init__(
         self.backbone = backbone
         self.bbox_head = bbox_head
         self.neck = neck
+        self.criterion = criterion
         self.init_cfg = init_cfg
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
@@ -129,7 +132,7 @@ def forward(
     def loss(
         self,
         entity: DetBatchDataEntity,
-    ) -> dict | list:
+    ) -> dict:
         """Calculate losses from a batch of inputs and data samples.
 
         Args:
@@ -143,7 +146,10 @@ def loss(
             dict: A dictionary of loss components.
         """
         x = self.extract_feat(entity.images)
-        return self.bbox_head.loss(x, entity)
+        # TODO (sungchul): compare .loss with other forwards and remove duplicated code
+        outputs = self.bbox_head.loss(x, entity)
+
+        return self.criterion(outputs)
 
     def predict(
         self,
diff --git a/src/otx/algo/detection/heads/anchor_head.py b/src/otx/algo/detection/heads/anchor_head.py
index 115741c7619..42dbe247399 100644
--- a/src/otx/algo/detection/heads/anchor_head.py
+++ b/src/otx/algo/detection/heads/anchor_head.py
@@ -31,7 +31,9 @@ class AnchorHead(BaseDenseHead):
         anchor_generator (nn.Module): Module for anchor generator
         bbox_coder (nn.Module): Module of bounding box coder.
         loss_cls (nn.Module): Module of classification loss.
+            It is related to RPNHead for iseg, will be deprecated.
         loss_bbox (nn.Module): Module of localization loss.
+            It is related to RPNHead for iseg, will be deprecated.
         train_cfg (dict): Training config of anchor head.
         test_cfg (dict, optional): Testing config of anchor head.
         feat_channels (int): Number of hidden channels. Used in child classes.
@@ -49,8 +51,8 @@ def __init__(
         in_channels: tuple[int, ...] | int,
         anchor_generator: nn.Module,
         bbox_coder: nn.Module,
-        loss_cls: nn.Module,
-        loss_bbox: nn.Module,
+        loss_cls: nn.Module,  # TODO (eugene): deprecated
+        loss_bbox: nn.Module,  # TODO (eugene): deprecated
         train_cfg: dict,
         test_cfg: dict | None = None,
         feat_channels: int = 256,
@@ -410,6 +412,8 @@ def loss_by_feat_single(
     ) -> tuple:
         """Calculate the loss of a single scale level based on the features extracted by the detection head.
 
+        TODO (eugene): it is related to RPNHead for iseg, will be deprecated
+
         Args:
             cls_score (Tensor): Box scores for each scale level
                 Has shape (N, num_anchors * num_classes, H, W).
@@ -459,6 +463,8 @@ def loss_by_feat(
     ) -> dict:
         """Calculate the loss based on the features extracted by the detection head.
 
+        TODO (eugene): it is related to RPNHead for iseg, will be deprecated
+
         Args:
             cls_scores (list[Tensor]): Box scores for each scale level
                 has shape (N, num_anchors * num_classes, H, W).
diff --git a/src/otx/algo/detection/heads/atss_head.py b/src/otx/algo/detection/heads/atss_head.py
index 9d85dbf0b77..67fb7a06175 100644
--- a/src/otx/algo/detection/heads/atss_head.py
+++ b/src/otx/algo/detection/heads/atss_head.py
@@ -11,8 +11,6 @@
 import torch
 from torch import Tensor, nn
 
-from otx.algo.common.losses import CrossEntropyLoss, CrossSigmoidFocalLoss
-from otx.algo.common.utils.bbox_overlaps import bbox_overlaps
 from otx.algo.common.utils.utils import multi_apply, reduce_mean
 from otx.algo.detection.heads.anchor_head import AnchorHead
 from otx.algo.detection.heads.class_incremental_mixin import (
@@ -46,7 +44,6 @@ class ATSSHead(ClassIncrementalMixin, AnchorHead):
             the predicted boxes and regression targets to absolute
             coordinates format. Defaults to False. It should be `True` when
             using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
-        loss_centerness (nn.Module, optinoal): Module of centerness loss. Defaults to None.
         init_cfg (dict, list[dict], optional): Initialization config dict.
     """
 
@@ -58,11 +55,7 @@ def __init__(
         stacked_convs: int = 4,
         norm_cfg: dict | None = None,
         reg_decoded_bbox: bool = True,
-        loss_centerness: nn.Module | None = None,
         init_cfg: dict | None = None,
-        bg_loss_weight: float = -1.0,
-        use_qfl: bool = False,
-        qfl_cfg: dict | None = None,
         **kwargs,
     ) -> None:
         self.pred_kernel_size = pred_kernel_size
@@ -83,21 +76,6 @@ def __init__(
         )
 
         self.sampling = False
-        self.loss_centerness = loss_centerness or CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)
-
-        if use_qfl:
-            kwargs["loss_cls"] = (
-                qfl_cfg
-                if qfl_cfg
-                else {
-                    "type": "QualityFocalLoss",
-                    "use_sigmoid": True,
-                    "beta": 2.0,
-                    "loss_weight": 1.0,
-                }
-            )
-        self.bg_loss_weight = bg_loss_weight
-        self.use_qfl = use_qfl
 
     def _init_layers(self) -> None:
         """Initialize layers of the head."""
@@ -221,7 +199,7 @@ def loss_by_feat(  # type: ignore[override]
                 Defaults to None.
 
         Returns:
-            dict[str, Tensor]: A dictionary of loss components.
+            dict[str, Tensor]: A dictionary of raw outputs.
         """
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         if len(featmap_sizes) != self.prior_generator.num_levels:
@@ -250,182 +228,17 @@ def loss_by_feat(  # type: ignore[override]
         ) = cls_reg_targets
         avg_factor = reduce_mean(torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
 
-        losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = multi_apply(
-            self.loss_by_feat_single,
-            anchor_list,
-            cls_scores,
-            bbox_preds,
-            centernesses,
-            labels_list,
-            label_weights_list,
-            bbox_targets_list,
-            valid_label_mask,
-            avg_factor=avg_factor,
-        )
-
-        bbox_avg_factor = sum(bbox_avg_factor)
-        bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
-        losses_bbox = [loss_bbox / bbox_avg_factor for loss_bbox in losses_bbox]
-        return {"loss_cls": losses_cls, "loss_bbox": losses_bbox, "loss_centerness": loss_centerness}
-
-    def loss_by_feat_single(  # type: ignore[override]
-        self,
-        anchors: Tensor,
-        cls_score: Tensor,
-        bbox_pred: Tensor,
-        centerness: Tensor,
-        labels: Tensor,
-        label_weights: Tensor,
-        bbox_targets: Tensor,
-        valid_label_mask: Tensor,
-        avg_factor: float,
-    ) -> tuple:
-        """Compute loss of a single scale level.
-
-        Args:
-            anchors (Tensor): Box reference for each scale level with shape
-                (N, num_total_anchors, 4).
-            cls_score (Tensor): Box scores for each scale level
-                Has shape (N, num_anchors * num_classes, H, W).
-            bbox_pred (Tensor): Box energies / deltas for each scale
-                level with shape (N, num_anchors * 4, H, W).
-            centerness(Tensor): Centerness scores for each scale level.
-            labels (Tensor): Labels of each anchors with shape
-                (N, num_total_anchors).
-            label_weights (Tensor): Label weights of each anchor with shape
-                (N, num_total_anchors)
-            bbox_targets (Tensor): BBox regression targets of each anchor with
-                shape (N, num_total_anchors, 4).
-            valid_label_mask (Tensor): Label mask for consideration of ignored
-                label with shape (N, num_total_anchors, 1).
-            avg_factor (float): Average factor that is used to average
-                the loss. When using sampling method, avg_factor is usually
-                the sum of positive and negative priors. When using
-                `PseudoSampler`, `avg_factor` is usually equal to the number
-                of positive priors.
-
-        Returns:
-            tuple[Tensor]: A tuple of loss components.
-        """
-        anchors = anchors.reshape(-1, 4)
-        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous()
-        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
-        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
-        bbox_targets = bbox_targets.reshape(-1, 4)
-        labels = labels.reshape(-1)
-        label_weights = label_weights.reshape(-1)
-        valid_label_mask = valid_label_mask.reshape(-1, self.cls_out_channels)
-
-        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
-        pos_inds = self._get_pos_inds(labels)
-
-        if self.use_qfl:
-            quality = label_weights.new_zeros(labels.shape)
-
-        if len(pos_inds) > 0:
-            pos_bbox_targets = bbox_targets[pos_inds]
-            pos_bbox_pred = bbox_pred[pos_inds]
-            pos_anchors = anchors[pos_inds]
-            pos_centerness = centerness[pos_inds]
-
-            centerness_targets = self.centerness_target(pos_anchors, pos_bbox_targets)
-            if self.reg_decoded_bbox:
-                pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
-
-            if self.use_qfl:
-                quality[pos_inds] = bbox_overlaps(pos_bbox_pred.detach(), pos_bbox_targets, is_aligned=True).clamp(
-                    min=1e-6,
-                )
-
-            # regression loss
-            loss_bbox = self._get_loss_bbox(pos_bbox_targets, pos_bbox_pred, centerness_targets)
-
-            # centerness loss
-            loss_centerness = self._get_loss_centerness(avg_factor, pos_centerness, centerness_targets)
-
-        else:
-            loss_bbox = bbox_pred.sum() * 0
-            loss_centerness = centerness.sum() * 0
-            centerness_targets = bbox_targets.new_tensor(0.0)
-
-        # Re-weigting BG loss
-        if self.bg_loss_weight >= 0.0:
-            neg_indices = labels == self.num_classes
-            label_weights[neg_indices] = self.bg_loss_weight
-
-        if self.use_qfl:
-            labels = (labels, quality)  # For quality focal loss arg spec
-
-        # classification loss
-        loss_cls = self._get_loss_cls(cls_score, labels, label_weights, valid_label_mask, avg_factor)
-
-        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
-
-    def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
-        """Calculate the centerness between anchors and gts.
-
-        Only calculate pos centerness targets, otherwise there may be nan.
-
-        Args:
-            anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
-            gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format.
-
-        Returns:
-            Tensor: Centerness between anchors and gts.
-        """
-        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
-        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
-        l_ = anchors_cx - gts[:, 0]
-        t_ = anchors_cy - gts[:, 1]
-        r_ = gts[:, 2] - anchors_cx
-        b_ = gts[:, 3] - anchors_cy
-
-        left_right = torch.stack([l_, r_], dim=1)
-        top_bottom = torch.stack([t_, b_], dim=1)
-        return torch.sqrt(
-            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
-            * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]),
-        )
-
-    def _get_pos_inds(self, labels: Tensor) -> Tensor:
-        bg_class_ind = self.num_classes
-        return ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
-
-    def _get_loss_cls(
-        self,
-        cls_score: Tensor,
-        labels: Tensor,
-        label_weights: Tensor,
-        valid_label_mask: Tensor,
-        avg_factor: Tensor,
-    ) -> Tensor:
-        if isinstance(self.loss_cls, CrossSigmoidFocalLoss):
-            loss_cls = self.loss_cls(
-                cls_score,
-                labels,
-                label_weights,
-                avg_factor=avg_factor,
-                valid_label_mask=valid_label_mask,
-            )
-        else:
-            loss_cls = self.loss_cls(cls_score, labels, label_weights, avg_factor=avg_factor)
-        return loss_cls
-
-    def _get_loss_centerness(
-        self,
-        avg_factor: Tensor,
-        pos_centerness: Tensor,
-        centerness_targets: Tensor,
-    ) -> Tensor:
-        return self.loss_centerness(pos_centerness, centerness_targets, avg_factor=avg_factor)
-
-    def _get_loss_bbox(
-        self,
-        pos_bbox_targets: Tensor,
-        pos_bbox_pred: Tensor,
-        centerness_targets: Tensor,
-    ) -> Tensor:
-        return self.loss_bbox(pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0)
+        return {
+            "anchors": anchor_list,
+            "cls_score": cls_scores,
+            "bbox_pred": bbox_preds,
+            "centerness": centernesses,
+            "labels": labels_list,
+            "label_weights": label_weights_list,
+            "bbox_targets": bbox_targets_list,
+            "valid_label_mask": valid_label_mask,
+            "avg_factor": avg_factor,
+        }
 
     def get_targets(
         self,
diff --git a/src/otx/algo/detection/heads/rtmdet_head.py b/src/otx/algo/detection/heads/rtmdet_head.py
index 71623aad9e7..c8590c8053f 100644
--- a/src/otx/algo/detection/heads/rtmdet_head.py
+++ b/src/otx/algo/detection/heads/rtmdet_head.py
@@ -14,7 +14,7 @@
 from torch import Tensor, nn
 
 from otx.algo.common.utils.nms import multiclass_nms
-from otx.algo.common.utils.utils import distance2bbox, inverse_sigmoid, multi_apply, reduce_mean
+from otx.algo.common.utils.utils import distance2bbox, inverse_sigmoid, multi_apply
 from otx.algo.detection.heads import ATSSHead
 from otx.algo.detection.utils.prior_generators.utils import anchor_inside_flags
 from otx.algo.detection.utils.utils import (
@@ -153,75 +153,6 @@ def forward(self, feats: tuple[Tensor, ...]) -> tuple:
             bbox_preds.append(reg_dist)
         return tuple(cls_scores), tuple(bbox_preds)
 
-    def loss_by_feat_single(  # type: ignore[override]
-        self,
-        cls_score: Tensor,
-        bbox_pred: Tensor,
-        labels: Tensor,
-        label_weights: Tensor,
-        bbox_targets: Tensor,
-        assign_metrics: Tensor,
-        stride: list[int],
-    ) -> tuple[Tensor, ...]:
-        """Compute loss of a single scale level.
-
-        Args:
-            cls_score (Tensor): Box scores for each scale level
-                Has shape (N, num_anchors * num_classes, H, W).
-            bbox_pred (Tensor): Decoded bboxes for each scale
-                level with shape (N, num_anchors * 4, H, W).
-            labels (Tensor): Labels of each anchors with shape
-                (N, num_total_anchors).
-            label_weights (Tensor): Label weights of each anchor with shape
-                (N, num_total_anchors).
-            bbox_targets (Tensor): BBox regression targets of each anchor with
-                shape (N, num_total_anchors, 4).
-            assign_metrics (Tensor): Assign metrics with shape
-                (N, num_total_anchors).
-            stride (list[int]): Downsample stride of the feature map.
-
-        Returns:
-            dict[str, Tensor]: A dictionary of loss components.
-        """
-        if stride[0] != stride[1]:
-            msg = "h stride is not equal to w stride!"
-            raise ValueError(msg)
-        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous()
-        bbox_pred = bbox_pred.reshape(-1, 4)
-        bbox_targets = bbox_targets.reshape(-1, 4)
-        labels = labels.reshape(-1)
-        assign_metrics = assign_metrics.reshape(-1)
-        label_weights = label_weights.reshape(-1)
-        targets = (labels, assign_metrics)
-
-        loss_cls = self.loss_cls(cls_score, targets, label_weights, avg_factor=1.0)
-
-        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
-        bg_class_ind = self.num_classes
-        pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
-
-        if len(pos_inds) > 0:
-            pos_bbox_targets = bbox_targets[pos_inds]
-            pos_bbox_pred = bbox_pred[pos_inds]
-
-            pos_decode_bbox_pred = pos_bbox_pred
-            pos_decode_bbox_targets = pos_bbox_targets
-
-            # regression loss
-            pos_bbox_weight = assign_metrics[pos_inds]
-
-            loss_bbox = self.loss_bbox(
-                pos_decode_bbox_pred,
-                pos_decode_bbox_targets,
-                weight=pos_bbox_weight,
-                avg_factor=1.0,
-            )
-        else:
-            loss_bbox = bbox_pred.sum() * 0
-            pos_bbox_weight = bbox_targets.new_tensor(0.0)
-
-        return loss_cls, loss_bbox, assign_metrics.sum(), pos_bbox_weight.sum()
-
     def loss_by_feat(  # type: ignore[override]
         self,
         cls_scores: list[Tensor],
@@ -249,7 +180,7 @@ def loss_by_feat(  # type: ignore[override]
                 Defaults to None.
 
         Returns:
-            dict[str, Tensor]: A dictionary of loss components.
+            dict[str, Tensor]: A dictionary of raw outputs.
         """
         num_imgs = len(batch_img_metas)
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
@@ -292,23 +223,16 @@ def loss_by_feat(  # type: ignore[override]
             batch_gt_instances_ignore=batch_gt_instances_ignore,
         )
 
-        losses_cls, losses_bbox, cls_avg_factors, bbox_avg_factors = multi_apply(
-            self.loss_by_feat_single,
-            cls_scores,
-            decoded_bboxes,
-            labels_list,
-            label_weights_list,
-            bbox_targets_list,
-            assign_metrics_list,
-            self.prior_generator.strides,
-        )
-
-        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
-        losses_cls = [x / cls_avg_factor for x in losses_cls]
-
-        bbox_avg_factor = reduce_mean(sum(bbox_avg_factors)).clamp_(min=1).item()
-        losses_bbox = [x / bbox_avg_factor for x in losses_bbox]
-        return {"loss_cls": losses_cls, "loss_bbox": losses_bbox}
+        return {
+            "cls_score": cls_scores,
+            "bbox_pred": decoded_bboxes,
+            "labels": labels_list,
+            "label_weights": label_weights_list,
+            "bbox_targets": bbox_targets_list,
+            "assign_metrics": assign_metrics_list,
+            "stride": self.prior_generator.strides,
+            "sampling_results_list": sampling_results_list,
+        }
 
     def export_by_feat(  # type: ignore[override]
         self,
diff --git a/src/otx/algo/detection/heads/ssd_head.py b/src/otx/algo/detection/heads/ssd_head.py
index fb6aa316153..4d7532af3a8 100644
--- a/src/otx/algo/detection/heads/ssd_head.py
+++ b/src/otx/algo/detection/heads/ssd_head.py
@@ -13,9 +13,7 @@
 import torch
 from torch import Tensor, nn
 
-from otx.algo.common.losses import CrossEntropyLoss, smooth_l1_loss
 from otx.algo.common.utils.samplers import PseudoSampler
-from otx.algo.common.utils.utils import multi_apply
 from otx.algo.detection.heads.anchor_head import AnchorHead
 
 if TYPE_CHECKING:
@@ -27,7 +25,6 @@ class SSDHead(AnchorHead):
 
     Args:
         anchor_generator (nn.Module): Config dict for anchor generator.
-        bbox_coder (nn.Module): Config of bounding box coder.
         init_cfg (dict, list[dict]): Initialization config dict.
         train_cfg (dict): Training config of anchor head.
         num_classes (int): Number of categories excluding the background category.
@@ -38,18 +35,12 @@ class SSDHead(AnchorHead):
             Defaults to 256.
         use_depthwise (bool): Whether to use DepthwiseSeparableConv.
             Defaults to False.
-        reg_decoded_bbox (bool): If true, the regression loss would be
-            applied directly on decoded bounding boxes, converting both
-            the predicted boxes and regression targets to absolute
-            coordinates format. Defaults to False. It should be `True` when
-            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
         test_cfg (dict, Optional): Testing config of anchor head.
     """
 
     def __init__(
         self,
         anchor_generator: nn.Module,
-        bbox_coder: nn.Module,
         init_cfg: dict | list[dict],
         train_cfg: dict,
         num_classes: int = 80,
@@ -57,7 +48,6 @@ def __init__(
         stacked_convs: int = 0,
         feat_channels: int = 256,
         use_depthwise: bool = False,
-        reg_decoded_bbox: bool = False,
         test_cfg: dict | None = None,
     ) -> None:
         super(AnchorHead, self).__init__(init_cfg=init_cfg)
@@ -75,12 +65,8 @@ def __init__(
         # heads but a list of int in SSDHead
         self.num_base_priors = self.prior_generator.num_base_priors
 
-        self.loss_cls = CrossEntropyLoss(use_sigmoid=False, reduction="none", loss_weight=1.0)
-
         self._init_layers()
 
-        self.bbox_coder = bbox_coder
-        self.reg_decoded_bbox = reg_decoded_bbox
         self.use_sigmoid_cls = False
         self.cls_focal_loss = False
         self.train_cfg = train_cfg
@@ -113,66 +99,6 @@ def forward(self, x: tuple[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
             bbox_preds.append(reg_conv(feat))
         return cls_scores, bbox_preds
 
-    def loss_by_feat_single(
-        self,
-        cls_score: Tensor,
-        bbox_pred: Tensor,
-        anchor: Tensor,
-        labels: Tensor,
-        label_weights: Tensor,
-        bbox_targets: Tensor,
-        bbox_weights: Tensor,
-        avg_factor: int,
-    ) -> tuple[Tensor, Tensor]:
-        """Compute loss of a single image.
-
-        Args:
-            cls_score (Tensor): Box scores for each image has shape (num_total_anchors, num_classes).
-            bbox_pred (Tensor): Box energies / deltas for each image level with shape (num_total_anchors, 4).
-            anchors (Tensor): Box reference for each scale level with shape (num_total_anchors, 4).
-            labels (Tensor): Labels of each anchors with shape (num_total_anchors,).
-            label_weights (Tensor): Label weights of each anchor with shape (num_total_anchors,)
-            bbox_targets (Tensor): BBox regression targets of each anchor with shape (num_total_anchors, 4).
-            bbox_weights (Tensor): BBox regression loss weights of each anchor with shape (num_total_anchors, 4).
-            avg_factor (int): Average factor that is used to average
-                the loss. When using sampling method, avg_factor is usually
-                the sum of positive and negative priors. When using
-                `PseudoSampler`, `avg_factor` is usually equal to the number
-                of positive priors.
-
-        Returns:
-            tuple[Tensor, Tensor]: A tuple of cls loss and bbox loss of one
-            feature map.
-        """
-        loss_cls_all = nn.functional.cross_entropy(cls_score, labels, reduction="none") * label_weights
-        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
-        pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(as_tuple=False).reshape(-1)
-        neg_inds = (labels == self.num_classes).nonzero(as_tuple=False).view(-1)
-
-        num_pos_samples = pos_inds.size(0)
-        num_neg_samples = self.train_cfg["neg_pos_ratio"] * num_pos_samples
-        if num_neg_samples > neg_inds.size(0):
-            num_neg_samples = neg_inds.size(0)
-        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
-        loss_cls_pos = loss_cls_all[pos_inds].sum()
-        loss_cls_neg = topk_loss_cls_neg.sum()
-        loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor
-
-        if self.reg_decoded_bbox:
-            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
-            # is applied directly on the decoded bounding boxes, it
-            # decodes the already encoded coordinates to absolute format.
-            bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)
-
-        loss_bbox = smooth_l1_loss(
-            bbox_pred,
-            bbox_targets,
-            bbox_weights,
-            beta=self.train_cfg["smoothl1_beta"],
-            avg_factor=avg_factor,
-        )
-        return loss_cls[None], loss_bbox
-
     def loss_by_feat(
         self,
         cls_scores: list[Tensor],
@@ -180,7 +106,7 @@ def loss_by_feat(
         batch_gt_instances: list[InstanceData],
         batch_img_metas: list[dict],
         batch_gt_instances_ignore: list[InstanceData] | None = None,
-    ) -> dict[str, list[Tensor]]:
+    ) -> dict[str, Tensor]:
         """Compute losses of the head.
 
         Args:
@@ -195,13 +121,7 @@ def loss_by_feat(
                 Defaults to None.
 
         Returns:
-            dict[str, list[Tensor]]: A dictionary of loss components. the dict
-            has components below:
-
-            - loss_cls (list[Tensor]): A list containing each feature map \
-            classification loss.
-            - loss_bbox (list[Tensor]): A list containing each feature map \
-            regression loss.
+            dict[str, Tensor]: A dictionary of raw outputs.
         """
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
 
@@ -232,18 +152,16 @@ def loss_by_feat(
         # concat all level anchors to a single tensor
         all_anchors = [torch.cat(anchor) for anchor in anchor_list]
 
-        losses_cls, losses_bbox = multi_apply(
-            self.loss_by_feat_single,
-            all_cls_scores,
-            all_bbox_preds,
-            all_anchors,
-            all_labels,
-            all_label_weights,
-            all_bbox_targets,
-            all_bbox_weights,
-            avg_factor=avg_factor,
-        )
-        return {"loss_cls": losses_cls, "loss_bbox": losses_bbox}
+        return {
+            "all_cls_scores": all_cls_scores,
+            "all_bbox_preds": all_bbox_preds,
+            "all_anchors": all_anchors,
+            "all_labels": all_labels,
+            "all_label_weights": all_label_weights,
+            "all_bbox_targets": all_bbox_targets,
+            "all_bbox_weights": all_bbox_weights,
+            "avg_factor": avg_factor,
+        }
 
     def _init_layers(self) -> None:
         """Initialize layers of the head.
diff --git a/src/otx/algo/detection/heads/yolox_head.py b/src/otx/algo/detection/heads/yolox_head.py
index af5dd677af0..d8830ecb1aa 100644
--- a/src/otx/algo/detection/heads/yolox_head.py
+++ b/src/otx/algo/detection/heads/yolox_head.py
@@ -17,13 +17,11 @@
 from torch import Tensor, nn
 from torchvision.ops import box_convert
 
-from otx.algo.common.losses import CrossEntropyLoss, L1Loss
 from otx.algo.common.utils.nms import batched_nms, multiclass_nms
 from otx.algo.common.utils.prior_generators import MlvlPointGenerator
 from otx.algo.common.utils.samplers import PseudoSampler
 from otx.algo.common.utils.utils import multi_apply, reduce_mean
 from otx.algo.detection.heads.base_head import BaseDenseHead
-from otx.algo.detection.losses import IoULoss
 from otx.algo.modules.activation import Swish
 from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule
 from otx.algo.utils.mmengine_utils import InstanceData
@@ -54,10 +52,6 @@ class YOLOXHead(BaseDenseHead):
             Defaults to dict(type='BN', momentum=0.03, eps=0.001).
         activation_callable (Callable[..., nn.Module]): Activation layer module.
             Defaults to `Swish`.
-        loss_cls (nn.Module, optional): Module of classification loss.
-        loss_bbox (nn.Module, optional): Module of localization loss.
-        loss_obj (nn.Module, optional): Module of objectness loss.
-        loss_l1 (nn.Module, optional): Module of L1 loss.
         train_cfg (dict, optional): Training config of anchor head.
             Defaults to None.
         test_cfg (dict, optional): Testing config of anchor head.
@@ -78,10 +72,6 @@ def __init__(
         conv_bias: bool | str = "auto",
         norm_cfg: dict | None = None,
         activation_callable: Callable[..., nn.Module] = Swish,
-        loss_cls: nn.Module | None = None,
-        loss_bbox: nn.Module | None = None,
-        loss_obj: nn.Module | None = None,
-        loss_l1: nn.Module | None = None,
         train_cfg: dict | None = None,
         test_cfg: dict | None = None,
         init_cfg: dict | list[dict] | None = None,
@@ -118,12 +108,7 @@ def __init__(
         self.norm_cfg = norm_cfg
         self.activation_callable = activation_callable
 
-        self.loss_cls = loss_cls or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)
-        self.loss_bbox = loss_bbox or IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0)
-        self.loss_obj = loss_obj or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)
-
         self.use_l1 = False  # This flag will be modified by hooks.
-        self.loss_l1 = loss_l1 or L1Loss(reduction="sum", loss_weight=1.0)
 
         self.prior_generator = MlvlPointGenerator(strides, offset=0)  # type: ignore[arg-type]
 
@@ -491,7 +476,7 @@ def loss_by_feat(  # type: ignore[override]
                 Defaults to None.
 
         Returns:
-            dict[str, Tensor]: A dictionary of losses.
+            dict[str, Tensor]: A dictionary of raw outputs.
         """
         num_imgs = len(batch_img_metas)
         if batch_gt_instances_ignore is None:
@@ -543,34 +528,19 @@ def loss_by_feat(  # type: ignore[override]
         if self.use_l1:
             l1_targets = torch.cat(l1_targets, 0)
 
-        loss_obj = self.loss_obj(flatten_objectness.view(-1, 1), obj_targets) / num_total_samples
-        if num_pos > 0:
-            loss_cls = (
-                self.loss_cls(flatten_cls_preds.view(-1, self.num_classes)[pos_masks], cls_targets) / num_total_samples
-            )
-            loss_bbox = self.loss_bbox(flatten_bboxes.view(-1, 4)[pos_masks], bbox_targets) / num_total_samples
-        else:
-            # Avoid cls and reg branch not participating in the gradient
-            # propagation when there is no ground-truth in the images.
-            # For more details, please refer to
-            # https://github.com/open-mmlab/mmdetection/issues/7298
-            loss_cls = flatten_cls_preds.sum() * 0
-            loss_bbox = flatten_bboxes.sum() * 0
-
-        loss_dict = {"loss_cls": loss_cls, "loss_bbox": loss_bbox, "loss_obj": loss_obj}
-
-        if self.use_l1:
-            if num_pos > 0:
-                loss_l1 = self.loss_l1(flatten_bbox_preds.view(-1, 4)[pos_masks], l1_targets) / num_total_samples
-            else:
-                # Avoid cls and reg branch not participating in the gradient
-                # propagation when there is no ground-truth in the images.
-                # For more details, please refer to
-                # https://github.com/open-mmlab/mmdetection/issues/7298
-                loss_l1 = flatten_bbox_preds.sum() * 0
-            loss_dict.update(loss_l1=loss_l1)
-
-        return loss_dict
+        return {
+            "flatten_objectness": flatten_objectness,
+            "flatten_cls_preds": flatten_cls_preds,
+            "flatten_bbox_preds": flatten_bbox_preds,
+            "flatten_bboxes": flatten_bboxes,
+            "obj_targets": obj_targets,
+            "cls_targets": cls_targets,
+            "bbox_targets": bbox_targets,
+            "l1_targets": l1_targets,
+            "num_total_samples": num_total_samples,
+            "num_pos": num_pos,
+            "pos_masks": pos_masks,
+        }
 
     @torch.no_grad()
     def _get_targets_single(
diff --git a/src/otx/algo/detection/losses/__init__.py b/src/otx/algo/detection/losses/__init__.py
index 768be6e9778..91b4ad733a4 100644
--- a/src/otx/algo/detection/losses/__init__.py
+++ b/src/otx/algo/detection/losses/__init__.py
@@ -3,7 +3,11 @@
 #
 """Custom OTX Losses for Object Detection."""
 
+from .atss_loss import ATSSCriterion
 from .iou_loss import IoULoss
 from .rtdetr_loss import DetrCriterion
+from .rtmdet_loss import RTMDetCriterion
+from .ssd_loss import SSDCriterion
+from .yolox_loss import YOLOXCriterion
 
-__all__ = ["IoULoss", "DetrCriterion"]
+__all__ = ["ATSSCriterion", "IoULoss", "DetrCriterion", "RTMDetCriterion", "SSDCriterion", "YOLOXCriterion"]
diff --git a/src/otx/algo/detection/losses/atss_loss.py b/src/otx/algo/detection/losses/atss_loss.py
new file mode 100644
index 00000000000..f62fcfef634
--- /dev/null
+++ b/src/otx/algo/detection/losses/atss_loss.py
@@ -0,0 +1,234 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) OpenMMLab. All rights reserved.
+#
+"""ATSS criterion."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+
+from otx.algo.common.losses import CrossEntropyLoss, CrossSigmoidFocalLoss, QualityFocalLoss
+from otx.algo.common.utils.bbox_overlaps import bbox_overlaps
+from otx.algo.common.utils.utils import reduce_mean
+
+
+class ATSSCriterion(nn.Module):
+    """ATSSCriterion is a loss criterion used in the Adaptive Training Sample Selection (ATSS) algorithm.
+
+    Args:
+        num_classes (int): The number of object classes.
+        bbox_coder (nn.Module): The module used for encoding and decoding bounding box coordinates.
+        loss_cls (nn.Module): The module used for calculating the classification loss.
+        loss_bbox (nn.Module): The module used for calculating the bounding box regression loss.
+        loss_centerness (nn.Module | None, optional): The module used for calculating the centerness loss.
+            Defaults to None.
+        use_qfl (bool, optional): Whether to use the Quality Focal Loss (QFL).
+            Defaults to ``CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)``.
+        reg_decoded_bbox (bool, optional): Whether to use the decoded bounding box coordinates
+            for regression loss calculation. Defaults to True.
+        bg_loss_weight (float, optional): The weight for the background loss.
+            Defaults to -1.0.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        bbox_coder: nn.Module,
+        loss_cls: nn.Module,
+        loss_bbox: nn.Module,
+        loss_centerness: nn.Module | None = None,
+        use_qfl: bool = False,
+        qfl_cfg: dict | None = None,
+        reg_decoded_bbox: bool = True,
+        bg_loss_weight: float = -1.0,
+    ) -> None:
+        super().__init__()
+        self.num_classes = num_classes
+        self.bbox_coder = bbox_coder
+        self.use_qfl = use_qfl
+        self.reg_decoded_bbox = reg_decoded_bbox
+        self.bg_loss_weight = bg_loss_weight
+
+        self.loss_bbox = loss_bbox
+        self.loss_centerness = loss_centerness or CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)
+
+        if use_qfl:
+            loss_cls = qfl_cfg or QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0)
+
+        self.loss_cls = loss_cls
+
+        self.use_sigmoid_cls = loss_cls.use_sigmoid
+        if self.use_sigmoid_cls:
+            self.cls_out_channels = num_classes
+        else:
+            self.cls_out_channels = num_classes + 1
+
+        if self.cls_out_channels <= 0:
+            msg = f"num_classes={num_classes} is too small"
+            raise ValueError(msg)
+
+    def forward(
+        self,
+        anchors: Tensor,
+        cls_score: Tensor,
+        bbox_pred: Tensor,
+        centerness: Tensor,
+        labels: Tensor,
+        label_weights: Tensor,
+        bbox_targets: Tensor,
+        valid_label_mask: Tensor,
+        avg_factor: float,
+    ) -> dict[str, Tensor]:
+        """Compute loss of a single scale level.
+
+        Args:
+            anchors (Tensor): Box reference for each scale level with shape
+                (N, num_total_anchors, 4).
+            cls_score (Tensor): Box scores for each scale level
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_pred (Tensor): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            centerness(Tensor): Centerness scores for each scale level.
+            labels (Tensor): Labels of each anchors with shape
+                (N, num_total_anchors).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (N, num_total_anchors, 4).
+            valid_label_mask (Tensor): Label mask for consideration of ignored
+                label with shape (N, num_total_anchors, 1).
+            avg_factor (float): Average factor that is used to average
+                the loss. When using sampling method, avg_factor is usually
+                the sum of positive and negative priors. When using
+                `PseudoSampler`, `avg_factor` is usually equal to the number
+                of positive priors.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        anchors = anchors.reshape(-1, 4)
+        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous()
+        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
+        bbox_targets = bbox_targets.reshape(-1, 4)
+        labels = labels.reshape(-1)
+        label_weights = label_weights.reshape(-1)
+        valid_label_mask = valid_label_mask.reshape(-1, self.cls_out_channels)
+
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        pos_inds = self._get_pos_inds(labels)
+
+        if self.use_qfl:
+            quality = label_weights.new_zeros(labels.shape)
+
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox_targets[pos_inds]
+            pos_bbox_pred = bbox_pred[pos_inds]
+            pos_anchors = anchors[pos_inds]
+            pos_centerness = centerness[pos_inds]
+
+            centerness_targets = self.centerness_target(pos_anchors, pos_bbox_targets)
+            if self.reg_decoded_bbox:
+                pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
+
+            if self.use_qfl:
+                quality[pos_inds] = bbox_overlaps(pos_bbox_pred.detach(), pos_bbox_targets, is_aligned=True).clamp(
+                    min=1e-6,
+                )
+
+            # regression loss
+            loss_bbox = self._get_loss_bbox(pos_bbox_targets, pos_bbox_pred, centerness_targets)
+
+            # centerness loss
+            loss_centerness = self._get_loss_centerness(avg_factor, pos_centerness, centerness_targets)
+
+        else:
+            loss_bbox = bbox_pred.sum() * 0
+            loss_centerness = centerness.sum() * 0
+            centerness_targets = bbox_targets.new_tensor(0.0)
+
+        # Re-weigting BG loss
+        if self.bg_loss_weight >= 0.0:
+            neg_indices = labels == self.num_classes
+            label_weights[neg_indices] = self.bg_loss_weight
+
+        if self.use_qfl:
+            labels = (labels, quality)  # For quality focal loss arg spec
+
+        # classification loss
+        loss_cls = self._get_loss_cls(cls_score, labels, label_weights, valid_label_mask, avg_factor)
+
+        bbox_avg_factor = centerness_targets.sum()
+
+        bbox_avg_factor = sum(bbox_avg_factor)
+        bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
+        loss_bbox = [x / bbox_avg_factor for x in loss_bbox]
+        return {"loss_cls": loss_cls, "loss_bbox": loss_bbox, "loss_centerness": loss_centerness}
+
+    def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
+        """Calculate the centerness between anchors and gts.
+
+        Only calculate pos centerness targets, otherwise there may be nan.
+
+        Args:
+            anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
+            gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format.
+
+        Returns:
+            Tensor: Centerness between anchors and gts.
+        """
+        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+        l_ = anchors_cx - gts[:, 0]
+        t_ = anchors_cy - gts[:, 1]
+        r_ = gts[:, 2] - anchors_cx
+        b_ = gts[:, 3] - anchors_cy
+
+        left_right = torch.stack([l_, r_], dim=1)
+        top_bottom = torch.stack([t_, b_], dim=1)
+        return torch.sqrt(
+            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
+            * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]),
+        )
+
+    def _get_pos_inds(self, labels: Tensor) -> Tensor:
+        bg_class_ind = self.num_classes
+        return ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
+
+    def _get_loss_bbox(
+        self,
+        pos_bbox_targets: Tensor,
+        pos_bbox_pred: Tensor,
+        centerness_targets: Tensor,
+    ) -> Tensor:
+        return self.loss_bbox(pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0)
+
+    def _get_loss_centerness(
+        self,
+        avg_factor: Tensor,
+        pos_centerness: Tensor,
+        centerness_targets: Tensor,
+    ) -> Tensor:
+        return self.loss_centerness(pos_centerness, centerness_targets, avg_factor=avg_factor)
+
+    def _get_loss_cls(
+        self,
+        cls_score: Tensor,
+        labels: Tensor,
+        label_weights: Tensor,
+        valid_label_mask: Tensor,
+        avg_factor: Tensor,
+    ) -> Tensor:
+        if isinstance(self.loss_cls, CrossSigmoidFocalLoss):
+            loss_cls = self.loss_cls(
+                cls_score,
+                labels,
+                label_weights,
+                avg_factor=avg_factor,
+                valid_label_mask=valid_label_mask,
+            )
+        else:
+            loss_cls = self.loss_cls(cls_score, labels, label_weights, avg_factor=avg_factor)
+        return loss_cls
diff --git a/src/otx/algo/detection/losses/rtmdet_loss.py b/src/otx/algo/detection/losses/rtmdet_loss.py
new file mode 100644
index 00000000000..691b810079a
--- /dev/null
+++ b/src/otx/algo/detection/losses/rtmdet_loss.py
@@ -0,0 +1,114 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) OpenMMLab. All rights reserved.
+#
+"""RTMDet criterion."""
+
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+from otx.algo.common.utils.utils import reduce_mean
+
+
+class RTMDetCriterion(nn.Module):
+    """RTMDetCriterion is a criterion module for RTM-based object detection.
+
+    Args:
+        num_classes (int): Number of object classes.
+        loss_cls (nn.Module): Classification loss module.
+        loss_bbox (nn.Module): Bounding box regression loss module.
+    """
+
+    def __init__(self, num_classes: int, loss_cls: nn.Module, loss_bbox: nn.Module) -> None:
+        super().__init__()
+        self.num_classes = num_classes
+        self.loss_cls = loss_cls
+        self.loss_bbox = loss_bbox
+        self.use_sigmoid_cls = loss_cls.use_sigmoid
+        if self.use_sigmoid_cls:
+            self.cls_out_channels = num_classes
+        else:
+            self.cls_out_channels = num_classes + 1
+
+        if self.cls_out_channels <= 0:
+            msg = f"num_classes={num_classes} is too small"
+            raise ValueError(msg)
+
+    def forward(  # type: ignore[override]
+        self,
+        cls_score: Tensor,
+        bbox_pred: Tensor,
+        labels: Tensor,
+        label_weights: Tensor,
+        bbox_targets: Tensor,
+        assign_metrics: Tensor,
+        stride: list[int],
+        **kwargs,
+    ) -> dict[str, Tensor]:
+        """Compute loss of a single scale level.
+
+        Args:
+            cls_score (Tensor): Box scores for each scale level
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_pred (Tensor): Decoded bboxes for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            labels (Tensor): Labels of each anchors with shape
+                (N, num_total_anchors).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors).
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (N, num_total_anchors, 4).
+            assign_metrics (Tensor): Assign metrics with shape
+                (N, num_total_anchors).
+            stride (list[int]): Downsample stride of the feature map.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        if stride[0] != stride[1]:
+            msg = "h stride is not equal to w stride!"
+            raise ValueError(msg)
+        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels).contiguous()
+        bbox_pred = bbox_pred.reshape(-1, 4)
+        bbox_targets = bbox_targets.reshape(-1, 4)
+        labels = labels.reshape(-1)
+        assign_metrics = assign_metrics.reshape(-1)
+        label_weights = label_weights.reshape(-1)
+        targets = (labels, assign_metrics)
+
+        loss_cls = self.loss_cls(cls_score, targets, label_weights, avg_factor=1.0)
+
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        bg_class_ind = self.num_classes
+        pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
+
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox_targets[pos_inds]
+            pos_bbox_pred = bbox_pred[pos_inds]
+
+            pos_decode_bbox_pred = pos_bbox_pred
+            pos_decode_bbox_targets = pos_bbox_targets
+
+            # regression loss
+            pos_bbox_weight = assign_metrics[pos_inds]
+
+            loss_bbox = self.loss_bbox(
+                pos_decode_bbox_pred,
+                pos_decode_bbox_targets,
+                weight=pos_bbox_weight,
+                avg_factor=1.0,
+            )
+        else:
+            loss_bbox = bbox_pred.sum() * 0
+            pos_bbox_weight = bbox_targets.new_tensor(0.0)
+
+        cls_avg_factors = assign_metrics.sum()
+        bbox_avg_factors = pos_bbox_weight.sum()
+
+        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
+        loss_cls = [x / cls_avg_factor for x in loss_cls]
+
+        bbox_avg_factor = reduce_mean(sum(bbox_avg_factors)).clamp_(min=1).item()
+        loss_bbox = [x / bbox_avg_factor for x in loss_bbox]
+        return {"loss_cls": loss_cls, "loss_bbox": loss_bbox}
diff --git a/src/otx/algo/detection/losses/ssd_loss.py b/src/otx/algo/detection/losses/ssd_loss.py
new file mode 100644
index 00000000000..3b6d2c7b27c
--- /dev/null
+++ b/src/otx/algo/detection/losses/ssd_loss.py
@@ -0,0 +1,107 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) OpenMMLab. All rights reserved.
+#
+"""SSD criterion."""
+
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+from otx.algo.common.losses import smooth_l1_loss
+
+
+class SSDCriterion(nn.Module):
+    """SSDCriterion is a loss criterion for Single Shot MultiBox Detector (SSD).
+
+    Args:
+        num_classes (int): Number of classes including the background class.
+        bbox_coder (nn.Module): Bounding box coder module. Defaults to None.
+        neg_pos_ratio (int, optional): Ratio of negative to positive samples. Defaults to 3.
+        reg_decoded_bbox (bool): If true, the regression loss would be
+            applied directly on decoded bounding boxes, converting both
+            the predicted boxes and regression targets to absolute
+            coordinates format. Defaults to False. It should be `True` when
+            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+        smoothl1_beta (float, optional): Beta parameter for the smooth L1 loss. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        bbox_coder: nn.Module | None = None,
+        neg_pos_ratio: int = 3,
+        reg_decoded_bbox: bool = False,
+        smoothl1_beta: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.num_classes = num_classes
+        self.bbox_coder = bbox_coder
+        self.neg_pos_ratio = neg_pos_ratio
+        self.reg_decoded_bbox = reg_decoded_bbox
+        self.smoothl1_beta = smoothl1_beta
+
+    def forward(
+        self,
+        cls_score: Tensor,
+        bbox_pred: Tensor,
+        anchor: Tensor,
+        labels: Tensor,
+        label_weights: Tensor,
+        bbox_targets: Tensor,
+        bbox_weights: Tensor,
+        avg_factor: int,
+    ) -> dict[str, Tensor]:
+        """Compute loss of a single image.
+
+        Args:
+            cls_score (Tensor): Box scores for each image has shape (num_total_anchors, num_classes).
+            bbox_pred (Tensor): Box energies / deltas for each image level with shape (num_total_anchors, 4).
+            anchors (Tensor): Box reference for each scale level with shape (num_total_anchors, 4).
+            labels (Tensor): Labels of each anchors with shape (num_total_anchors,).
+            label_weights (Tensor): Label weights of each anchor with shape (num_total_anchors,)
+            bbox_targets (Tensor): BBox regression targets of each anchor with shape (num_total_anchors, 4).
+            bbox_weights (Tensor): BBox regression loss weights of each anchor with shape (num_total_anchors, 4).
+            avg_factor (int): Average factor that is used to average
+                the loss. When using sampling method, avg_factor is usually
+                the sum of positive and negative priors. When using
+                `PseudoSampler`, `avg_factor` is usually equal to the number
+                of positive priors.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components. the dict
+            has components below:
+
+            - loss_cls (list[Tensor]): A list containing each feature map \
+            classification loss.
+            - loss_bbox (list[Tensor]): A list containing each feature map \
+            regression loss.
+        """
+        loss_cls_all = nn.functional.cross_entropy(cls_score, labels, reduction="none") * label_weights
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(as_tuple=False).reshape(-1)
+        neg_inds = (labels == self.num_classes).nonzero(as_tuple=False).view(-1)
+
+        num_pos_samples = pos_inds.size(0)
+        num_neg_samples = self.neg_pos_ratio * num_pos_samples
+        if num_neg_samples > neg_inds.size(0):
+            num_neg_samples = neg_inds.size(0)
+        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+        loss_cls_pos = loss_cls_all[pos_inds].sum()
+        loss_cls_neg = topk_loss_cls_neg.sum()
+        loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor
+
+        if self.reg_decoded_bbox and self.bbox_coder is not None:
+            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+            # is applied directly on the decoded bounding boxes, it
+            # decodes the already encoded coordinates to absolute format.
+            bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)
+
+        loss_bbox = smooth_l1_loss(
+            bbox_pred,
+            bbox_targets,
+            bbox_weights,
+            beta=self.smoothl1_beta,
+            avg_factor=avg_factor,
+        )
+        return {"loss_cls": loss_cls[None], "loss_bbox": loss_bbox}
diff --git a/src/otx/algo/detection/losses/yolox_loss.py b/src/otx/algo/detection/losses/yolox_loss.py
new file mode 100644
index 00000000000..350688b3b4b
--- /dev/null
+++ b/src/otx/algo/detection/losses/yolox_loss.py
@@ -0,0 +1,108 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) OpenMMLab. All rights reserved.
+#
+"""YOLOX criterion."""
+
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+from otx.algo.common.losses.cross_entropy_loss import CrossEntropyLoss
+from otx.algo.common.losses.smooth_l1_loss import L1Loss
+from otx.algo.detection.losses.iou_loss import IoULoss
+
+
+class YOLOXCriterion(nn.Module):
+    """YOLOX criterion module.
+
+    This module calculates the loss for YOLOX object detection model.
+
+    Args:
+        num_classes (int): The number of classes.
+        loss_cls (nn.Module | None): The classification loss module. Defaults to None.
+        loss_bbox (nn.Module | None): The bounding box regression loss module. Defaults to None.
+        loss_obj (nn.Module | None): The objectness loss module. Defaults to None.
+        loss_l1 (nn.Module | None): The L1 loss module. Defaults to None.
+
+    Returns:
+        dict[str, Tensor]: A dictionary containing the calculated losses.
+
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        loss_cls: nn.Module | None = None,
+        loss_bbox: nn.Module | None = None,
+        loss_obj: nn.Module | None = None,
+        loss_l1: nn.Module | None = None,
+    ) -> None:
+        super().__init__()
+        self.num_classes = num_classes
+        self.loss_cls = loss_cls or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)
+        self.loss_bbox = loss_bbox or IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0)
+        self.loss_obj = loss_obj or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0)
+        self.loss_l1 = loss_l1 or L1Loss(reduction="sum", loss_weight=1.0)
+
+    def forward(
+        self,
+        flatten_objectness: Tensor,
+        flatten_cls_preds: Tensor,
+        flatten_bbox_preds: Tensor,
+        flatten_bboxes: Tensor,
+        obj_targets: Tensor,
+        cls_targets: Tensor,
+        bbox_targets: Tensor,
+        l1_targets: Tensor,
+        num_total_samples: Tensor,
+        num_pos: Tensor,
+        pos_masks: Tensor,
+    ) -> dict[str, Tensor]:
+        """Forward pass of the YOLOX criterion module.
+
+        Args:
+            flatten_objectness (Tensor): Flattened objectness predictions.
+            flatten_cls_preds (Tensor): Flattened class predictions.
+            flatten_bbox_preds (Tensor): Flattened bounding box predictions.
+            flatten_bboxes (Tensor): Flattened ground truth bounding boxes.
+            obj_targets (Tensor): Objectness targets.
+            cls_targets (Tensor): Class targets.
+            bbox_targets (Tensor): Bounding box targets.
+            l1_targets (Tensor): L1 targets.
+            num_total_samples (Tensor): Total number of samples.
+            num_pos (Tensor): Number of positive samples.
+            pos_masks (Tensor): Positive masks.
+
+        Returns:
+            dict[str, Tensor]: A dictionary containing the calculated losses.
+
+        """
+        loss_obj = self.loss_obj(flatten_objectness.view(-1, 1), obj_targets) / num_total_samples
+        if num_pos > 0:
+            loss_cls = (
+                self.loss_cls(flatten_cls_preds.view(-1, self.num_classes)[pos_masks], cls_targets) / num_total_samples
+            )
+            loss_bbox = self.loss_bbox(flatten_bboxes.view(-1, 4)[pos_masks], bbox_targets) / num_total_samples
+        else:
+            # Avoid cls and reg branch not participating in the gradient
+            # propagation when there is no ground-truth in the images.
+            # For more details, please refer to
+            # https://github.com/open-mmlab/mmdetection/issues/7298
+            loss_cls = flatten_cls_preds.sum() * 0
+            loss_bbox = flatten_bboxes.sum() * 0
+
+        loss_dict = {"loss_cls": loss_cls, "loss_bbox": loss_bbox, "loss_obj": loss_obj}
+
+        if self.use_l1:
+            if num_pos > 0:
+                loss_l1 = self.loss_l1(flatten_bbox_preds.view(-1, 4)[pos_masks], l1_targets) / num_total_samples
+            else:
+                # Avoid cls and reg branch not participating in the gradient
+                # propagation when there is no ground-truth in the images.
+                # For more details, please refer to
+                # https://github.com/open-mmlab/mmdetection/issues/7298
+                loss_l1 = flatten_bbox_preds.sum() * 0
+            loss_dict.update(loss_l1=loss_l1)
+
+        return loss_dict
diff --git a/src/otx/algo/detection/rtmdet.py b/src/otx/algo/detection/rtmdet.py
index 791eb863697..855d9828404 100644
--- a/src/otx/algo/detection/rtmdet.py
+++ b/src/otx/algo/detection/rtmdet.py
@@ -12,13 +12,13 @@
 
 from otx.algo.common.backbones import CSPNeXt
 from otx.algo.common.losses import GIoULoss, QualityFocalLoss
-from otx.algo.common.losses.cross_entropy_loss import CrossEntropyLoss
 from otx.algo.common.utils.assigners import DynamicSoftLabelAssigner
 from otx.algo.common.utils.coders import DistancePointBBoxCoder
 from otx.algo.common.utils.prior_generators import MlvlPointGenerator
 from otx.algo.common.utils.samplers import PseudoSampler
 from otx.algo.detection.detectors import SingleStageDetector
 from otx.algo.detection.heads import RTMDetSepBNHead
+from otx.algo.detection.losses import RTMDetCriterion
 from otx.algo.detection.necks import CSPNeXtPAFPN
 from otx.core.config.data import TileConfig
 from otx.core.exporter.base import OTXModelExporter
@@ -144,19 +144,25 @@ def _build_model(self, num_classes: int) -> RTMDet:
             with_objectness=False,
             anchor_generator=MlvlPointGenerator(offset=0, strides=[8, 16, 32]),
             bbox_coder=DistancePointBBoxCoder(),
-            loss_cls=QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0),
-            loss_bbox=GIoULoss(loss_weight=2.0),
-            loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0),
             norm_cfg={"type": "BN"},
             activation_callable=partial(nn.SiLU, inplace=True),
             train_cfg=train_cfg,
             test_cfg=test_cfg,
+            loss_cls=QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0),  # TODO (eugene): deprecated
+            loss_bbox=GIoULoss(loss_weight=2.0),  # TODO (eugene): deprecated
+        )
+
+        criterion = RTMDetCriterion(
+            num_classes=num_classes,
+            loss_cls=QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0),
+            loss_bbox=GIoULoss(loss_weight=2.0),
         )
 
         return SingleStageDetector(
             backbone=backbone,
             neck=neck,
             bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
diff --git a/src/otx/algo/detection/ssd.py b/src/otx/algo/detection/ssd.py
index 775a45a2aa2..df765f58f98 100644
--- a/src/otx/algo/detection/ssd.py
+++ b/src/otx/algo/detection/ssd.py
@@ -20,6 +20,7 @@
 from otx.algo.common.utils.coders import DeltaXYWHBBoxCoder
 from otx.algo.detection.detectors import SingleStageDetector
 from otx.algo.detection.heads import SSDHead
+from otx.algo.detection.losses import SSDCriterion
 from otx.algo.detection.utils.prior_generators import SSDAnchorGeneratorClustered
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
 from otx.core.config.data import TileConfig
@@ -81,10 +82,8 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
                 pos_iou_thr=0.4,
                 neg_iou_thr=0.4,
             ),
-            "smoothl1_beta": 1.0,
             "allowed_border": -1,
             "pos_weight": -1,
-            "neg_pos_ratio": 3,
             "debug": False,
             "use_giou": False,
             "use_focal": False,
@@ -116,10 +115,6 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
                     [587.6216059488938, 381.60024152086544, 323.5988913027747, 702.7486097568518, 741.4865860938451],
                 ],
             ),
-            bbox_coder=DeltaXYWHBBoxCoder(
-                target_means=(0.0, 0.0, 0.0, 0.0),
-                target_stds=(0.1, 0.1, 0.2, 0.2),
-            ),
             num_classes=num_classes,
             in_channels=(96, 320),
             use_depthwise=True,
@@ -127,7 +122,20 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
-        return SingleStageDetector(backbone, bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
+        criterion = SSDCriterion(
+            num_classes=num_classes,
+            bbox_coder=DeltaXYWHBBoxCoder(
+                target_means=(0.0, 0.0, 0.0, 0.0),
+                target_stds=(0.1, 0.1, 0.2, 0.2),
+            ),
+        )
+        return SingleStageDetector(
+            backbone=backbone,
+            bbox_head=bbox_head,
+            criterion=criterion,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+        )
 
     def setup(self, stage: str) -> None:
         """Callback for setup OTX SSD Model.
diff --git a/src/otx/algo/detection/yolox.py b/src/otx/algo/detection/yolox.py
index a5e1583d3ca..a5760cb29c3 100644
--- a/src/otx/algo/detection/yolox.py
+++ b/src/otx/algo/detection/yolox.py
@@ -11,7 +11,7 @@
 from otx.algo.detection.backbones import CSPDarknet
 from otx.algo.detection.detectors import SingleStageDetector
 from otx.algo.detection.heads import YOLOXHead
-from otx.algo.detection.losses import IoULoss
+from otx.algo.detection.losses import IoULoss, YOLOXCriterion
 from otx.algo.detection.necks import YOLOXPAFPN
 from otx.algo.detection.utils.assigners import SimOTAAssigner
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -184,14 +184,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
             num_classes=num_classes,
             in_channels=96,
             feat_channels=96,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+        )
+        criterion = YOLOXCriterion(
+            num_classes=num_classes,
             loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0),
             loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_l1=L1Loss(reduction="sum", loss_weight=1.0),
+        )
+        return SingleStageDetector(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
-        return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)
 
 
 class YOLOXS(YOLOX):
@@ -221,14 +231,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
             num_classes=num_classes,
             in_channels=128,
             feat_channels=128,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+        )
+        criterion = YOLOXCriterion(
+            num_classes=num_classes,
             loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0),
             loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_l1=L1Loss(reduction="sum", loss_weight=1.0),
+        )
+        return SingleStageDetector(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
-        return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)
 
 
 class YOLOXL(YOLOX):
@@ -253,14 +273,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
         bbox_head = YOLOXHead(
             num_classes=num_classes,
             in_channels=256,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+        )
+        criterion = YOLOXCriterion(
+            num_classes=num_classes,
             loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0),
             loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_l1=L1Loss(reduction="sum", loss_weight=1.0),
+        )
+        return SingleStageDetector(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
-        return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)
 
 
 class YOLOXX(YOLOX):
@@ -290,11 +320,21 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
             num_classes=num_classes,
             in_channels=320,
             feat_channels=320,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+        )
+        criterion = YOLOXCriterion(
+            num_classes=num_classes,
             loss_cls=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_bbox=IoULoss(mode="square", eps=1e-16, reduction="sum", loss_weight=5.0),
             loss_obj=CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0),
             loss_l1=L1Loss(reduction="sum", loss_weight=1.0),
+        )
+        return SingleStageDetector(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )
-        return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)
diff --git a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py b/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py
index 87c5c3dbc99..e8db7790366 100644
--- a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py
+++ b/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py
@@ -21,10 +21,8 @@
 
 from otx.algo.common.utils.nms import batched_nms, multiclass_nms
 from otx.algo.common.utils.utils import (
-    distance2bbox,
     filter_scores_and_topk,
     inverse_sigmoid,
-    multi_apply,
     reduce_mean,
     select_single_mlvl,
 )
@@ -49,7 +47,6 @@ class RTMDetInsHead(RTMDetHead):
     """Detection Head of RTMDet-Ins.
 
     Args:
-        loss_mask (nn.Module): A module for mask loss.
         num_prototypes (int): Number of mask prototype features extracted
             from the mask head. Defaults to 8.
         dyconv_channels (int): Channel of the dynamic conv layers.
@@ -63,7 +60,6 @@ class RTMDetInsHead(RTMDetHead):
     def __init__(
         self,
         *args,
-        loss_mask: nn.Module,
         num_prototypes: int = 8,
         dyconv_channels: int = 8,
         num_dyconvs: int = 3,
@@ -75,7 +71,6 @@ def __init__(
         self.dyconv_channels = dyconv_channels
         self.mask_loss_stride = mask_loss_stride
         super().__init__(*args, **kwargs)
-        self.loss_mask = loss_mask
 
     def _init_layers(self) -> None:
         """Initialize layers of the head."""
@@ -540,7 +535,7 @@ def loss_mask_by_feat(
         flatten_kernels: Tensor,
         sampling_results_list: list,
         batch_gt_instances: list[InstanceData],
-    ) -> Tensor:
+    ) -> dict[str, Tensor]:
         """Compute instance segmentation loss.
 
         Args:
@@ -555,7 +550,7 @@ def loss_mask_by_feat(
                 attributes.
 
         Returns:
-            Tensor: The mask loss tensor.
+            dict[str, Tensor]: A dictionary of raw outputs.
         """
         batch_pos_mask_logits = []
         pos_gt_masks = []
@@ -611,7 +606,11 @@ def loss_mask_by_feat(
             self.mask_loss_stride // 2 :: self.mask_loss_stride,
         ]
 
-        return self.loss_mask(batch_pos_mask_logits, pos_gt_masks, weight=None, avg_factor=num_pos)
+        return {
+            "batch_pos_mask_logits": batch_pos_mask_logits,
+            "pos_gt_masks": pos_gt_masks,
+            "num_pos": num_pos,
+        }
 
     def loss_by_feat(
         self,
@@ -625,17 +624,15 @@ def loss_by_feat(
     ) -> dict[str, Tensor]:
         """Compute losses of the head."""
         num_imgs = len(batch_img_metas)
-        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
-        if len(featmap_sizes) != self.prior_generator.num_levels:
-            msg = "The number of featmap sizes should be equal to the number of levels."
-            raise ValueError(msg)
-
         device = cls_scores[0].device
-        anchor_list, valid_flag_list = self.get_anchors(featmap_sizes, batch_img_metas, device=device)
-        flatten_cls_scores = torch.cat(
-            [cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) for cls_score in cls_scores],
-            1,
+        raw_outputs = super().loss_by_feat(
+            cls_scores=cls_scores,
+            bbox_preds=bbox_preds,
+            batch_gt_instances=batch_gt_instances,
+            batch_img_metas=batch_img_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
         )
+
         flatten_kernels = torch.cat(
             [
                 kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_gen_params)
@@ -643,14 +640,7 @@ def loss_by_feat(
             ],
             1,
         )
-        decoded_bboxes = []
-        for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
-            anchor = anchor.reshape(-1, 4)  # noqa: PLW2901
-            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)  # noqa: PLW2901
-            bbox_pred = distance2bbox(anchor, bbox_pred)  # noqa: PLW2901
-            decoded_bboxes.append(bbox_pred)
-
-        flatten_bboxes = torch.cat(decoded_bboxes, 1)
+
         # Convert polygon masks to bitmap masks
         if isinstance(batch_gt_instances[0].masks[0], Polygon):
             for gt_instances, img_meta in zip(batch_gt_instances, batch_img_metas):
@@ -659,43 +649,15 @@ def loss_by_feat(
                     ndarray_masks = np.empty((0, *img_meta["img_shape"]), dtype=np.uint8)
                 gt_instances.masks = torch.tensor(ndarray_masks, dtype=torch.bool, device=device)
 
-        cls_reg_targets = self.get_targets(
-            flatten_cls_scores,
-            flatten_bboxes,
-            anchor_list,
-            valid_flag_list,
+        raw_iseg_outputs = self.loss_mask_by_feat(
+            mask_feat,
+            flatten_kernels,
+            raw_outputs["sampling_results_list"],
             batch_gt_instances,
-            batch_img_metas,
-            batch_gt_instances_ignore=batch_gt_instances_ignore,
         )
-        (
-            anchor_list,
-            labels_list,
-            label_weights_list,
-            bbox_targets_list,
-            assign_metrics_list,
-            sampling_results_list,
-        ) = cls_reg_targets
-
-        losses_cls, losses_bbox, cls_avg_factors, bbox_avg_factors = multi_apply(
-            self.loss_by_feat_single,
-            cls_scores,
-            decoded_bboxes,
-            labels_list,
-            label_weights_list,
-            bbox_targets_list,
-            assign_metrics_list,
-            self.prior_generator.strides,
-        )
-
-        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
-        losses_cls = [x / cls_avg_factor for x in losses_cls]
-
-        bbox_avg_factor = reduce_mean(sum(bbox_avg_factors)).clamp_(min=1).item()
-        losses_bbox = [x / bbox_avg_factor for x in losses_bbox]
+        raw_outputs.update(raw_iseg_outputs)
 
-        loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels, sampling_results_list, batch_gt_instances)
-        return {"loss_cls": losses_cls, "loss_bbox": losses_bbox, "loss_mask": loss_mask}
+        return raw_outputs
 
 
 class MaskFeatModule(BaseModule):
diff --git a/src/otx/algo/instance_segmentation/losses/__init__.py b/src/otx/algo/instance_segmentation/losses/__init__.py
index 903ea4c077e..bcd3bf03a51 100644
--- a/src/otx/algo/instance_segmentation/losses/__init__.py
+++ b/src/otx/algo/instance_segmentation/losses/__init__.py
@@ -5,5 +5,6 @@
 
 from .accuracy import accuracy
 from .dice_loss import DiceLoss
+from .rtmdet_inst_loss import RTMDetInstCriterion
 
-__all__ = ["accuracy", "DiceLoss"]
+__all__ = ["accuracy", "DiceLoss", "RTMDetInstCriterion"]
diff --git a/src/otx/algo/instance_segmentation/losses/rtmdet_inst_loss.py b/src/otx/algo/instance_segmentation/losses/rtmdet_inst_loss.py
new file mode 100644
index 00000000000..3935c0990db
--- /dev/null
+++ b/src/otx/algo/instance_segmentation/losses/rtmdet_inst_loss.py
@@ -0,0 +1,89 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) OpenMMLab. All rights reserved.
+#
+"""RTMDet for instance segmentation criterion."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from otx.algo.detection.losses import RTMDetCriterion
+
+if TYPE_CHECKING:
+    from torch import Tensor, nn
+
+
+class RTMDetInstCriterion(RTMDetCriterion):
+    """Criterion of RTMDet for instance segmentation.
+
+    Args:
+        num_classes (int): Number of object classes.
+        loss_cls (nn.Module): Classification loss module.
+        loss_bbox (nn.Module): Bounding box regression loss module.
+        loss_mask (nn.Module): Mask loss module.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        loss_cls: nn.Module,
+        loss_bbox: nn.Module,
+        loss_mask: nn.Module,
+    ) -> None:
+        super().__init__(num_classes, loss_cls, loss_bbox)
+        self.loss_mask = loss_mask
+
+    def forward(
+        self,
+        cls_score: Tensor,
+        bbox_pred: Tensor,
+        labels: Tensor,
+        label_weights: Tensor,
+        bbox_targets: Tensor,
+        assign_metrics: Tensor,
+        stride: list[int],
+        **kwargs,
+    ) -> dict[str, Tensor]:
+        """Compute loss of a single scale level.
+
+        Args:
+            cls_score (Tensor): Box scores for each scale level
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_pred (Tensor): Decoded bboxes for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            labels (Tensor): Labels of each anchors with shape
+                (N, num_total_anchors).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors).
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (N, num_total_anchors, 4).
+            assign_metrics (Tensor): Assign metrics with shape
+                (N, num_total_anchors).
+            stride (list[int]): Downsample stride of the feature map.
+            batch_pos_mask_logits (Tensor): The prediction, has a shape (n, *).
+            pos_gt_masks (Tensor): The label of the prediction,
+                shape (n, *), same shape of pred.
+            num_pos (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        loss_dict = super().forward(
+            cls_score=cls_score,
+            bbox_pred=bbox_pred,
+            labels=labels,
+            label_weights=label_weights,
+            bbox_targets=bbox_targets,
+            assign_metrics=assign_metrics,
+            stride=stride,
+        )
+
+        batch_pos_mask_logits: Tensor = kwargs.pop("batch_pos_mask_logits")
+        pos_gt_masks: Tensor = kwargs.pop("pos_gt_masks")
+        num_pos: int = kwargs.pop("num_pos")
+
+        loss_mask = self.loss_mask(batch_pos_mask_logits, pos_gt_masks, weight=None, avg_factor=num_pos)
+        loss_dict.update({"loss_mask": loss_mask})
+        return loss_dict
diff --git a/src/otx/algo/instance_segmentation/rtmdet_inst.py b/src/otx/algo/instance_segmentation/rtmdet_inst.py
index c3ff45ce5f6..f2d610488e5 100644
--- a/src/otx/algo/instance_segmentation/rtmdet_inst.py
+++ b/src/otx/algo/instance_segmentation/rtmdet_inst.py
@@ -19,7 +19,7 @@
 from otx.algo.detection.detectors import SingleStageDetector
 from otx.algo.detection.necks import CSPNeXtPAFPN
 from otx.algo.instance_segmentation.heads import RTMDetInsSepBNHead
-from otx.algo.instance_segmentation.losses import DiceLoss
+from otx.algo.instance_segmentation.losses import DiceLoss, RTMDetInstCriterion
 from otx.core.config.data import TileConfig
 from otx.core.exporter.base import OTXModelExporter
 from otx.core.exporter.native import OTXNativeModelExporter
@@ -168,6 +168,11 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
             ),
             bbox_coder=DistancePointBBoxCoder(),
             loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0),
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+        )
+        criterion = RTMDetInstCriterion(
+            num_classes=num_classes,
             loss_cls=QualityFocalLoss(
                 use_sigmoid=True,
                 beta=2.0,
@@ -179,14 +184,13 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
                 eps=5.0e-06,
                 reduction="mean",
             ),
-            train_cfg=train_cfg,
-            test_cfg=test_cfg,
         )
 
         return SingleStageDetector(
             backbone=backbone,
             neck=neck,
             bbox_head=bbox_head,
+            criterion=criterion,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
         )