Skip to content

Commit

Permalink
Introduce the accuracy and change meta_info -> label_info (#2994)
Browse files Browse the repository at this point in the history
* Introduce the accuracy and change meta_info -> label_info

* Edit the accuracy class to remove the label_info from the init_args

* Add condition for the ovmodel and add tests to check the update of metric

* Add reset_prediction_layers to some models

* Add label_info to OVModel

* Change all meta related variables

* Enhance the docstring

* Edit the NamedConfusionMatrix

* Change the name of accuracy
  • Loading branch information
sungmanc authored Feb 28, 2024
1 parent 8ac9f73 commit f9aeb73
Show file tree
Hide file tree
Showing 36 changed files with 672 additions and 215 deletions.
26 changes: 13 additions & 13 deletions src/otx/algo/classification/heads/custom_hlabel_cls_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import nn

if TYPE_CHECKING:
from otx.core.data.entity.classification import HLabelInfo
from otx.core.data.entity.classification import HLabelData


@MODELS.register_module()
Expand Down Expand Up @@ -74,19 +74,19 @@ def forward(self, feats: tuple[torch.Tensor]) -> torch.Tensor:
pre_logits = self.pre_logits(feats)
return self.fc(pre_logits)

def set_hlabel_info(self, hlabel_info: HLabelInfo) -> None:
def set_hlabel_data(self, hlabel_data: HLabelData) -> None:
"""Set hlabel information."""
self.hlabel_info = hlabel_info
self.hlabel_data = hlabel_data

def _get_gt_label(self, data_samples: list[DataSample]) -> torch.Tensor:
"""Get gt labels from data samples."""
return torch.stack([data_sample.gt_label for data_sample in data_samples])

def _get_head_idx_to_logits_range(self, hlabel_info: HLabelInfo, idx: int) -> tuple[int, int]:
def _get_head_idx_to_logits_range(self, hlabel_data: HLabelData, idx: int) -> tuple[int, int]:
"""Get head_idx_to_logits_range information from hlabel information."""
return (
hlabel_info.head_idx_to_logits_range[str(idx)][0],
hlabel_info.head_idx_to_logits_range[str(idx)][1],
hlabel_data.head_idx_to_logits_range[str(idx)][0],
hlabel_data.head_idx_to_logits_range[str(idx)][1],
)

def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwargs) -> dict:
Expand All @@ -113,9 +113,9 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwa
# Multiclass loss
num_effective_heads_in_batch = 0 # consider the label removal case
for i in range(self.num_multiclass_heads):
if i not in self.hlabel_info.empty_multiclass_head_indices:
if i not in self.hlabel_data.empty_multiclass_head_indices:
head_gt = gt_labels[:, i]
logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
valid_mask = head_gt >= 0

Expand All @@ -130,15 +130,15 @@ def loss(self, feats: tuple[torch.Tensor], data_samples: list[DataSample], **kwa

# Multilabel loss
if self.num_multilabel_classes > 0:
head_gt = gt_labels[:, self.hlabel_info.num_multiclass_heads :]
head_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]
head_gt = gt_labels[:, self.hlabel_data.num_multiclass_heads :]
head_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]
valid_mask = head_gt > 0
head_gt = head_gt[valid_mask]
if len(head_gt) > 0:
img_metas = [data_sample.metainfo for data_sample in data_samples]
head_logits = head_logits[valid_mask]
valid_label_mask = self.get_valid_label_mask(img_metas).to(head_logits.device)
valid_label_mask = valid_label_mask[:, self.hlabel_info.num_single_label_classes :]
valid_label_mask = valid_label_mask[:, self.hlabel_data.num_single_label_classes :]
valid_label_mask = valid_label_mask[valid_mask]
losses["loss"] += self.multilabel_loss(head_logits, head_gt, valid_label_mask=valid_label_mask)

Expand Down Expand Up @@ -183,7 +183,7 @@ def _get_predictions(
multiclass_pred_scores = []
multiclass_pred_labels = []
for i in range(self.num_multiclass_heads):
logit_range = self._get_head_idx_to_logits_range(self.hlabel_info, i)
logit_range = self._get_head_idx_to_logits_range(self.hlabel_data, i)
multiclass_logit = cls_scores[:, logit_range[0] : logit_range[1]]
multiclass_pred = torch.softmax(multiclass_logit, dim=1)
multiclass_pred_score, multiclass_pred_label = torch.max(multiclass_pred, dim=1)
Expand All @@ -195,7 +195,7 @@ def _get_predictions(
multiclass_pred_labels = torch.cat(multiclass_pred_labels, dim=1)

if self.num_multilabel_classes > 0:
multilabel_logits = cls_scores[:, self.hlabel_info.num_single_label_classes :]
multilabel_logits = cls_scores[:, self.hlabel_data.num_single_label_classes :]

multilabel_pred_scores = torch.sigmoid(multilabel_logits)
multilabel_pred_labels = (multilabel_pred_scores >= self.thr).int()
Expand Down
2 changes: 1 addition & 1 deletion src/otx/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def instantiate_model(self, model_config: Namespace) -> tuple:

# Update num_classes
if not self.get_config_value(self.config_init, "disable_infer_num_classes", False):
num_classes = self.datamodule.meta_info.num_classes
num_classes = self.datamodule.label_info.num_classes
if num_classes != model_config.init_args.num_classes:
warning_msg = (
f"The `num_classes` in dataset is {num_classes} "
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self.max_refetch = max_refetch
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.meta_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])

def __len__(self) -> int:
return len(self.ids)
Expand Down
26 changes: 13 additions & 13 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from otx.core.data.entity.classification import (
HlabelClsBatchDataEntity,
HlabelClsDataEntity,
HLabelInfo,
HLabelData,
MulticlassClsBatchDataEntity,
MulticlassClsDataEntity,
MultilabelClsBatchDataEntity,
Expand All @@ -28,10 +28,10 @@


@dataclass
class HLabelMetaInfo(LabelInfo):
class HLabelInfo(LabelInfo):
"""Meta information of hlabel classification."""

hlabel_info: HLabelInfo
hlabel_data: HLabelData


class OTXMulticlassClsDataset(OTXDataset[MulticlassClsDataEntity]):
Expand Down Expand Up @@ -117,14 +117,14 @@ def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.dm_categories = self.dm_subset.categories()[AnnotationType.label]

# Hlabel classification used HLabelMetaInfo to insert the HLabelInfo.
self.meta_info = HLabelMetaInfo(
# Hlabel classification used HLabelInfo to insert the HLabelData.
self.label_info = HLabelInfo(
label_names=[category.name for category in self.dm_categories],
label_groups=[label_group.labels for label_group in self.dm_categories.label_groups],
hlabel_info=HLabelInfo.from_dm_label_groups(self.dm_categories),
hlabel_data=HLabelData.from_dm_label_groups(self.dm_categories),
)

if self.meta_info.hlabel_info.num_multiclass_heads == 0:
if self.label_info.hlabel_data.num_multiclass_heads == 0:
msg = "The number of multiclass heads should be larger than 0."
raise ValueError(msg)

Expand Down Expand Up @@ -157,7 +157,7 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label
It converts the label format to h-label format.
i.e.
Let's assume that we used the same dataset with example of the definition of HLabelInfo
Let's assume that we used the same dataset with example of the definition of HLabelData
and the original labels are ["Rigid", "Panda", "Lion"].
Then, h-label format will be [1, -1, 0, 1, 1].
Expand All @@ -171,20 +171,20 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label
[Multilabel Head: [0, 1, 1]]
2, 3, 4 indices = [0, 1, 1] -> ["Circle"(X), "Lion"(O), "Panda"(O)]
"""
if not isinstance(self.meta_info, HLabelMetaInfo):
msg = f"The type of meta_info should be HLabelMetaInfo, got {type(self.meta_info)}."
if not isinstance(self.label_info, HLabelInfo):
msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}."
raise TypeError(msg)

num_multiclass_heads = self.meta_info.hlabel_info.num_multiclass_heads
num_multilabel_classes = self.meta_info.hlabel_info.num_multilabel_classes
num_multiclass_heads = self.label_info.hlabel_data.num_multiclass_heads
num_multilabel_classes = self.label_info.hlabel_data.num_multilabel_classes

class_indices = [0] * (num_multiclass_heads + num_multilabel_classes)
for i in range(num_multiclass_heads):
class_indices[i] = -1

for ann in label_anns:
ann_name = self.dm_categories.items[ann.label].name
group_idx, in_group_idx = self.meta_info.hlabel_info.class_to_group_idx[ann_name]
group_idx, in_group_idx = self.label_info.hlabel_data.class_to_group_idx[ann_name]

if group_idx < num_multiclass_heads:
class_indices[group_idx] = in_group_idx
Expand Down
10 changes: 5 additions & 5 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@dataclass
class SegMetaInfo(LabelInfo):
class SegLabelInfo(LabelInfo):
"""Meta information of Semantic Segmentation."""

def __init__(self, label_names: list[str], label_groups: list[list[str]]) -> None:
Expand Down Expand Up @@ -64,15 +64,15 @@ def __init__(
image_color_channel,
stack_images,
)
self.meta_info = SegMetaInfo(
label_names=self.meta_info.label_names,
label_groups=self.meta_info.label_groups,
self.label_info = SegLabelInfo(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
)

def _get_item_impl(self, index: int) -> SegDataEntity | None:
item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name)
img = item.media_as(Image)
num_classes = self.meta_info.num_classes
num_classes = self.label_info.num_classes
ignored_labels: list[int] = []
img_data, img_shape = self._get_img_data_and_shape(img)

Expand Down
11 changes: 4 additions & 7 deletions src/otx/core/data/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

from dataclasses import dataclass
from operator import itemgetter
from typing import TYPE_CHECKING, Any

import torch
Expand Down Expand Up @@ -179,7 +178,7 @@ class MultilabelClsBatchPredEntityWithXAI(MultilabelClsBatchDataEntity, OTXBatch


@dataclass
class HLabelInfo:
class HLabelData:
"""The label information represents the hierarchy.
All params should be kept since they're also used at the Model API side.
Expand Down Expand Up @@ -245,8 +244,8 @@ class HLabelInfo:
empty_multiclass_head_indices: list[int]

@classmethod
def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelInfo:
"""Generate HLabelInfo from the Datumaro LabelCategories.
def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelData:
"""Generate HLabelData from the Datumaro LabelCategories.
Args:
dm_label_categories (LabelCategories): the label categories of datumaro.
Expand All @@ -255,7 +254,6 @@ def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelInf
def get_exclusive_group_info(all_groups: list[Label | list[Label]]) -> dict[str, Any]:
"""Get exclusive group information."""
exclusive_groups = [g for g in all_groups if len(g) > 1]
exclusive_groups.sort(key=itemgetter(0))

last_logits_pos = 0
num_single_label_classes = 0
Expand All @@ -282,7 +280,6 @@ def get_single_label_group_info(
) -> dict[str, Any]:
"""Get single label group information."""
single_label_groups = [g for g in all_groups if len(g) == 1]
single_label_groups.sort(key=itemgetter(0))

class_to_idx = {}

Expand Down Expand Up @@ -324,7 +321,7 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str
single_label_group_info["class_to_idx"],
)

return HLabelInfo(
return HLabelData(
num_multiclass_heads=exclusive_group_info["num_multiclass_heads"],
num_multilabel_classes=single_label_group_info["num_multilabel_classes"],
head_idx_to_logits_range=exclusive_group_info["head_idx_to_logits_range"],
Expand Down
12 changes: 6 additions & 6 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(
mem_size=mem_size,
)

meta_infos: list[LabelInfo] = []
label_infos: list[LabelInfo] = []
for name, dm_subset in dataset.subsets().items():
if name not in config_mapping:
log.warning(f"{name} is not available. Skip it")
Expand All @@ -114,18 +114,18 @@ def __init__(
)
self.subsets[name] = dataset

meta_infos += [self.subsets[name].meta_info]
label_infos += [self.subsets[name].label_info]
log.info(f"Add name: {name}, self.subsets: {self.subsets}")

if self._is_meta_info_valid(meta_infos) is False:
if self._is_meta_info_valid(label_infos) is False:
msg = "All data meta infos of subsets should be the same."
raise ValueError(msg)

self.meta_info = next(iter(meta_infos))
self.label_info = next(iter(label_infos))

def _is_meta_info_valid(self, meta_infos: list[LabelInfo]) -> bool:
def _is_meta_info_valid(self, label_infos: list[LabelInfo]) -> bool:
"""Check whether there are mismatches in the metainfo for the all subsets."""
if all(meta_info == meta_infos[0] for meta_info in meta_infos):
if all(label_info == label_infos[0] for label_info in label_infos):
return True
return False

Expand Down
4 changes: 0 additions & 4 deletions src/otx/core/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,4 @@

from torchmetrics import Metric

from .accuracy import HLabelAccuracy

MetricCallable = Union[Callable[[], Metric], Callable[[int], Metric]]

__all__ = ["HLabelAccuracy"]
Loading

0 comments on commit f9aeb73

Please sign in to comment.