diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b16d2b2e6c..c6a9b8bc820 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,8 @@ All notable changes to this project will be documented in this file. () - Fix multilabel_accuracy of MixedHLabelAccuracy () +- Fix wrong indices setting in HLabelInfo + () ## \[v2.1.0\] diff --git a/README.md b/README.md index f185b9c1d96..435415abb0f 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,7 @@ In addition to the examples above, please refer to the documentation for tutoria - Fix num_trials calculation on dataset length less than num_class - Fix out_features in HierarchicalCBAMClsHead - Fix multilabel_accuracy of MixedHLabelAccuracy +- Fix wrong indices setting in HLabelInfo ### Known issues diff --git a/docs/source/guide/release_notes/index.rst b/docs/source/guide/release_notes/index.rst index 965cbb1f377..e0b8dc86383 100644 --- a/docs/source/guide/release_notes/index.rst +++ b/docs/source/guide/release_notes/index.rst @@ -54,6 +54,7 @@ Bug fixes - Fix num_trials calculation on dataset length less than num_class - Fix out_features in HierarchicalCBAMClsHead - Fix multilabel_accuracy of MixedHLabelAccuracy +- Fix wrong indices setting in HLabelInfo v2.1.0 (2024.07) ---------------- diff --git a/src/otx/__init__.py b/src/otx/__init__.py index 87cf0e846ff..6924cfdb186 100644 --- a/src/otx/__init__.py +++ b/src/otx/__init__.py @@ -3,7 +3,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -__version__ = "2.2.0rc8" +__version__ = "2.2.0rc9" import os from pathlib import Path diff --git a/src/otx/core/types/label.py b/src/otx/core/types/label.py index 7f00aa0b496..b9e26b48511 100644 --- a/src/otx/core/types/label.py +++ b/src/otx/core/types/label.py @@ -264,6 +264,8 @@ def convert_labels_if_needed( single_label_group_info["class_to_idx"], ) + label_to_idx = {lbl: i for i, lbl in enumerate(merged_class_to_idx.keys())} + return HLabelInfo( label_names=label_names, label_groups=all_groups, @@ -273,7 +275,7 @@ def convert_labels_if_needed( num_single_label_classes=exclusive_group_info["num_single_label_classes"], class_to_group_idx=merged_class_to_idx, all_groups=all_groups, - label_to_idx=dm_label_categories._indices, # noqa: SLF001 + label_to_idx=label_to_idx, label_tree_edges=get_label_tree_edges(dm_label_categories.items), empty_multiclass_head_indices=[], # consider the label removing case ) diff --git a/tests/unit/core/types/test_label.py b/tests/unit/core/types/test_label.py index 78daec6982e..3ae1ae1f463 100644 --- a/tests/unit/core/types/test_label.py +++ b/tests/unit/core/types/test_label.py @@ -1,7 +1,10 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations -from otx.core.types.label import NullLabelInfo, SegLabelInfo +from datumaro import LabelCategories +from datumaro.components.annotation import GroupType +from otx.core.types.label import HLabelInfo, NullLabelInfo, SegLabelInfo def test_as_json(fxt_label_info): @@ -18,3 +21,34 @@ def test_seg_label_info(): ) assert SegLabelInfo.from_num_classes(1) == SegLabelInfo(["background", "label_0"], [["background", "label_0"]]) assert SegLabelInfo.from_num_classes(0) == NullLabelInfo() + + +# Unit test +def test_hlabel_info(): + labels = [ + LabelCategories.Category(name="car", parent="vehicle"), + LabelCategories.Category(name="truck", parent="vehicle"), + LabelCategories.Category(name="plush toy", parent="plush toy"), + LabelCategories.Category(name="No class"), + ] + label_groups = [ + LabelCategories.LabelGroup( + name="Detection labels___vehicle", + labels=["car", "truck"], + group_type=GroupType.EXCLUSIVE, + ), + LabelCategories.LabelGroup( + name="Detection labels___plush toy", + labels=["plush toy"], + group_type=GroupType.EXCLUSIVE, + ), + LabelCategories.LabelGroup(name="No class", labels=["No class"], group_type=GroupType.RESTRICTED), + ] + dm_label_categories = LabelCategories(items=labels, label_groups=label_groups) + + hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories) + + # Check if class_to_group_idx and label_to_idx have the same keys + assert list(hlabel_info.class_to_group_idx.keys()) == list( + hlabel_info.label_to_idx.keys(), + ), "class_to_group_idx and label_to_idx keys do not match"