Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wrong indices setting in HLabelInfo #4044

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4016>)
- Fix multilabel_accuracy of MixedHLabelAccuracy
(<https://github.com/openvinotoolkit/training_extensions/pull/4042>)
- Fix wrong indices setting in HLabelInfo
(<https://github.com/openvinotoolkit/training_extensions/pull/4044>)

## \[v2.1.0\]

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ In addition to the examples above, please refer to the documentation for tutoria
- Fix num_trials calculation on dataset length less than num_class
- Fix out_features in HierarchicalCBAMClsHead
- Fix multilabel_accuracy of MixedHLabelAccuracy
- Fix wrong indices setting in HLabelInfo

### Known issues

Expand Down
1 change: 1 addition & 0 deletions docs/source/guide/release_notes/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Bug fixes
- Fix num_trials calculation on dataset length less than num_class
- Fix out_features in HierarchicalCBAMClsHead
- Fix multilabel_accuracy of MixedHLabelAccuracy
- Fix wrong indices setting in HLabelInfo

v2.1.0 (2024.07)
----------------
Expand Down
2 changes: 1 addition & 1 deletion src/otx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

__version__ = "2.2.0rc8"
__version__ = "2.2.0rc9"

import os
from pathlib import Path
Expand Down
4 changes: 3 additions & 1 deletion src/otx/core/types/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ def convert_labels_if_needed(
single_label_group_info["class_to_idx"],
)

label_to_idx = {lbl: i for i, lbl in enumerate(merged_class_to_idx.keys())}

return HLabelInfo(
label_names=label_names,
label_groups=all_groups,
Expand All @@ -273,7 +275,7 @@ def convert_labels_if_needed(
num_single_label_classes=exclusive_group_info["num_single_label_classes"],
class_to_group_idx=merged_class_to_idx,
all_groups=all_groups,
label_to_idx=dm_label_categories._indices, # noqa: SLF001
label_to_idx=label_to_idx,
label_tree_edges=get_label_tree_edges(dm_label_categories.items),
empty_multiclass_head_indices=[], # consider the label removing case
)
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/core/types/test_label.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from otx.core.types.label import NullLabelInfo, SegLabelInfo
from datumaro import LabelCategories
from datumaro.components.annotation import GroupType
from otx.core.types.label import HLabelInfo, NullLabelInfo, SegLabelInfo


def test_as_json(fxt_label_info):
Expand All @@ -18,3 +21,34 @@ def test_seg_label_info():
)
assert SegLabelInfo.from_num_classes(1) == SegLabelInfo(["background", "label_0"], [["background", "label_0"]])
assert SegLabelInfo.from_num_classes(0) == NullLabelInfo()


# Unit test
def test_hlabel_info():
labels = [
LabelCategories.Category(name="car", parent="vehicle"),
LabelCategories.Category(name="truck", parent="vehicle"),
LabelCategories.Category(name="plush toy", parent="plush toy"),
LabelCategories.Category(name="No class"),
]
label_groups = [
LabelCategories.LabelGroup(
name="Detection labels___vehicle",
labels=["car", "truck"],
group_type=GroupType.EXCLUSIVE,
),
LabelCategories.LabelGroup(
name="Detection labels___plush toy",
labels=["plush toy"],
group_type=GroupType.EXCLUSIVE,
),
LabelCategories.LabelGroup(name="No class", labels=["No class"], group_type=GroupType.RESTRICTED),
]
dm_label_categories = LabelCategories(items=labels, label_groups=label_groups)

hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories)

# Check if class_to_group_idx and label_to_idx have the same keys
assert list(hlabel_info.class_to_group_idx.keys()) == list(
hlabel_info.label_to_idx.keys(),
), "class_to_group_idx and label_to_idx keys do not match"
Loading