From df131920ea356a405b51351e84c4016e35c88106 Mon Sep 17 00:00:00 2001 From: sooahleex Date: Thu, 5 Sep 2024 14:29:30 +0900 Subject: [PATCH] Modify label_groups for dm_label_categories with id/key of label --- src/otx/core/types/label.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/otx/core/types/label.py b/src/otx/core/types/label.py index cd472965336..7f00aa0b496 100644 --- a/src/otx/core/types/label.py +++ b/src/otx/core/types/label.py @@ -229,7 +229,32 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str """Get label tree edges information. Each edges represent [child, parent].""" return [[item.name, item.parent] for item in dm_label_items if item.parent != ""] - all_groups = [label_group.labels for label_group in dm_label_categories.label_groups] + def convert_labels_if_needed( + dm_label_categories: LabelCategories, + label_names: list[str], + ) -> list[list[str]]: + # Check if the labels need conversion and create name to ID mapping if required + name_to_id_mapping = None + for label_group in dm_label_categories.label_groups: + if label_group.labels and label_group.labels[0] not in label_names: + name_to_id_mapping = { + attr[len("__name__") :]: category.name + for category in dm_label_categories.items + for attr in category.attributes + if attr.startswith("__name__") + } + break + + # If mapping exists, update the labels + if name_to_id_mapping: + for label_group in dm_label_categories.label_groups: + label_group.labels = [name_to_id_mapping.get(label, label) for label in label_group.labels] + + # Retrieve all label groups after conversion + return [group.labels for group in dm_label_categories.label_groups] + + label_names = [item.name for item in dm_label_categories.items] + all_groups = convert_labels_if_needed(dm_label_categories, label_names) exclusive_group_info = get_exclusive_group_info(all_groups) single_label_group_info = get_single_label_group_info(all_groups, exclusive_group_info["num_multiclass_heads"]) @@ -240,7 +265,7 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str ) return HLabelInfo( - label_names=[item.name for item in dm_label_categories.items], + label_names=label_names, label_groups=all_groups, num_multiclass_heads=exclusive_group_info["num_multiclass_heads"], num_multilabel_classes=single_label_group_info["num_multilabel_classes"],