Skip to content

Commit

Permalink
Add configuration to segmentation ir
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Apr 18, 2023
1 parent 6cbac46 commit d2f21cc
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 2 deletions.
1 change: 0 additions & 1 deletion otx/algorithms/segmentation/adapters/mmseg/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
from otx.algorithms.segmentation.adapters.mmseg.utils.exporter import SegmentationExporter
from otx.algorithms.segmentation.task import OTXSegmentationTask

# from otx.algorithms.segmentation.utils import get_det_model_api_configuration
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.datasets import DatasetEntity
Expand Down
2 changes: 1 addition & 1 deletion otx/algorithms/segmentation/configs/base/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
class SegmentationConfig(BaseConfig):
"""Configurations of OTX Segmentation."""

header = string_attribute("Configuration for an object detection task of MPA")
header = string_attribute("Configuration for an object semantic segmentation task of OTX")
description = header

@attrs
Expand Down
7 changes: 7 additions & 0 deletions otx/algorithms/segmentation/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
InferenceProgressCallback,
TrainingProgressCallback,
)
from otx.algorithms.common.utils.ir import embed_ir_model_data
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.segmentation.adapters.openvino.model_wrappers.blur import (
get_activation_map,
)
from otx.algorithms.segmentation.configs.base import SegmentationConfig
from otx.algorithms.segmentation.utils.metadata import get_seg_model_api_configuration
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.datasets import DatasetEntity
Expand Down Expand Up @@ -228,6 +230,11 @@ def export(
xml_file = outputs.get("xml")
onnx_file = outputs.get("onnx")

ir_extra_data = get_seg_model_api_configuration(
self._task_environment.label_schema, self._task_type, self._hyperparams
)
embed_ir_model_data(xml_file, ir_extra_data)

if xml_file is None or bin_file is None or onnx_file is None:
raise RuntimeError("invalid status of exporting. bin and xml or onnx should not be None")
with open(bin_file, "rb") as f:
Expand Down
3 changes: 3 additions & 0 deletions otx/algorithms/segmentation/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.


from .metadata import get_seg_model_api_configuration
37 changes: 37 additions & 0 deletions otx/algorithms/segmentation/utils/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Utils for hadnling metadata of segmentation models."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.


from mmcv.utils import ConfigDict
from otx.api.entities.label_schema import LabelSchemaEntity


def get_seg_model_api_configuration(label_schema: LabelSchemaEntity, hyperparams: ConfigDict):
"""Get ModelAPI config."""
omz_config = {}
omz_config[("model_info", "model_type")] = "segmentation"

omz_config[("model_info", "soft_threshold")] = hyperparams.postprocessing.soft_threshold
omz_config[("model_info", "blur_strength")] = hyperparams.postprocessing.blur_strength

all_labels = ""
for lbl in label_schema.get_labels(include_empty=False):
all_labels += lbl.name.replace(" ", "_") + " "
all_labels = all_labels.strip()

omz_config[("model_info", "labels")] = all_labels

return omz_config

0 comments on commit d2f21cc

Please sign in to comment.