Add training loss dynamics exportation feature for detection task (#2109)

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
vinnamkim authored May 12, 2023
1 parent 5a3fed1 commit 1ece09e
Showing 22 changed files with 940 additions and 125 deletions.
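At a high level, the feature works in three stages: the tracker is initialized with the training DatasetEntity (init_with_otx_dataset), each train_step reports per-annotation loss values keyed by (entity_id, ann_id) (accumulate), and the accumulated statistics are written out as a Datumaro dataset (export). The following minimal sketch of that life cycle is not part of the diff; the module path is inferred from the relative import in the detector file below, and the training loop is illustrative only:

    from otx.algorithms.detection.adapters.mmdet.models.detectors.loss_dynamics_mixin import (
        DetLossDynamicsTracker,
    )
    from otx.algorithms.detection.adapters.mmdet.models.loss_dyns import TrackingLossType

    # Sketch only: otx_dataset, model, optimizer, and dataloader are assumed
    # to come from an OTX detection task.
    tracker = DetLossDynamicsTracker([TrackingLossType.cls, TrackingLossType.bbox])
    tracker.init_with_otx_dataset(otx_dataset)  # builds the Datumaro export dataset

    for it, batch in enumerate(dataloader):
        outputs = model.train_step(batch, optimizer)  # carries per-annotation losses
        tracker.accumulate(outputs, it)

    tracker.export("./loss_dynamics")  # writes annotations with loss statistics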
@@ -3,14 +3,19 @@
# SPDX-License-Identifier: Apache-2.0
#

from collections import defaultdict
from typing import Any, Dict, List

import datumaro as dm
import numpy as np
import pandas as pd

from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.dataset_item import DatasetItemEntityWithID
from otx.api.entities.datasets import DatasetEntity
from otx.core.data.noisy_label_detection import LossDynamicsTracker, LossDynamicsTrackingMixin
from otx.core.data.noisy_label_detection import (
    LossDynamicsTracker,
    LossDynamicsTrackingMixin,
)

logger = get_logger()

@@ -27,42 +32,19 @@ def train_step(self, data, optimizer=None, **kwargs):
class MultiClassClsLossDynamicsTracker(LossDynamicsTracker):
    """Loss dynamics tracker for multi-class classification task."""

    TASK_NAME = "OTX-MultiClassCls"

    def __init__(self) -> None:
        super().__init__()

    def init_with_otx_dataset(self, otx_dataset: DatasetEntity[DatasetItemEntityWithID]) -> None:
        """DatasetEntity should be injected to the tracker for the initialization."""
        otx_labels = otx_dataset.get_labels()
        label_categories = dm.LabelCategories.from_iterable([label_entity.name for label_entity in otx_labels])
        self.otx_label_map = {label_entity.id_: idx for idx, label_entity in enumerate(otx_labels)}

        def _convert_anns(item: DatasetItemEntityWithID):
            labels = [
                dm.Label(label=self.otx_label_map[label.id_])
                for ann in item.get_annotations()
                for label in ann.get_labels()
            ]
            return labels

        self._export_dataset = dm.Dataset.from_iterable(
            [
                dm.DatasetItem(
                    id=item.id_,
                    subset="train",
                    media=dm.Image.from_file(path=item.media.path, size=(item.media.height, item.media.width))
                    if item.media.path
                    else dm.Image.from_numpy(
                        data=getattr(item.media, "_Image__data"), size=(item.media.height, item.media.width)
                    ),
                    annotations=_convert_anns(item),
                )
                for item in otx_dataset
            ],
            infos={"purpose": "noisy_label_detection", "task": "OTX-MultiClassCls"},
            categories={dm.AnnotationType.label: label_categories},
        )

        super().init_with_otx_dataset(otx_dataset)
        self._loss_dynamics: Dict[Any, List] = defaultdict(list)

    def _convert_anns(self, item: DatasetItemEntityWithID):
        labels = [
            dm.Label(label=self.otx_label_map[label.id_])
            for ann in item.get_annotations()
            for label in ann.get_labels()
        ]
        return labels

    def accumulate(self, outputs, iter) -> None:
        """Accumulate training loss dynamics for each training step."""
8 changes: 6 additions & 2 deletions otx/algorithms/detection/adapters/mmdet/datasets/dataset.py
@@ -59,11 +59,11 @@ def get_annotation_mmdet_format(
    gt_bboxes = []
    gt_labels = []
    gt_polygons = []
    gt_ann_ids = []

    label_idx = {label.id: i for i, label in enumerate(labels)}

    for annotation in dataset_item.get_annotations(labels=labels, include_empty=False):

    for annotation in dataset_item.get_annotations(labels=labels, include_empty=False, preserve_id=True):
        box = ShapeFactory.shape_as_rectangle(annotation.shape)

        if min(box.width * width, box.height * height) < min_size:
@@ -80,18 +80,22 @@
            polygon = np.array([p for point in polygon.points for p in [point.x * width, point.y * height]])
            gt_polygons.extend([[polygon] for _ in range(n)])
        gt_labels.extend(class_indices)
        item_id = getattr(dataset_item, "id_", None)
        gt_ann_ids.append((item_id, annotation.id_))

    if len(gt_bboxes) > 0:
        ann_info = dict(
            bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
            labels=np.array(gt_labels, dtype=int),
            masks=PolygonMasks(gt_polygons, height=height, width=width) if gt_polygons else [],
            ann_ids=gt_ann_ids,
        )
    else:
        ann_info = dict(
            bboxes=np.zeros((0, 4), dtype=np.float32),
            labels=np.array([], dtype=int),
            masks=[],
            ann_ids=[],
        )
    return ann_info
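With ann_ids added above, callers can map the i-th ground-truth box back to its originating OTX annotation. A toy illustration of the returned structure, not part of the diff, with made-up ids and values:

    import numpy as np

    ann_info = dict(
        bboxes=np.array([[10.0, 20.0, 50.0, 80.0], [5.0, 5.0, 30.0, 40.0]], dtype=np.float32),
        labels=np.array([0, 2], dtype=int),
        masks=[],
        ann_ids=[("item-0001", "ann-0001"), ("item-0001", "ann-0002")],
    )
    assert len(ann_info["ann_ids"]) == len(ann_info["bboxes"])  # parallel lists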

@@ -63,6 +63,7 @@ def __init__(
    def _load_bboxes(results, ann_info):
        results["bbox_fields"].append("gt_bboxes")
        results["gt_bboxes"] = copy.deepcopy(ann_info["bboxes"])
        results["gt_ann_ids"] = copy.deepcopy(ann_info["ann_ids"])
        return results

    @staticmethod
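For the gt_ann_ids loaded above to reach train_step, which reads them from data["img_metas"] (see the mixin below), the data pipeline must carry the key into the image metas. A hedged sketch of what that looks like in an MMDetection-style Collect step; the exact meta-key list in OTX's shipped configs may differ:

    # Assumption: gt_ann_ids has to be listed among the collected meta keys so
    # that it survives into img_metas; the other keys mirror common mmdet configs.
    collect = dict(
        type="Collect",
        keys=["img", "gt_bboxes", "gt_labels"],
        meta_keys=[
            "filename", "ori_shape", "img_shape", "pad_shape",
            "scale_factor", "flip", "img_norm_cfg",
            "gt_ann_ids",  # maps losses back to annotations in train_step
        ],
    )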
@@ -18,8 +18,10 @@
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
    DetClassProbabilityMapHook,
)
from otx.algorithms.detection.adapters.mmdet.models.loss_dyns import TrackingLossType

from .l2sp_detector_mixin import L2SPDetectorMixin
from .loss_dynamics_mixin import DetLossDynamicsTrackingMixin
from .sam_detector_mixin import SAMDetectorMixin

logger = get_logger()
@@ -29,9 +31,11 @@


@DETECTORS.register_module()
class CustomATSS(SAMDetectorMixin, L2SPDetectorMixin, ATSS):
class CustomATSS(SAMDetectorMixin, DetLossDynamicsTrackingMixin, L2SPDetectorMixin, ATSS):
    """SAM optimizer & L2SP regularizer enabled custom ATSS."""

    TRACKING_LOSS_TYPE = (TrackingLossType.cls, TrackingLossType.bbox, TrackingLossType.centerness)

    def __init__(self, *args, task_adapt=None, **kwargs):
        super().__init__(*args, **kwargs)

@@ -46,10 +50,6 @@ def __init__(self, *args, task_adapt=None, **kwargs):
                )
            )

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=None, **kwargs):
        """Forward function for CustomATSS."""
        return super().forward_train(img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=gt_bboxes_ignore)

    @staticmethod
    def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs):
        """Modify input state_dict according to class name matching before weight loading."""
@@ -0,0 +1,131 @@
"""LossDynamics Mix-in for detection tasks."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from collections import defaultdict
from typing import Dict, Sequence, Tuple

import datumaro as dm
import numpy as np
import pandas as pd

from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.adapters.mmdet.models.loss_dyns import TrackingLossType
from otx.api.entities.dataset_item import DatasetItemEntityWithID
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.shapes.rectangle import Rectangle
from otx.core.data.noisy_label_detection import (
LossDynamicsTracker,
LossDynamicsTrackingMixin,
)

logger = get_logger()


class DetLossDynamicsTracker(LossDynamicsTracker):
"""Loss dynamics tracker for detection tasks."""

TASK_NAME = "OTX-Det"

def __init__(self, tracking_loss_types: Sequence[TrackingLossType]) -> None:
super().__init__()
self._loss_dynamics: Dict[TrackingLossType, Dict] = {
loss_type: defaultdict(list) for loss_type in tracking_loss_types
}

def _convert_anns(self, item: DatasetItemEntityWithID):
labels = []

cnt = 0
for ann in item.get_annotations(preserve_id=True):
if isinstance(ann.shape, Rectangle):
for label in ann.get_labels():
bbox = dm.Bbox(
x=ann.shape.x1 * item.width,
y=ann.shape.y1 * item.height,
w=ann.shape.width * item.width,
h=ann.shape.height * item.height,
label=self.otx_label_map[label.id_],
id=cnt,
)
labels.append(bbox)
self.otx_ann_id_to_dm_ann_map[(item.id_, ann.id_)] = bbox
cnt += 1

return labels

def init_with_otx_dataset(self, otx_dataset: DatasetEntity[DatasetItemEntityWithID]) -> None:
"""DatasetEntity should be injected to the tracker for the initialization."""
self.otx_ann_id_to_dm_ann_map: Dict[Tuple[str, str], dm.Bbox] = {}
super().init_with_otx_dataset(otx_dataset)

def accumulate(self, outputs, iter) -> None:
"""Accumulate training loss dynamics for each training step."""
for key, loss_dyns in outputs.items():
if isinstance(key, TrackingLossType):
for (entity_id, ann_id), value in loss_dyns.items():
self._loss_dynamics[key][(entity_id, ann_id)].append((iter, value))

def export(self, output_path: str) -> None:
"""Export loss dynamics statistics to Datumaro format."""
dfs = [
pd.DataFrame.from_dict(
{
k: (np.array([iter for iter, _ in arr]), np.array([value for _, value in arr]))
for k, arr in loss_dyns.items()
},
orient="index",
columns=["iters", f"loss_dynamics_{key.name}"],
)
for key, loss_dyns in self._loss_dynamics.items()
]
df = pd.concat(dfs, axis=1)
df = df.loc[:, ~df.columns.duplicated()]

for (entity_id, ann_id), row in df.iterrows():
ann = self.otx_ann_id_to_dm_ann_map.get((entity_id, ann_id), None)
if ann:
ann.attributes = row.to_dict()

self._export_dataset.export(output_path, format="datumaro")


class DetLossDynamicsTrackingMixin(LossDynamicsTrackingMixin):
"""Mix-in to track loss dynamics during training for classification tasks."""

TRACKING_LOSS_TYPE: Tuple[TrackingLossType, ...] = ()

def __init__(self, track_loss_dynamics: bool = False, **kwargs):
if track_loss_dynamics:
head_cfg = kwargs.get("bbox_head", None)
head_type = head_cfg.get("type", None)
assert head_type is not None, "head_type should be specified from the config."
new_head_type = head_type + "TrackingLossDynamics"
head_cfg["type"] = new_head_type
logger.info(f"Replace head_type from {head_type} to {new_head_type}.")

super().__init__(**kwargs)

# This should be called after super().__init__(),
# since LossDynamicsTrackingMixin.__init__() creates self._loss_dyns_tracker
self._loss_dyns_tracker = DetLossDynamicsTracker(self.TRACKING_LOSS_TYPE)

def train_step(self, data, optimizer):
"""The iteration step during training."""

outputs = super().train_step(data, optimizer)

if self.loss_dyns_tracker.initialized:
gt_ann_ids = [item["gt_ann_ids"] for item in data["img_metas"]]

to_update = {}
for key, loss_dyns in self.bbox_head.loss_dyns.items():
to_update[key] = {}
for (batch_idx, bbox_idx), value in loss_dyns.items():
entity_id, ann_id = gt_ann_ids[batch_idx][bbox_idx]
to_update[key][(entity_id, ann_id)] = value.mean

outputs.update(to_update)

return outputs
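After export, each Datumaro annotation carries its loss curves in the annotation attributes (the "iters" and "loss_dynamics_<name>" columns written by export() above). A small consumption sketch, assuming Datumaro's Dataset.import_from API and an export directory of "./loss_dynamics"; the averaging heuristic is illustrative, not part of this commit:

    import datumaro as dm

    dataset = dm.Dataset.import_from("./loss_dynamics", format="datumaro")

    scores = []
    for item in dataset:
        for ann in item.annotations:
            dyns = ann.attributes.get("loss_dynamics_cls")
            if dyns is not None:
                # A high average loss over training is a simple proxy for a
                # potentially noisy label.
                scores.append((item.id, ann.id, float(sum(dyns) / len(dyns))))

    scores.sort(key=lambda t: -t[2])
    print(scores[:10])  # candidates for label review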
@@ -5,7 +5,7 @@

from .cross_dataset_detector_head import CrossDatasetDetectorHead
from .custom_anchor_generator import SSDAnchorGeneratorClustered
from .custom_atss_head import CustomATSSHead
from .custom_atss_head import CustomATSSHead, CustomATSSHeadTrackingLossDynamics
from .custom_retina_head import CustomRetinaHead
from .custom_roi_head import CustomRoIHead
from .custom_ssd_head import CustomSSDHead
@@ -21,4 +21,6 @@
"CustomRoIHead",
"CustomVFNetHead",
"CustomYOLOXHead",
# Loss dynamics tracking
"CustomATSSHeadTrackingLossDynamics",
]
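Because DetLossDynamicsTrackingMixin.__init__ rewrites bbox_head.type by appending "TrackingLossDynamics", enabling tracking is a config-level switch that resolves to the head exported here. A hedged sketch of an mmdet-style model config; everything except the two relevant fields is elided:

    model = dict(
        type="CustomATSS",
        track_loss_dynamics=True,  # consumed by DetLossDynamicsTrackingMixin
        bbox_head=dict(
            type="CustomATSSHead",  # rewritten to "CustomATSSHeadTrackingLossDynamics"
            # ... remaining head settings elided ...
        ),
        # ... backbone/neck/train_cfg elided ...
    )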