From 1ece09e82f62d9f4350f21ed7ee855fd5c36dbaf Mon Sep 17 00:00:00 2001 From: Vinnam Kim Date: Fri, 12 May 2023 16:15:32 +0900 Subject: [PATCH] Add training loss dynamics exportation feature for detection task (#2109) Signed-off-by: Kim, Vinnam --- .../mmcls/models/classifiers/mixin.py | 54 ++-- .../adapters/mmdet/datasets/dataset.py | 8 +- .../datasets/pipelines/load_pipelines.py | 1 + .../models/detectors/custom_atss_detector.py | 10 +- .../models/detectors/loss_dynamics_mixin.py | 131 ++++++++ .../adapters/mmdet/models/heads/__init__.py | 4 +- .../mmdet/models/heads/custom_atss_head.py | 293 +++++++++++++++++- .../adapters/mmdet/models/loss_dyns.py | 41 +++ .../detection/adapters/mmdet/task.py | 7 +- .../adapters/mmdet/utils/config_utils.py | 2 +- .../detection/configs/base/configuration.py | 11 + .../configs/detection/configuration.yaml | 15 + otx/api/entities/dataset_item.py | 10 +- .../data/adapter/detection_dataset_adapter.py | 12 +- otx/core/data/noisy_label_detection/base.py | 34 +- .../mmdet/models/dense_heads/conftest.py | 90 ++++++ .../test_loss_dynamics_tracking_heads.py | 53 ++++ .../mmdet/models/detectors/conftest.py | 95 ++++++ .../detectors/test_custom_atss_detector.py | 65 +--- .../detectors/test_loss_dynamics_tracking.py | 117 +++++++ tests/unit/algorithms/detection/conftest.py | 8 + .../unit/algorithms/detection/test_helpers.py | 4 +- 22 files changed, 940 insertions(+), 125 deletions(-) create mode 100644 otx/algorithms/detection/adapters/mmdet/models/detectors/loss_dynamics_mixin.py create mode 100644 otx/algorithms/detection/adapters/mmdet/models/loss_dyns.py create mode 100644 tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/conftest.py create mode 100644 tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/test_loss_dynamics_tracking_heads.py create mode 100644 tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py create mode 100644 tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_loss_dynamics_tracking.py diff --git a/otx/algorithms/classification/adapters/mmcls/models/classifiers/mixin.py b/otx/algorithms/classification/adapters/mmcls/models/classifiers/mixin.py index 2d6e3e733fa..1674a2d182b 100644 --- a/otx/algorithms/classification/adapters/mmcls/models/classifiers/mixin.py +++ b/otx/algorithms/classification/adapters/mmcls/models/classifiers/mixin.py @@ -3,14 +3,19 @@ # SPDX-License-Identifier: Apache-2.0 # +from collections import defaultdict +from typing import Any, Dict, List + import datumaro as dm import numpy as np import pandas as pd from otx.algorithms.common.utils.logger import get_logger from otx.api.entities.dataset_item import DatasetItemEntityWithID -from otx.api.entities.datasets import DatasetEntity -from otx.core.data.noisy_label_detection import LossDynamicsTracker, LossDynamicsTrackingMixin +from otx.core.data.noisy_label_detection import ( + LossDynamicsTracker, + LossDynamicsTrackingMixin, +) logger = get_logger() @@ -27,42 +32,19 @@ def train_step(self, data, optimizer=None, **kwargs): class MultiClassClsLossDynamicsTracker(LossDynamicsTracker): """Loss dynamics tracker for multi-class classification task.""" + TASK_NAME = "OTX-MultiClassCls" + def __init__(self) -> None: super().__init__() - - def init_with_otx_dataset(self, otx_dataset: DatasetEntity[DatasetItemEntityWithID]) -> None: - """DatasetEntity should be injected to the tracker for the initialization.""" - otx_labels = otx_dataset.get_labels() - label_categories = 
dm.LabelCategories.from_iterable([label_entity.name for label_entity in otx_labels]) - self.otx_label_map = {label_entity.id_: idx for idx, label_entity in enumerate(otx_labels)} - - def _convert_anns(item: DatasetItemEntityWithID): - labels = [ - dm.Label(label=self.otx_label_map[label.id_]) - for ann in item.get_annotations() - for label in ann.get_labels() - ] - return labels - - self._export_dataset = dm.Dataset.from_iterable( - [ - dm.DatasetItem( - id=item.id_, - subset="train", - media=dm.Image.from_file(path=item.media.path, size=(item.media.height, item.media.width)) - if item.media.path - else dm.Image.from_numpy( - data=getattr(item.media, "_Image__data"), size=(item.media.height, item.media.width) - ), - annotations=_convert_anns(item), - ) - for item in otx_dataset - ], - infos={"purpose": "noisy_label_detection", "task": "OTX-MultiClassCls"}, - categories={dm.AnnotationType.label: label_categories}, - ) - - super().init_with_otx_dataset(otx_dataset) + self._loss_dynamics: Dict[Any, List] = defaultdict(list) + + def _convert_anns(self, item: DatasetItemEntityWithID): + labels = [ + dm.Label(label=self.otx_label_map[label.id_]) + for ann in item.get_annotations() + for label in ann.get_labels() + ] + return labels def accumulate(self, outputs, iter) -> None: """Accumulate training loss dynamics for each training step.""" diff --git a/otx/algorithms/detection/adapters/mmdet/datasets/dataset.py b/otx/algorithms/detection/adapters/mmdet/datasets/dataset.py index de3714b5f70..b76d46c4940 100644 --- a/otx/algorithms/detection/adapters/mmdet/datasets/dataset.py +++ b/otx/algorithms/detection/adapters/mmdet/datasets/dataset.py @@ -59,11 +59,11 @@ def get_annotation_mmdet_format( gt_bboxes = [] gt_labels = [] gt_polygons = [] + gt_ann_ids = [] label_idx = {label.id: i for i, label in enumerate(labels)} - for annotation in dataset_item.get_annotations(labels=labels, include_empty=False): - + for annotation in dataset_item.get_annotations(labels=labels, include_empty=False, preserve_id=True): box = ShapeFactory.shape_as_rectangle(annotation.shape) if min(box.width * width, box.height * height) < min_size: @@ -80,18 +80,22 @@ def get_annotation_mmdet_format( polygon = np.array([p for point in polygon.points for p in [point.x * width, point.y * height]]) gt_polygons.extend([[polygon] for _ in range(n)]) gt_labels.extend(class_indices) + item_id = getattr(dataset_item, "id_", None) + gt_ann_ids.append((item_id, annotation.id_)) if len(gt_bboxes) > 0: ann_info = dict( bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4), labels=np.array(gt_labels, dtype=int), masks=PolygonMasks(gt_polygons, height=height, width=width) if gt_polygons else [], + ann_ids=gt_ann_ids, ) else: ann_info = dict( bboxes=np.zeros((0, 4), dtype=np.float32), labels=np.array([], dtype=int), masks=[], + ann_ids=[], ) return ann_info diff --git a/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py b/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py index c92a252fe4f..0986a4b1b70 100644 --- a/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py +++ b/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py @@ -63,6 +63,7 @@ def __init__( def _load_bboxes(results, ann_info): results["bbox_fields"].append("gt_bboxes") results["gt_bboxes"] = copy.deepcopy(ann_info["bboxes"]) + results["gt_ann_ids"] = copy.deepcopy(ann_info["ann_ids"]) return results @staticmethod diff --git 
a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_atss_detector.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_atss_detector.py index bed97ef6a63..87ae474cd89 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_atss_detector.py +++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_atss_detector.py @@ -18,8 +18,10 @@ from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import ( DetClassProbabilityMapHook, ) +from otx.algorithms.detection.adapters.mmdet.models.loss_dyns import TrackingLossType from .l2sp_detector_mixin import L2SPDetectorMixin +from .loss_dynamics_mixin import DetLossDynamicsTrackingMixin from .sam_detector_mixin import SAMDetectorMixin logger = get_logger() @@ -29,9 +31,11 @@ @DETECTORS.register_module() -class CustomATSS(SAMDetectorMixin, L2SPDetectorMixin, ATSS): +class CustomATSS(SAMDetectorMixin, DetLossDynamicsTrackingMixin, L2SPDetectorMixin, ATSS): """SAM optimizer & L2SP regularizer enabled custom ATSS.""" + TRACKING_LOSS_TYPE = (TrackingLossType.cls, TrackingLossType.bbox, TrackingLossType.centerness) + def __init__(self, *args, task_adapt=None, **kwargs): super().__init__(*args, **kwargs) @@ -46,10 +50,6 @@ def __init__(self, *args, task_adapt=None, **kwargs): ) ) - def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=None, **kwargs): - """Forward function for CustomATSS.""" - return super().forward_train(img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=gt_bboxes_ignore) - @staticmethod def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs): """Modify input state_dict according to class name matching before weight loading.""" diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/loss_dynamics_mixin.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/loss_dynamics_mixin.py new file mode 100644 index 00000000000..5463409cfe4 --- /dev/null +++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/loss_dynamics_mixin.py @@ -0,0 +1,131 @@ +"""LossDynamics Mix-in for detection tasks.""" +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from collections import defaultdict +from typing import Dict, Sequence, Tuple + +import datumaro as dm +import numpy as np +import pandas as pd + +from otx.algorithms.common.utils.logger import get_logger +from otx.algorithms.detection.adapters.mmdet.models.loss_dyns import TrackingLossType +from otx.api.entities.dataset_item import DatasetItemEntityWithID +from otx.api.entities.datasets import DatasetEntity +from otx.api.entities.shapes.rectangle import Rectangle +from otx.core.data.noisy_label_detection import ( + LossDynamicsTracker, + LossDynamicsTrackingMixin, +) + +logger = get_logger() + + +class DetLossDynamicsTracker(LossDynamicsTracker): + """Loss dynamics tracker for detection tasks.""" + + TASK_NAME = "OTX-Det" + + def __init__(self, tracking_loss_types: Sequence[TrackingLossType]) -> None: + super().__init__() + self._loss_dynamics: Dict[TrackingLossType, Dict] = { + loss_type: defaultdict(list) for loss_type in tracking_loss_types + } + + def _convert_anns(self, item: DatasetItemEntityWithID): + labels = [] + + cnt = 0 + for ann in item.get_annotations(preserve_id=True): + if isinstance(ann.shape, Rectangle): + for label in ann.get_labels(): + bbox = dm.Bbox( + x=ann.shape.x1 * item.width, + y=ann.shape.y1 * item.height, + w=ann.shape.width * item.width, + h=ann.shape.height 
* item.height,
+                        label=self.otx_label_map[label.id_],
+                        id=cnt,
+                    )
+                    labels.append(bbox)
+                    self.otx_ann_id_to_dm_ann_map[(item.id_, ann.id_)] = bbox
+                    cnt += 1
+
+        return labels
+
+    def init_with_otx_dataset(self, otx_dataset: DatasetEntity[DatasetItemEntityWithID]) -> None:
+        """DatasetEntity should be injected to the tracker for the initialization."""
+        self.otx_ann_id_to_dm_ann_map: Dict[Tuple[str, str], dm.Bbox] = {}
+        super().init_with_otx_dataset(otx_dataset)
+
+    def accumulate(self, outputs, iter) -> None:
+        """Accumulate training loss dynamics for each training step."""
+        for key, loss_dyns in outputs.items():
+            if isinstance(key, TrackingLossType):
+                for (entity_id, ann_id), value in loss_dyns.items():
+                    self._loss_dynamics[key][(entity_id, ann_id)].append((iter, value))
+
+    def export(self, output_path: str) -> None:
+        """Export loss dynamics statistics to Datumaro format."""
+        dfs = [
+            pd.DataFrame.from_dict(
+                {
+                    k: (np.array([iter for iter, _ in arr]), np.array([value for _, value in arr]))
+                    for k, arr in loss_dyns.items()
+                },
+                orient="index",
+                columns=["iters", f"loss_dynamics_{key.name}"],
+            )
+            for key, loss_dyns in self._loss_dynamics.items()
+        ]
+
+        df = pd.concat(dfs, axis=1)
+        df = df.loc[:, ~df.columns.duplicated()]
+
+        for (entity_id, ann_id), row in df.iterrows():
+            ann = self.otx_ann_id_to_dm_ann_map.get((entity_id, ann_id), None)
+            if ann:
+                ann.attributes = row.to_dict()
+
+        self._export_dataset.export(output_path, format="datumaro")
+
+
+class DetLossDynamicsTrackingMixin(LossDynamicsTrackingMixin):
+    """Mix-in to track loss dynamics during training for detection tasks."""
+
+    TRACKING_LOSS_TYPE: Tuple[TrackingLossType, ...] = ()
+
+    def __init__(self, track_loss_dynamics: bool = False, **kwargs):
+        if track_loss_dynamics:
+            head_cfg = kwargs.get("bbox_head", None)
+            head_type = head_cfg.get("type", None)
+            assert head_type is not None, "head_type should be specified from the config."
+ new_head_type = head_type + "TrackingLossDynamics" + head_cfg["type"] = new_head_type + logger.info(f"Replace head_type from {head_type} to {new_head_type}.") + + super().__init__(**kwargs) + + # This should be called after super().__init__(), + # since LossDynamicsTrackingMixin.__init__() creates self._loss_dyns_tracker + self._loss_dyns_tracker = DetLossDynamicsTracker(self.TRACKING_LOSS_TYPE) + + def train_step(self, data, optimizer): + """The iteration step during training.""" + + outputs = super().train_step(data, optimizer) + + if self.loss_dyns_tracker.initialized: + gt_ann_ids = [item["gt_ann_ids"] for item in data["img_metas"]] + + to_update = {} + for key, loss_dyns in self.bbox_head.loss_dyns.items(): + to_update[key] = {} + for (batch_idx, bbox_idx), value in loss_dyns.items(): + entity_id, ann_id = gt_ann_ids[batch_idx][bbox_idx] + to_update[key][(entity_id, ann_id)] = value.mean + + outputs.update(to_update) + + return outputs diff --git a/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py b/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py index 02b82f763ce..c531b25265d 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py +++ b/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py @@ -5,7 +5,7 @@ from .cross_dataset_detector_head import CrossDatasetDetectorHead from .custom_anchor_generator import SSDAnchorGeneratorClustered -from .custom_atss_head import CustomATSSHead +from .custom_atss_head import CustomATSSHead, CustomATSSHeadTrackingLossDynamics from .custom_retina_head import CustomRetinaHead from .custom_roi_head import CustomRoIHead from .custom_ssd_head import CustomSSDHead @@ -21,4 +21,6 @@ "CustomRoIHead", "CustomVFNetHead", "CustomYOLOXHead", + # Loss dynamics tracking + "CustomATSSHeadTrackingLossDynamics", ] diff --git a/otx/algorithms/detection/adapters/mmdet/models/heads/custom_atss_head.py b/otx/algorithms/detection/adapters/mmdet/models/heads/custom_atss_head.py index 4e3489351c3..3ff6649bb5c 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/heads/custom_atss_head.py +++ b/otx/algorithms/detection/adapters/mmdet/models/heads/custom_atss_head.py @@ -2,15 +2,26 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # +from collections import defaultdict + import torch from mmcv.runner import force_fp32 -from mmdet.core import bbox_overlaps, multi_apply, reduce_mean +from mmdet.core import ( + anchor_inside_flags, + bbox_overlaps, + images_to_levels, + multi_apply, + reduce_mean, + unmap, +) from mmdet.models.builder import HEADS from mmdet.models.dense_heads.atss_head import ATSSHead +from mmdet.models.losses.utils import weight_reduce_loss from otx.algorithms.detection.adapters.mmdet.models.heads.cross_dataset_detector_head import ( CrossDatasetDetectorHead, ) +from otx.algorithms.detection.adapters.mmdet.models.loss_dyns import LossAccumulator, TrackingLossType from otx.algorithms.detection.adapters.mmdet.models.losses.cross_focal_loss import ( CrossSigmoidFocalLoss, ) @@ -163,8 +174,7 @@ def loss_single( valid_label_mask = valid_label_mask.reshape(-1, self.cls_out_channels) # FG cat_id: [0, num_classes -1], BG cat_id: num_classes - bg_class_ind = self.num_classes - pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1) + pos_inds = self._get_pos_inds(labels) if self.use_qfl: quality = label_weights.new_zeros(labels.shape) @@ -185,10 +195,10 @@ def loss_single( ) # regression loss - loss_bbox = self.loss_bbox(pos_bbox_pred, 
pos_bbox_targets, weight=centerness_targets, avg_factor=1.0) + loss_bbox = self._get_loss_bbox(pos_bbox_targets, pos_bbox_pred, centerness_targets) # centerness loss - loss_centerness = self.loss_centerness(pos_centerness, centerness_targets, avg_factor=num_total_samples) + loss_centerness = self._get_loss_centerness(num_total_samples, pos_centerness, centerness_targets) else: loss_bbox = bbox_pred.sum() * 0 @@ -204,14 +214,29 @@ def loss_single( labels = (labels, quality) # For quality focal loss arg spec # classification loss + loss_cls = self._get_loss_cls(cls_score, labels, label_weights, valid_label_mask, num_total_samples) + + return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum() + + def _get_pos_inds(self, labels): + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1) + return pos_inds + + def _get_loss_cls(self, cls_score, labels, label_weights, valid_label_mask, num_total_samples): if isinstance(self.loss_cls, CrossSigmoidFocalLoss): loss_cls = self.loss_cls( cls_score, labels, label_weights, avg_factor=num_total_samples, valid_label_mask=valid_label_mask ) else: loss_cls = self.loss_cls(cls_score, labels, label_weights, avg_factor=num_total_samples) + return loss_cls - return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum() + def _get_loss_centerness(self, num_total_samples, pos_centerness, centerness_targets): + return self.loss_centerness(pos_centerness, centerness_targets, avg_factor=num_total_samples) + + def _get_loss_bbox(self, pos_bbox_targets, pos_bbox_pred, centerness_targets): + return self.loss_bbox(pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0) def get_targets( self, @@ -242,3 +267,259 @@ def get_targets( label_channels, unmap_outputs, ) + + +@HEADS.register_module() +class CustomATSSHeadTrackingLossDynamics(CustomATSSHead): + """CustomATSSHead which supports tracking loss dynamics.""" + + def __init__(self, *args, bg_loss_weight=-1, use_qfl=False, qfl_cfg=None, **kwargs): + super().__init__(*args, bg_loss_weight=bg_loss_weight, use_qfl=use_qfl, qfl_cfg=qfl_cfg, **kwargs) + + def loss(self, cls_scores, bbox_preds, centernesses, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + centernesses (list[Tensor]): Centerness for each scale + level with shape (N, num_anchors * 1, H, W) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + self.cur_loss_idx = 0 + self.loss_dyns = { + TrackingLossType.cls: defaultdict(LossAccumulator), + TrackingLossType.bbox: defaultdict(LossAccumulator), + TrackingLossType.centerness: defaultdict(LossAccumulator), + } + losses = super().loss(cls_scores, bbox_preds, centernesses, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore) + return losses + + def _get_pos_inds(self, labels): + pos_inds = super()._get_pos_inds(labels) + + if len(pos_inds) > 0: + pos_assigned_gt_inds = self.all_pos_assigned_gt_inds[self.cur_loss_idx].reshape(-1) + + gt_inds = pos_assigned_gt_inds[pos_inds].cpu() + + self.batch_inds = gt_inds // self.max_gt_bboxes_len + self.bbox_inds = gt_inds % self.max_gt_bboxes_len + + self.pos_inds = pos_inds + return pos_inds + + def _store_loss_dyns(self, losses: torch.Tensor, key: TrackingLossType) -> None: + loss_dyns = self.loss_dyns[key] + for batch_idx, bbox_idx, loss_item in zip(self.batch_inds, self.bbox_inds, losses.detach().cpu()): + loss_dyns[(batch_idx.item(), bbox_idx.item())].add(loss_item.item()) + + def _postprocess_loss(self, losses: torch.Tensor, reduction: str, avg_factor: float) -> torch.Tensor: + return weight_reduce_loss(losses, reduction=reduction, avg_factor=avg_factor) + + def _get_loss_cls(self, cls_score, labels, label_weights, valid_label_mask, num_total_samples): + if isinstance(self.loss_cls, CrossSigmoidFocalLoss): + loss_cls = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=num_total_samples, + valid_label_mask=valid_label_mask, + reduction_override="none", + ) + else: + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=num_total_samples, reduction_override="none" + ) + + if len(self.pos_inds) > 0: + self._store_loss_dyns(loss_cls[self.pos_inds].detach().mean(-1), TrackingLossType.cls) + return self._postprocess_loss(loss_cls, self.loss_cls.reduction, avg_factor=num_total_samples) + + def _get_loss_centerness(self, num_total_samples, pos_centerness, centerness_targets): + loss_centerness = self.loss_centerness( + pos_centerness, centerness_targets, avg_factor=num_total_samples, reduction_override="none" + ) + self._store_loss_dyns(loss_centerness, TrackingLossType.centerness) + return self._postprocess_loss(loss_centerness, self.loss_centerness.reduction, avg_factor=num_total_samples) + + def _get_loss_bbox(self, pos_bbox_targets, pos_bbox_pred, centerness_targets): + loss_bbox = self.loss_bbox( + pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0, reduction_override="none" + ) + self._store_loss_dyns(loss_bbox, TrackingLossType.bbox) + return self._postprocess_loss(loss_bbox, self.loss_centerness.reduction, avg_factor=1.0) + + def loss_single( + self, + anchors, + cls_score, + bbox_pred, + centerness, + labels, + label_weights, + bbox_targets, + valid_label_mask, + num_total_samples, + ): + """Compute loss of a single scale level. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + centerness (list[Tensor]): Centerness for each scale + level with shape (N, num_anchors * num_classes, H, W) + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). 
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (N, num_total_anchors, 4).
+            valid_label_mask (Tensor): Label mask for consideration of ignored
+                label with shape (N, num_total_anchors, 1).
+            num_total_samples (int): Number of positive samples that is
+                reduced over all GPUs.
+
+        Returns:
+            tuple[Tensor]: A tuple of loss components.
+        """
+
+        losses = super().loss_single(
+            anchors,
+            cls_score,
+            bbox_pred,
+            centerness,
+            labels,
+            label_weights,
+            bbox_targets,
+            valid_label_mask,
+            num_total_samples,
+        )
+        self.cur_loss_idx += 1
+        return losses
+
+    def get_targets(
+        self,
+        anchor_list,
+        valid_flag_list,
+        gt_bboxes_list,
+        img_metas,
+        gt_bboxes_ignore_list=None,
+        gt_labels_list=None,
+        label_channels=1,
+        unmap_outputs=True,
+    ):
+        """Get targets for Detection head."""
+        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+        self.batch_size = len(gt_bboxes_list)
+        self.max_gt_bboxes_len = max([len(gt_bboxes) for gt_bboxes in gt_bboxes_list])
+        self.cur_batch_idx = 0
+        self.pos_assigned_gt_inds_list = []
+        targets = super().get_targets(
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes_list,
+            img_metas,
+            gt_bboxes_ignore_list,
+            gt_labels_list,
+            label_channels,
+            unmap_outputs,
+        )
+        self.all_pos_assigned_gt_inds = images_to_levels(self.pos_assigned_gt_inds_list, num_level_anchors)
+        return targets
+
+    def _get_target_single(
+        self,
+        flat_anchors,
+        valid_flags,
+        num_level_anchors,
+        gt_bboxes,
+        gt_bboxes_ignore,
+        gt_labels,
+        img_meta,
+        label_channels=1,
+        unmap_outputs=True,
+    ):
+        """Compute regression, classification targets for anchors in a single image."""
+        inside_flags = anchor_inside_flags(
+            flat_anchors, valid_flags, img_meta["img_shape"][:2], self.train_cfg.allowed_border
+        )
+        if not inside_flags.any():
+            return (None,) * 7
+        # assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+
+        num_level_anchors_inside = self.get_num_level_anchors_inside(num_level_anchors, inside_flags)
+        assign_result = self.assigner.assign(anchors, num_level_anchors_inside, gt_bboxes, gt_bboxes_ignore, gt_labels)
+
+        sampling_result = self.sampler.sample(assign_result, anchors, gt_bboxes)
+
+        num_valid_anchors = anchors.shape[0]
+        bbox_targets = torch.zeros_like(anchors)
+        bbox_weights = torch.zeros_like(anchors)
+        labels = anchors.new_full((num_valid_anchors,), self.num_classes, dtype=torch.long)
+        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            if self.reg_decoded_bbox:
+                pos_bbox_targets = sampling_result.pos_gt_bboxes
+            else:
+                pos_bbox_targets = self.bbox_coder.encode(sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+
+            bbox_targets[pos_inds, :] = pos_bbox_targets
+            bbox_weights[pos_inds, :] = 1.0
+            if gt_labels is None:
+                # Only rpn gives gt_labels as None
+                # Foreground is the first class since v2.5.0
+                labels[pos_inds] = 0
+            else:
+                labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+            if self.train_cfg.pos_weight <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = self.train_cfg.pos_weight
+
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        # map up to original set of anchors
+        if unmap_outputs:
+            num_total_anchors = flat_anchors.size(0)
+            anchors = unmap(anchors, num_total_anchors, inside_flags)
+            labels = unmap(labels, num_total_anchors, inside_flags, fill=self.num_classes)
+            label_weights = unmap(label_weights, num_total_anchors, inside_flags)
+            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+        ########## What we changed from the original mmdet code ###############
+        # Store all_pos_assigned_gt_inds to member variable
+        # to look up training loss dynamics for each gt_bboxes afterwards
+        pos_assigned_gt_inds = anchors.new_full((num_valid_anchors,), -1, dtype=torch.long)
+        if len(pos_inds) > 0:
+            pos_assigned_gt_inds[pos_inds] = (
+                self.cur_batch_idx * self.max_gt_bboxes_len + sampling_result.pos_assigned_gt_inds
+            )
+        if unmap_outputs:
+            pos_assigned_gt_inds = unmap(pos_assigned_gt_inds, num_total_anchors, inside_flags, fill=-1)
+        self.pos_assigned_gt_inds_list += [pos_assigned_gt_inds]
+        self.cur_batch_idx += 1
+        ########################################################################
+
+        return (anchors, labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)
diff --git a/otx/algorithms/detection/adapters/mmdet/models/loss_dyns.py b/otx/algorithms/detection/adapters/mmdet/models/loss_dyns.py
new file mode 100644
index 00000000000..d194c9eb0f2
--- /dev/null
+++ b/otx/algorithms/detection/adapters/mmdet/models/loss_dyns.py
@@ -0,0 +1,41 @@
+"""Utility classes for tracking loss dynamics."""
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from enum import IntEnum
+
+
+class TrackingLossType(IntEnum):
+    """Type of loss functions to track."""
+
+    cls = 0
+    bbox = 1
+    centerness = 2
+
+
+class LossAccumulator:
+    """Accumulator for tracking loss dynamics."""
+
+    def __init__(self):
+        self.sum = 0.0
+        self.cnt = 0
+
+    def add(self, value):
+        """Add a loss value to the accumulator."""
+        if isinstance(value, float):
+            self.sum += value
+            self.cnt += 1
+        elif isinstance(value, LossAccumulator):
+            self.sum += value.sum
+            self.cnt += value.cnt
+        else:
+            raise NotImplementedError()
+
+    @property
+    def mean(self):
+        """Obtain the mean of the accumulated values."""
+        if self.cnt == 0:
+            return 0.0
+
+        return self.sum / self.cnt
diff --git a/otx/algorithms/detection/adapters/mmdet/task.py b/otx/algorithms/detection/adapters/mmdet/task.py
index 892651902bd..34b0857f9a6 100644
--- a/otx/algorithms/detection/adapters/mmdet/task.py
+++ b/otx/algorithms/detection/adapters/mmdet/task.py
@@ -86,6 +86,7 @@
 from otx.api.entities.task_environment import TaskEnvironment
 from otx.api.serialization.label_mapper import label_schema_to_bytes
 from otx.core.data import caching
+from otx.core.data.noisy_label_detection import LossDynamicsTrackingHook
 
 logger = get_logger()
@@ -149,6 +150,10 @@ def _init_task(self, dataset: Optional[DatasetEntity] = None, export: bool = Fal
         # Update recipe with caching modules
         self._update_caching_modules(self._recipe_cfg.data)
 
+        # Loss dynamics tracking
+        if getattr(self._hyperparams.algo_backend, "enable_noisy_label_detection", False):
+            LossDynamicsTrackingHook.configure_recipe(self._recipe_cfg, self._output_path)
+
         logger.info("initialized.")
 
     def build_model(
@@ -367,7 +372,6 @@ def _infer_model(
         time_monitor = [hook.time_monitor for hook in cfg.custom_hooks if hook.type == "OTXProgressHook"]
         time_monitor = time_monitor[0] if time_monitor else None
         if time_monitor is not None:
-            # pylint: disable=unused-argument
 
             def pre_hook(module, inp):
                 time_monitor.on_test_batch_begin(None, None)
@@ -564,7 +568,6 @@ def _explain_model(
         time_monitor = [hook.time_monitor for hook in cfg.custom_hooks if
hook.type == "OTXProgressHook"] time_monitor = time_monitor[0] if time_monitor else None if time_monitor is not None: - # pylint: disable=unused-argument def pre_hook(module, inp): time_monitor.on_test_batch_begin(None, None) diff --git a/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py b/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py index fe3e77c3592..791de1c8c74 100644 --- a/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py +++ b/otx/algorithms/detection/adapters/mmdet/utils/config_utils.py @@ -220,7 +220,7 @@ def patch_datasets( def update_pipeline(cfg): if subset == "train": for collect_cfg in get_configs_by_pairs(cfg, dict(type="Collect")): - get_meta_keys(collect_cfg) + get_meta_keys(collect_cfg, ["gt_ann_ids"]) for cfg_ in get_configs_by_pairs(cfg, dict(type="LoadImageFromFile")): cfg_.type = "LoadImageFromOTXDataset" for cfg_ in get_configs_by_pairs(cfg, dict(type="LoadAnnotations")): diff --git a/otx/algorithms/detection/configs/base/configuration.py b/otx/algorithms/detection/configs/base/configuration.py index 403a969f914..0e258f7ed8c 100644 --- a/otx/algorithms/detection/configs/base/configuration.py +++ b/otx/algorithms/detection/configs/base/configuration.py @@ -23,6 +23,8 @@ selectable, string_attribute, ) +from otx.api.configuration.elements.primitive_parameters import configurable_boolean +from otx.api.configuration.enums.model_lifecycle import ModelLifecycle # pylint: disable=invalid-name @@ -74,6 +76,15 @@ class __AlgoBackend(BaseConfig.BaseAlgoBackendParameters): header = string_attribute("Parameters for the MPA algo-backend") description = header + enable_noisy_label_detection = configurable_boolean( + default_value=False, + header="Enable loss dynamics tracking for noisy label detection", + description="Set to True to enable loss dynamics tracking for each sample to detect noisy labeled samples.", + editable=False, + visible_in_ui=False, + affects_outcome_of=ModelLifecycle.TRAINING, + ) + @attrs class __TilingParameters(BaseConfig.BaseTilingParameters): header = string_attribute("Tiling Parameters") diff --git a/otx/algorithms/detection/configs/detection/configuration.yaml b/otx/algorithms/detection/configs/detection/configuration.yaml index fd982489929..79db40f6299 100644 --- a/otx/algorithms/detection/configs/detection/configuration.yaml +++ b/otx/algorithms/detection/configs/detection/configuration.yaml @@ -317,6 +317,21 @@ algo_backend: type: UI_RULES visible_in_ui: false warning: null + enable_noisy_label_detection: + affects_outcome_of: TRAINING + default_value: false + description: Set to True to enable loss dynamics tracking for each sample to detect noisy labeled samples. + editable: true + header: Enable loss dynamics tracking for noisy label detection + type: BOOLEAN + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + value: true + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/api/entities/dataset_item.py b/otx/api/entities/dataset_item.py index 51e47edab2a..7975a6a5436 100644 --- a/otx/api/entities/dataset_item.py +++ b/otx/api/entities/dataset_item.py @@ -249,6 +249,7 @@ def get_annotations( labels: Optional[List[LabelEntity]] = None, include_empty: bool = False, include_ignored: bool = False, + preserve_id: bool = False, ) -> List[Annotation]: """Returns a list of annotations that exist in the dataset item (wrt. ROI). @@ -259,6 +260,7 @@ def get_annotations( the ROI are returned. 
include_empty (bool): if True, returns both empty and non-empty labels include_ignored (bool): if True, includes the labels in ignored_labels + preserve_id (bool): if True, preserve the annotation id when copying Returns: List[Annotation]: The intersection of the input label set and those present within the ROI @@ -300,7 +302,13 @@ def get_annotations( # without tampering with the original shape. shape = copy.deepcopy(annotation.shape) - annotations.append(Annotation(shape=shape, labels=shape_labels)) + annotations.append( + Annotation( + shape=shape, + labels=shape_labels, + id=annotation.id_ if preserve_id else None, + ) + ) return annotations def append_annotations(self, annotations: Sequence[Annotation]): diff --git a/otx/core/data/adapter/detection_dataset_adapter.py b/otx/core/data/adapter/detection_dataset_adapter.py index cd9cbf290ff..a6ce1b2bce5 100644 --- a/otx/core/data/adapter/detection_dataset_adapter.py +++ b/otx/core/data/adapter/detection_dataset_adapter.py @@ -9,7 +9,7 @@ from datumaro.components.annotation import AnnotationType as DatumAnnotationType -from otx.api.entities.dataset_item import DatasetItemEntity +from otx.api.entities.dataset_item import DatasetItemEntityWithID from otx.api.entities.datasets import DatasetEntity from otx.api.entities.image import Image from otx.api.entities.model_template import TaskType @@ -28,7 +28,7 @@ def get_otx_dataset(self) -> DatasetEntity: # Prepare label information label_information = self._prepare_label_information(self.dataset) self.label_entities = label_information["label_entities"] - dataset_items: List[DatasetItemEntity] = [] + dataset_items: List[DatasetItemEntityWithID] = [] used_labels: List[int] = [] for subset, subset_data in self.dataset.items(): for _, datumaro_items in subset_data.subsets().items(): @@ -37,7 +37,6 @@ def get_otx_dataset(self) -> DatasetEntity: assert isinstance(image, Image) shapes = [] for ann in datumaro_item.annotations: - if ( self.task_type in (TaskType.INSTANCE_SEGMENTATION, TaskType.ROTATED_DETECTION) and ann.type == DatumAnnotationType.polygon @@ -56,7 +55,12 @@ def get_otx_dataset(self) -> DatasetEntity: or subset == Subset.UNLABELED or (subset != Subset.TRAINING and len(datumaro_item.annotations) == 0) ): - dataset_item = DatasetItemEntity(image, self._get_ann_scene_entity(shapes), subset=subset) + dataset_item = DatasetItemEntityWithID( + image, + self._get_ann_scene_entity(shapes), + subset=subset, + id_=datumaro_item.id, + ) dataset_items.append(dataset_item) self.remove_unused_label_entities(used_labels) return DatasetEntity(items=dataset_items) diff --git a/otx/core/data/noisy_label_detection/base.py b/otx/core/data/noisy_label_detection/base.py index 55df59f2b77..114ad728854 100644 --- a/otx/core/data/noisy_label_detection/base.py +++ b/otx/core/data/noisy_label_detection/base.py @@ -3,9 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 # -from collections import defaultdict -from typing import Any, Dict, List +from typing import List, Optional +import datumaro as dm + +from otx.api.entities.dataset_item import DatasetItemEntityWithID from otx.api.entities.datasets import DatasetEntity __all__ = ["LossDynamicsTracker", "LossDynamicsTrackingMixin"] @@ -14,14 +16,40 @@ class LossDynamicsTracker: """Class to track loss dynamics and export it to Datumaro format.""" + TASK_NAME: Optional[str] = None + def __init__(self) -> None: self.initialized = False def init_with_otx_dataset(self, otx_dataset: DatasetEntity) -> None: """DatasetEntity should be injected to the tracker for the 
initialization.""" - self._loss_dynamics: Dict[Any, List] = defaultdict(list) + otx_labels = otx_dataset.get_labels() + label_categories = dm.LabelCategories.from_iterable([label_entity.name for label_entity in otx_labels]) + self.otx_label_map = {label_entity.id_: idx for idx, label_entity in enumerate(otx_labels)} + + self._export_dataset = dm.Dataset.from_iterable( + [ + dm.DatasetItem( + id=item.id_, + subset="train", + media=dm.Image.from_file(path=item.media.path, size=(item.media.height, item.media.width)) + if item.media.path + else dm.Image.from_numpy( + data=getattr(item.media, "_Image__data"), size=(item.media.height, item.media.width) + ), + annotations=self._convert_anns(item), + ) + for item in otx_dataset + ], + infos={"purpose": "noisy_label_detection", "task": self.TASK_NAME}, + categories={dm.AnnotationType.label: label_categories}, + ) + self.initialized = True + def _convert_anns(self, item: DatasetItemEntityWithID) -> List[dm.Annotation]: + raise NotImplementedError() + def accumulate(self, outputs, iter) -> None: """Accumulate training loss dynamics for each training step.""" raise NotImplementedError() diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/conftest.py b/tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/conftest.py new file mode 100644 index 00000000000..46fb0d73bd1 --- /dev/null +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/conftest.py @@ -0,0 +1,90 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from typing import Dict + +import mmcv +import pytest +import torch +from mmdet.models.builder import build_head + +from otx.algorithms.detection.adapters.mmdet.models.heads import * + + +@pytest.fixture +def fxt_head_input( + img_size=256, + n_bboxes=3, + n_classes=4, + batch_size=2, + n_channels=64, +): + img_metas = [ + {"img_shape": (img_size, img_size, 3), "scale_factor": 1, "pad_shape": (img_size, img_size, 3)} + for _ in range(batch_size) + ] + + def _gen_gt_bboxes(): + gt_bboxes = torch.rand(size=[n_bboxes, 4]) + gt_bboxes[:, :2] = img_size * 0.5 * gt_bboxes[:, :2] + gt_bboxes[:, 2:] = img_size * (0.5 * gt_bboxes[:, 2:] + 0.5) + return gt_bboxes.clamp(0, img_size) + + feat = [ + torch.rand(batch_size, n_channels, img_size // feat_size, img_size // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + gt_bboxes = [_gen_gt_bboxes() for _ in range(batch_size)] + gt_labels = [torch.randint(0, n_classes, size=(n_bboxes,)) for _ in range(batch_size)] + gt_bboxes_ignore = None + return feat, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore + + +@pytest.fixture +def fxt_cfg_atss_head(n_classes=4, n_channels=64) -> Dict: + train_cfg = mmcv.Config( + dict( + assigner=dict(type="ATSSAssigner", topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False, + ) + ) + + head_cfg = dict( + type="CustomATSSHead", + num_classes=n_classes, + in_channels=n_channels, + feat_channels=n_channels, + anchor_generator=dict( + type="AnchorGenerator", + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128], + ), + bbox_coder=dict( + type="DeltaXYWHBBoxCoder", + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[0.1, 0.1, 0.2, 0.2], + ), + loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), + loss_bbox=dict(type="GIoULoss", loss_weight=2.0), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), + use_qfl=False, + qfl_cfg=dict( + type="QualityFocalLoss", + 
use_sigmoid=True,
+            beta=2.0,
+            loss_weight=1.0,
+        ),
+        train_cfg=train_cfg,
+    )
+
+    return head_cfg
+
+
+@pytest.fixture
+def fxt_atss_head(fxt_cfg_atss_head: Dict) -> CustomATSSHead:
+    return build_head(fxt_cfg_atss_head)
diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/test_loss_dynamics_tracking_heads.py b/tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/test_loss_dynamics_tracking_heads.py
new file mode 100644
index 00000000000..35bb7241e26
--- /dev/null
+++ b/tests/unit/algorithms/detection/adapters/mmdet/models/dense_heads/test_loss_dynamics_tracking_heads.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+
+from typing import Dict, Tuple
+
+import pytest
+import torch
+from mmdet.models.builder import build_head
+
+from otx.algorithms.detection.adapters.mmdet.models.heads import (
+    CustomATSSHead,
+    CustomATSSHeadTrackingLossDynamics,
+)
+
+
+class TestLossDynamicsTrackingHeads:
+    @pytest.fixture(scope="class", autouse=True)
+    def set_seed(self):
+        torch.random.manual_seed(3003)
+
+    @pytest.fixture
+    def fxt_atss_head_with_tracking_loss(
+        self, fxt_atss_head: CustomATSSHead, fxt_cfg_atss_head: Dict
+    ) -> Tuple[CustomATSSHead, CustomATSSHeadTrackingLossDynamics]:
+        fxt_cfg_atss_head["type"] = fxt_cfg_atss_head["type"] + "TrackingLossDynamics"
+
+        atss_head_with_tracking_loss = build_head(fxt_cfg_atss_head)
+        # Copy atss_head's weights so that both heads produce identical outputs
+        atss_head_with_tracking_loss.load_state_dict(fxt_atss_head.state_dict())
+        return fxt_atss_head, atss_head_with_tracking_loss
+
+    @torch.no_grad()
+    def test_output_equivalence(self, fxt_atss_head_with_tracking_loss, fxt_head_input):
+        atss_head, atss_head_with_tracking_loss = fxt_atss_head_with_tracking_loss
+        feat, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore = fxt_head_input
+
+        scores = atss_head_with_tracking_loss.forward(feat)
+        expected_scores = atss_head.forward(feat)
+
+        for actual, expected in zip(scores, expected_scores):
+            # actual and expected are lists over feature pyramid levels
+            for a, e in zip(actual, expected):
+                assert torch.allclose(a, e)
+
+        losses = atss_head_with_tracking_loss.loss(*scores, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore)
+        expected_losses = atss_head.loss(*expected_scores, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore)
+
+        for actual, expected in zip(losses.values(), expected_losses.values()):
+            # actual and expected are lists over feature pyramid levels
+            for a, e in zip(actual, expected):
+                assert torch.allclose(a, e)
diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py
new file mode 100644
index 00000000000..1b0dfda47c3
--- /dev/null
+++ b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py
@@ -0,0 +1,95 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import os.path as osp
+import uuid
+from typing import Dict
+
+import datumaro as dm
+import mmcv
+import numpy as np
+import pytest
+from mmdet.datasets import build_dataloader as mmdet_build_dataloader
+from mmdet.datasets import build_dataset as mmdet_build_dataset
+
+from otx.algorithms.common.adapters.mmcv.utils.builder import (
+    build_dataloader,
+    build_dataset,
+)
+from otx.api.entities.datasets import DatasetEntity
+from otx.api.entities.model_template import TaskType
+from otx.core.data.adapter.detection_dataset_adapter import DetectionDatasetAdapter
+
+
+@pytest.fixture +def fxt_cfg_custom_atss(num_classes: int = 2) -> Dict: + train_cfg = mmcv.Config( + dict( + assigner=dict(type="ATSSAssigner", topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False, + ) + ) + cfg = dict( + type="CustomATSS", + backbone=dict( + avg_down=False, + base_channels=64, + conv_cfg=None, + dcn=None, + deep_stem=False, + depth=18, + dilations=(1, 1, 1, 1), + frozen_stages=-1, + in_channels=3, + init_cfg=None, + norm_cfg=dict(requires_grad=True, type="BN"), + norm_eval=True, + num_stages=4, + out_indices=(0, 1, 2, 3), + plugins=None, + pretrained=None, + stage_with_dcn=(False, False, False, False), + stem_channels=None, + strides=(1, 2, 2, 2), + style="pytorch", + type="mmdet.ResNet", + with_cp=False, + zero_init_residual=True, + ), + neck=dict( + type="FPN", + in_channels=[64, 128, 256, 512], + out_channels=64, + start_level=1, + add_extra_convs="on_output", + num_outs=5, + relu_before_extra_convs=True, + ), + bbox_head=dict( + type="CustomATSSHead", + num_classes=num_classes, + in_channels=64, + stacked_convs=4, + feat_channels=64, + anchor_generator=dict( + type="AnchorGenerator", + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128], + ), + bbox_coder=dict( + type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] + ), + loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), + loss_bbox=dict(type="GIoULoss", loss_weight=2.0), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), + use_qfl=False, + qfl_cfg=dict(type="QualityFocalLoss", use_sigmoid=True, beta=2.0, loss_weight=1.0), + ), + train_cfg=train_cfg, + ) + return cfg diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_atss_detector.py b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_atss_detector.py index 7d6bd9cbcc5..55bd05f200a 100644 --- a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_atss_detector.py +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_atss_detector.py @@ -1,3 +1,4 @@ +from typing import Dict import torch from mmdet.models.builder import build_detector @@ -9,68 +10,8 @@ class TestCustomATSS: @e2e_pytest_unit - def test_custom_atss_build(self): - model_cfg = dict( - type="CustomATSS", - backbone=dict( - avg_down=False, - base_channels=64, - conv_cfg=None, - dcn=None, - deep_stem=False, - depth=18, - dilations=(1, 1, 1, 1), - frozen_stages=-1, - in_channels=3, - init_cfg=None, - norm_cfg=dict(requires_grad=True, type="BN"), - norm_eval=True, - num_stages=4, - out_indices=(0, 1, 2, 3), - plugins=None, - pretrained=None, - stage_with_dcn=(False, False, False, False), - stem_channels=None, - strides=(1, 2, 2, 2), - style="pytorch", - type="mmdet.ResNet", - with_cp=False, - zero_init_residual=True, - ), - neck=dict( - type="FPN", - in_channels=[64, 128, 256, 512], - out_channels=64, - start_level=1, - add_extra_convs="on_output", - num_outs=5, - relu_before_extra_convs=True, - ), - bbox_head=dict( - type="CustomATSSHead", - num_classes=2, - in_channels=64, - stacked_convs=4, - feat_channels=64, - anchor_generator=dict( - type="AnchorGenerator", - ratios=[1.0], - octave_base_scale=8, - scales_per_octave=1, - strides=[8, 16, 32, 64, 128], - ), - bbox_coder=dict( - type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] - ), - loss_cls=dict(type="FocalLoss", use_sigmoid=True, 
gamma=2.0, alpha=0.25, loss_weight=1.0), - loss_bbox=dict(type="GIoULoss", loss_weight=2.0), - loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), - use_qfl=False, - qfl_cfg=dict(type="QualityFocalLoss", use_sigmoid=True, beta=2.0, loss_weight=1.0), - ), - ) - - model = build_detector(model_cfg) + def test_custom_atss_build(self, fxt_cfg_custom_atss: Dict): + model = build_detector(fxt_cfg_custom_atss) assert isinstance(model, CustomATSS) @e2e_pytest_unit diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_loss_dynamics_tracking.py b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_loss_dynamics_tracking.py new file mode 100644 index 00000000000..c8889c55c2a --- /dev/null +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_loss_dynamics_tracking.py @@ -0,0 +1,117 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +import os.path as osp +from typing import Any, Dict, Type + +import datumaro as dm +import numpy as np +import pytest +import torch +from mmcv import ConfigDict +from mmdet.datasets import build_dataloader, build_dataset +from mmdet.models.builder import build_detector + +from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomATSS +from otx.algorithms.detection.adapters.mmdet.models.loss_dyns import TrackingLossType +from otx.api.entities.datasets import DatasetEntity +from otx.api.entities.label import Domain + + +class TestLossDynamicsTrackingMixin: + @pytest.fixture() + def dataloader(self, fxt_det_dataset_entity: DatasetEntity): + img_size = 256 + dataloader_cfg = dict(samples_per_gpu=len(fxt_det_dataset_entity), workers_per_gpu=1) + dataset_cfg = ConfigDict( + dict( + type="OTXDetDataset", + pipeline=[ + dict(type="LoadImageFromOTXDataset"), + dict( + type="LoadAnnotationFromOTXDataset", + with_bbox=True, + with_mask=False, + domain=Domain.DETECTION, + min_size=-1, + ), + dict(type="RandomFlip", flip_ratio=0.5), + dict(type="DefaultFormatBundle"), + dict( + type="Collect", + keys=["img", "gt_bboxes", "gt_labels"], + meta_keys=( + "filename", + "ori_shape", + "img_shape", + "pad_shape", + "scale_factor", + "flip", + "img_norm_cfg", + "gt_ann_ids", + ), + ), + ], + otx_dataset=fxt_det_dataset_entity, + labels=fxt_det_dataset_entity.get_labels(), + domain=Domain.DETECTION, + ) + ) + + dataset = build_dataset(dataset_cfg) + dataloader = build_dataloader(dataset, **dataloader_cfg) + + return dataloader + + @pytest.fixture + def fxt_custom_atss(self, fxt_cfg_custom_atss: Dict, fxt_det_dataset_entity: DatasetEntity) -> CustomATSS: + fxt_cfg_custom_atss["track_loss_dynamics"] = True + + detector = build_detector(fxt_cfg_custom_atss) + detector.loss_dyns_tracker.init_with_otx_dataset(fxt_det_dataset_entity) + return detector + + @pytest.fixture() + def detector(self, request: Type[pytest.FixtureRequest]): + return request.getfixturevalue(request.param) + + TESTCASE = ["fxt_custom_atss"] + + @torch.no_grad() + @pytest.mark.parametrize("detector", TESTCASE, indirect=True) + def test_train_step(self, detector, dataloader: Dict[str, Any], tmp_dir_path: str): + for data in dataloader: + outputs = detector.train_step({k: v.data[0] for k, v in data.items()}, None) + + output_keys = {key for key in outputs.keys()} + for loss_type in detector.TRACKING_LOSS_TYPE: + assert loss_type in output_keys + + n_steps = 3 + for iter in range(n_steps): + detector.loss_dyns_tracker.accumulate(outputs, iter) + + export_dir = osp.join(tmp_dir_path, 
"noisy_label_detection") + detector.loss_dyns_tracker.export(export_dir) + + dataset = dm.Dataset.import_from(export_dir, format="datumaro") + + cnt = 0 + for item in dataset: + for ann in item.annotations: + has_attrs = False + for v in ann.attributes.values(): + assert set(list(ann.attributes.keys())) == { + "iters", + *[f"loss_dynamics_{loss_type.name}" for loss_type in detector.TRACKING_LOSS_TYPE], + } + assert len(v) == n_steps + has_attrs = True + if has_attrs: + cnt += 1 + + for loss_type, values in outputs.items(): + if loss_type in detector.TRACKING_LOSS_TYPE: + assert cnt == len( + values + ), "The number of accumulated statistics is equal to the number of Datumaro items which have attirbutes." diff --git a/tests/unit/algorithms/detection/conftest.py b/tests/unit/algorithms/detection/conftest.py index 2ea32f0969c..a036f17d07b 100644 --- a/tests/unit/algorithms/detection/conftest.py +++ b/tests/unit/algorithms/detection/conftest.py @@ -4,6 +4,8 @@ from otx.api.entities.datasets import DatasetEntity from otx.api.entities.label_schema import LabelSchemaEntity from otx.api.entities.model import ModelConfiguration, ModelEntity +from otx.api.entities.model_template import TaskType +from .test_helpers import generate_det_dataset @pytest.fixture @@ -13,3 +15,9 @@ def otx_model(): label_schema=LabelSchemaEntity(), ) return ModelEntity(train_dataset=DatasetEntity(), configuration=model_configuration) + + +@pytest.fixture(scope="session") +def fxt_det_dataset_entity(number_of_images: int = 8) -> DatasetEntity: + dataset, _ = generate_det_dataset(TaskType.DETECTION, number_of_images) + return dataset diff --git a/tests/unit/algorithms/detection/test_helpers.py b/tests/unit/algorithms/detection/test_helpers.py index 4b3bc73434f..0249bb218dd 100644 --- a/tests/unit/algorithms/detection/test_helpers.py +++ b/tests/unit/algorithms/detection/test_helpers.py @@ -12,7 +12,7 @@ from otx.algorithms.detection.utils import generate_label_schema from otx.api.entities.annotation import AnnotationSceneEntity, AnnotationSceneKind -from otx.api.entities.dataset_item import DatasetItemEntity +from otx.api.entities.dataset_item import DatasetItemEntityWithID from otx.api.entities.datasets import DatasetEntity from otx.api.entities.id import ID from otx.api.entities.image import Image @@ -91,7 +91,7 @@ def generate_det_dataset(task_type, number_of_images=1): anno.shape = ShapeFactory.shape_as_polygon(anno.shape) image = Image(data=image_numpy) annotation_scene = AnnotationSceneEntity(kind=AnnotationSceneKind.ANNOTATION, annotations=annos) - items.append(DatasetItemEntity(media=image, annotation_scene=annotation_scene, subset=subset)) + items.append(DatasetItemEntityWithID(media=image, annotation_scene=annotation_scene, subset=subset)) dataset = DatasetEntity(items) return dataset, dataset.get_labels()