Skip to content

Commit

Permalink
Create criterion modules
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 committed Aug 17, 2024
1 parent e1d156d commit 7a3b01d
Show file tree
Hide file tree
Showing 19 changed files with 871 additions and 514 deletions.
53 changes: 48 additions & 5 deletions src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from otx.algo.common.utils.samplers import PseudoSampler
from otx.algo.detection.detectors import SingleStageDetector
from otx.algo.detection.heads import ATSSHead
from otx.algo.detection.losses import ATSSCriterion
from otx.algo.detection.necks import FPN
from otx.algo.detection.utils.assigners import ATSSAssigner
from otx.algo.utils.support_otx_v1 import OTXv1Helper
Expand Down Expand Up @@ -145,6 +146,23 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(0.1, 0.1, 0.2, 0.2),
),
feat_channels=64,
train_cfg=train_cfg,
test_cfg=test_cfg,
loss_cls=CrossSigmoidFocalLoss( # TODO (eugene): deprecated
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0,
),
loss_bbox=GIoULoss(loss_weight=2.0), # TODO (eugene): deprecated
)
criterion = ATSSCriterion(
num_classes=num_classes,
bbox_coder=DeltaXYWHBBoxCoder(
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(0.1, 0.1, 0.2, 0.2),
),
loss_cls=CrossSigmoidFocalLoss(
use_sigmoid=True,
gamma=2.0,
Expand All @@ -153,11 +171,15 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
),
loss_bbox=GIoULoss(loss_weight=2.0),
loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0),
feat_channels=64,
)
return SingleStageDetector(
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
criterion=criterion,
train_cfg=train_cfg,
test_cfg=test_cfg,
)
return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)


class ResNeXt101ATSS(ATSS):
Expand Down Expand Up @@ -210,6 +232,24 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(0.1, 0.1, 0.2, 0.2),
),
num_classes=num_classes,
in_channels=256,
train_cfg=train_cfg,
test_cfg=test_cfg,
loss_cls=CrossSigmoidFocalLoss( # TODO (eugene): deprecated
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0,
),
loss_bbox=GIoULoss(loss_weight=2.0), # TODO (eugene): deprecated
)
criterion = ATSSCriterion(
num_classes=num_classes,
bbox_coder=DeltaXYWHBBoxCoder(
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(0.1, 0.1, 0.2, 0.2),
),
loss_cls=CrossSigmoidFocalLoss(
use_sigmoid=True,
gamma=2.0,
Expand All @@ -218,12 +258,15 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
),
loss_bbox=GIoULoss(loss_weight=2.0),
loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0),
num_classes=num_classes,
in_channels=256,
)
return SingleStageDetector(
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
criterion=criterion,
train_cfg=train_cfg,
test_cfg=test_cfg,
)
return SingleStageDetector(backbone, bbox_head, neck=neck, train_cfg=train_cfg, test_cfg=test_cfg)

def to(self, *args, **kwargs) -> Self:
"""Return a model with specified device."""
Expand Down
10 changes: 8 additions & 2 deletions src/otx/algo/detection/detectors/single_stage_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class SingleStageDetector(BaseModule):
Args:
backbone (nn.Module): Backbone module.
bbox_head (nn.Module): Bbox head module.
criterion (nn.Module | None, optional): Criterion module.
neck (nn.Module | None, optional): Neck module. Defaults to None.
train_cfg (dict | None, optional): Training config. Defaults to None.
test_cfg (dict | None, optional): Test config. Defaults to None.
Expand All @@ -36,6 +37,7 @@ def __init__(
self,
backbone: nn.Module,
bbox_head: nn.Module,
criterion: nn.Module,
neck: nn.Module | None = None,
train_cfg: dict | None = None,
test_cfg: dict | None = None,
Expand All @@ -46,6 +48,7 @@ def __init__(
self.backbone = backbone
self.bbox_head = bbox_head
self.neck = neck
self.criterion = criterion
self.init_cfg = init_cfg
self.train_cfg = train_cfg
self.test_cfg = test_cfg
Expand Down Expand Up @@ -129,7 +132,7 @@ def forward(
def loss(
self,
entity: DetBatchDataEntity,
) -> dict | list:
) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
Expand All @@ -143,7 +146,10 @@ def loss(
dict: A dictionary of loss components.
"""
x = self.extract_feat(entity.images)
return self.bbox_head.loss(x, entity)
# TODO (sungchul): compare .loss with other forwards and remove duplicated code
outputs = self.bbox_head.loss(x, entity)

return self.criterion(outputs)

def predict(
self,
Expand Down
10 changes: 8 additions & 2 deletions src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ class AnchorHead(BaseDenseHead):
anchor_generator (nn.Module): Module for anchor generator
bbox_coder (nn.Module): Module of bounding box coder.
loss_cls (nn.Module): Module of classification loss.
It is related to RPNHead for iseg, will be deprecated.
loss_bbox (nn.Module): Module of localization loss.
It is related to RPNHead for iseg, will be deprecated.
train_cfg (dict): Training config of anchor head.
test_cfg (dict, optional): Testing config of anchor head.
feat_channels (int): Number of hidden channels. Used in child classes.
Expand All @@ -49,8 +51,8 @@ def __init__(
in_channels: tuple[int, ...] | int,
anchor_generator: nn.Module,
bbox_coder: nn.Module,
loss_cls: nn.Module,
loss_bbox: nn.Module,
loss_cls: nn.Module, # TODO (eugene): deprecated
loss_bbox: nn.Module, # TODO (eugene): deprecated
train_cfg: dict,
test_cfg: dict | None = None,
feat_channels: int = 256,
Expand Down Expand Up @@ -410,6 +412,8 @@ def loss_by_feat_single(
) -> tuple:
"""Calculate the loss of a single scale level based on the features extracted by the detection head.
TODO (eugene): it is related to RPNHead for iseg, will be deprecated
Args:
cls_score (Tensor): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W).
Expand Down Expand Up @@ -459,6 +463,8 @@ def loss_by_feat(
) -> dict:
"""Calculate the loss based on the features extracted by the detection head.
TODO (eugene): it is related to RPNHead for iseg, will be deprecated
Args:
cls_scores (list[Tensor]): Box scores for each scale level
has shape (N, num_anchors * num_classes, H, W).
Expand Down
Loading

0 comments on commit 7a3b01d

Please sign in to comment.