Skip to content

Commit

Permalink
Split export function and _export_model function for detection task (#…
Browse files Browse the repository at this point in the history
…1997)

* Split export function and _export_model function for detection task

* Add basic unit test for detection task

* Update change log

* Reflect changes from #1976

* Fix unit test failure
  • Loading branch information
jaegukhyun authored Apr 12, 2023
1 parent 5652811 commit 10fce79
Show file tree
Hide file tree
Showing 6 changed files with 431 additions and 211 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file.

- Clean up and refactor the output of the OTX CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1946>)
- Enhance DetCon logic and SupCon for semantic segmentation(<https://github.com/openvinotoolkit/training_extensions/pull/1958>)
- Detection task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1955>)
- Classification task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1972>)
- Extend OTX explain CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1941>)
- Segmentation task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1977>)
Expand Down
61 changes: 6 additions & 55 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from copy import deepcopy
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from mmcv.runner import wrap_fp16_model
from mmcv.utils import Config, ConfigDict, get_git_hash
Expand Down Expand Up @@ -52,7 +51,6 @@
from otx.algorithms.common.utils import set_random_seed
from otx.algorithms.common.utils.callback import InferenceProgressCallback
from otx.algorithms.common.utils.data import get_dataset
from otx.algorithms.common.utils.ir import embed_ir_model_data
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.adapters.mmdet.configurer import (
DetectionConfigurer,
Expand All @@ -69,24 +67,20 @@
)
from otx.algorithms.detection.adapters.mmdet.utils.exporter import DetectionExporter
from otx.algorithms.detection.task import OTXDetectionTask
from otx.algorithms.detection.utils import get_det_model_api_configuration
from otx.algorithms.detection.utils.data import adaptive_tile_params
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.train_parameters import default_progress_callback
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.core.data import caching

logger = get_logger()
Expand Down Expand Up @@ -461,21 +455,12 @@ def hook(module, inp, outp):
return prediction_results, metric

# pylint: disable=too-many-statements
def export(
def _export_model(
self,
export_type: ExportType,
output_model: ModelEntity,
precision: ModelPrecision = ModelPrecision.FP32,
dump_features: bool = True,
precision: ModelPrecision,
dump_features: bool,
):
"""Export function of OTX Detection Task."""
# copied from OTX inference_task.py
logger.info("Exporting the model")
if export_type != ExportType.OPENVINO:
raise RuntimeError(f"not supported export type {export_type}")
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO

"""Main export function of OTX MMDetection Task."""
self._init_task(export=True)

cfg = self.configure(False, "test", None)
Expand Down Expand Up @@ -506,41 +491,7 @@ def export(
**export_options,
)

outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
if outputs is None:
raise RuntimeError(results.get("msg"))

bin_file = outputs.get("bin")
xml_file = outputs.get("xml")
onnx_file = outputs.get("onnx")

ir_extra_data = get_det_model_api_configuration(
self._task_environment.label_schema, self._task_type, self.confidence_threshold
)
embed_ir_model_data(xml_file, ir_extra_data)

if xml_file is None or bin_file is None or onnx_file is None:
raise RuntimeError("invalid status of exporting. bin and xml or onnx should not be None")
with open(bin_file, "rb") as f:
output_model.set_data("openvino.bin", f.read())
with open(xml_file, "rb") as f:
output_model.set_data("openvino.xml", f.read())
with open(onnx_file, "rb") as f:
output_model.set_data("model.onnx", f.read())
output_model.set_data(
"confidence_threshold",
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
)
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
output_model.precision = self._precision
output_model.optimization_methods = self._optimization_methods
output_model.has_xai = dump_features
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
)
logger.info("Exporting completed")
return results

def explain(
self,
Expand Down
60 changes: 56 additions & 4 deletions otx/algorithms/detection/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
InferenceProgressCallback,
TrainingProgressCallback,
)
from otx.algorithms.common.utils.ir import embed_ir_model_data
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.configs.base import DetectionConfig
from otx.algorithms.detection.utils import get_det_model_api_configuration
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
from otx.api.entities.annotation import Annotation
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
Expand All @@ -50,7 +52,12 @@
ScoreMetric,
VisualizationType,
)
from otx.api.entities.model import ModelEntity, ModelPrecision
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.model_template import TaskType
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.scored_label import ScoredLabel
Expand Down Expand Up @@ -246,15 +253,60 @@ def _infer_model(
"""Get inference results from dataset."""
raise NotImplementedError

@abstractmethod
def export(
self,
export_type: ExportType,
output_model: ModelEntity,
precision: ModelPrecision = ModelPrecision.FP32,
dump_features: bool = True,
):
"""Export function of OTX Task."""
"""Export function of OTX Detection Task."""
logger.info("Exporting the model")
if export_type != ExportType.OPENVINO:
raise RuntimeError(f"not supported export type {export_type}")
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO

results = self._export_model(precision, dump_features)
outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
if outputs is None:
raise RuntimeError(results.get("msg"))

bin_file = outputs.get("bin")
xml_file = outputs.get("xml")
onnx_file = outputs.get("onnx")

ir_extra_data = get_det_model_api_configuration(
self._task_environment.label_schema, self._task_type, self.confidence_threshold
)
embed_ir_model_data(xml_file, ir_extra_data)

if xml_file is None or bin_file is None or onnx_file is None:
raise RuntimeError("invalid status of exporting. bin and xml or onnx should not be None")
with open(bin_file, "rb") as f:
output_model.set_data("openvino.bin", f.read())
with open(xml_file, "rb") as f:
output_model.set_data("openvino.xml", f.read())
with open(onnx_file, "rb") as f:
output_model.set_data("model.onnx", f.read())
output_model.set_data(
"confidence_threshold",
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
)
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
output_model.precision = self._precision
output_model.optimization_methods = self._optimization_methods
output_model.has_xai = dump_features
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
)
logger.info("Exporting completed")

@abstractmethod
def _export_model(self, precision: ModelPrecision, dump_features: bool):
"""Main export function using training backend."""
raise NotImplementedError

@abstractmethod
Expand Down
Loading

0 comments on commit 10fce79

Please sign in to comment.