Skip to content

Commit

Permalink
Split export function and _export_model function for detection task
Browse files Browse the repository at this point in the history
  • Loading branch information
jaegukhyun committed Apr 11, 2023
1 parent b63980e commit c26c241
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 56 deletions.
58 changes: 6 additions & 52 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from copy import deepcopy
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from mmcv.runner import wrap_fp16_model
from mmcv.utils import Config, ConfigDict, get_git_hash
Expand Down Expand Up @@ -52,7 +51,6 @@
from otx.algorithms.common.utils import set_random_seed
from otx.algorithms.common.utils.callback import InferenceProgressCallback
from otx.algorithms.common.utils.data import get_dataset
from otx.algorithms.common.utils.ir import embed_ir_model_data
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.adapters.mmdet.configurer import (
DetectionConfigurer,
Expand All @@ -69,24 +67,20 @@
)
from otx.algorithms.detection.adapters.mmdet.utils.exporter import DetectionExporter
from otx.algorithms.detection.task import OTXDetectionTask
from otx.algorithms.detection.utils import get_det_model_api_configuration
from otx.algorithms.detection.utils.data import adaptive_tile_params
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.train_parameters import default_progress_callback
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.core.data import caching

logger = get_logger()
Expand Down Expand Up @@ -461,21 +455,12 @@ def hook(module, inp, outp):
return prediction_results, metric

# pylint: disable=too-many-statements
def export(
def _export_model(
self,
export_type: ExportType,
output_model: ModelEntity,
precision: ModelPrecision = ModelPrecision.FP32,
dump_features: bool = True,
precision: ModelPrecision,
dump_features: bool,
):
"""Export function of OTX Detection Task."""
# copied from OTX inference_task.py
logger.info("Exporting the model")
if export_type != ExportType.OPENVINO:
raise RuntimeError(f"not supported export type {export_type}")
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO

"""Main export function of OTX MMDetection Task."""
self._init_task(export=True)

cfg = self.configure(False, "test", None)
Expand Down Expand Up @@ -506,38 +491,7 @@ def export(
**export_options,
)

outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
if outputs is None:
raise RuntimeError(results.get("msg"))

bin_file = outputs.get("bin")
xml_file = outputs.get("xml")

ir_extra_data = get_det_model_api_configuration(
self._task_environment.label_schema, self._task_type, self.confidence_threshold
)
embed_ir_model_data(xml_file, ir_extra_data)

if xml_file is None or bin_file is None:
raise RuntimeError("invalid status of exporting. bin and xml should not be None")
with open(bin_file, "rb") as f:
output_model.set_data("openvino.bin", f.read())
with open(xml_file, "rb") as f:
output_model.set_data("openvino.xml", f.read())
output_model.set_data(
"confidence_threshold",
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
)
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
output_model.precision = self._precision
output_model.optimization_methods = self._optimization_methods
output_model.has_xai = dump_features
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
)
logger.info("Exporting completed")
return results

def explain(
self,
Expand Down
57 changes: 53 additions & 4 deletions otx/algorithms/detection/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
InferenceProgressCallback,
TrainingProgressCallback,
)
from otx.algorithms.common.utils.ir import embed_ir_model_data
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.configs.base import DetectionConfig
from otx.algorithms.detection.utils import get_det_model_api_configuration
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
from otx.api.entities.annotation import Annotation
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
Expand All @@ -50,7 +52,12 @@
ScoreMetric,
VisualizationType,
)
from otx.api.entities.model import ModelEntity, ModelPrecision
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.model_template import TaskType
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.scored_label import ScoredLabel
Expand Down Expand Up @@ -245,15 +252,57 @@ def _infer_model(
"""Get inference results from dataset."""
raise NotImplementedError

@abstractmethod
def export(
self,
export_type: ExportType,
output_model: ModelEntity,
precision: ModelPrecision = ModelPrecision.FP32,
dump_features: bool = True,
):
"""Export function of OTX Task."""
"""Export function of OTX Detection Task."""
logger.info("Exporting the model")
if export_type != ExportType.OPENVINO:
raise RuntimeError(f"not supported export type {export_type}")
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO

results = self._export_model(precision, dump_features)
outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
if outputs is None:
raise RuntimeError(results.get("msg"))

bin_file = outputs.get("bin")
xml_file = outputs.get("xml")

ir_extra_data = get_det_model_api_configuration(
self._task_environment.label_schema, self._task_type, self.confidence_threshold
)
embed_ir_model_data(xml_file, ir_extra_data)

if xml_file is None or bin_file is None:
raise RuntimeError("invalid status of exporting. bin and xml should not be None")
with open(bin_file, "rb") as f:
output_model.set_data("openvino.bin", f.read())
with open(xml_file, "rb") as f:
output_model.set_data("openvino.xml", f.read())
output_model.set_data(
"confidence_threshold",
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
)
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
output_model.precision = self._precision
output_model.optimization_methods = self._optimization_methods
output_model.has_xai = dump_features
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
)
logger.info("Exporting completed")

@abstractmethod
def _export_model(self, precision: ModelPrecision, dump_features: bool):
"""Main export function using training backend."""
raise NotImplementedError

@abstractmethod
Expand Down

0 comments on commit c26c241

Please sign in to comment.