Skip to content

Commit

Permalink
Add h-cls label info normalization (#4173)
Browse files Browse the repository at this point in the history
* Normalize h_cls label info on export

* Add unit test

* Upda copyright

* Update changelog
  • Loading branch information
sovrasov authored Jan 9, 2025
1 parent c3f5e02 commit b9debee
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4028>)
- Fix SupCon flag
(https://github.com/openvinotoolkit/training_extensions/pull/4076)
- Add h-cls label info normalization
(<https://github.com/openvinotoolkit/training_extensions/pull/4173>)
- Fix arrow support for semantic segmentation task
(<https://github.com/openvinotoolkit/training_extensions/pull/4172>)

## \[2.2.2\]

Expand Down
7 changes: 4 additions & 3 deletions src/otx/core/types/export.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""OTX export-related types definition."""
Expand Down Expand Up @@ -123,10 +123,11 @@ def to_metadata(self) -> dict[tuple[str, str], str]:
}

if isinstance(self.label_info, HLabelInfo):
dict_info = self.label_info.as_dict(normalize_label_names=True)
metadata[("model_info", "hierarchical_config")] = json.dumps(
{
"cls_heads_info": self.label_info.as_dict(),
"label_tree_edges": self.label_info.label_tree_edges,
"cls_heads_info": dict_info,
"label_tree_edges": dict_info["label_tree_edges"],
},
)

Expand Down
25 changes: 22 additions & 3 deletions src/otx/core/types/label.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2023 Intel Corporation
# Copyright (C) 2023-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Dataclasses for label information."""
Expand Down Expand Up @@ -119,9 +119,28 @@ def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> Lab
label_ids=label_ids,
)

def as_dict(self) -> dict[str, Any]:
def as_dict(self, normalize_label_names: bool = False) -> dict[str, Any]:
"""Return a dictionary including all params."""
return asdict(self)
result = asdict(self)

if normalize_label_names:

def normalize_fn(node: str | list | tuple | dict | int) -> str | list | tuple | dict | int:
"""Normalizes the label names stored in various nested structures."""
if isinstance(node, str):
return node.replace(" ", "_")
if isinstance(node, list):
return [normalize_fn(item) for item in node]
if isinstance(node, tuple):
return tuple(normalize_fn(item) for item in node)
if isinstance(node, dict):
return {normalize_fn(key): normalize_fn(value) for key, value in node.items()}
return node

for k in result:
result[k] = normalize_fn(result[k])

return result

def to_json(self) -> str:
"""Return JSON serialized string."""
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/core/types/test_label.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

Expand Down Expand Up @@ -74,6 +74,11 @@ def test_hlabel_info():

hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories)

# check if label info can be normalized on export
dict_label_info = hlabel_info.as_dict(normalize_label_names=True)
for lbl in dict_label_info["label_names"]:
assert " " not in lbl

# Check if class_to_group_idx and label_to_idx have the same keys
assert list(hlabel_info.class_to_group_idx.keys()) == list(
hlabel_info.label_to_idx.keys(),
Expand Down

0 comments on commit b9debee

Please sign in to comment.