Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move mpa.deploy to otx.algorithms.common #1903

34 changes: 0 additions & 34 deletions docs/source/guide/reference/mpa/deploy.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/source/guide/reference/mpa/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ Model Preparation Algorithm
classification
detection
segmentation
deploy
utils
10 changes: 10 additions & 0 deletions otx/algorithms/common/adapters/mmdeploy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Adapters for mmdeploy."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from .utils.mmdeploy import is_mmdeploy_enabled

__all__ = [
"is_mmdeploy_enabled",
]
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""API of otx.algorithms.common.adapters.mmdeploy."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Expand All @@ -12,20 +13,23 @@

import mmcv
import numpy as np
import onnx
import torch
from mmcv.parallel import collate, scatter

from .utils import numpy_2_list
from .utils.mmdeploy import (
is_mmdeploy_enabled,
mmdeploy_init_model_helper,
update_deploy_cfg,
)
from .utils.onnx import prepare_onnx_for_openvino
from .utils.utils import numpy_2_list

# pylint: disable=too-many-locals


class NaiveExporter:
"""NaiveExporter for non-mmdeploy export."""

@staticmethod
def export2openvino(
output_dir: str,
Expand All @@ -38,13 +42,15 @@ def export2openvino(
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
opset_version: int = 11,
dynamic_axes: Dict[Any, Any] = {},
dynamic_axes: Optional[Dict[Any, Any]] = None,
mo_transforms: str = "",
):
"""Function for exporting to openvino."""
input_data = scatter(collate([input_data], samples_per_gpu=1), [-1])[0]

model = model_builder(cfg)
model = model.cpu().eval()
dynamic_axes = dynamic_axes if dynamic_axes else dict()

onnx_path = NaiveExporter.torch2onnx(
output_dir,
Expand Down Expand Up @@ -108,17 +114,19 @@ def torch2onnx(
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
opset_version: int = 11,
dynamic_axes: Dict[Any, Any] = {},
dynamic_axes: Optional[Dict[Any, Any]] = None,
verbose: bool = False,
**onnx_options,
) -> str:
"""Function for torch to onnx exporting."""

img_metas = input_data.get("img_metas")
numpy_2_list(img_metas)
imgs = input_data.get("img")
model.forward = partial(model.forward, img_metas=img_metas, return_loss=False)

onnx_file_name = model_name + ".onnx"
dynamic_axes = dynamic_axes if dynamic_axes else dict()
torch.onnx.export(
model,
imgs,
Expand All @@ -143,6 +151,7 @@ def onnx2openvino(
model_name: str = "model",
**openvino_options,
) -> Tuple[str, str]:
"""Function for onnx to openvino exporting."""
from otx.mpa.utils import mo_wrapper

mo_args = {
Expand All @@ -163,17 +172,15 @@ def onnx2openvino(

if is_mmdeploy_enabled():
import mmdeploy.apis.openvino as openvino_api
from mmdeploy.apis import (
build_task_processor,
extract_model,
get_predefined_partition_cfg,
torch2onnx,
)
from mmdeploy.apis import build_task_processor, extract_model, torch2onnx
from mmdeploy.apis.openvino import get_input_info_from_cfg, get_mo_options_from_cfg
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import get_backend_config, get_ir_config, get_partition_config

# from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import get_ir_config, get_partition_config

class MMdeployExporter:
"""MMdeployExporter for mmdeploy exporting."""

@staticmethod
def export2openvino(
output_dir: str,
Expand All @@ -183,6 +190,7 @@ def export2openvino(
*,
model_name: str = "model",
):
"""Function for exporting to openvino."""

task_processor = build_task_processor(cfg, deploy_cfg, "cpu")

Expand Down Expand Up @@ -248,6 +256,7 @@ def torch2onnx(
*,
model_name: str = "model",
) -> str:
"""Function for torch to onnx exporting."""
onnx_file_name = model_name + ".onnx"
torch2onnx(
input_data,
Expand All @@ -266,6 +275,7 @@ def partition_onnx(
onnx_path: str,
partition_cfgs: Union[mmcv.ConfigDict, List[mmcv.ConfigDict]],
) -> Tuple[str, ...]:
"""Function for parition onnx."""
partitioned_paths = []

if not isinstance(partition_cfgs, list):
Expand All @@ -290,6 +300,7 @@ def onnx2openvino(
*,
model_name: Optional[str] = None,
) -> Tuple[str, str]:
"""Function for onnx to openvino exporting."""

input_info = get_input_info_from_cfg(deploy_cfg)
output_names = get_ir_config(deploy_cfg).output_names
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Init file for otx.algorithms.common.adapters.mmdeploy.utils."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
"""Functions for mmdeploy adapters."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import importlib

import onnx
from mmcv.utils import ConfigDict


def is_mmdeploy_enabled():
"""Checks if the 'mmdeploy' Python module is installed and available for use.

Returns:
bool: True if 'mmdeploy' is installed, False otherwise.

Example:
>>> is_mmdeploy_enabled()
True
"""
return importlib.util.find_spec("mmdeploy") is not None


def mmdeploy_init_model_helper(ctx, model_checkpoint=None, cfg_options=None, **kwargs):
"""Helper function for initializing a model for inference using the 'mmdeploy' library."""

model_builder = kwargs.pop("model_builder")
model = model_builder(
ctx.model_cfg,
Expand All @@ -31,12 +42,14 @@ def mmdeploy_init_model_helper(ctx, model_checkpoint=None, cfg_options=None, **k
return model


def update_deploy_cfg(onnx_path, deploy_cfg, mo_options={}):
def update_deploy_cfg(onnx_path, deploy_cfg, mo_options=None):
"""Update the 'deploy_cfg' configuration file based on the ONNX model specified by 'onnx_path'."""

from mmdeploy.utils import get_backend_config, get_ir_config

onnx_model = onnx.load(onnx_path)
ir_config = get_ir_config(deploy_cfg)
backend_config = get_backend_config(deploy_cfg)
get_backend_config(deploy_cfg)

# update input
input_names = [i.name for i in onnx_model.graph.input]
Expand All @@ -47,6 +60,7 @@ def update_deploy_cfg(onnx_path, deploy_cfg, mo_options={}):
ir_config["output_names"] = output_names

# update mo options
mo_options = mo_options if mo_options else dict()
deploy_cfg.merge_from_dict({"backend_config": {"mo_options": mo_options}})


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Functions for onnx adapters."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Expand All @@ -6,6 +7,7 @@


def remove_nodes_by_op_type(onnx_model, op_type):
"""Remove all nodes of a specified op type from the ONNX model."""
# TODO: support more nodes

supported_op_types = ["Mark", "Conv", "Gemm"]
Expand Down Expand Up @@ -42,6 +44,7 @@ def remove_nodes_by_op_type(onnx_model, op_type):


def prepare_onnx_for_openvino(in_path, out_path):
"""Modify the specified ONNX model to be compatible with OpenVINO by removing 'Mark' op nodes."""
onnx_model = onnx.load(in_path)
onnx_model = remove_nodes_by_op_type(onnx_model, "Mark")
onnx.checker.check_model(onnx_model)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Add domain function."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Expand All @@ -6,4 +7,5 @@


def add_domain(name_operator: str) -> str:
"""Function for adding to DOMAIN_CUSTOM_OPS_NAME."""
return DOMAIN_CUSTOM_OPS_NAME + "::" + name_operator
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Util functions of otx.algorithms.common.adapters.mmdeploy."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
Expand All @@ -9,6 +10,7 @@


def sync_batchnorm_2_batchnorm(module, dim=2):
"""Syncs the BatchNorm layers in a model to use regular BatchNorm layers."""
if dim == 1:
bn = torch.nn.BatchNorm1d
elif dim == 2:
Expand Down Expand Up @@ -48,6 +50,7 @@ def sync_batchnorm_2_batchnorm(module, dim=2):


def numpy_2_list(data):
"""Converts NumPy arrays to Python lists."""

if isinstance(data, np.ndarray):
return data.tolist()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.atss import ATSS

from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
)
from otx.mpa.deploy.utils import is_mmdeploy_enabled
from otx.mpa.modules.hooks.recording_forward_hooks import FeatureVectorHook
from otx.mpa.modules.utils.task_adapt import map_class_names
from otx.mpa.utils.logger import get_logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.mask_rcnn import MaskRCNN

from otx.mpa.deploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.mpa.modules.hooks.recording_forward_hooks import (
ActivationMapHook,
FeatureVectorHook,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.single_stage import SingleStageDetector

from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
)
from otx.mpa.deploy.utils import is_mmdeploy_enabled
from otx.mpa.modules.hooks.recording_forward_hooks import FeatureVectorHook
from otx.mpa.modules.utils.task_adapt import map_class_names
from otx.mpa.utils.logger import get_logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.yolox import YOLOX

from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
)
from otx.mpa.deploy.utils import is_mmdeploy_enabled
from otx.mpa.modules.hooks.recording_forward_hooks import FeatureVectorHook
from otx.mpa.modules.utils.task_adapt import map_class_names
from otx.mpa.utils.logger import get_logger
Expand Down
2 changes: 1 addition & 1 deletion otx/algorithms/detection/adapters/mmdet/nncf/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from mmdet.models.roi_heads.bbox_heads.sabl_head import SABLHead
from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead

from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.adapters.nncf import (
NNCF_PATCHER,
is_in_nncf_tracing,
nncf_trace_wrapper,
no_nncf_trace_wrapper,
)
from otx.algorithms.common.adapters.nncf.patches import nncf_trace_context
from otx.mpa.deploy.utils import is_mmdeploy_enabled

HEADS_TARGETS = dict(
classes=(
Expand Down
4 changes: 2 additions & 2 deletions otx/mpa/cls/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from mmcv.runner import wrap_fp16_model

from otx.mpa.deploy.utils import sync_batchnorm_2_batchnorm
from otx.algorithms.common.adapters.mmdeploy.utils import sync_batchnorm_2_batchnorm
from otx.mpa.exporter_mixin import ExporterMixin
from otx.mpa.registry import STAGES
from otx.mpa.utils.logger import get_logger
Expand Down Expand Up @@ -47,7 +47,7 @@ def model_builder_helper(*args, **kwargs):
def naive_export(output_dir, model_builder, precision, cfg, model_name="model"):
from mmcls.datasets.pipelines import Compose

from ..deploy.apis import NaiveExporter
from otx.algorithms.common.adapters.mmdeploy.apis import NaiveExporter

def get_fake_data(cfg, orig_img_shape=(128, 128, 3)):
pipeline = cfg.data.test.pipeline
Expand Down
9 changes: 0 additions & 9 deletions otx/mpa/deploy/__init__.py

This file was deleted.

Loading