Skip to content

Commit

Permalink
Add dumping of hierarchical config on export (#2868)
Browse files Browse the repository at this point in the history
Add dumping of hierarhical config on export
  • Loading branch information
sovrasov authored Feb 7, 2024
1 parent 356509f commit e74a9f4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
16 changes: 14 additions & 2 deletions src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch

from otx.core.data.dataset.classification import HLabelMetaInfo
from otx.core.data.entity.base import OTXBatchLossEntity, T_OTXBatchDataEntity, T_OTXBatchPredEntity
from otx.core.data.entity.classification import (
HlabelClsBatchDataEntity,
Expand Down Expand Up @@ -347,8 +348,19 @@ def _export_parameters(self) -> dict[str, Any]:
"""Defines parameters required to export a particular model implementation."""
parameters = super()._export_parameters
hierarchical_config: dict = {}
hierarchical_config["cls_heads_info"] = {}
hierarchical_config["label_tree_edges"] = []

label_info: HLabelMetaInfo = self.label_info # type: ignore[assignment]
hierarchical_config["cls_heads_info"] = {
"num_multiclass_heads": label_info.hlabel_info.num_multiclass_heads,
"num_multilabel_classes": label_info.hlabel_info.num_multilabel_classes,
"head_idx_to_logits_range": label_info.hlabel_info.head_idx_to_logits_range,
"num_single_label_classes": label_info.hlabel_info.num_single_label_classes,
"class_to_group_idx": label_info.hlabel_info.class_to_group_idx,
"all_groups": label_info.hlabel_info.all_groups,
"label_to_idx": label_info.hlabel_info.label_to_idx,
"empty_multiclass_head_indices": label_info.hlabel_info.empty_multiclass_head_indices,
}
hierarchical_config["label_tree_edges"] = label_info.hlabel_info.label_tree_edges

parameters["metadata"].update(
{
Expand Down
3 changes: 2 additions & 1 deletion src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def export(
)
loaded_checkpoint = torch.load(ckpt_path)
lit_module.meta_info = loaded_checkpoint["state_dict"]["meta_info"]
# self.model.label_info = lit_module.meta_info # this doesn't work for some models yet
self.model.label_info = lit_module.meta_info

lit_module.load_state_dict(loaded_checkpoint)

return self.model.export(
Expand Down

0 comments on commit e74a9f4

Please sign in to comment.