Skip to content

Commit

Permalink
Extract a part of duplicated code in export method
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed May 17, 2023
1 parent c4c60f4 commit b1760f8
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 64 deletions.
17 changes: 1 addition & 16 deletions otx/algorithms/action/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@
)
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.model_template import TaskType
Expand Down Expand Up @@ -245,17 +243,7 @@ def export(
"The saliency maps and representation vector outputs will not be dumped in the exported model."
)

if export_type == ExportType.ONNX:
output_model.model_format = ModelFormat.ONNX
output_model.optimization_type = ModelOptimizationType.ONNX
if precision == ModelPrecision.FP16:
raise RuntimeError("Export to FP16 ONNX is not supported")
elif export_type == ExportType.OPENVINO:
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO
else:
raise RuntimeError(f"not supported export type {export_type}")

self._update_model_export_metadata(output_model, export_type, precision, dump_features)
results = self._export_model(precision, export_type, dump_features)

outputs = results.get("outputs")
Expand All @@ -281,9 +269,6 @@ def export(
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
)
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
output_model.precision = self._precision
output_model.optimization_methods = self._optimization_methods
output_model.has_xai = dump_features
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
Expand Down
18 changes: 2 additions & 16 deletions otx/algorithms/classification/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@
Performance,
ScoreMetric,
)
from otx.api.entities.model import ( # ModelStatus
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.resultset import ResultSetEntity
Expand Down Expand Up @@ -246,17 +244,7 @@ def export(

logger.info("Exporting the model")

if export_type == ExportType.ONNX:
output_model.model_format = ModelFormat.ONNX
output_model.optimization_type = ModelOptimizationType.ONNX
if precision == ModelPrecision.FP16:
raise RuntimeError("Export to FP16 ONNX is not supported")
elif export_type == ExportType.OPENVINO:
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO
else:
raise RuntimeError(f"not supported export type {export_type}")

self._update_model_export_metadata(output_model, export_type, precision, dump_features)
results = self._export_model(precision, export_type, dump_features)
outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
Expand All @@ -282,8 +270,6 @@ def export(
with open(xml_file, "rb") as f:
output_model.set_data("openvino.xml", f.read())

output_model.precision = self._precision
output_model.has_xai = dump_features
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
Expand Down
21 changes: 20 additions & 1 deletion otx/algorithms/common/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.label import LabelEntity
from otx.api.entities.metrics import MetricsGroup
from otx.api.entities.model import ModelEntity, ModelPrecision, OptimizationMethod
from otx.api.entities.model import ModelEntity, ModelFormat, ModelOptimizationType, ModelPrecision, OptimizationMethod
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.train_parameters import TrainParameters
Expand Down Expand Up @@ -331,3 +331,22 @@ def config(self):
    @config.setter
    def config(self, config: Dict[Any, Any]):
        """Set the task's configuration dictionary (stored as-is; no validation is performed here)."""
        self._config = config

def _update_model_export_metadata(
self, output_model: ModelEntity, export_type: ExportType, precision: ModelPrecision, dump_features: bool
) -> None:
"""Updates a model entity with format and optimization related attributes."""
if export_type == ExportType.ONNX:
output_model.model_format = ModelFormat.ONNX
output_model.optimization_type = ModelOptimizationType.ONNX
if precision == ModelPrecision.FP16:
raise RuntimeError("Export to FP16 ONNX is not supported")
elif export_type == ExportType.OPENVINO:
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO
else:
raise RuntimeError(f"not supported export type {export_type}")

output_model.has_xai = dump_features
output_model.optimization_methods = self._optimization_methods
output_model.precision = [precision]
16 changes: 1 addition & 15 deletions otx/algorithms/detection/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@
)
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.model_template import TaskType
Expand Down Expand Up @@ -301,16 +299,7 @@ def export(
"""Export function of OTX Detection Task."""
logger.info("Exporting the model")

if export_type == ExportType.ONNX:
output_model.model_format = ModelFormat.ONNX
output_model.optimization_type = ModelOptimizationType.ONNX
if precision == ModelPrecision.FP16:
raise RuntimeError("Export to FP16 ONNX is not supported")
elif export_type == ExportType.OPENVINO:
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO
else:
raise RuntimeError(f"not supported export type {export_type}")
self._update_model_export_metadata(output_model, export_type, precision, dump_features)

results = self._export_model(precision, export_type, dump_features)
outputs = results.get("outputs")
Expand Down Expand Up @@ -358,9 +347,6 @@ def export(
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
)
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
output_model.precision = self._precision
output_model.optimization_methods = self._optimization_methods
output_model.has_xai = dump_features
output_model.set_data(
"label_schema.json",
label_schema_to_bytes(self._task_environment.label_schema),
Expand Down
17 changes: 1 addition & 16 deletions otx/algorithms/segmentation/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@
)
from otx.api.entities.model import (
ModelEntity,
ModelFormat,
ModelOptimizationType,
ModelPrecision,
)
from otx.api.entities.result_media import ResultMediaEntity
Expand Down Expand Up @@ -223,17 +221,7 @@ def export(
"""Export function of OTX Task."""
logger.info("Exporting the model")

if export_type == ExportType.ONNX:
output_model.model_format = ModelFormat.ONNX
output_model.optimization_type = ModelOptimizationType.ONNX
if precision == ModelPrecision.FP16:
raise RuntimeError("Export to FP16 ONNX is not supported")
elif export_type == ExportType.OPENVINO:
output_model.model_format = ModelFormat.OPENVINO
output_model.optimization_type = ModelOptimizationType.MO
else:
raise RuntimeError(f"not supported export type {export_type}")

self._update_model_export_metadata(output_model, export_type, precision, dump_features)
results = self._export_model(precision, export_type, dump_features)
outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
Expand All @@ -256,9 +244,6 @@ def export(
with open(xml_file, "rb") as f:
output_model.set_data("openvino.xml", f.read())

output_model.precision = self._precision
output_model.optimization_methods = self._optimization_methods
output_model.has_xai = dump_features
output_model.set_data("label_schema.json", label_schema_to_bytes(self._task_environment.label_schema))
logger.info("Exporting completed")

Expand Down

0 comments on commit b1760f8

Please sign in to comment.