diff --git a/CHANGELOG.md b/CHANGELOG.md index 965a0b761c..bb157675bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,10 @@ All notable changes to this project will be documented in this file. () - Fix SupCon flag (https://github.com/openvinotoolkit/training_extensions/pull/4076) +- Add h-cls label info normalization + () +- Fix arrow support for semantic segmentation task + () ## \[2.2.2\] diff --git a/src/otx/core/types/export.py b/src/otx/core/types/export.py index a8a511355a..6343c64fff 100644 --- a/src/otx/core/types/export.py +++ b/src/otx/core/types/export.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # """OTX export-related types definition.""" @@ -123,10 +123,11 @@ def to_metadata(self) -> dict[tuple[str, str], str]: } if isinstance(self.label_info, HLabelInfo): + dict_info = self.label_info.as_dict(normalize_label_names=True) metadata[("model_info", "hierarchical_config")] = json.dumps( { - "cls_heads_info": self.label_info.as_dict(), - "label_tree_edges": self.label_info.label_tree_edges, + "cls_heads_info": dict_info, + "label_tree_edges": dict_info["label_tree_edges"], }, ) diff --git a/src/otx/core/types/label.py b/src/otx/core/types/label.py index 19c3ece3bb..f140d2b39b 100644 --- a/src/otx/core/types/label.py +++ b/src/otx/core/types/label.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # """Dataclasses for label information.""" @@ -119,9 +119,28 @@ def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> Lab label_ids=label_ids, ) - def as_dict(self) -> dict[str, Any]: + def as_dict(self, normalize_label_names: bool = False) -> dict[str, Any]: """Return a dictionary including all params.""" - return asdict(self) + result = asdict(self) + + if normalize_label_names: + + def normalize_fn(node: str | list | tuple | dict | int) -> str | list | tuple | dict | int: + """Normalizes the label names stored in various nested structures.""" + if isinstance(node, str): + return node.replace(" ", "_") + if isinstance(node, list): + return [normalize_fn(item) for item in node] + if isinstance(node, tuple): + return tuple(normalize_fn(item) for item in node) + if isinstance(node, dict): + return {normalize_fn(key): normalize_fn(value) for key, value in node.items()} + return node + + for k in result: + result[k] = normalize_fn(result[k]) + + return result def to_json(self) -> str: """Return JSON serialized string.""" diff --git a/tests/unit/core/types/test_label.py b/tests/unit/core/types/test_label.py index 7c6d2359b7..e1115e71b8 100644 --- a/tests/unit/core/types/test_label.py +++ b/tests/unit/core/types/test_label.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -74,6 +74,11 @@ def test_hlabel_info(): hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories) + # check if label info can be normalized on export + dict_label_info = hlabel_info.as_dict(normalize_label_names=True) + for lbl in dict_label_info["label_names"]: + assert " " not in lbl + # Check if class_to_group_idx and label_to_idx have the same keys assert list(hlabel_info.class_to_group_idx.keys()) == list( hlabel_info.label_to_idx.keys(),