Skip to content

Commit

Permalink
Fix label list order for h-label classification (#2440)
Browse files Browse the repository at this point in the history
* Fix label list for h-label cls
* Fix unit tests
  • Loading branch information
GalyaZalesskaya authored Aug 21, 2023
1 parent 6b09e65 commit 4f9c2f1
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 4 deletions.
17 changes: 15 additions & 2 deletions src/otx/algorithms/classification/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from otx.algorithms.classification.utils import (
get_cls_deploy_config,
get_cls_inferencer_configuration,
get_hierarchical_label_list,
)
from otx.algorithms.common.utils import OTXOpenVinoDataLoader
from otx.algorithms.common.utils.ir import check_if_quantized
Expand Down Expand Up @@ -228,12 +229,18 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu
if saliency_map is not None and repr_vector is not None:
feature_vec_media = TensorEntity(name="representation_vector", numpy=repr_vector.reshape(-1))
dataset_item.append_metadata_item(feature_vec_media, model=self.model)
label_list = self.task_environment.get_labels()
# Fix the order for hierarchical labels to adjust classes with model outputs
if self.inferencer.model.hierarchical:
label_list = get_hierarchical_label_list(
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
)

add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
Expand Down Expand Up @@ -284,6 +291,12 @@ def explain(
explain_predicted_classes = explain_parameters.explain_predicted_classes

dataset_size = len(dataset)
label_list = self.task_environment.get_labels()
# Fix the order for hierarchical labels to adjust classes with model outputs
if self.inferencer.model.hierarchical:
label_list = get_hierarchical_label_list(
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
)
for i, dataset_item in enumerate(dataset, 1):
predicted_scene, _, saliency_map, _, _ = self.inferencer.predict(dataset_item.numpy)
if saliency_map is None:
Expand All @@ -298,7 +311,7 @@ def explain(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
Expand Down
13 changes: 11 additions & 2 deletions src/otx/algorithms/classification/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_cls_deploy_config,
get_cls_inferencer_configuration,
get_cls_model_api_configuration,
get_hierarchical_label_list,
)
from otx.algorithms.classification.utils import (
get_multihead_class_info as get_hierarchical_info,
Expand Down Expand Up @@ -345,6 +346,10 @@ def _add_predictions_to_dataset(

dataset_size = len(dataset)
pos_thr = 0.5
label_list = self._labels
# Fix the order for hierarchical labels to adjust classes with model outputs
if self._hierarchical:
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
for i, (dataset_item, prediction_items) in enumerate(zip(dataset, prediction_results)):
prediction_item, feature_vector, saliency_map = prediction_items
if any(np.isnan(prediction_item)):
Expand Down Expand Up @@ -373,7 +378,7 @@ def _add_predictions_to_dataset(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self._task_environment.model,
labels=self._labels,
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
Expand Down Expand Up @@ -436,13 +441,17 @@ def _add_explanations_to_dataset(
):
"""Loop over dataset again and assign saliency maps."""
dataset_size = len(dataset)
label_list = self._labels
# Fix the order for hierarchical labels to adjust classes with model outputs
if self._hierarchical:
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
for i, (dataset_item, prediction_item, saliency_map) in enumerate(zip(dataset, predictions, saliency_maps)):
item_labels = self._get_item_labels(prediction_item, pos_thr=0.5)
add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self._task_environment.model,
labels=self._labels,
labels=label_list,
predicted_scored_labels=item_labels,
explain_predicted_classes=explain_predicted_classes,
process_saliency_maps=process_saliency_maps,
Expand Down
2 changes: 2 additions & 0 deletions src/otx/algorithms/classification/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
get_cls_deploy_config,
get_cls_inferencer_configuration,
get_cls_model_api_configuration,
get_hierarchical_label_list,
get_multihead_class_info,
)

__all__ = [
"get_hierarchical_label_list",
"get_multihead_class_info",
"get_cls_inferencer_configuration",
"get_cls_deploy_config",
Expand Down
21 changes: 21 additions & 0 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,24 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c

mapi_config[("model_info", "hierarchical_config")] = json.dumps(hierarchical_config)
return mapi_config


def get_hierarchical_label_list(hierarchical_info, labels):
"""Return hierarchical labels list which is adjusted to model outputs classes."""
hierarchical_labels = []
for head_idx in range(hierarchical_info["num_multiclass_heads"]):
logits_begin, logits_end = hierarchical_info["head_idx_to_logits_range"][str(head_idx)]
for logit in range(0, logits_end - logits_begin):
label_str = hierarchical_info["all_groups"][head_idx][logit]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])

if hierarchical_info["num_multilabel_classes"]:
logits_begin = hierarchical_info["num_single_label_classes"]
logits_end = len(labels)
for logit_idx, logit in enumerate(range(0, logits_end - logits_begin)):
label_str_idx = hierarchical_info["num_multiclass_heads"] + logit_idx
label_str = hierarchical_info["all_groups"][label_str_idx][0]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])
return hierarchical_labels
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def test_explain(self, mocker):
self.fake_input,
),
)
self.cls_ov_task.inferencer.model.hierarchical = False
updpated_dataset = self.cls_ov_task.explain(self.dataset)

assert updpated_dataset is not None
Expand Down

0 comments on commit 4f9c2f1

Please sign in to comment.