Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix label_to_idx for hierarchical classification #2906

Merged
merged 7 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@

self.model = Model.create_model(model_adapter, "otx_classification", self.configuration, preload=True)

self.converter = ClassificationToAnnotationConverter(self.label_schema)
if self.model.hierarchical:
hierarchical_info = self.model.hierarchical_info["cls_heads_info"]
GalyaZalesskaya marked this conversation as resolved.
Show resolved Hide resolved
else:
hierarchical_info = None

Check warning on line 132 in src/otx/algorithms/classification/adapters/openvino/task.py

View check run for this annotation

Codecov / codecov/patch

src/otx/algorithms/classification/adapters/openvino/task.py#L132

Added line #L132 was not covered by tests
self.converter = ClassificationToAnnotationConverter(self.label_schema, hierarchical_info)
self.callback_exceptions: List[Exception] = []
self.model.inference_adapter.set_callback(self._async_callback)

Expand Down
39 changes: 19 additions & 20 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

import json
from operator import itemgetter
from typing import Any, Dict
from typing import Any, Dict, List

from otx.api.entities.label import LabelEntity
from otx.api.entities.label_schema import LabelSchemaEntity
from otx.api.serialization.label_mapper import LabelSchemaMapper

Expand Down Expand Up @@ -51,8 +52,8 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl
for j, group in enumerate(single_label_groups):
class_to_idx[group[0]] = (len(exclusive_groups), j)

all_labels = label_schema.get_labels(include_empty=False)
label_to_idx = {lbl.name: i for i, lbl in enumerate(all_labels)}
# Idx of label corresponds to model output
label_to_idx = {lbl: i for i, lbl in enumerate(class_to_idx.keys())}

mixed_cls_heads_info = {
"num_multiclass_heads": len(exclusive_groups),
Expand Down Expand Up @@ -104,9 +105,13 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
mapi_config[("model_info", "hierarchical")] = str(inference_config["hierarchical"])
mapi_config[("model_info", "output_raw_scores")] = str(True)

label_entities = label_schema.get_labels(include_empty=False)
if inference_config["hierarchical"]:
label_entities = get_hierarchical_label_list(inference_config["multihead_class_info"], label_entities)

all_labels = ""
all_label_ids = ""
for lbl in label_schema.get_labels(include_empty=False):
for lbl in label_entities:
all_labels += lbl.name.replace(" ", "_") + " "
all_label_ids += f"{lbl.id_} "

Expand All @@ -123,22 +128,16 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
return mapi_config


def get_hierarchical_label_list(hierarchical_info, labels):
def get_hierarchical_label_list(hierarchical_info: Dict, labels: List) -> List[LabelEntity]:
"""Return hierarchical labels list which is adjusted to model outputs classes."""

# Create the list of Label Entities (took from "labels")
# corresponding to names and order in "label_to_idx"
label_to_idx = hierarchical_info["label_to_idx"]
hierarchical_labels = []
for head_idx in range(hierarchical_info["num_multiclass_heads"]):
logits_begin, logits_end = hierarchical_info["head_idx_to_logits_range"][str(head_idx)]
for logit in range(0, logits_end - logits_begin):
label_str = hierarchical_info["all_groups"][head_idx][logit]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])

if hierarchical_info["num_multilabel_classes"]:
logits_begin = hierarchical_info["num_single_label_classes"]
logits_end = len(labels)
for logit_idx, logit in enumerate(range(0, logits_end - logits_begin)):
label_str_idx = hierarchical_info["num_multiclass_heads"] + logit_idx
label_str = hierarchical_info["all_groups"][label_str_idx][0]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])
for label_str, _ in label_to_idx.items():
for label_entity in labels:
if label_entity.name == label_str:
hierarchical_labels.append(label_entity)
break
return hierarchical_labels
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from openvino.model_api.models import utils
from openvino.model_api.models.utils import AnomalyResult

from otx.algorithms.classification.utils import get_hierarchical_label_list
from otx.api.entities.annotation import (
Annotation,
AnnotationSceneEntity,
Expand Down Expand Up @@ -171,7 +172,11 @@
elif converter_type == Domain.SEGMENTATION:
converter = SegmentationToAnnotationConverter(labels)
elif converter_type == Domain.CLASSIFICATION:
converter = ClassificationToAnnotationConverter(labels)
if configuration is not None and configuration.get("hierarchical", False):
hierarchical_info = configuration["multihead_class_info"]

Check warning on line 176 in src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py

View check run for this annotation

Codecov / codecov/patch

src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py#L176

Added line #L176 was not covered by tests
else:
hierarchical_info = None
converter = ClassificationToAnnotationConverter(labels, hierarchical_info)
elif converter_type == Domain.ANOMALY_CLASSIFICATION:
converter = AnomalyClassificationToAnnotationConverter(labels)
elif converter_type == Domain.ANOMALY_DETECTION:
Expand Down Expand Up @@ -268,9 +273,10 @@

Args:
labels (LabelSchemaEntity): Label Schema containing the label info of the task
hierarchical_info (Dict): Info from model.hierarchical_info["cls_heads_info"]
"""

def __init__(self, label_schema: LabelSchemaEntity):
def __init__(self, label_schema: LabelSchemaEntity, hierarchical_info: Optional[Dict] = None):
GalyaZalesskaya marked this conversation as resolved.
Show resolved Hide resolved
if len(label_schema.get_labels(False)) == 1:
self.labels = label_schema.get_labels(include_empty=True)
else:
Expand All @@ -284,6 +290,9 @@

self.label_schema = label_schema

if self.hierarchical:
self.labels = get_hierarchical_label_list(hierarchical_info, self.labels)

def convert_to_annotation(
self, predictions: List[Tuple[int, float]], metadata: Optional[Dict] = None
) -> AnnotationSceneEntity:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_post_process(self):
}
fake_metadata = {"original_shape": (254, 320, 3), "resized_shape": (224, 224, 3)}
self.cls_ov_inferencer.model.postprocess.return_value = [[0, 0.87], [1, 0.13]]
self.cls_ov_inferencer.model.hierarchical = False
returned_value = self.cls_ov_inferencer.post_process(fake_prediction, fake_metadata)

assert len(returned_value.annotations[0].get_labels()) > 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,10 @@ def test_classification_to_annotation_init(self):
labels=other_non_empty_labels,
)
label_schema = LabelSchemaEntity(label_groups=[label_group, other_label_group])
converter = ClassificationToAnnotationConverter(label_schema=label_schema)
hierarchical_info = {"label_to_idx": {label_0_1.name: 0, label_0_1_1.name: 1, label_0_2.name: 2}}
converter = ClassificationToAnnotationConverter(
label_schema=label_schema, hierarchical_info=hierarchical_info
)
assert not converter.empty_label
assert converter.label_schema == label_schema
assert converter.hierarchical
Expand Down Expand Up @@ -840,7 +843,10 @@ def check_annotation(actual_annotation: Annotation, expected_labels: list):
label_schema = LabelSchemaEntity(label_groups=[label_group, other_label_group])

label_schema.add_child(parent=label_0_1, child=label_0_1_1)
converter = ClassificationToAnnotationConverter(label_schema=label_schema)
hierarchical_info = {"label_to_idx": {label_0_1.name: 0, label_0_1_1.name: 1, label_0_2.name: 2}}
converter = ClassificationToAnnotationConverter(
label_schema=label_schema, hierarchical_info=hierarchical_info
)
predictions = [(2, 0.9), (1, 0.8)]
predictions_to_annotations = converter.convert_to_annotation(predictions)
check_annotation_scene(annotation_scene=predictions_to_annotations, expected_length=1)
Expand Down
Loading