Skip to content

Commit

Permalink
Update ModelAPI configuration (#2564 from 1.4) (#2568)
Browse files Browse the repository at this point in the history
Update ModelAPI configuration (#2564)

* Update MAPI rt infor for detection

* Upadte export info for cls, det and seg

* Update unit tests
  • Loading branch information
sovrasov authored Oct 24, 2023
1 parent 65ddbfa commit 0981035
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 7 deletions.
8 changes: 6 additions & 2 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,20 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
"""Get ModelAPI config."""
mapi_config = {}
mapi_config[("model_info", "model_type")] = "Classification"
mapi_config[("model_info", "task_type")] = "classification"
mapi_config[("model_info", "confidence_threshold")] = str(inference_config["confidence_threshold"])
mapi_config[("model_info", "multilabel")] = str(inference_config["multilabel"])
mapi_config[("model_info", "hierarchical")] = str(inference_config["hierarchical"])
mapi_config[("model_info", "output_raw_scores")] = str(True)

all_labels = ""
all_label_ids = ""
for lbl in label_schema.get_labels(include_empty=False):
all_labels += lbl.name.replace(" ", "_") + " "
all_labels = all_labels.strip()
mapi_config[("model_info", "labels")] = all_labels
all_label_ids += f"{lbl.id_} "

mapi_config[("model_info", "labels")] = all_labels.strip()
mapi_config[("model_info", "label_ids")] = all_label_ids.strip()

hierarchical_config = {}
hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema)
Expand Down
13 changes: 10 additions & 3 deletions src/otx/algorithms/detection/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,22 @@ def get_det_model_api_configuration(
"""Get ModelAPI config."""
omz_config = {}
all_labels = ""
all_label_ids = ""
if task_type == TaskType.DETECTION:
omz_config[("model_info", "model_type")] = "ssd"
omz_config[("model_info", "task_type")] = "detection"
if task_type == TaskType.INSTANCE_SEGMENTATION:
omz_config[("model_info", "model_type")] = "MaskRCNN"
omz_config[("model_info", "task_type")] = "instance_segmentation"
all_labels = "otx_empty_lbl "
all_label_ids = "None "
if tiling_parameters.enable_tiling:
omz_config[("model_info", "resize_type")] = "fit_to_window_letterbox"
if task_type == TaskType.ROTATED_DETECTION:
omz_config[("model_info", "model_type")] = "rotated_detection"
omz_config[("model_info", "model_type")] = "MaskRCNN"
omz_config[("model_info", "task_type")] = "rotated_detection"
all_labels = "otx_empty_lbl "
all_label_ids = "None "
if tiling_parameters.enable_tiling:
omz_config[("model_info", "resize_type")] = "fit_to_window_letterbox"

Expand All @@ -137,9 +143,10 @@ def get_det_model_api_configuration(

for lbl in label_schema.get_labels(include_empty=False):
all_labels += lbl.name.replace(" ", "_") + " "
all_labels = all_labels.strip()
all_label_ids += f"{lbl.id_} "

omz_config[("model_info", "labels")] = all_labels
omz_config[("model_info", "labels")] = all_labels.strip()
omz_config[("model_info", "label_ids")] = all_label_ids.strip()

return omz_config

Expand Down
7 changes: 5 additions & 2 deletions src/otx/algorithms/segmentation/utils/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
def get_seg_model_api_configuration(label_schema: LabelSchemaEntity, hyperparams: ConfigDict):
"""Get ModelAPI config."""
all_labels = ""
all_label_ids = ""
for lbl in label_schema.get_labels(include_empty=False):
all_labels += lbl.name.replace(" ", "_") + " "
all_labels = all_labels.strip()
all_label_ids += f"{lbl.id_} "

return {
("model_info", "model_type"): "Segmentation",
("model_info", "soft_threshold"): str(hyperparams.postprocessing.soft_threshold),
("model_info", "blur_strength"): str(hyperparams.postprocessing.blur_strength),
("model_info", "labels"): all_labels,
("model_info", "return_soft_prediction"): "True",
("model_info", "labels"): all_labels.strip(),
("model_info", "label_ids"): all_label_ids.strip(),
("model_info", "task_type"): "segmentation",
}
4 changes: 4 additions & 0 deletions tests/unit/algorithms/classification/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,7 @@ def test_get_cls_model_api_configuration(default_hierarchical_data):
assert len(model_api_cfg) > 0
assert model_api_cfg[("model_info", "confidence_threshold")] == str(config["confidence_threshold"])
assert ("model_info", "hierarchical_config") in model_api_cfg
assert ("model_info", "labels") in model_api_cfg
assert ("model_info", "label_ids") in model_api_cfg
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split())
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split())
4 changes: 4 additions & 0 deletions tests/unit/algorithms/detection/utils/test_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ def test_get_det_model_api_configuration():
tiling_parameters.tile_overlap / tiling_parameters.tile_ir_scale_factor
)
assert model_api_cfg[("model_info", "max_pred_number")] == str(tiling_parameters.tile_max_number)
assert ("model_info", "labels") in model_api_cfg
assert ("model_info", "label_ids") in model_api_cfg
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split())
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split())

0 comments on commit 0981035

Please sign in to comment.