From 4b3daf9444720b905c38344198715797fee54836 Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Thu, 23 Mar 2023 09:34:01 +0900 Subject: [PATCH 1/6] First Refactoring --- otx/core/ov/__init__.py | 4 + otx/core/ov/graph/__init__.py | 4 + otx/core/ov/graph/parsers/__init__.py | 4 + otx/core/ov/graph/parsers/cls/__init__.py | 8 + .../ov/graph/parsers/cls/cls_base_parser.py | 18 +- otx/{mpa/modules => core}/ov/omz_wrapper.py | 251 ++++++------ otx/core/ov/ops/__init__.py | 4 + otx/core/ov/ops/activations.py | 356 ++++++++++++++++++ .../modules => core}/ov/ops/arithmetics.py | 55 ++- otx/{mpa/modules => core}/ov/ops/builder.py | 38 +- .../modules => core}/ov/ops/convolutions.py | 41 +- .../modules => core}/ov/ops/generation.py | 16 +- .../ov/ops/image_processings.py | 50 +-- .../ov/ops/infrastructures.py | 56 ++- otx/{mpa/modules => core}/ov/ops/matmuls.py | 19 +- otx/core/ov/ops/modules/__init__.py | 8 + .../ov/ops/modules/op_module.py | 39 +- otx/{mpa/modules => core}/ov/ops/movements.py | 242 ++++++++---- .../modules => core}/ov/ops/normalizations.py | 67 ++-- .../ov/ops/object_detections.py | 64 ++-- otx/{mpa/modules => core}/ov/ops/op.py | 42 ++- otx/{mpa/modules => core}/ov/ops/poolings.py | 54 ++- .../modules => core}/ov/ops/reductions.py | 61 ++- .../ov/ops/shape_manipulations.py | 72 ++-- .../ov/ops/sorting_maximization.py | 41 +- .../ov/ops/type_conversions.py | 20 +- otx/core/ov/ops/utils.py | 16 + otx/{mpa/modules => core}/ov/registry.py | 18 +- otx/mpa/modules/__init__.py | 12 +- otx/mpa/modules/ov/__init__.py | 7 +- otx/mpa/modules/ov/graph/graph.py | 5 +- otx/mpa/modules/ov/graph/parsers/__init__.py | 2 +- otx/mpa/modules/ov/graph/parsers/builder.py | 2 +- .../modules/ov/graph/parsers/cls/__init__.py | 6 - otx/mpa/modules/ov/graph/utils.py | 5 +- .../models/mmcls/backbones/mmov_backbone.py | 3 +- .../ov/models/mmcls/heads/mmov_cls_head.py | 3 +- .../ov/models/mmcls/necks/mmov_neck.py | 3 +- otx/mpa/modules/ov/models/ov_model.py | 6 +- 
otx/mpa/modules/ov/ops/__init__.py | 23 -- otx/mpa/modules/ov/ops/activations.py | 271 ------------- otx/mpa/modules/ov/ops/modules/__init__.py | 6 - otx/mpa/modules/ov/ops/utils.py | 48 --- otx/mpa/modules/ov/utils.py | 7 +- otx/mpa/utils/file.py | 8 - .../graph/parsers/test_ov_graph_cls_parser.py | 2 +- .../modules/ov/ops/test_ov_ops_activations.py | 2 +- .../modules/ov/ops/test_ov_ops_arithmetics.py | 8 +- .../mpa/modules/ov/ops/test_ov_ops_builder.py | 4 +- .../ov/ops/test_ov_ops_convolutions.py | 2 +- .../modules/ov/ops/test_ov_ops_generation.py | 2 +- .../ov/ops/test_ov_ops_image_processings.py | 2 +- .../ov/ops/test_ov_ops_infrastructures.py | 2 +- .../mpa/modules/ov/ops/test_ov_ops_matmuls.py | 2 +- .../mpa/modules/ov/ops/test_ov_ops_module.py | 4 +- .../modules/ov/ops/test_ov_ops_movements.py | 2 +- .../ov/ops/test_ov_ops_normalizations.py | 2 +- .../ov/ops/test_ov_ops_object_detections.py | 2 +- .../unit/mpa/modules/ov/ops/test_ov_ops_op.py | 2 +- .../modules/ov/ops/test_ov_ops_poolings.py | 2 +- .../modules/ov/ops/test_ov_ops_reductions.py | 2 +- .../ov/ops/test_ov_ops_shape_manipulations.py | 2 +- .../ops/test_ov_ops_sorting_maximization.py | 2 +- .../ov/ops/test_ov_ops_type_conversions.py | 2 +- .../mpa/modules/ov/ops/test_ov_ops_utils.py | 3 +- .../mpa/modules/ov/test_ov_omz_wrapper.py | 2 +- tests/unit/mpa/modules/ov/test_ov_registry.py | 2 +- tests/unit/mpa/modules/ov/test_ov_utils.py | 4 +- 68 files changed, 1267 insertions(+), 877 deletions(-) create mode 100644 otx/core/ov/__init__.py create mode 100644 otx/core/ov/graph/__init__.py create mode 100644 otx/core/ov/graph/parsers/__init__.py create mode 100644 otx/core/ov/graph/parsers/cls/__init__.py rename otx/{mpa/modules => core}/ov/graph/parsers/cls/cls_base_parser.py (86%) rename otx/{mpa/modules => core}/ov/omz_wrapper.py (60%) create mode 100644 otx/core/ov/ops/__init__.py create mode 100644 otx/core/ov/ops/activations.py rename otx/{mpa/modules => core}/ov/ops/arithmetics.py (70%) rename 
otx/{mpa/modules => core}/ov/ops/builder.py (50%) rename otx/{mpa/modules => core}/ov/ops/convolutions.py (74%) rename otx/{mpa/modules => core}/ov/ops/generation.py (59%) rename otx/{mpa/modules => core}/ov/ops/image_processings.py (78%) rename otx/{mpa/modules => core}/ov/ops/infrastructures.py (83%) rename otx/{mpa/modules => core}/ov/ops/matmuls.py (68%) create mode 100644 otx/core/ov/ops/modules/__init__.py rename otx/{mpa/modules => core}/ov/ops/modules/op_module.py (60%) rename otx/{mpa/modules => core}/ov/ops/movements.py (58%) rename otx/{mpa/modules => core}/ov/ops/normalizations.py (73%) rename otx/{mpa/modules => core}/ov/ops/object_detections.py (80%) rename otx/{mpa/modules => core}/ov/ops/op.py (64%) rename otx/{mpa/modules => core}/ov/ops/poolings.py (79%) rename otx/{mpa/modules => core}/ov/ops/reductions.py (58%) rename otx/{mpa/modules => core}/ov/ops/shape_manipulations.py (64%) rename otx/{mpa/modules => core}/ov/ops/sorting_maximization.py (76%) rename otx/{mpa/modules => core}/ov/ops/type_conversions.py (70%) create mode 100644 otx/core/ov/ops/utils.py rename otx/{mpa/modules => core}/ov/registry.py (67%) delete mode 100644 otx/mpa/modules/ov/graph/parsers/cls/__init__.py delete mode 100644 otx/mpa/modules/ov/ops/__init__.py delete mode 100644 otx/mpa/modules/ov/ops/activations.py delete mode 100644 otx/mpa/modules/ov/ops/modules/__init__.py delete mode 100644 otx/mpa/modules/ov/ops/utils.py delete mode 100644 otx/mpa/utils/file.py diff --git a/otx/core/ov/__init__.py b/otx/core/ov/__init__.py new file mode 100644 index 00000000000..9ae460d03b2 --- /dev/null +++ b/otx/core/ov/__init__.py @@ -0,0 +1,4 @@ +"""Module for otx.core.ov.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT diff --git a/otx/core/ov/graph/__init__.py b/otx/core/ov/graph/__init__.py new file mode 100644 index 00000000000..9696d660171 --- /dev/null +++ b/otx/core/ov/graph/__init__.py @@ -0,0 +1,4 @@ +"""Module for otx.core.ov.graph.""" +# 
Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT diff --git a/otx/core/ov/graph/parsers/__init__.py b/otx/core/ov/graph/parsers/__init__.py new file mode 100644 index 00000000000..63ec930e665 --- /dev/null +++ b/otx/core/ov/graph/parsers/__init__.py @@ -0,0 +1,4 @@ +"""Module for otx.core.ov.graph.parser.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT diff --git a/otx/core/ov/graph/parsers/cls/__init__.py b/otx/core/ov/graph/parsers/cls/__init__.py new file mode 100644 index 00000000000..74b90157a52 --- /dev/null +++ b/otx/core/ov/graph/parsers/cls/__init__.py @@ -0,0 +1,8 @@ +"""Module for otx.core.ov.graph.parsers.cls.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from .cls_base_parser import cls_base_parser + +__all__ = ["cls_base_parser"] diff --git a/otx/mpa/modules/ov/graph/parsers/cls/cls_base_parser.py b/otx/core/ov/graph/parsers/cls/cls_base_parser.py similarity index 86% rename from otx/mpa/modules/ov/graph/parsers/cls/cls_base_parser.py rename to otx/core/ov/graph/parsers/cls/cls_base_parser.py index 2ca1c5d27d9..4b88ab5e447 100644 --- a/otx/mpa/modules/ov/graph/parsers/cls/cls_base_parser.py +++ b/otx/core/ov/graph/parsers/cls/cls_base_parser.py @@ -1,16 +1,18 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Class base parser for otx.core.ov.graph.parsers.cls.cls_base_parser.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from typing import Dict, List, Optional +from otx.mpa.modules.ov.graph.parsers.builder import PARSERS +from otx.mpa.modules.ov.graph.parsers.parser import parameter_parser from otx.mpa.utils.logger import get_logger -from ..builder import PARSERS -from ..parser import parameter_parser - logger = get_logger() +# pylint: disable=too-many-return-statements, too-many-branches + NECK_INPUT_TYPES = ["ReduceMean", "MaxPool", "AvgPool"] NECK_TYPES = [ @@ -27,6 +29,7 @@ @PARSERS.register() 
def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, List[str]]]: + """Class base parser for OMZ models.""" assert component in ["backbone", "neck", "head"] result_nodes = graph.get_nodes_by_types(["Result"]) @@ -73,13 +76,13 @@ def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, Li outputs=outputs, ) - elif component == "neck": + if component == "neck": return dict( inputs=[neck_input.name], outputs=[neck_output.name], ) - elif component == "head": + if component == "head": inputs = list(graph.successors(neck_output)) # if len(inputs) != 1: # logger.debug(f"neck_output {neck_output.name} has more than one successors.") @@ -102,3 +105,4 @@ def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, Li inputs=[input.name for input in inputs], outputs=[output.name for output in outputs], ) + return None diff --git a/otx/mpa/modules/ov/omz_wrapper.py b/otx/core/ov/omz_wrapper.py similarity index 60% rename from otx/mpa/modules/ov/omz_wrapper.py rename to otx/core/ov/omz_wrapper.py index ab321b04023..21e2adebed2 100644 --- a/otx/mpa/modules/ov/omz_wrapper.py +++ b/otx/core/ov/omz_wrapper.py @@ -1,6 +1,7 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""OMZ wrapper-related code for otx.core.ov.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT import hashlib import os @@ -9,6 +10,7 @@ import sys import time from pathlib import Path +from typing import Dict, List import requests from openvino.model_zoo import _common, _reporting @@ -18,13 +20,16 @@ from openvino.model_zoo.omz_converter import ModelOptimizerProperties, convert_to_onnx from requests.exceptions import HTTPError -from otx.mpa.utils.file import MPA_CACHE +# pylint: disable=too-many-locals, too-many-branches -MPA_OMZ_CACHE = os.path.join(MPA_CACHE, "omz") -os.makedirs(MPA_OMZ_CACHE, exist_ok=True) +OTX_CACHE = os.path.expanduser(os.getenv("OTX_CACHE", 
os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "otx"))) +os.makedirs(OTX_CACHE, exist_ok=True) +OMZ_CACHE = os.path.join(OTX_CACHE, "omz") +os.makedirs(OMZ_CACHE, exist_ok=True) -OMZ_PUBLIC_MODELS = dict( + +OMZ_PUBLIC_MODELS: Dict[str, List[str]] = dict( cls=[ "alexnet", "caffenet", @@ -86,37 +91,43 @@ ) -AVAILABLE_OMZ_MODELS = [model for models in OMZ_PUBLIC_MODELS.values() for model in models] +AVAILABLE_OMZ_MODELS: List[str] = [] +for models_ in OMZ_PUBLIC_MODELS.values(): + for model_ in models_: + AVAILABLE_OMZ_MODELS.append(model_) class NameSpace: + """NameSpace class for otx.core.ov.omz_wrapper.""" + def __init__(self, **kwargs): self.__dict__.update(kwargs) def _get_etag(url): + """Getter etag function from url.""" try: - response = requests.head(url, allow_redirects=True) + response = requests.head(url, allow_redirects=True, timeout=100) if response.status_code != 200: return None - else: - return response.headers.get("ETag", None) - except HTTPError as e: + return response.headers.get("ETag", None) + except HTTPError: return None def _get_ir_path(directory): + """Getter IR path function from directory path.""" directory = Path(directory) model_path = list(directory.glob("**/*.xml")) weight_path = list(directory.glob("**/*.bin")) if model_path and weight_path: assert len(model_path) == 1 and len(weight_path) == 1 return dict(model_path=model_path[0], weight_path=weight_path[0]) - else: - return None + return None def _run_pre_convert(reporter, model, output_dir, args): + """Run pre-converting function.""" script = _common.MODEL_ROOT / model.subdirectory_ori / "pre-convert.py" if not script.exists(): return True @@ -146,14 +157,15 @@ def _run_pre_convert(reporter, model, output_dir, args): def _update_model(model): - m = hashlib.sha256() + """Update model configs for omz_wrapper.""" + m_hash = hashlib.sha256() for file in model.files: url = file.source.url etag = _get_etag(url) if etag is not None: - m.update(bytes(etag, "utf-8")) + 
m_hash.update(bytes(etag, "utf-8")) model.subdirectory_ori = model.subdirectory - model.subdirectory = Path(m.hexdigest()) + model.subdirectory = Path(m_hash.hexdigest()) # FIXME: a bug from openvino-dev==2022.3.0 # It has been fixed on master branch. @@ -165,16 +177,19 @@ def _update_model(model): def get_model_configuration(model_name): + """Getter function of model configuration from name.""" model_configurations = load_models(_common.MODEL_ROOT, {}) - for i, model in enumerate(model_configurations): + for model in model_configurations: if model.name == model_name: _update_model(model) return model return None -def download_model(model, download_dir=MPA_OMZ_CACHE, precisions={"FP32"}, force=False): +def download_model(model, download_dir=OMZ_CACHE, precisions=None, force=False): + """Function for downloading model from directory.""" download_dir = Path("") if download_dir is None else Path(download_dir) + precisions = precisions if precisions else {"FP32"} # TODO: need delicate cache management if not force and (download_dir / model.subdirectory).exists(): @@ -207,15 +222,108 @@ def download_model(model, download_dir=MPA_OMZ_CACHE, precisions={"FP32"}, force sys.exit(1) +def _convert(reporter, model, output_dir, namespace, mo_props, requested_precisions): + """Convert function for OMZ wrapper.""" + if model.mo_args is None: + reporter.print_section_heading("Skipping {} (no conversions defined)", model.name) + reporter.print() + return True + + model_precisions = requested_precisions & model.precisions + if not model_precisions: + reporter.print_section_heading("Skipping {} (all conversions skipped)", model.name) + reporter.print() + return True + + (output_dir / model.subdirectory).mkdir(parents=True, exist_ok=True) + + if not _run_pre_convert(reporter, model, output_dir, namespace): + return False + + model_format = model.framework + mo_extension_dir = mo_props.base_dir / "extensions" + if not mo_extension_dir.exists(): + mo_extension_dir = mo_props.base_dir + 
+ template_variables = { + "config_dir": _common.MODEL_ROOT / model.subdirectory_ori, + "conv_dir": output_dir / model.subdirectory, + "dl_dir": namespace.download_dir / model.subdirectory, + "mo_dir": mo_props.base_dir, + "mo_ext_dir": mo_extension_dir, + } + + if model.conversion_to_onnx_args: + if not convert_to_onnx(reporter, model, output_dir, namespace, template_variables): + return False + model_format = "onnx" + + expanded_mo_args = [string.Template(arg).substitute(template_variables) for arg in model.mo_args] + + for model_precision in sorted(model_precisions): + data_type = model_precision.split("-")[0] + layout_string = ",".join(f"{input.name}({input.layout})" for input in model.input_info if input.layout) + shape_string = ",".join(str(input.shape) for input in model.input_info if input.shape) + + if layout_string: + expanded_mo_args.append(f"--layout={layout_string}") + if shape_string: + expanded_mo_args.append(f"--input_shape={shape_string}") + + mo_cmd = [ + *mo_props.cmd_prefix, + f"--framework={model_format}", + f"--data_type={data_type}", + f"--output_dir={output_dir / model.subdirectory / model_precision}", + f"--model_name={model.name}", + f"--input={','.join(input.name for input in model.input_info)}".format(), + *expanded_mo_args, + *mo_props.extra_args, + ] + + reporter.print_section_heading( + "{}Converting {} to IR ({})", + "(DRY RUN) " if namespace.dry_run else "", + model.name, + model_precision, + ) + + reporter.print("Conversion command: {}", _common.command_string(mo_cmd)) + + if not namespace.dry_run: + reporter.print(flush=True) + + if not reporter.job_context.subprocess(mo_cmd): + # NOTE: mo returns non zero return code (245) even though it successfully generate IR + cur_time = time.time() + time_threshold = 5 + xml_path = output_dir / model.subdirectory / model_precision / f"{model.name}.xml" + bin_path = output_dir / model.subdirectory / model_precision / f"{model.name}.bin" + if not ( + os.path.exists(xml_path) + and 
os.path.exists(bin_path) + and os.path.getmtime(xml_path) - cur_time < time_threshold + and os.path.getmtime(bin_path) - cur_time < time_threshold + ): + return False + + reporter.print() + + return True + + def convert_model( model, - download_dir=MPA_OMZ_CACHE, - output_dir=MPA_OMZ_CACHE, - precisions={"FP32"}, + download_dir=OMZ_CACHE, + output_dir=OMZ_CACHE, + precisions=None, force=False, -): + *args, +): # pylint: disable=keyword-arg-before-vararg + """Converting model for OMZ wrapping.""" download_dir = Path("") if download_dir is None else Path(download_dir) output_dir = Path("") if output_dir is None else Path(output_dir) + precisions = precisions if precisions else {"FP32"} out = _get_ir_path(output_dir / model.subdirectory) if out and not force: @@ -254,7 +362,7 @@ def convert_model( if mo_package_path is None: mo_package_path, stderr = _common.get_package_path(args.python, "mo") if mo_package_path is None: - sys.exit("Unable to load Model Optimizer. Errors occurred: {}".format(stderr)) + sys.exit(f"Unable to load Model Optimizer. 
Errors occurred: {stderr}") mo_dir = mo_package_path.parent reporter = _reporting.Reporter(_reporting.DirectOutputContext()) @@ -265,104 +373,14 @@ def convert_model( ) shared_convert_args = (output_dir, namespace, mo_props, precisions) - def convert(reporter, model, output_dir, namespace, mo_props, requested_precisions): - if model.mo_args is None: - reporter.print_section_heading("Skipping {} (no conversions defined)", model.name) - reporter.print() - return True - - model_precisions = requested_precisions & model.precisions - if not model_precisions: - reporter.print_section_heading("Skipping {} (all conversions skipped)", model.name) - reporter.print() - return True - - (output_dir / model.subdirectory).mkdir(parents=True, exist_ok=True) - - if not _run_pre_convert(reporter, model, output_dir, namespace): - return False - - model_format = model.framework - mo_extension_dir = mo_props.base_dir / "extensions" - if not mo_extension_dir.exists(): - mo_extension_dir = mo_props.base_dir - - template_variables = { - "config_dir": _common.MODEL_ROOT / model.subdirectory_ori, - "conv_dir": output_dir / model.subdirectory, - "dl_dir": namespace.download_dir / model.subdirectory, - "mo_dir": mo_props.base_dir, - "mo_ext_dir": mo_extension_dir, - } - - if model.conversion_to_onnx_args: - if not convert_to_onnx(reporter, model, output_dir, namespace, template_variables): - return False - model_format = "onnx" - - expanded_mo_args = [string.Template(arg).substitute(template_variables) for arg in model.mo_args] - - for model_precision in sorted(model_precisions): - data_type = model_precision.split("-")[0] - layout_string = ",".join( - "{}({})".format(input.name, input.layout) for input in model.input_info if input.layout - ) - shape_string = ",".join(str(input.shape) for input in model.input_info if input.shape) - - if layout_string: - expanded_mo_args.append("--layout={}".format(layout_string)) - if shape_string: - 
expanded_mo_args.append("--input_shape={}".format(shape_string)) - - mo_cmd = [ - *mo_props.cmd_prefix, - "--framework={}".format(model_format), - "--data_type={}".format(data_type), - "--output_dir={}".format(output_dir / model.subdirectory / model_precision), - "--model_name={}".format(model.name), - "--input={}".format(",".join(input.name for input in model.input_info)), - *expanded_mo_args, - *mo_props.extra_args, - ] - - reporter.print_section_heading( - "{}Converting {} to IR ({})", - "(DRY RUN) " if namespace.dry_run else "", - model.name, - model_precision, - ) - - reporter.print("Conversion command: {}", _common.command_string(mo_cmd)) - - if not namespace.dry_run: - reporter.print(flush=True) - - if not reporter.job_context.subprocess(mo_cmd): - # NOTE: mo returns non zero return code (245) even though it successfully generate IR - cur_time = time.time() - time_threshold = 5 - xml_path = output_dir / model.subdirectory / model_precision / f"{model.name}.xml" - bin_path = output_dir / model.subdirectory / model_precision / f"{model.name}.bin" - if not ( - os.path.exists(xml_path) - and os.path.exists(bin_path) - and os.path.getmtime(xml_path) - cur_time < time_threshold - and os.path.getmtime(bin_path) - cur_time < time_threshold - ): - return False - - reporter.print() - - return True - results = [] models = [] if model.model_stages: for model_stage in model.model_stages: - results.append(convert(reporter, model_stage, *shared_convert_args)) + results.append(_convert(reporter, model_stage, *shared_convert_args)) models.append(model_stage) else: - results.append(convert(reporter, model, *shared_convert_args)) + results.append(_convert(reporter, model, *shared_convert_args)) models.append(model) failed_models = [model.name for model, successful in zip(models, results) if not successful] @@ -376,7 +394,8 @@ def convert(reporter, model, output_dir, namespace, mo_props, requested_precisio return _get_ir_path(output_dir / model.subdirectory) -def 
get_omz_model(model_name, download_dir=MPA_OMZ_CACHE, output_dir=MPA_OMZ_CACHE, force=False): +def get_omz_model(model_name, download_dir=OMZ_CACHE, output_dir=OMZ_CACHE, force=False): + """Get OMZ model from name and download_dir.""" model = get_model_configuration(model_name) download_model(model, download_dir=download_dir, force=force) return convert_model(model, download_dir=download_dir, output_dir=output_dir, force=force) diff --git a/otx/core/ov/ops/__init__.py b/otx/core/ov/ops/__init__.py new file mode 100644 index 00000000000..df6df449c4a --- /dev/null +++ b/otx/core/ov/ops/__init__.py @@ -0,0 +1,4 @@ +"""Module of otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT diff --git a/otx/core/ov/ops/activations.py b/otx/core/ov/ops/activations.py new file mode 100644 index 00000000000..5d7f7ffc16b --- /dev/null +++ b/otx/core/ov/ops/activations.py @@ -0,0 +1,356 @@ +"""Activation-related modules for otx.core.ov.ops.activations.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import math +from dataclasses import dataclass, field + +import torch +from torch.nn import functional as F + +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation + + +@dataclass +class SoftMaxV0Attribute(Attribute): + """SoftMaxV0Attribute class.""" + + axis: int = field(default=1) + + +@OPS.register() +class SoftMaxV0(Operation[SoftMaxV0Attribute]): + """SoftMaxV0 class.""" + + TYPE = "Softmax" + VERSION = 0 + ATTRIBUTE_FACTORY = SoftMaxV0Attribute + + def forward(self, inputs): + """SoftMaxV0's forward function.""" + return F.softmax(input=inputs, dim=self.attrs.axis) + + +@dataclass +class SoftMaxV1Attribute(Attribute): + """SoftMaxV1Attribute class.""" + + axis: int = field(default=1) + + +@OPS.register() +class SoftMaxV1(Operation[SoftMaxV1Attribute]): + """SoftMaxV1 class.""" + + TYPE = "Softmax" + VERSION = 1 + ATTRIBUTE_FACTORY = SoftMaxV1Attribute + + def 
forward(self, inputs): + """SoftMaxV1's forward function.""" + return F.softmax(input=inputs, dim=self.attrs.axis) + + +@dataclass +class ReluV0Attribute(Attribute): + """ReluV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class ReluV0(Operation[ReluV0Attribute]): + """ReluV0 class.""" + + TYPE = "Relu" + VERSION = 0 + ATTRIBUTE_FACTORY = ReluV0Attribute + + def forward(self, inputs): + """ReluV0's forward function.""" + return F.relu(inputs) + + +@dataclass +class SwishV4Attribute(Attribute): + """SwishV4Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class SwishV4(Operation[SwishV4Attribute]): + """SwishV4 class.""" + + TYPE = "Swish" + VERSION = 4 + ATTRIBUTE_FACTORY = SwishV4Attribute + + def forward(self, inputs, beta=1.0): + """SwishV4's forward function.""" + return inputs * torch.sigmoid(inputs * beta) + + +@dataclass +class SigmoidV0Attribute(Attribute): + """SigmoidV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class SigmoidV0(Operation[SigmoidV0Attribute]): + """SigmoidV0 class.""" + + TYPE = "Sigmoid" + VERSION = 0 + ATTRIBUTE_FACTORY = SigmoidV0Attribute + + def forward(self, inputs): + """SigmoidV0's forward function.""" + return torch.sigmoid(inputs) + + +@dataclass +class ClampV0Attribute(Attribute): + """ClampV0Attribute class.""" + + min: float + max: float + + +@OPS.register() +class ClampV0(Operation[ClampV0Attribute]): + """ClampV0 class.""" + + TYPE = "Clamp" + VERSION = 0 + ATTRIBUTE_FACTORY = ClampV0Attribute + + def forward(self, inputs): + """ClampV0's forward function.""" + return inputs.clamp(min=self.attrs.min, max=self.attrs.max) + + +@dataclass +class PReluV0Attribute(Attribute): + """PReluV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class PReluV0(Operation[PReluV0Attribute]): + """PReluV0 class.""" + + TYPE = "PRelu" + VERSION = 0 + ATTRIBUTE_FACTORY = PReluV0Attribute + + 
def forward(self, inputs, slope): + """PReluV0's forward function.""" + return F.prelu(input=inputs, weight=slope) + + +@dataclass +class TanhV0Attribute(Attribute): + """TanhV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class TanhV0(Operation[TanhV0Attribute]): + """TanhV0 class.""" + + TYPE = "Tanh" + VERSION = 0 + ATTRIBUTE_FACTORY = TanhV0Attribute + + def forward(self, inputs): + """TanhV0's forward function.""" + return F.tanh(inputs) + + +@dataclass +class EluV0Attribute(Attribute): + """EluV0Attribute class.""" + + alpha: float + + +@OPS.register() +class EluV0(Operation[EluV0Attribute]): + """EluV0 class.""" + + TYPE = "Elu" + VERSION = 0 + ATTRIBUTE_FACTORY = EluV0Attribute + + def forward(self, inputs): + """EluV0's forward function.""" + return F.elu(input=inputs, alpha=self.attrs.alpha) + + +@dataclass +class SeluV0Attribute(Attribute): + """SeluV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class SeluV0(Operation[SeluV0Attribute]): + """SeluV0 class.""" + + TYPE = "Selu" + VERSION = 0 + ATTRIBUTE_FACTORY = SeluV0Attribute + + def forward(self, inputs, alpha, lambda_): + """SeluV0's forward function.""" + return lambda_ * F.elu(input=inputs, alpha=alpha) + + +@dataclass +class MishV4Attribute(Attribute): + """MishV4Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class MishV4(Operation[MishV4Attribute]): + """MishV4 class.""" + + TYPE = "Mish" + VERSION = 4 + ATTRIBUTE_FACTORY = MishV4Attribute + + def forward(self, inputs): + """MishV4's forward function.""" + # NOTE: pytorch 1.8.2 does not have mish function + # return F.mish(input=input) + return inputs * F.tanh(F.softplus(inputs)) + + +@dataclass +class HSwishV4Attribute(Attribute): + """HSwishV4Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class HSwishV4(Operation[HSwishV4Attribute]): + """HSwishV4 class.""" + + TYPE = "HSwish" + VERSION = 
4 + ATTRIBUTE_FACTORY = HSwishV4Attribute + + def forward(self, inputs): + """HSwishV4's forward function.""" + return F.hardswish(input=inputs) + + +@dataclass +class HSigmoidV5Attribute(Attribute): + """HSigmoidV5Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class HSigmoidV5(Operation[HSigmoidV5Attribute]): + """HSigmoidV5 class.""" + + TYPE = "HSigmoid" + VERSION = 5 + ATTRIBUTE_FACTORY = HSigmoidV5Attribute + + def forward(self, inputs): + """HSigmoidV5's forward function.""" + return F.hardsigmoid(input=inputs) + + +@dataclass +class ExpV0Attribute(Attribute): + """ExpV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class ExpV0(Operation[ExpV0Attribute]): + """ExpV0 class.""" + + TYPE = "Exp" + VERSION = 0 + ATTRIBUTE_FACTORY = ExpV0Attribute + + def forward(self, inputs): + """ExpV0's forward function.""" + return torch.exp(inputs) + + +@dataclass +class HardSigmoidV0Attribute(Attribute): + """HardSigmoidV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass + + +@OPS.register() +class HardSigmoidV0(Operation[HardSigmoidV0Attribute]): + """HardSigmoidV0 class.""" + + TYPE = "HardSigmoid" + VERSION = 0 + ATTRIBUTE_FACTORY = HardSigmoidV0Attribute + + def forward(self, inputs, alpha, beta): + """HardSigmoidV0's forward function.""" + return torch.maximum( + torch.zeros_like(inputs), + torch.minimum(torch.ones_like(inputs), inputs * alpha + beta), + ) + + +@dataclass +class GeluV7Attribute(Attribute): + """GeluV7Attribute class.""" + + approximation_mode: str = field(default="ERF") + + def __post_init__(self): + """GeluV7Attribute's post init function.""" + super().__post_init__() + valid_approximation_mode = ["ERF", "tanh"] + if self.approximation_mode not in valid_approximation_mode: + raise ValueError( + f"Invalid approximation_mode {self.approximation_mode}. " + f"It must be one of {valid_approximation_mode}." 
+ ) + + +@OPS.register() +class GeluV7(Operation[GeluV7Attribute]): + """GeluV7 class.""" + + TYPE = "Gelu" + VERSION = 7 + ATTRIBUTE_FACTORY = GeluV7Attribute + + def forward(self, inputs): + """GeluV7's forward function.""" + mode = self.attrs.approximation_mode + if mode == "ERF": + return F.gelu(input=inputs) + if mode == "tanh": + return ( + inputs * 0.5 * (1 + F.tanh(torch.sqrt(2 / torch.tensor(math.pi)) * (inputs + 0.044715 * inputs**3))) + ) + return None diff --git a/otx/mpa/modules/ov/ops/arithmetics.py b/otx/core/ov/ops/arithmetics.py similarity index 70% rename from otx/mpa/modules/ov/ops/arithmetics.py rename to otx/core/ov/ops/arithmetics.py index f637fca53d3..b959037650c 100644 --- a/otx/mpa/modules/ov/ops/arithmetics.py +++ b/otx/core/ov/ops/arithmetics.py @@ -1,51 +1,61 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Arithmetics-related codes for otx.core.ov.ops.arithmetics.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field import torch -from .builder import OPS -from .op import Attribute, Operation +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation @dataclass class MultiplyV1Attribute(Attribute): + """MultiplyV1Attribute class.""" + auto_broadcast: str = field(default="numpy") @OPS.register() class MultiplyV1(Operation[MultiplyV1Attribute]): + """MultiplyV1 class.""" + TYPE = "Multiply" VERSION = 1 ATTRIBUTE_FACTORY = MultiplyV1Attribute def forward(self, input_0, input_1): + """MultiplyV1's forward function.""" broadcast = self.attrs.auto_broadcast if broadcast == "none": assert input_0.shape == input_1.shape return input_0 * input_1 - elif broadcast == "numpy": + if broadcast == "numpy": return input_0 * input_1 - else: - raise NotImplementedError + raise NotImplementedError @dataclass class DivideV1Attribute(Attribute): + """DivideV1Attribute class.""" + m_pythondiv: bool = field(default=True) 
auto_broadcast: str = field(default="numpy") @OPS.register() class DivideV1(Operation[DivideV1Attribute]): + """DivideV1 class.""" + TYPE = "Divide" VERSION = 1 ATTRIBUTE_FACTORY = DivideV1Attribute def forward(self, input_0, input_1): + """DivideV1's forward function.""" broadcast = self.attrs.auto_broadcast if broadcast == "none": @@ -65,60 +75,73 @@ def forward(self, input_0, input_1): @dataclass class AddV1Attribute(Attribute): + """AddV1Attribute class.""" + auto_broadcast: str = field(default="numpy") @OPS.register() class AddV1(Operation[AddV1Attribute]): + """AddV1 class.""" + TYPE = "Add" VERSION = 1 ATTRIBUTE_FACTORY = AddV1Attribute def forward(self, input_0, input_1): + """AddV1's forward function.""" broadcast = self.attrs.auto_broadcast if broadcast == "none": assert input_0.shape == input_1.shape return input_0 + input_1 - elif broadcast == "numpy": + if broadcast == "numpy": return input_0 + input_1 - else: - raise NotImplementedError + raise NotImplementedError @dataclass class SubtractV1Attribute(Attribute): + """SubtractV1Attribute class.""" + auto_broadcast: str = field(default="numpy") @OPS.register() class SubtractV1(Operation[SubtractV1Attribute]): + """SubtractV1 class.""" + TYPE = "Subtract" VERSION = 1 ATTRIBUTE_FACTORY = SubtractV1Attribute def forward(self, input_0, input_1): + """SubtractV1's forward function.""" broadcast = self.attrs.auto_broadcast if broadcast == "none": assert input_0.shape == input_1.shape return input_0 - input_1 - elif broadcast == "numpy": + if broadcast == "numpy": return input_0 - input_1 - else: - raise NotImplementedError + raise NotImplementedError @dataclass class TanV0Attribute(Attribute): - pass + """TanV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class TanV0(Operation[TanV0Attribute]): + """TanV0 class.""" + TYPE = "Tan" VERSION = 0 ATTRIBUTE_FACTORY = TanV0Attribute - def forward(self, input): - return torch.tan(input) + def forward(self, inputs): + """TanV0's 
forward function.""" + return torch.tan(inputs) diff --git a/otx/mpa/modules/ov/ops/builder.py b/otx/core/ov/ops/builder.py similarity index 50% rename from otx/mpa/modules/ov/ops/builder.py rename to otx/core/ov/ops/builder.py index 2b97d93ae5d..ace6e17777f 100644 --- a/otx/mpa/modules/ov/ops/builder.py +++ b/otx/core/ov/ops/builder.py @@ -1,18 +1,23 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""OPS (OperationRegistry) module for otx.core.ov.ops.builder.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from typing import Any, Optional -from ..registry import Registry +from otx.core.ov.registry import Registry class OperationRegistry(Registry): + """OperationRegistry class.""" + def __init__(self, name, add_name_as_attr=False): super().__init__(name, add_name_as_attr) self._registry_dict_by_type = {} def register(self, name: Optional[Any] = None): + """Register function from name.""" + def wrap(obj): layer_name = name if layer_name is None: @@ -27,23 +32,26 @@ def wrap(obj): return wrap - def _register(self, obj, name, type, version): + def _register(self, obj, name, types, version): + """Register function from obj and obj name.""" super()._register(obj, name) - if type not in self._registry_dict_by_type: - self._registry_dict_by_type[type] = {} - if version in self._registry_dict_by_type[type]: - raise KeyError(f"{version} is already registered in {type}") - self._registry_dict_by_type[type][version] = obj + if types not in self._registry_dict_by_type: + self._registry_dict_by_type[types] = {} + if version in self._registry_dict_by_type[types]: + raise KeyError(f"{version} is already registered in {types}") + self._registry_dict_by_type[types][version] = obj def get_by_name(self, name): + """Get obj from name.""" return self.get(name) - def get_by_type_version(self, type, version): - if type not in self._registry_dict_by_type: - raise KeyError(f"type {type} is not registered in {self._name}") - if 
version not in self._registry_dict_by_type[type]: - raise KeyError(f"version {version} is not registered in {type} of {self._name}") - return self._registry_dict_by_type[type][version] + def get_by_type_version(self, types, version): + """Get obj from type and version.""" + if types not in self._registry_dict_by_type: + raise KeyError(f"type {types} is not registered in {self._name}") + if version not in self._registry_dict_by_type[types]: + raise KeyError(f"version {version} is not registered in {types} of {self._name}") + return self._registry_dict_by_type[types][version] OPS = OperationRegistry("ov ops") diff --git a/otx/mpa/modules/ov/ops/convolutions.py b/otx/core/ov/ops/convolutions.py similarity index 74% rename from otx/mpa/modules/ov/ops/convolutions.py rename to otx/core/ov/ops/convolutions.py index 20d21dcc2f2..a551be27810 100644 --- a/otx/mpa/modules/ov/ops/convolutions.py +++ b/otx/core/ov/ops/convolutions.py @@ -1,20 +1,22 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Convolutions-related module for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field from typing import Callable, List -import torch from torch.nn import functional as F -from .builder import OPS -from .op import Attribute, Operation -from .utils import get_torch_padding +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.movements import get_torch_padding +from otx.core.ov.ops.op import Attribute, Operation @dataclass class ConvolutionV1Attribute(Attribute): + """ConvolutionV1Attribute class.""" + strides: List[int] pads_begin: List[int] pads_end: List[int] @@ -22,6 +24,7 @@ class ConvolutionV1Attribute(Attribute): auto_pad: str = field(default="explicit") def __post_init__(self): + """ConvolutionV1Attribute's post-init function.""" super().__post_init__() valid_auto_pad = ["explicit", "same_upper", "same_Lower", "valid"] if self.auto_pad not in 
valid_auto_pad: @@ -30,11 +33,14 @@ def __post_init__(self): @OPS.register() class ConvolutionV1(Operation[ConvolutionV1Attribute]): + """ConvolutionV1 class.""" + TYPE = "Convolution" VERSION = 1 ATTRIBUTE_FACTORY = ConvolutionV1Attribute - def forward(self, input, weight): + def forward(self, inputs, weight): + """ConvolutionV1's forward function.""" if weight.dim() == 3: func = F.conv1d elif weight.dim() == 4: @@ -48,17 +54,17 @@ def forward(self, input, weight): self.attrs.pads_begin, self.attrs.pads_end, self.attrs.auto_pad, - list(input.shape[2:]), + list(inputs.shape[2:]), list(weight.shape[2:]), self.attrs.strides, self.attrs.dilations, ) if isinstance(padding, Callable): - input = padding(input=input) + inputs = padding(input=inputs) padding = 0 return func( - input=input, + input=inputs, weight=weight, bias=None, stride=self.attrs.strides, @@ -69,16 +75,21 @@ def forward(self, input, weight): @dataclass class GroupConvolutionV1Attribute(ConvolutionV1Attribute): - pass + """GroupConvolutionV1Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class GroupConvolutionV1(Operation[GroupConvolutionV1Attribute]): + """GroupConvolutionV1 class.""" + TYPE = "GroupConvolution" VERSION = 1 ATTRIBUTE_FACTORY = GroupConvolutionV1Attribute - def forward(self, input, weight): + def forward(self, inputs, weight): + """GroupConvolutionV1's forward function.""" if weight.dim() == 4: func = F.conv1d elif weight.dim() == 5: @@ -96,17 +107,17 @@ def forward(self, input, weight): self.attrs.pads_begin, self.attrs.pads_end, self.attrs.auto_pad, - list(input.shape[2:]), + list(inputs.shape[2:]), list(weight.shape[2:]), self.attrs.strides, self.attrs.dilations, ) if isinstance(padding, Callable): - input = padding(input=input) + inputs = padding(input=inputs) padding = 0 output = func( - input=input, + input=inputs, weight=weight, bias=None, stride=self.attrs.strides, diff --git a/otx/mpa/modules/ov/ops/generation.py b/otx/core/ov/ops/generation.py 
similarity index 59% rename from otx/mpa/modules/ov/ops/generation.py rename to otx/core/ov/ops/generation.py index 5785aa8aa80..395801923bd 100644 --- a/otx/mpa/modules/ov/ops/generation.py +++ b/otx/core/ov/ops/generation.py @@ -1,28 +1,34 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Generation-related module for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass import torch -from .builder import OPS -from .op import Attribute, Operation -from .type_conversions import _ov_to_torch +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation +from otx.core.ov.ops.type_conversions import _ov_to_torch @dataclass class RangeV4Attribute(Attribute): + """RangeV4Attribute class.""" + output_type: str @OPS.register() class RangeV4(Operation[RangeV4Attribute]): + """RangeV4 class.""" + TYPE = "Range" VERSION = 4 ATTRIBUTE_FACTORY = RangeV4Attribute def forward(self, start, stop, step): + """RangeV4's forward function.""" dtype = _ov_to_torch[self.attrs.output_type] return torch.arange( start=start, diff --git a/otx/mpa/modules/ov/ops/image_processings.py b/otx/core/ov/ops/image_processings.py similarity index 78% rename from otx/mpa/modules/ov/ops/image_processings.py rename to otx/core/ov/ops/image_processings.py index 8c54357a1a1..f6da2fc53d7 100644 --- a/otx/mpa/modules/ov/ops/image_processings.py +++ b/otx/core/ov/ops/image_processings.py @@ -1,6 +1,7 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Image Processings-related code for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field from typing import List @@ -12,9 +13,13 @@ from .movements import PadV1 from .op import Attribute, Operation +# pylint: disable=too-many-instance-attributes, too-many-branches + @dataclass class 
InterpolateV4Attribute(Attribute): + """InterpolateV4Attribute class.""" + mode: str shape_calculation_mode: str coordinate_transformation_mode: str = field(default="half_pixel") @@ -25,6 +30,7 @@ class InterpolateV4Attribute(Attribute): cube_coeff: float = field(default=-0.75) def __post_init__(self): + """InterpolateV4Attribute's post-init function.""" super().__post_init__() valid_mode = ["nearest", "linear", "linear_onnx", "cubic"] if self.mode not in valid_mode: @@ -60,6 +66,8 @@ def __post_init__(self): @OPS.register() class InterpolateV4(Operation[InterpolateV4Attribute]): + """InterpolateV4 class.""" + TYPE = "Interpolate" VERSION = 4 ATTRIBUTE_FACTORY = InterpolateV4Attribute @@ -68,7 +76,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pad = PadV1("tmp", shape=self.shape, pad_mode="constant") - def forward(self, input, sizes, scales, axes=None): + def forward(self, inputs, sizes, scales, axes=None): + """InterpolateV4's forward function.""" # TODO list # - handle 'linear_onnx' mode # - coordinate_transformation_mode @@ -77,14 +86,14 @@ def forward(self, input, sizes, scales, axes=None): # - antialias if axes is None: - axes = list(range(input.dim())) + axes = list(range(inputs.dim())) else: axes = axes.detach().cpu().tolist() - output = self.pad(input, self.attrs.pads_begin, self.attrs.pads_end, 0) + output = self.pad(inputs, self.attrs.pads_begin, self.attrs.pads_end, 0) mode = self.attrs.mode - if mode == "linear" or mode == "linear_onnx": + if mode in ("linear", "linear_onnx"): align_corners = False if output.dim() == 3: pass @@ -96,13 +105,13 @@ def forward(self, input, sizes, scales, axes=None): align_corners = False if output.dim() == 3: raise NotImplementedError - elif output.dim() == 4: + if output.dim() == 4: mode = "bicubic" elif output.dim() == 5: raise NotImplementedError elif mode == "nearest": align_corners = None - pass + pass # pylint: disable=unnecessary-pass else: raise NotImplementedError @@ -119,16 
+128,15 @@ def forward(self, input, sizes, scales, axes=None): mode=mode, align_corners=align_corners, ) - else: - scales = scales.detach().cpu().numpy() - scales = scales[np.argsort(axes)].tolist() - if output.dim() == len(scales): - scales = scales[2:] - - return F.interpolate( - input=output, - size=None, - scale_factor=scales, - mode=mode, - align_corners=align_corners, - ) + scales = scales.detach().cpu().numpy() + scales = scales[np.argsort(axes)].tolist() + if output.dim() == len(scales): + scales = scales[2:] + + return F.interpolate( + input=output, + size=None, + scale_factor=scales, + mode=mode, + align_corners=align_corners, + ) diff --git a/otx/mpa/modules/ov/ops/infrastructures.py b/otx/core/ov/ops/infrastructures.py similarity index 83% rename from otx/mpa/modules/ov/ops/infrastructures.py rename to otx/core/ov/ops/infrastructures.py index 05d1caf0d3c..0a24882d8aa 100644 --- a/otx/mpa/modules/ov/ops/infrastructures.py +++ b/otx/core/ov/ops/infrastructures.py @@ -1,6 +1,7 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Infrastructure-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from collections import OrderedDict from dataclasses import dataclass, field @@ -9,14 +10,13 @@ import numpy as np import torch +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation +from otx.core.ov.ops.type_conversions import ConvertV0 +from otx.core.ov.ops.utils import get_dynamic_shape +from otx.mpa.modules.ov.utils import get_op_name from otx.mpa.utils.logger import get_logger -from ..utils import get_op_name -from .builder import OPS -from .op import Attribute, Operation -from .type_conversions import ConvertV0 -from .utils import get_dynamic_shape - logger = get_logger() @@ -36,6 +36,8 @@ @dataclass class ParameterV0Attribute(Attribute): + """ParameterV0Attribute class.""" + element_type: Optional[str] = field(default=None) 
layout: Optional[Tuple[str]] = field(default=None) @@ -43,6 +45,7 @@ class ParameterV0Attribute(Attribute): verify_shape: bool = field(default=True) def __post_init__(self): + """ParameterV0Attribute's post-init function.""" super().__post_init__() # fmt: off valid_element_type = [ @@ -57,29 +60,33 @@ def __post_init__(self): @OPS.register() class ParameterV0(Operation[ParameterV0Attribute]): + """ParameterV0 class.""" + TYPE = "Parameter" VERSION = 0 ATTRIBUTE_FACTORY = ParameterV0Attribute - def forward(self, input): + def forward(self, inputs): + """ParameterV0's forward function.""" # TODO: validate shape # need to handle new generated op from reshaped model if self.attrs.verify_shape: assert self.shape is not None ov_shape = self.shape[0] - torch_shape = list(input.shape) + torch_shape = list(inputs.shape) for ov_shape_, torch_shape_ in zip(ov_shape, torch_shape): if ov_shape_ == -1: continue assert ov_shape_ == torch_shape_, f"input shape {torch_shape} does not match with ov shape {ov_shape}" if self.attrs.permute: - input = input.permute(self.attrs.permute) + inputs = inputs.permute(self.attrs.permute) - return input + return inputs @classmethod def from_ov(cls, ov_op): + """ParameterV0's from_ov function.""" op_type = ov_op.get_type_name() op_version = ov_op.get_version() op_name = get_op_name(ov_op) @@ -122,8 +129,8 @@ def from_ov(cls, ov_op): new_shape = [] for shape in attrs["shape"]: new_shape.append([-1 if j == i else k for j, k in enumerate(shape)]) - new_shape = tuple(tuple(shape) for shape in new_shape) - attrs["shape"] = new_shape + new_shape = [tuple(shape) for shape in new_shape] + attrs["shape"] = tuple(new_shape) # change shape and layout based on permute if "permute" in attrs and attrs["permute"] != (0, 1, 2, 3): @@ -135,28 +142,35 @@ def from_ov(cls, ov_op): for shape in attrs["shape"]: new_shape.append([shape[i] for i in permute]) attrs["shape"] = tuple(tuple(shape) for shape in new_shape) - attrs["layout"] = tuple([attrs["layout"][i] for i 
in permute]) + attrs["layout"] = tuple(attrs["layout"][i] for i in permute) return cls(name=op_name, **attrs) @dataclass class ResultV0Attribute(Attribute): - pass + """ResultV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class ResultV0(Operation[ResultV0Attribute]): + """ResultV0 class.""" + TYPE = "Result" VERSION = 0 ATTRIBUTE_FACTORY = ResultV0Attribute - def forward(self, input): - return input + def forward(self, inputs): + """ResultV0's forward function.""" + return inputs @dataclass class ConstantV0Attribute(Attribute): + """ConstantV0Attribute class.""" + element_type: str offset: int = field(default=0) size: int = field(default=0) @@ -164,6 +178,7 @@ class ConstantV0Attribute(Attribute): is_parameter: bool = field(default=False) def __post_init__(self): + """ConstantV0Attribute's post-init function.""" super().__post_init__() # fmt: off valid_element_type = [ @@ -177,6 +192,8 @@ def __post_init__(self): @OPS.register() class ConstantV0(Operation[ConstantV0Attribute]): + """ConstantV0 class.""" + TYPE = "Constant" VERSION = 0 ATTRIBUTE_FACTORY = ConstantV0Attribute @@ -194,10 +211,12 @@ def __init__(self, *args, **kwargs): self.register_buffer("data", data) def forward(self): + """ConstantV0's forward function.""" return self.data @classmethod def from_ov(cls, ov_op): + """ConstantV0's from_ov function.""" op_type = ov_op.get_type_name() op_version = ov_op.get_version() op_name = get_op_name(ov_op) @@ -228,6 +247,7 @@ def from_ov(cls, ov_op): # FIXME: need a better way to distinghish if it is parameter or no is_parameter = False + # pylint: disable=too-many-boolean-expressions if ( set(op_node_types).intersection(NODE_TYPES_WITH_WEIGHT) and len(in_port_indices) == 1 diff --git a/otx/mpa/modules/ov/ops/matmuls.py b/otx/core/ov/ops/matmuls.py similarity index 68% rename from otx/mpa/modules/ov/ops/matmuls.py rename to otx/core/ov/ops/matmuls.py index bc5764e66e7..0bf2e3a173b 100644 --- a/otx/mpa/modules/ov/ops/matmuls.py 
+++ b/otx/core/ov/ops/matmuls.py @@ -1,28 +1,34 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""MatMul-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field import torch -from .builder import OPS -from .op import Attribute, Operation +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation @dataclass class MatMulV0Attribute(Attribute): + """MatMulV0Attribute class.""" + transpose_a: bool = field(default=False) transpose_b: bool = field(default=False) @OPS.register() class MatMulV0(Operation[MatMulV0Attribute]): + """MatMulV0 class.""" + TYPE = "MatMul" VERSION = 0 ATTRIBUTE_FACTORY = MatMulV0Attribute def forward(self, input_a, input_b): + """MatMulV0's forward function.""" if self.attrs.transpose_a: input_a = torch.transpose(input_a, -1, -2) if self.attrs.transpose_b: @@ -32,14 +38,19 @@ def forward(self, input_a, input_b): @dataclass class EinsumV7Attribute(Attribute): + """EinsumV7Attribute class.""" + equation: str @OPS.register() class EinsumV7(Operation[EinsumV7Attribute]): + """EinsumV7 class.""" + TYPE = "Einsum" VERSION = 7 ATTRIBUTE_FACTORY = EinsumV7Attribute def forward(self, *inputs): + """EinsumV7's forward function.""" return torch.einsum(self.attrs.equation, *inputs) diff --git a/otx/core/ov/ops/modules/__init__.py b/otx/core/ov/ops/modules/__init__.py new file mode 100644 index 00000000000..6f18c9c4a33 --- /dev/null +++ b/otx/core/ov/ops/modules/__init__.py @@ -0,0 +1,8 @@ +"""Module for otx.core.ov.pos.modules.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from .op_module import OperationModule + +__all__ = ["OperationModule"] diff --git a/otx/mpa/modules/ov/ops/modules/op_module.py b/otx/core/ov/ops/modules/op_module.py similarity index 60% rename from otx/mpa/modules/ov/ops/modules/op_module.py rename to 
otx/core/ov/ops/modules/op_module.py index 1d4ad1ba8b3..00c50c62438 100644 --- a/otx/mpa/modules/ov/ops/modules/op_module.py +++ b/otx/core/ov/ops/modules/op_module.py @@ -1,6 +1,7 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Operation module for otx.core.ov.ops.modeuls.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT import inspect from typing import Dict, List, Optional, Union @@ -11,17 +12,19 @@ class OperationModule(torch.nn.Module): + """OperationModule class.""" + def __init__( self, - op: Operation, - dependent_ops: Union[List[Optional[Operation]], Dict[str, Optional[Operation]]], + op_v: Operation, + dependent_ops: Union[List[Operation], Dict[str, Optional[Operation]]], ): super().__init__() - self.op = op + self.op_v = op_v self._dependent_ops = torch.nn.ModuleDict() - spec = inspect.getfullargspec(op.forward) + spec = inspect.getfullargspec(op_v.forward) kwargs = spec.args[1:] self._dependents_with_defaults = [] @@ -30,8 +33,8 @@ def __init__( if isinstance(dependent_ops, list): assert len(dependent_ops) == len(kwargs) - for op, kwarg in zip(dependent_ops, kwargs): - self._dependent_ops[kwarg] = op + for op_, kwarg in zip(dependent_ops, kwargs): + self._dependent_ops[kwarg] = op_ elif isinstance(dependent_ops, dict): for kwarg in kwargs: self._dependent_ops[kwarg] = dependent_ops[kwarg] @@ -39,6 +42,7 @@ def __init__( raise NotImplementedError def forward(self, *args, **kwargs): + """Operationmodule's forward function.""" inputs = {k: v() if v is not None else None for k, v in self._dependent_ops.items()} if args: @@ -53,24 +57,29 @@ def forward(self, *args, **kwargs): assert all(v is not None for v in inputs.values() if v not in self._dependents_with_defaults) - return self.op(**inputs) + return self.op_v(**inputs) @property - def type(self): - return self.op.type + def type(self): # pylint: disable=invalid-overridden-method + """Operationmodule's type property.""" + return 
self.op_v.type @property def version(self): - return self.op.version + """Operationmodule's version property.""" + return self.op_v.version @property def name(self): - return self.op.name + """Operationmodule's name property.""" + return self.op_v.name @property def shape(self): - return self.op.shape + """Operationmodule's shape property.""" + return self.op_v.shape @property def attrs(self): - return self.op.attrs + """Operationmodule's attrs property.""" + return self.op_v.attrs diff --git a/otx/mpa/modules/ov/ops/movements.py b/otx/core/ov/ops/movements.py similarity index 58% rename from otx/mpa/modules/ov/ops/movements.py rename to otx/core/ov/ops/movements.py index 2faf812b4d4..e529214e992 100644 --- a/otx/mpa/modules/ov/ops/movements.py +++ b/otx/core/ov/ops/movements.py @@ -1,9 +1,11 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Movement-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT import math from dataclasses import dataclass, field +from functools import partial from typing import List import torch @@ -12,12 +14,17 @@ from .builder import OPS from .op import Attribute, Operation +# pylint: disable=too-many-branches + @dataclass class PadV1Attribute(Attribute): + """PadV1Attribute class.""" + pad_mode: str def __post_init__(self): + """PadV1Attribute's post-init function.""" super().__post_init__() valid_pad_mode = ["constant", "edge", "reflect", "symmetric"] if self.pad_mode not in valid_pad_mode: @@ -26,6 +33,8 @@ def __post_init__(self): @OPS.register() class PadV1(Operation[PadV1Attribute]): + """PadV1 class.""" + TYPE = "Pad" VERSION = 1 ATTRIBUTE_FACTORY = PadV1Attribute @@ -36,76 +45,91 @@ def __init__(self, *args, **kwargs): @staticmethod def get_torch_pad_mode(pad_mode): + """PadV1's get_torch_pad_mode function.""" if pad_mode == "constant": return "constant" - elif pad_mode == "edge": + if pad_mode == "edge": return "replicate" - elif 
pad_mode == "reflect": + if pad_mode == "reflect": return "reflect" - elif pad_mode == "symmetric": - raise NotImplementedError - else: - raise NotImplementedError + raise NotImplementedError @staticmethod def get_torch_pad_dim(pads_begin, pads_end): + """PadV1's get_torch_pad_dim function.""" # reverse padding return [val for tup in zip(pads_begin[::-1], pads_end[::-1]) for val in tup] - def forward(self, input, pads_begin, pads_end, pad_value=0): + def forward(self, inputs, pads_begin, pads_end, pad_value=0): + """PadV1's forward function.""" pads_begin = pads_begin if isinstance(pads_begin, list) else pads_begin.detach().cpu().tolist() pads_end = pads_end if isinstance(pads_end, list) else pads_end.detach().cpu().tolist() pad = self.get_torch_pad_dim(pads_begin, pads_end) pad = list(map(math.ceil, pad)) - return F.pad(input=input, pad=pad, mode=self._pad_mode, value=pad_value) + return F.pad(input=inputs, pad=pad, mode=self._pad_mode, value=pad_value) @dataclass class ConcatV0Attribute(Attribute): + """ConcatV0Attribute class.""" + axis: int @OPS.register() class ConcatV0(Operation[ConcatV0Attribute]): + """ConcatV0 class.""" + TYPE = "Concat" VERSION = 0 ATTRIBUTE_FACTORY = ConcatV0Attribute def forward(self, *inputs): + """ConcatV0's forward function.""" return torch.cat(inputs, self.attrs.axis) @dataclass class TransposeV1Attribute(Attribute): - pass + """TransposeV1Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class TransposeV1(Operation[TransposeV1Attribute]): + """TransposeV1 class.""" + TYPE = "Transpose" VERSION = 1 ATTRIBUTE_FACTORY = TransposeV1Attribute - def forward(self, input, order): + def forward(self, inputs, order): + """TransposeV1's forward function.""" if order.numel() == 0: - order = list(range(input.dim()))[::-1] + order = list(range(inputs.dim()))[::-1] elif isinstance(order, torch.Tensor): order = order.detach().cpu().tolist() - return input.permute(order) + return inputs.permute(order) @dataclass class 
GatherV0Attribute(Attribute): + """GatherV0Attribute class.""" + batch_dims: int = field(default=0) @OPS.register() class GatherV0(Operation[GatherV0Attribute]): + """GatherV0 class.""" + TYPE = "Gather" VERSION = 0 ATTRIBUTE_FACTORY = GatherV0Attribute - def forward(self, input, indices, axis): + def forward(self, inputs, indices, axis): + """GatherV0's forward function.""" assert axis.numel() == 1 axis = axis.squeeze() squeeze_axis = indices.dim() == 0 @@ -119,22 +143,22 @@ def forward(self, input, indices, axis): indices = indices.reshape(*indices_shape[:batch_dims], -1) indices_shape = indices_shape[batch_dims:] - if indices.dim() != input.dim(): + if indices.dim() != inputs.dim(): if indices.dim() != 0: while indices.dim() - 1 < axis: indices = indices.unsqueeze(batch_dims) - while indices.dim() < input.dim(): + while indices.dim() < inputs.dim(): indices = indices.unsqueeze(-1) repeat = [] - for i, (j, k) in enumerate(zip(input.shape, indices.shape)): + for i, (j, k) in enumerate(zip(inputs.shape, indices.shape)): if i == axis: repeat.append(1) else: assert j % k == 0 repeat.append(j // k) indices = indices.repeat(repeat) - output = torch.gather(input=input, dim=axis, index=indices.type(torch.int64)) + output = torch.gather(input=inputs, dim=axis, index=indices.type(torch.int64)) if squeeze_axis: output = output.squeeze(axis) @@ -144,21 +168,28 @@ def forward(self, input, indices, axis): @dataclass class GatherV1Attribute(Attribute): - pass + """GatherV1Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class GatherV1(Operation[GatherV1Attribute]): + """GatherV1 class.""" + TYPE = "Gather" VERSION = 1 ATTRIBUTE_FACTORY = GatherV1Attribute - def forward(self, input, indices, axis): - return torch.gather(input=input, dim=axis, index=indices) + def forward(self, inputs, indices, axis): + """GatherV1's forward function.""" + return torch.gather(input=inputs, dim=axis, index=indices) @dataclass class StridedSliceV1Attribute(Attribute): 
+ """StridedSliceV1Attribute class.""" + begin_mask: List[int] end_mask: List[int] new_axis_mask: List[int] = field(default_factory=lambda: [0]) @@ -168,11 +199,14 @@ class StridedSliceV1Attribute(Attribute): @OPS.register() class StridedSliceV1(Operation[StridedSliceV1Attribute]): + """StridedSliceV1 class.""" + TYPE = "StridedSlice" VERSION = 1 ATTRIBUTE_FACTORY = StridedSliceV1Attribute - def forward(self, input, begin, end, stride=None): + def forward(self, inputs, begin, end, stride=None): + """StridedSliceV1's forward function.""" if sum(self.attrs.ellipsis_mask) > 0: raise NotImplementedError @@ -181,28 +215,28 @@ def forward(self, input, begin, end, stride=None): begin[i] = 0 for i, mask in enumerate(self.attrs.end_mask): if mask == 1: - end[i] = input.size(i) + end[i] = inputs.size(i) if stride is None: stride = torch.tensor([1 for _ in begin], dtype=begin.dtype) - output = input - for i, (b, e, s) in enumerate(zip(begin, end, stride)): - length = input.size(i) + output = inputs + for i, (b, e, stride_0) in enumerate(zip(begin, end, stride)): + length = inputs.size(i) # begin index is inclusive b = torch.clamp(b, -length, length - 1) # end index is exclusive e = torch.clamp(e, -length - 1, length) - if s > 0: + if stride_0 > 0: b = b + length if b < 0 else b e = e + length if e < 0 else e - indices = torch.arange(b, e, s, device=input.device) + indices = torch.arange(b, e, stride_0, device=inputs.device) else: b = b - length if b >= 0 else b e = e - length if e >= 0 else e - indices = torch.arange(b, e, s, device=input.device) + indices = torch.arange(b, e, stride_0, device=inputs.device) indices += length output = torch.index_select(output, i, indices) @@ -224,46 +258,56 @@ def forward(self, input, begin, end, stride=None): @dataclass class SplitV1Attribute(Attribute): + """SplitV1Attribute class.""" + num_splits: int @OPS.register() class SplitV1(Operation[SplitV1Attribute]): + """SplitV1 class.""" + TYPE = "Split" VERSION = 1 ATTRIBUTE_FACTORY = 
SplitV1Attribute - def forward(self, input, axis): - split_size = input.shape[axis] // self.attrs.num_splits - return torch.split(tensor=input, split_size_or_sections=split_size, dim=axis) + def forward(self, inputs, axis): + """SplitV1's forward function.""" + split_size = inputs.shape[axis] // self.attrs.num_splits + return torch.split(tensor=inputs, split_size_or_sections=split_size, dim=axis) @dataclass class VariadicSplitV1Attribute(Attribute): - pass + """VariadicSplitV1Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class VariadicSplitV1(Operation[VariadicSplitV1Attribute]): + """VariadicSplitV1 class.""" + TYPE = "VariadicSplit" VERSION = 1 ATTRIBUTE_FACTORY = VariadicSplitV1Attribute - def forward(self, input, axis, split_lengths): + def forward(self, inputs, axis, split_lengths): + """VariadicSplitV1's forward function.""" idx = [i for i, j in enumerate(split_lengths) if j == -1] if idx: assert len(idx) == 1 idx = idx[0] - split_lengths[idx] = input.size(axis) - sum(split_lengths) - 1 - assert input.size(axis) == sum(split_lengths) + split_lengths[idx] = inputs.size(axis) - sum(split_lengths) - 1 + assert inputs.size(axis) == sum(split_lengths) outputs = [] start_idx = 0 for length in split_lengths: outputs.append( torch.index_select( - input, + inputs, axis, - torch.arange(start_idx, start_idx + length, device=input.device), + torch.arange(start_idx, start_idx + length, device=inputs.device), ) ) start_idx += length @@ -272,25 +316,30 @@ def forward(self, input, axis, split_lengths): @dataclass class ShuffleChannelsV0Attribute(Attribute): + """ShuffleChannelsV0Attribute class.""" + axis: int = field(default=1) group: int = field(default=1) @OPS.register() class ShuffleChannelsV0(Operation[ShuffleChannelsV0Attribute]): + """ShuffleChannelsV0 class.""" + TYPE = "ShuffleChannels" VERSION = 0 ATTRIBUTE_FACTORY = ShuffleChannelsV0Attribute - def forward(self, input): + def forward(self, inputs): + """ShuffleChannelsV0's forward 
function.""" # n, c, h, w = input.shape - assert input.dim() == 4 - origin_shape = input.shape - origin_dim = input.dim() + assert inputs.dim() == 4 + origin_shape = inputs.shape + origin_dim = inputs.dim() assert origin_shape[self.attrs.axis] % self.attrs.group == 0 axis = self.attrs.axis - axis = axis if axis >= 0 else axis + input.dim() + axis = axis if axis >= 0 else axis + inputs.dim() target_shape = [ 0, @@ -301,14 +350,14 @@ def forward(self, input): if axis == 0: target_shape[0] = 1 target_shape[-1] = math.prod([origin_shape[i] for i in range(axis + 1, origin_dim)]) - elif axis == input.dim() - 1: + elif axis == inputs.dim() - 1: target_shape[0] = math.prod([origin_shape[i] for i in range(0, axis)]) target_shape[-1] = 1 else: target_shape[0] = math.prod([origin_shape[i] for i in range(0, axis)]) target_shape[-1] = math.prod([origin_shape[i] for i in range(axis + 1, origin_dim)]) - output = input.reshape(target_shape) + output = inputs.reshape(target_shape) output = output.permute([0, 2, 1, 3]) output = output.reshape(origin_shape) return output @@ -316,9 +365,12 @@ def forward(self, input): @dataclass class BroadcastV3Attribute(Attribute): + """BroadcastV3Attribute class.""" + mode: str = field(default="numpy") def __post_init__(self): + """BroadcastV3Attribute's post-init function.""" super().__post_init__() valid_mode = ["numpy", "explicit", "bidirectional"] if self.mode not in valid_mode: @@ -327,40 +379,47 @@ def __post_init__(self): @OPS.register() class BroadcastV3(Operation[BroadcastV3Attribute]): + """BroadcastV3 class.""" + TYPE = "Broadcast" VERSION = 3 ATTRIBUTE_FACTORY = BroadcastV3Attribute - def forward(self, input, target_shape, axes_mapping=None): + def forward(self, inputs, target_shape, axes_mapping=None): + """BroadcastV3's forward function.""" if self.attrs.mode == "numpy": - return input.expand(*target_shape) + return inputs.expand(*target_shape) if self.attrs.mode == "bidirectional": - return torch.ones(*target_shape, 
device=input.device) * input - else: - assert axes_mapping is not None - prev = -1 - for axes in axes_mapping: + return torch.ones(*target_shape, device=inputs.device) * inputs + assert axes_mapping is not None + prev = -1 + for axes in axes_mapping: + prev += 1 + while axes - prev > 0: + inputs = inputs.unsqueeze(axes - 1) prev += 1 - while axes - prev > 0: - input = input.unsqueeze(axes - 1) - prev += 1 - while input.dim() < len(target_shape): - input = input.unsqueeze(-1) - return input.expand(*target_shape) + while inputs.dim() < len(target_shape): + inputs = inputs.unsqueeze(-1) + return inputs.expand(*target_shape) @dataclass class ScatterNDUpdateV3Attribute(Attribute): - pass + """ScatterNDUpdateV3Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class ScatterNDUpdateV3(Operation[ScatterNDUpdateV3Attribute]): + """ScatterNDUpdateV3 class.""" + TYPE = "ScatterNDUpdate" VERSION = 3 ATTRIBUTE_FACTORY = ScatterNDUpdateV3Attribute - def forward(self, input, indicies, updates): + def forward(self, inputs, indicies, updates): + """ScatterNDUpdateV3's forward function.""" # TODO: need to verify if updates.numel() == 1: raise NotImplementedError @@ -369,53 +428,90 @@ def forward(self, input, indicies, updates): last_dim = indicies.shape[-1] assert last_dim == 2 assert indicies[..., -2].sum() == 0 - input.shape[indicies.shape[-1] :] + inputs.shape[indicies.shape[-1] :] # pylint: disable=pointless-statement index = indicies[..., -1] - for i in input.shape[indicies.shape[-1] :]: + for i in inputs.shape[indicies.shape[-1] :]: index = index.unsqueeze(-1).tile((i,)) - output = torch.scatter(input, 1, index, updates) + output = torch.scatter(inputs, 1, index, updates) return output @dataclass class ScatterUpdateV3Attribute(Attribute): - pass + """ScatterUpdateV3Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class ScatterUpdateV3(Operation[ScatterUpdateV3Attribute]): + """ScatterUpdateV3 class.""" + TYPE = 
"ScatterUpdate" VERSION = 3 ATTRIBUTE_FACTORY = ScatterUpdateV3Attribute - def forward(self, input, indicies, updates, axis): + def forward(self, inputs, indicies, updates, axis): + """ScatterUpdateV3's forward function.""" # TODO: need to verify axis = axis.item() - if input.dtype != updates.dtype: - updates = updates.type(input.dtype) + if inputs.dtype != updates.dtype: + updates = updates.type(inputs.dtype) if indicies.dim() == 0: assert axis == 0 - output = input + output = inputs output[indicies] = updates - output = torch.scatter(input, axis, indicies, updates) + output = torch.scatter(inputs, axis, indicies, updates) return output @dataclass class TileV0Attribute(Attribute): - pass + """TileV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class TileV0(Operation[TileV0Attribute]): + """TileV0 class.""" + TYPE = "Tile" VERSION = 0 ATTRIBUTE_FACTORY = TileV0Attribute - def forward(self, input, repeats): - return torch.tile(input, repeats.tolist()) + def forward(self, inputs, repeats): + """TileV0's forward function.""" + return torch.tile(inputs, repeats.tolist()) + + +def get_torch_padding(pads_begin, pads_end, auto_pad, input_size, weight_size, stride, dilation=None): + """Getter function for torch padding.""" + if dilation is None: + dilation = [1 for _ in input_size] + + if auto_pad == "valid": + return 0 + if auto_pad in ("same_upper", "same_lower"): + assert len(set(dilation)) == 1 and dilation[0] == 1 + pads_begin = [] + pads_end = [] + for input_size_, weight_size_, stride_, _ in zip(input_size, weight_size, stride, dilation): + out_size = math.ceil(input_size_ / stride_) + padding_needed = max(0, (out_size - 1) * stride_ + weight_size_ - input_size_) + padding_lhs = int(padding_needed / 2) + padding_rhs = padding_needed - padding_lhs + + pads_begin.append(padding_lhs if auto_pad == "same_upper" else padding_rhs) + pads_end.append(padding_rhs if auto_pad == "same_upper" else padding_lhs) + pad = 
PadV1.get_torch_pad_dim(pads_begin, pads_end) + return partial(F.pad, pad=pad, mode="constant", value=0) + if auto_pad == "explicit": + pad = PadV1.get_torch_pad_dim(pads_begin, pads_end) + return partial(F.pad, pad=pad, mode="constant", value=0) + raise NotImplementedError diff --git a/otx/mpa/modules/ov/ops/normalizations.py b/otx/core/ov/ops/normalizations.py similarity index 73% rename from otx/mpa/modules/ov/ops/normalizations.py rename to otx/core/ov/ops/normalizations.py index 5d0df82d163..7098f3541f3 100644 --- a/otx/mpa/modules/ov/ops/normalizations.py +++ b/otx/core/ov/ops/normalizations.py @@ -1,25 +1,30 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Normalization-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field import torch from torch.nn import functional as F -from .builder import OPS -from .op import Attribute, Operation -from .poolings import AvgPoolV1 +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation +from otx.core.ov.ops.poolings import AvgPoolV1 @dataclass class BatchNormalizationV0Attribute(Attribute): + """BatchNormalizationV0Attribute class.""" + epsilon: float max_init_iter: int = field(default=2) @OPS.register() class BatchNormalizationV0(Operation[BatchNormalizationV0Attribute]): + """BatchNormalizationV0 class.""" + TYPE = "BatchNormInference" VERSION = 0 ATTRIBUTE_FACTORY = BatchNormalizationV0Attribute @@ -28,10 +33,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.register_buffer("_num_init_iter", torch.tensor(0)) - def forward(self, input, gamma, beta, mean, variance): + def forward(self, inputs, gamma, beta, mean, variance): + """BatchNormalizationV0's forward function.""" output = F.batch_norm( - input=input, + input=inputs, running_mean=mean, running_var=variance, weight=gamma, @@ -42,13 +48,13 @@ def forward(self, 
input, gamma, beta, mean, variance): ) if self.training and self._num_init_iter < self.attrs.max_init_iter: - n_dims = input.dim() - 2 + n_dims = inputs.dim() - 2 gamma = gamma.unsqueeze(0) beta = beta.unsqueeze(0) for _ in range(n_dims): gamma = gamma.unsqueeze(-1) beta = beta.unsqueeze(-1) - output = input * gamma + beta + output = inputs * gamma + beta self._num_init_iter += 1 if self._num_init_iter >= self.attrs.max_init_iter: # Adapt weight & bias using the first batch statistics @@ -61,6 +67,8 @@ def forward(self, input, gamma, beta, mean, variance): @dataclass class LocalResponseNormalizationV0Attribute(Attribute): + """LocalResponseNormalizationV0Attribute class.""" + alpha: float beta: float bias: float @@ -69,12 +77,15 @@ class LocalResponseNormalizationV0Attribute(Attribute): @OPS.register() class LocalResponseNormalizationV0(Operation[LocalResponseNormalizationV0Attribute]): + """LocalResponseNormalizationV0 class.""" + TYPE = "LRN" VERSION = 0 ATTRIBUTE_FACTORY = LocalResponseNormalizationV0Attribute - def forward(self, input, axes): - dim = input.dim() + def forward(self, inputs, axes): + """LocalResponseNormalizationV0's forward function.""" + dim = inputs.dim() axes = axes.detach().cpu().tolist() assert all(ax >= 1 for ax in axes) @@ -84,10 +95,10 @@ def forward(self, input, axes): stride = [1 for _ in range(dim - 1)] pads_begin = [0 for _ in range(dim - 1)] pads_end = [0 for _ in range(dim - 1)] - for ax in axes: - kernel[ax] = self.attrs.size - pads_begin[ax] = self.attrs.size // 2 - pads_end[ax] = (self.attrs.size - 1) // 2 + for axe in axes: + kernel[axe] = self.attrs.size + pads_begin[axe] = self.attrs.size // 2 + pads_end[axe] = (self.attrs.size - 1) // 2 avg_attrs = { "auto_pad": "explicit", @@ -100,20 +111,23 @@ def forward(self, input, axes): } avg_pool = AvgPoolV1("temp", **avg_attrs) - div = input.mul(input).unsqueeze(1) + div = inputs.mul(inputs).unsqueeze(1) div = avg_pool(div) div = div.squeeze(1) div = 
div.mul(self.attrs.alpha).add(self.attrs.bias).pow(self.attrs.beta) - output = input / div + output = inputs / div return output @dataclass class NormalizeL2V0Attribute(Attribute): + """NormalizeL2V0Attribute class.""" + eps: float eps_mode: str def __post_init__(self): + """NormalizeL2V0Attribute post-init function.""" super().__post_init__() valid_eps_mode = ["add", "max"] if self.eps_mode not in valid_eps_mode: @@ -122,11 +136,14 @@ def __post_init__(self): @OPS.register() class NormalizeL2V0(Operation[NormalizeL2V0Attribute]): + """NormalizeL2V0 class.""" + TYPE = "NormalizeL2" VERSION = 0 ATTRIBUTE_FACTORY = NormalizeL2V0Attribute - def forward(self, input, axes): + def forward(self, inputs, axes): + """NormalizeL2V0's forward function.""" eps = self.attrs.eps eps_mode = self.attrs.eps_mode @@ -136,7 +153,7 @@ def forward(self, input, axes): axes = [axes] # normalization layer convert to FP32 in FP16 training - input_float = input.float() + input_float = inputs.float() if axes: norm = input_float.pow(2).sum(axes, keepdim=True) else: @@ -147,16 +164,19 @@ def forward(self, input, axes): elif eps_mode == "max": norm = torch.clamp(norm, max=eps) - return (input_float / norm.sqrt()).type_as(input) + return (input_float / norm.sqrt()).type_as(inputs) @dataclass class MVNV6Attribute(Attribute): + """MVNV6Attribute class.""" + normalize_variance: bool eps: float eps_mode: str def __post_init__(self): + """MVNV6Attribute's post-init function.""" super().__post_init__() valid_eps_mode = ["INSIDE_SQRT", "OUTSIDE_SQRT"] if self.eps_mode not in valid_eps_mode: @@ -165,12 +185,15 @@ def __post_init__(self): @OPS.register() class MVNV6(Operation[MVNV6Attribute]): + """MVNV6 class.""" + TYPE = "MVN" VERSION = 6 ATTRIBUTE_FACTORY = MVNV6Attribute - def forward(self, input, axes): - output = input - input.mean(axes.tolist(), keepdim=True) + def forward(self, inputs, axes): + """MVNV6's forward function.""" + output = inputs - inputs.mean(axes.tolist(), keepdim=True) if 
self.attrs.normalize_variance: eps_mode = self.attrs.eps_mode eps = self.attrs.eps diff --git a/otx/mpa/modules/ov/ops/object_detections.py b/otx/core/ov/ops/object_detections.py similarity index 80% rename from otx/mpa/modules/ov/ops/object_detections.py rename to otx/core/ov/ops/object_detections.py index 70b4c89a9ca..5d2bfe495e3 100644 --- a/otx/mpa/modules/ov/ops/object_detections.py +++ b/otx/core/ov/ops/object_detections.py @@ -1,16 +1,21 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Object-detection-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field from typing import List, Optional -from .builder import OPS -from .op import Attribute, Operation +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation + +# pylint: disable=too-many-instance-attributes @dataclass class ProposalV4Attribute(Attribute): + """ProposalV4Attribute class.""" + base_size: int pre_nms_topn: int post_nms_topn: int @@ -27,6 +32,7 @@ class ProposalV4Attribute(Attribute): framework: str = field(default="") def __post_init__(self): + """ProposalV4Attribute's post-init function.""" super().__post_init__() valid_framework = ["", "tensorflow"] if self.framework not in valid_framework: @@ -35,32 +41,21 @@ def __post_init__(self): @OPS.register() class ProposalV4(Operation[ProposalV4Attribute]): + """ProposalV4 class.""" + TYPE = "Proposal" VERSION = 4 ATTRIBUTE_FACTORY = ProposalV4Attribute - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # from mmdet.core.anchor.anchor_generator import AnchorGenerator - # self._anchor_generator = AnchorGenerator( - # strides=[attrs["feat_stride"]], - # ratios=attrs["ratio"], - # scales=attrs["scale"], - # base_sizes=[attrs["base_size"]], - # ) - - # from torchvision.models.detection.anchor_utils import AnchorGenerator - # self._anchor_generator = 
AnchorGenerator( - # sizes=(self.attrs["base_size"],), - # aspect_ratios= - def forward(self, class_probs, bbox_deltas, image_shape): + """ProposalV4's forward function.""" raise NotImplementedError @dataclass class ROIPoolingV0Attribute(Attribute): + """ROIPoolingV0Attribute class.""" + pooled_h: int pooled_w: int spatial_scale: float @@ -68,6 +63,7 @@ class ROIPoolingV0Attribute(Attribute): output_size: List[int] = field(default_factory=lambda: []) def __post_init__(self): + """ROIPoolingV0Attribute's post-init function.""" super().__post_init__() valid_method = ["max", "bilinear"] if self.method not in valid_method: @@ -76,16 +72,21 @@ def __post_init__(self): @OPS.register() class ROIPoolingV0(Operation[ROIPoolingV0Attribute]): + """ROIPoolingV0 class.""" + TYPE = "ROIPooling" VERSION = 0 ATTRIBUTE_FACTORY = ROIPoolingV0Attribute - def forward(self, input, boxes): + def forward(self, inputs, boxes): + """ROIPoolingV0's forward function.""" raise NotImplementedError @dataclass class DetectionOutputV0Attribute(Attribute): + """DetectionOutputV0Attribute class.""" + keep_top_k: List[int] nms_threshold: float background_label_id: int = field(default=0) @@ -103,6 +104,7 @@ class DetectionOutputV0Attribute(Attribute): objectness_score: float = field(default=0) def __post_init__(self): + """DetectionOutputV0Attribute's post-init function.""" super().__post_init__() valid_code_type = [ "caffe.PriorBoxParameter.CORNER", @@ -114,16 +116,21 @@ def __post_init__(self): @OPS.register() class DetectionOutputV0(Operation[DetectionOutputV0Attribute]): + """DetectionOutputV0 class.""" + TYPE = "DetectionOutput" VERSION = 0 ATTRIBUTE_FACTORY = DetectionOutputV0Attribute def forward(self, loc_data, conf_data, prior_data, arm_conf_data=None, arm_loc_data=None): + """DetectionOutputV0's forward.""" raise NotImplementedError @dataclass class RegionYoloV0Attribute(Attribute): + """RegionYoloV0Attribute class.""" + axis: int coords: int classes: int @@ -136,16 +143,21 @@ class 
RegionYoloV0Attribute(Attribute): @OPS.register() class RegionYoloV0(Operation[RegionYoloV0Attribute]): + """RegionYoloV0 class.""" + TYPE = "RegionYolo" VERSION = 0 ATTRIBUTE_FACTORY = RegionYoloV0Attribute - def forward(self, input): + def forward(self, inputs): + """RegionYoloV0's forward function.""" raise NotImplementedError @dataclass class PriorBoxV0Attribute(Attribute): + """PriorBoxV0Attribute class.""" + offset: float min_size: List[float] = field(default_factory=lambda: []) max_size: List[float] = field(default_factory=lambda: []) @@ -162,16 +174,21 @@ class PriorBoxV0Attribute(Attribute): @OPS.register() class PriorBoxV0(Operation[PriorBoxV0Attribute]): + """PriorBoxV0 class.""" + TYPE = "PriorBox" VERSION = 0 ATTRIBUTE_FACTORY = PriorBoxV0Attribute def forward(self, output_size, image_size): + """PriorBoxV0's forward function.""" raise NotImplementedError @dataclass class PriorBoxClusteredV0Attribute(Attribute): + """PriorBoxClusteredV0Attribute class.""" + offset: float width: List[float] = field(default_factory=lambda: [1.0]) height: List[float] = field(default_factory=lambda: [1.0]) @@ -184,9 +201,12 @@ class PriorBoxClusteredV0Attribute(Attribute): @OPS.register() class PriorBoxClusteredV0(Operation[PriorBoxClusteredV0Attribute]): + """PriorBoxClusteredV0 class.""" + TYPE = "PriorBoxClustered" VERSION = 0 ATTRIBUTE_FACTORY = PriorBoxClusteredV0Attribute def forward(self, output_size, image_size): + """PriorBoxClusteredV0's forward function.""" raise NotImplementedError diff --git a/otx/mpa/modules/ov/ops/op.py b/otx/core/ov/ops/op.py similarity index 64% rename from otx/mpa/modules/ov/ops/op.py rename to otx/core/ov/ops/op.py index 3658338f831..a5e6950b039 100644 --- a/otx/mpa/modules/ov/ops/op.py +++ b/otx/core/ov/ops/op.py @@ -1,6 +1,7 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Operation-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT 
import re from dataclasses import dataclass, fields @@ -8,15 +9,19 @@ import torch -from ..utils import get_op_name +from otx.mpa.modules.ov.utils import get_op_name + from .utils import get_dynamic_shape @dataclass class Attribute: + """Attribute class.""" + shape: Optional[Union[Tuple[Tuple[int]], Tuple[int]]] def __post_init__(self): + """Attribute's post-init function.""" if self.shape is not None and not isinstance(self.shape, tuple): raise ValueError("shape must be a tuple of ints or a tuple of tuples of ints.") @@ -24,10 +29,12 @@ def __post_init__(self): _T = TypeVar("_T", bound=Attribute) -class Operation(torch.nn.Module, Generic[_T]): +class Operation(torch.nn.Module, Generic[_T]): # pylint: disable=abstract-method, invalid-overridden-method + """Operation class.""" + TYPE = "" VERSION = -1 - ATTRIBUTE_FACTORY: Type[_T] = Attribute + ATTRIBUTE_FACTORY: Type[Attribute] = Attribute def __init__(self, name: str, **kwargs): super().__init__() @@ -36,6 +43,7 @@ def __init__(self, name: str, **kwargs): @classmethod def from_ov(cls, ov_op): + """Operation's from_ov function.""" op_type = ov_op.get_type_name() op_version = ov_op.get_version() op_name = get_op_name(ov_op) @@ -54,38 +62,40 @@ def from_ov(cls, ov_op): return cls(name=op_name, **attrs) @property - def type(self) -> str: + def type(self) -> str: # pylint: disable=invalid-overridden-method + """Operation's type property.""" return self.TYPE @property def version(self) -> int: + """Operation's version property.""" return self.VERSION @property def name(self) -> str: + """Operation's name property.""" return self._name @property def attrs(self): + """Operation's attrs property.""" return self._attrs @property def shape(self) -> Optional[Union[Tuple[Tuple[int]], Tuple[int]]]: + """Operation's shape property.""" return self.attrs.shape - # shape = self.attrs.get("shape", None) - # if shape is not None and len(shape) == 1: - # shape = shape[0] - # return shape def __repr__(self): - repr = 
f"{self.__class__.__name__}(" - repr += f"name={self.name}, " + """Operation's __repr__ function.""" + repr_str = f"{self.__class__.__name__}(" + repr_str += f"name={self.name}, " for field in fields(self.attrs): key = field.name if key == "shape": continue value = getattr(self.attrs, key) - repr += f"{key}={value}, " - repr = re.sub(", $", "", repr) - repr += ")" - return repr + repr_str += f"{key}={value}, " + repr_str = re.sub(", $", "", repr_str) + repr_str += ")" + return repr_str diff --git a/otx/mpa/modules/ov/ops/poolings.py b/otx/core/ov/ops/poolings.py similarity index 79% rename from otx/mpa/modules/ov/ops/poolings.py rename to otx/core/ov/ops/poolings.py index bd9752fccf3..1431d194f12 100644 --- a/otx/mpa/modules/ov/ops/poolings.py +++ b/otx/core/ov/ops/poolings.py @@ -1,19 +1,24 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Pooling-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field from typing import Callable, List from torch.nn import functional as F -from .builder import OPS -from .op import Attribute, Operation -from .utils import get_torch_padding +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.movements import get_torch_padding +from otx.core.ov.ops.op import Attribute, Operation + +# pylint: disable=too-many-instance-attributes @dataclass class MaxPoolV0Attribute(Attribute): + """MaxPoolV0Attribute class.""" + strides: List[int] pads_begin: List[int] pads_end: List[int] @@ -25,6 +30,7 @@ class MaxPoolV0Attribute(Attribute): axis: int = field(default=0) def __post_init__(self): + """MaxPoolV0Attribute's post-init functions.""" super().__post_init__() valid_auto_pad = ["explicit", "same_upper", "same_Lower", "valid"] if self.auto_pad not in valid_auto_pad: @@ -50,17 +56,19 @@ def __post_init__(self): @OPS.register() class MaxPoolV0(Operation[MaxPoolV0Attribute]): + """MaxPoolV0 class.""" + 
TYPE = "MaxPool" VERSION = 0 ATTRIBUTE_FACTORY = MaxPoolV0Attribute - def forward(self, input): - - if input.dim() == 3: + def forward(self, inputs): + """MaxPoolV0's forward function.""" + if inputs.dim() == 3: func = F.max_pool1d - elif input.dim() == 4: + elif inputs.dim() == 4: func = F.max_pool2d - elif input.dim() == 5: + elif inputs.dim() == 5: func = F.max_pool3d else: raise NotImplementedError @@ -69,16 +77,16 @@ def forward(self, input): self.attrs.pads_begin, self.attrs.pads_end, self.attrs.auto_pad, - list(input.shape[2:]), + list(inputs.shape[2:]), self.attrs.kernel, self.attrs.strides, ) if isinstance(padding, Callable): - input = padding(input=input) + inputs = padding(input=inputs) padding = 0 return func( - input=input, + input=inputs, kernel_size=self.attrs.kernel, stride=self.attrs.strides, padding=padding, @@ -90,6 +98,8 @@ def forward(self, input): @dataclass class AvgPoolV1Attribute(Attribute): + """AvgPoolV1Attribute class.""" + exclude_pad: bool strides: List[int] pads_begin: List[int] @@ -99,6 +109,7 @@ class AvgPoolV1Attribute(Attribute): auto_pad: str = field(default="explicit") def __post_init__(self): + """AvgPoolV1Attribute's post-init function.""" super().__post_init__() valid_auto_pad = ["explicit", "same_upper", "same_Lower", "valid"] if self.auto_pad not in valid_auto_pad: @@ -112,6 +123,8 @@ def __post_init__(self): @OPS.register() class AvgPoolV1(Operation[AvgPoolV1Attribute]): + """AvgPoolV1 class.""" + TYPE = "AvgPool" VERSION = 1 ATTRIBUTE_FACTORY = AvgPoolV1Attribute @@ -121,12 +134,13 @@ def __init__(self, *args, **kwargs): kwargs["exclude_pad"] = kwargs.pop("exclude-pad") super().__init__(*args, **kwargs) - def forward(self, input): - if input.dim() == 3: + def forward(self, inputs): + """AvgPoolV1's forward function.""" + if inputs.dim() == 3: func = F.avg_pool1d - elif input.dim() == 4: + elif inputs.dim() == 4: func = F.avg_pool2d - elif input.dim() == 5: + elif inputs.dim() == 5: func = F.avg_pool3d else: raise 
NotImplementedError @@ -135,16 +149,16 @@ def forward(self, input): self.attrs.pads_begin, self.attrs.pads_end, self.attrs.auto_pad, - list(input.shape[2:]), + list(inputs.shape[2:]), self.attrs.kernel, self.attrs.strides, ) if isinstance(padding, Callable): - input = padding(input=input) + inputs = padding(input=inputs) padding = 0 return func( - input=input, + input=inputs, kernel_size=self.attrs.kernel, stride=self.attrs.strides, padding=padding, diff --git a/otx/mpa/modules/ov/ops/reductions.py b/otx/core/ov/ops/reductions.py similarity index 58% rename from otx/mpa/modules/ov/ops/reductions.py rename to otx/core/ov/ops/reductions.py index ffffe451370..130d767c0a7 100644 --- a/otx/mpa/modules/ov/ops/reductions.py +++ b/otx/core/ov/ops/reductions.py @@ -1,61 +1,72 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Reduction-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field import torch -from .builder import OPS -from .op import Attribute, Operation +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation @dataclass class ReduceMeanV1Attribute(Attribute): + """ReduceMeanV1Attribute class.""" + keep_dims: bool = field(default=False) @OPS.register() class ReduceMeanV1(Operation[ReduceMeanV1Attribute]): + """ReduceMeanV1 class.""" + TYPE = "ReduceMean" VERSION = 1 ATTRIBUTE_FACTORY = ReduceMeanV1Attribute - def forward(self, input, axes): + def forward(self, inputs, axes): + """ReduceMeanV1's forward function.""" if isinstance(axes, torch.Tensor): axes = axes.tolist() if not axes: - return input + return inputs if not isinstance(axes, (list, tuple)): axes = [axes] - return torch.mean(input=input, dim=axes, keepdim=self.attrs.keep_dims) + return torch.mean(input=inputs, dim=axes, keepdim=self.attrs.keep_dims) @dataclass class ReduceProdV1Attribute(Attribute): + """ReduceProdV1Attribute
class.""" + keep_dims: bool = field(default=False) @OPS.register() class ReduceProdV1(Operation[ReduceProdV1Attribute]): + """ReduceProdV1 class.""" + TYPE = "ReduceProd" VERSION = 1 ATTRIBUTE_FACTORY = ReduceProdV1Attribute - def forward(self, input, axes): + def forward(self, inputs, axes): + """ReduceProdV1's forward function.""" if isinstance(axes, torch.Tensor): axes = axes.tolist() if not axes: - return input + return inputs if not isinstance(axes, (list, tuple)): axes = [axes] - output = input - for ax in axes: - output = torch.prod(input=output, dim=ax, keepdim=True) + output = inputs + for axe in axes: + output = torch.prod(input=output, dim=axe, keepdim=True) if not self.attrs.keep_dims: output = torch.squeeze(output) @@ -64,27 +75,32 @@ def forward(self, input, axes): @dataclass class ReduceMinV1Attribute(Attribute): + """ReduceMinV1Attribute class.""" + keep_dims: bool = field(default=False) @OPS.register() class ReduceMinV1(Operation[ReduceMinV1Attribute]): + """ReduceMinV1 class.""" + TYPE = "ReduceMin" VERSION = 1 ATTRIBUTE_FACTORY = ReduceMinV1Attribute - def forward(self, input, axes): + def forward(self, inputs, axes): + """ReduceMinV1's forward function.""" if isinstance(axes, torch.Tensor): axes = axes.tolist() if not axes: - return input + return inputs if not isinstance(axes, (list, tuple)): axes = [axes] - output = input - for ax in axes: - output = torch.min(input=output, dim=ax, keepdim=True)[0] + output = inputs + for axe in axes: + output = torch.min(input=output, dim=axe, keepdim=True)[0] if not self.attrs.keep_dims: output = torch.squeeze(output) @@ -93,19 +109,24 @@ def forward(self, input, axes): @dataclass class ReduceSumV1Attribute(Attribute): + """ReduceSumV1Attribute class.""" + keep_dims: bool = field(default=False) @OPS.register() class ReduceSumV1(Operation[ReduceSumV1Attribute]): + """ReduceSumV1 class.""" + TYPE = "ReduceSum" VERSION = 1 ATTRIBUTE_FACTORY = ReduceSumV1Attribute - def forward(self, input,
axes): + def forward(self, inputs, axes): + """ReduceSumV1's forward function.""" if isinstance(axes, torch.Tensor): axes = axes.tolist() if not axes: - return input + return inputs - return torch.sum(input=input, dim=axes, keepdim=self.attrs.keep_dims) + return torch.sum(input=inputs, dim=axes, keepdim=self.attrs.keep_dims) diff --git a/otx/mpa/modules/ov/ops/shape_manipulations.py b/otx/core/ov/ops/shape_manipulations.py similarity index 64% rename from otx/mpa/modules/ov/ops/shape_manipulations.py rename to otx/core/ov/ops/shape_manipulations.py index 3e14b204a2a..7eb9149225a 100644 --- a/otx/mpa/modules/ov/ops/shape_manipulations.py +++ b/otx/core/ov/ops/shape_manipulations.py @@ -1,41 +1,47 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Shape-manipulation-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field import torch -from .builder import OPS -from .op import Attribute, Operation -from .type_conversions import ConvertV0 +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation +from otx.core.ov.ops.type_conversions import ConvertV0 @dataclass class SqueezeV0Attribute(Attribute): - pass + """SqueezeV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class SqueezeV0(Operation[SqueezeV0Attribute]): + """SqueezeV0 class.""" + TYPE = "Squeeze" VERSION = 0 ATTRIBUTE_FACTORY = SqueezeV0Attribute - def forward(self, input, dims=None): + def forward(self, inputs, dims=None): + """SqueezeV0's forward function.""" if dims is None: - return torch.squeeze(input) + return torch.squeeze(inputs) if dims.dim() == 0: dims = torch.unsqueeze(dims, 0) - max_dim = input.dim() + max_dim = inputs.dim() dims = dims.detach().cpu().tolist() for i, dim in enumerate(dims): if dim < 0: dims[i] = max_dim + dim - output = input + output = inputs for dim in sorted(dims, reverse=True):
output = torch.squeeze(output, dim) @@ -44,28 +50,32 @@ def forward(self, input, dims=None): @dataclass class UnsqueezeV0Attribute(Attribute): - pass + """UnsqueezeV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class UnsqueezeV0(Operation[UnsqueezeV0Attribute]): + """UnsqueezeV0 class.""" + TYPE = "Unsqueeze" VERSION = 0 ATTRIBUTE_FACTORY = UnsqueezeV0Attribute - def forward(self, input, dims): - + def forward(self, inputs, dims): + """UnsqueezeV0's forward function.""" if dims.dim() == 0: dims = torch.unsqueeze(dims, 0) - max_dim = input.dim() + max_dim = inputs.dim() dims = dims.detach().cpu().tolist() if len(dims) > 1: for i, dim in enumerate(dims): if dim < 0: dims[i] = max_dim + dim - output = input + output = inputs for dim in sorted(dims, reverse=True): output = torch.unsqueeze(output, dim) @@ -74,18 +84,23 @@ def forward(self, input, dims): @dataclass class ReshapeV1Attribute(Attribute): + """ReshapeV1Attribute class.""" + special_zero: bool @OPS.register() class ReshapeV1(Operation[ReshapeV1Attribute]): + """ReshapeV1 class.""" + TYPE = "Reshape" VERSION = 1 ATTRIBUTE_FACTORY = ReshapeV1Attribute - def forward(self, input, shape): + def forward(self, inputs, shape): + """ReshapeV1's forward function.""" target_shape = shape.detach().cpu().tolist() - origin_shape = list(input.shape) + origin_shape = list(inputs.shape) for i, (origin_dim, target_dim) in enumerate(zip(origin_shape, target_shape)): if target_dim == 0 and self.attrs.special_zero: target_shape[i] = origin_dim @@ -96,29 +111,37 @@ def forward(self, input, shape): target_shape[i] = origin_dim elif target_dim == -1: break - return torch.reshape(input, target_shape) + return torch.reshape(inputs, target_shape) @dataclass class ShapeOfV0Attribute(Attribute): - pass + """ShapeOfV0Attribute class.""" + + pass # pylint: disable=unnecessary-pass @OPS.register() class ShapeOfV0(Operation[ShapeOfV0Attribute]): + """ShapeOfV0 class.""" + TYPE = "ShapeOf" VERSION = 0 
ATTRIBUTE_FACTORY = ShapeOfV0Attribute - def forward(self, input): - return torch.tensor(input.shape, device=input.device) + def forward(self, inputs): + """ShapeOfV0's forward function.""" + return torch.tensor(inputs.shape, device=inputs.device) @dataclass class ShapeOfV3Attribute(Attribute): + """ShapeOfV3Attribute class.""" + output_type: str = field(default="i64") def __post_init__(self): + """ShapeOfV3Attribute's post-init function.""" super().__post_init__() valid_output_type = ["i64", "i32"] if self.output_type not in valid_output_type: @@ -127,11 +150,14 @@ def __post_init__(self): @OPS.register() class ShapeOfV3(Operation[ShapeOfV3Attribute]): + """ShapeOfV3 class.""" + TYPE = "ShapeOf" VERSION = 3 ATTRIBUTE_FACTORY = ShapeOfV3Attribute - def forward(self, input): + def forward(self, inputs): + """ShapeOfV3's forward function.""" return ConvertV0("temp", shape=self.shape, destination_type=self.attrs.output_type)( - torch.tensor(input.shape, device=input.device) + torch.tensor(inputs.shape, device=inputs.device) ) diff --git a/otx/mpa/modules/ov/ops/sorting_maximization.py b/otx/core/ov/ops/sorting_maximization.py similarity index 76% rename from otx/mpa/modules/ov/ops/sorting_maximization.py rename to otx/core/ov/ops/sorting_maximization.py index 53d2f880d72..b51a463b4ac 100644 --- a/otx/mpa/modules/ov/ops/sorting_maximization.py +++ b/otx/core/ov/ops/sorting_maximization.py @@ -1,21 +1,25 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Sorting-maximization-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass, field -from .builder import OPS -from .op import Attribute, Operation +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation @dataclass class TopKV3Attribute(Attribute): + """TopKV3Attribute class.""" + axis: int mode: str sort: str index_element_type: str = field(default="i32") 
def __post_init__(self): + """TopKV3Attribute's post-init function.""" super().__post_init__() valid_mode = ["min", "max"] if self.mode not in valid_mode: @@ -35,30 +39,21 @@ def __post_init__(self): @OPS.register() class TopKV3(Operation[TopKV3Attribute]): + """TopKV3 class.""" + TYPE = "TopK" VERSION = 3 ATTRIBUTE_FACTORY = TopKV3Attribute - def forward(self, input, k): + def forward(self, inputs, k): + """TopKV3's forward function.""" raise NotImplementedError - # values, indices = torch.topk( - # input=input, - # k=k, - # dim=self.attrs.axis, - # largest=self.attrs.mode == "max", - # sorted=True, - # ) - # - # if self.attrs.sort == "index": - # sorted = torch.argsort(indices) - # indices = indices[sorted] - # values = values[sorted] - # - # return values, indices @dataclass class NonMaxSuppressionV5Attribute(Attribute): + """NonMaxSuppressionV5Attribute class.""" + box_encoding: str = field(default="corner") sort_result_descending: bool = field(default=True) output_type: str = field(default="i64") @@ -66,6 +61,8 @@ class NonMaxSuppressionV5Attribute(Attribute): @OPS.register() class NonMaxSuppressionV5(Operation[NonMaxSuppressionV5Attribute]): + """NonMaxSuppressionV5 class.""" + TYPE = "NonMaxSuppression" VERSION = 5 ATTRIBUTE_FACTORY = NonMaxSuppressionV5Attribute @@ -79,11 +76,14 @@ def forward( score_threshold=0, soft_nms_sigma=0, ): + """NonMaxSuppressionV5's forward function.""" raise NotImplementedError @dataclass class NonMaxSuppressionV9Attribute(Attribute): + """NonMaxSuppressionV9Attribute class.""" + box_encoding: str = field(default="corner") sort_result_descending: bool = field(default=True) output_type: str = field(default="i64") @@ -91,6 +91,8 @@ class NonMaxSuppressionV9Attribute(Attribute): @OPS.register() class NonMaxSuppressionV9(Operation[NonMaxSuppressionV9Attribute]): + """NonMaxSuppressionV9 class.""" + TYPE = "NonMaxSuppression" VERSION = 9 ATTRIBUTE_FACTORY = NonMaxSuppressionV9Attribute @@ -104,4 +106,5 @@ def forward( 
score_threshold=0, soft_nms_sigma=0, ): + """NonMaxSuppressionV9's forward function.""" raise NotImplementedError diff --git a/otx/mpa/modules/ov/ops/type_conversions.py b/otx/core/ov/ops/type_conversions.py similarity index 70% rename from otx/mpa/modules/ov/ops/type_conversions.py rename to otx/core/ov/ops/type_conversions.py index f120739076a..9b81d0d1ded 100644 --- a/otx/mpa/modules/ov/ops/type_conversions.py +++ b/otx/core/ov/ops/type_conversions.py @@ -1,13 +1,14 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Type-conversion-related modules for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from dataclasses import dataclass import torch -from .builder import OPS -from .op import Attribute, Operation +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.op import Attribute, Operation _torch_to_ov = { torch.uint8: ["u1", "u4", "u8"], @@ -39,26 +40,33 @@ @dataclass class ConvertV0Attribute(Attribute): + """ConvertV0Attribute class.""" + destination_type: str @OPS.register() class ConvertV0(Operation[ConvertV0Attribute]): + """ConvertV0 class.""" + TYPE = "Convert" VERSION = 0 ATTRIBUTE_FACTORY = ConvertV0Attribute @staticmethod def convert_ov_type(ov_type): + """ConvertV0's convert_ov_type function.""" if ov_type not in _ov_to_torch: raise NotImplementedError return _ov_to_torch[ov_type] @staticmethod def convert_torch_type(torch_type): + """ConvertV0's convert_torch_type function.""" if torch_type not in _torch_to_ov: raise NotImplementedError return _torch_to_ov[torch_type][-1] - def forward(self, input): - return input.type(self.convert_ov_type(self.attrs.destination_type)) + def forward(self, inputs): + """ConvertV0's forward function.""" + return inputs.type(self.convert_ov_type(self.attrs.destination_type)) diff --git a/otx/core/ov/ops/utils.py b/otx/core/ov/ops/utils.py new file mode 100644 index 00000000000..c1d4c51ce7d --- /dev/null +++ 
b/otx/core/ov/ops/utils.py @@ -0,0 +1,16 @@ +"""Utils function for otx.core.ov.ops.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + + +def get_dynamic_shape(output): + """Getter function for dynamic shape.""" + shape = [str(i) for i in output.get_partial_shape()] + for i, shape_ in enumerate(shape): + try: + shape_ = int(shape_) + except ValueError: + shape_ = -1 + shape[i] = shape_ + return shape diff --git a/otx/mpa/modules/ov/registry.py b/otx/core/ov/registry.py similarity index 67% rename from otx/mpa/modules/ov/registry.py rename to otx/core/ov/registry.py index a7e40ad45c4..2d790debe5b 100644 --- a/otx/mpa/modules/ov/registry.py +++ b/otx/core/ov/registry.py @@ -1,11 +1,14 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Registry Class for otx.core.ov.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from typing import Any, Dict, Optional class Registry: + """Registry Class for OMZ model.""" + REGISTERED_NAME_ATTR = "_registered_name" def __init__(self, name, add_name_as_attr=False): @@ -15,14 +18,18 @@ def __init__(self, name, add_name_as_attr=False): @property def registry_dict(self) -> Dict[Any, Any]: + """Dictionary of registered module.""" return self._registry_dict def _register(self, obj: Any, name: Any): + """Register obj with name.""" if name in self._registry_dict: - raise KeyError("{} is already registered in {}".format(name, self._name)) + raise KeyError(f"{name} is already registered in {self._name}") self._registry_dict[name] = obj def register(self, name: Optional[Any] = None): + """Register from name.""" + def wrap(obj): cls_name = name if cls_name is None: @@ -35,12 +42,15 @@ def wrap(obj): return wrap def get(self, key: Any) -> Any: + """Get from module name (key).""" if key not in self._registry_dict: self._key_not_found(key) return self._registry_dict[key] def _key_not_found(self, key: Any): - raise KeyError("{} is not found in {}".format(key, 
self._name)) + """Raise KeyError when key not founded.""" + raise KeyError(f"{key} is not found in {self._name}") def __contains__(self, item): + """Check containing of item.""" return item in self._registry_dict.values() diff --git a/otx/mpa/modules/__init__.py b/otx/mpa/modules/__init__.py index fd1fbacf7b3..db39b1f3d6d 100644 --- a/otx/mpa/modules/__init__.py +++ b/otx/mpa/modules/__init__.py @@ -4,9 +4,9 @@ # flake8: noqa -try: - import openvino -except ImportError: - pass -else: - from . import ov +# try: +# import openvino +# except ImportError: +# pass +# else: +# from . import ov diff --git a/otx/mpa/modules/ov/__init__.py b/otx/mpa/modules/ov/__init__.py index 0eadb4fe25d..aaeb84acfb5 100644 --- a/otx/mpa/modules/ov/__init__.py +++ b/otx/mpa/modules/ov/__init__.py @@ -3,6 +3,7 @@ # # flake8: noqa -from .graph import * -from .models import * -from .ops import * +# from .graph import * +# from .models import * + +# from otx.core.ov.ops import * diff --git a/otx/mpa/modules/ov/graph/graph.py b/otx/mpa/modules/ov/graph/graph.py index 5afa8d1aaab..ee83586fc0e 100644 --- a/otx/mpa/modules/ov/graph/graph.py +++ b/otx/mpa/modules/ov/graph/graph.py @@ -12,11 +12,10 @@ import networkx as nx from openvino.pyopenvino import Model +from otx.core.ov.ops.op import Operation +from otx.mpa.modules.ov.utils import convert_op_to_torch, get_op_name from otx.mpa.utils.logger import get_logger -from ..ops.op import Operation -from ..utils import convert_op_to_torch, get_op_name - logger = get_logger() diff --git a/otx/mpa/modules/ov/graph/parsers/__init__.py b/otx/mpa/modules/ov/graph/parsers/__init__.py index 37e7f3d50d2..c8f589f8ca2 100644 --- a/otx/mpa/modules/ov/graph/parsers/__init__.py +++ b/otx/mpa/modules/ov/graph/parsers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # -from . import cls +# from . 
import cls # flake8: noqa from .builder import PARSERS diff --git a/otx/mpa/modules/ov/graph/parsers/builder.py b/otx/mpa/modules/ov/graph/parsers/builder.py index 92eaf9c069b..b9cd8d60317 100644 --- a/otx/mpa/modules/ov/graph/parsers/builder.py +++ b/otx/mpa/modules/ov/graph/parsers/builder.py @@ -2,6 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # -from ...registry import Registry +from otx.core.ov.registry import Registry PARSERS = Registry("ov graph parsers") diff --git a/otx/mpa/modules/ov/graph/parsers/cls/__init__.py b/otx/mpa/modules/ov/graph/parsers/cls/__init__.py deleted file mode 100644 index e02b04e6302..00000000000 --- a/otx/mpa/modules/ov/graph/parsers/cls/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -# flake8: noqa -from .cls_base_parser import * diff --git a/otx/mpa/modules/ov/graph/utils.py b/otx/mpa/modules/ov/graph/utils.py index 990fbd0d5aa..5b7818906ec 100644 --- a/otx/mpa/modules/ov/graph/utils.py +++ b/otx/mpa/modules/ov/graph/utils.py @@ -6,10 +6,11 @@ import torch +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.infrastructures import ConstantV0 +from otx.core.ov.ops.op import Operation from otx.mpa.utils.logger import get_logger -from ..ops import OPS, Operation -from ..ops.infrastructures import ConstantV0 from .graph import Graph logger = get_logger() diff --git a/otx/mpa/modules/ov/models/mmcls/backbones/mmov_backbone.py b/otx/mpa/modules/ov/models/mmcls/backbones/mmov_backbone.py index 5c901229f37..22fee24e2a0 100644 --- a/otx/mpa/modules/ov/models/mmcls/backbones/mmov_backbone.py +++ b/otx/mpa/modules/ov/models/mmcls/backbones/mmov_backbone.py @@ -6,7 +6,8 @@ from mmcls.models.builder import BACKBONES -from ....graph.parsers.cls.cls_base_parser import cls_base_parser +from otx.core.ov.graph.parsers.cls import cls_base_parser + from ...mmov_model import MMOVModel diff --git a/otx/mpa/modules/ov/models/mmcls/heads/mmov_cls_head.py 
b/otx/mpa/modules/ov/models/mmcls/heads/mmov_cls_head.py index 804ee70babd..b0178635172 100644 --- a/otx/mpa/modules/ov/models/mmcls/heads/mmov_cls_head.py +++ b/otx/mpa/modules/ov/models/mmcls/heads/mmov_cls_head.py @@ -9,7 +9,8 @@ from mmcls.models.builder import HEADS from mmcls.models.heads import ClsHead -from ....graph.parsers.cls import cls_base_parser +from otx.core.ov.graph.parsers.cls import cls_base_parser + from ...mmov_model import MMOVModel diff --git a/otx/mpa/modules/ov/models/mmcls/necks/mmov_neck.py b/otx/mpa/modules/ov/models/mmcls/necks/mmov_neck.py index f37a6c2e699..e655ed5030f 100644 --- a/otx/mpa/modules/ov/models/mmcls/necks/mmov_neck.py +++ b/otx/mpa/modules/ov/models/mmcls/necks/mmov_neck.py @@ -6,7 +6,8 @@ from mmcls.models.builder import NECKS -from ....graph.parsers.cls.cls_base_parser import cls_base_parser +from otx.core.ov.graph.parsers.cls import cls_base_parser + from ...mmov_model import MMOVModel diff --git a/otx/mpa/modules/ov/models/ov_model.py b/otx/mpa/modules/ov/models/ov_model.py index f6a681418d6..31919029754 100644 --- a/otx/mpa/modules/ov/models/ov_model.py +++ b/otx/mpa/modules/ov/models/ov_model.py @@ -13,6 +13,8 @@ import torch from torch.nn import init +from otx.core.ov.ops.builder import OPS +from otx.mpa.modules.ov.utils import load_ov_model, normalize_name from otx.mpa.utils.logger import get_logger from ..graph import Graph @@ -21,8 +23,6 @@ handle_paired_batchnorm, handle_reshape, ) -from ..ops import OPS -from ..utils import load_ov_model, normalize_name logger = get_logger() @@ -102,7 +102,7 @@ def __init__( # internal init weight def init_weight(m, graph): - from ..ops.op import Operation + from .....core.ov.ops.op import Operation if not isinstance(m, Operation): return diff --git a/otx/mpa/modules/ov/ops/__init__.py b/otx/mpa/modules/ov/ops/__init__.py deleted file mode 100644 index 693b7dffa84..00000000000 --- a/otx/mpa/modules/ov/ops/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (C) 2022 Intel 
Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from .activations import * -from .arithmetics import * -from .builder import OPS -from .convolutions import * -from .generation import * -from .image_processings import * -from .infrastructures import * -from .matmuls import * -from .movements import * -from .normalizations import * -from .object_detections import * - -# flake8: noqa -from .op import * -from .poolings import * -from .reductions import * -from .shape_manipulations import * -from .sorting_maximization import * -from .type_conversions import * diff --git a/otx/mpa/modules/ov/ops/activations.py b/otx/mpa/modules/ov/ops/activations.py deleted file mode 100644 index 6cf73192367..00000000000 --- a/otx/mpa/modules/ov/ops/activations.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import math -from dataclasses import dataclass, field - -import torch -from torch.nn import functional as F - -from .builder import OPS -from .op import Attribute, Operation - - -@dataclass -class SoftMaxV0Attribute(Attribute): - axis: int = field(default=1) - - -@OPS.register() -class SoftMaxV0(Operation[SoftMaxV0Attribute]): - TYPE = "Softmax" - VERSION = 0 - ATTRIBUTE_FACTORY = SoftMaxV0Attribute - - def forward(self, input): - return F.softmax(input=input, dim=self.attrs.axis) - - -@dataclass -class SoftMaxV1Attribute(Attribute): - axis: int = field(default=1) - - -@OPS.register() -class SoftMaxV1(Operation[SoftMaxV1Attribute]): - TYPE = "Softmax" - VERSION = 1 - ATTRIBUTE_FACTORY = SoftMaxV1Attribute - - def forward(self, input): - return F.softmax(input=input, dim=self.attrs.axis) - - -@dataclass -class ReluV0Attribute(Attribute): - pass - - -@OPS.register() -class ReluV0(Operation[ReluV0Attribute]): - TYPE = "Relu" - VERSION = 0 - ATTRIBUTE_FACTORY = ReluV0Attribute - - def forward(self, input): - return F.relu(input) - - -@dataclass -class SwishV4Attribute(Attribute): - pass - - -@OPS.register() 
-class SwishV4(Operation[SwishV4Attribute]): - TYPE = "Swish" - VERSION = 4 - ATTRIBUTE_FACTORY = SwishV4Attribute - - def forward(self, input, beta=1.0): - return input * torch.sigmoid(input * beta) - - -@dataclass -class SigmoidV0Attribute(Attribute): - pass - - -@OPS.register() -class SigmoidV0(Operation[SigmoidV0Attribute]): - TYPE = "Sigmoid" - VERSION = 0 - ATTRIBUTE_FACTORY = SigmoidV0Attribute - - def forward(self, input): - return torch.sigmoid(input) - - -@dataclass -class ClampV0Attribute(Attribute): - min: float - max: float - - -@OPS.register() -class ClampV0(Operation[ClampV0Attribute]): - TYPE = "Clamp" - VERSION = 0 - ATTRIBUTE_FACTORY = ClampV0Attribute - - def forward(self, input): - return input.clamp(min=self.attrs.min, max=self.attrs.max) - - -@dataclass -class PReluV0Attribute(Attribute): - pass - - -@OPS.register() -class PReluV0(Operation[PReluV0Attribute]): - TYPE = "PRelu" - VERSION = 0 - ATTRIBUTE_FACTORY = PReluV0Attribute - - def forward(self, input, slope): - return F.prelu(input=input, weight=slope) - - -@dataclass -class TanhV0Attribute(Attribute): - pass - - -@OPS.register() -class TanhV0(Operation[TanhV0Attribute]): - TYPE = "Tanh" - VERSION = 0 - ATTRIBUTE_FACTORY = TanhV0Attribute - - def forward(self, input): - return F.tanh(input) - - -@dataclass -class EluV0Attribute(Attribute): - alpha: float - - -@OPS.register() -class EluV0(Operation[EluV0Attribute]): - TYPE = "Elu" - VERSION = 0 - ATTRIBUTE_FACTORY = EluV0Attribute - - def forward(self, input): - return F.elu(input=input, alpha=self.attrs.alpha) - - -@dataclass -class SeluV0Attribute(Attribute): - pass - - -@OPS.register() -class SeluV0(Operation[SeluV0Attribute]): - TYPE = "Selu" - VERSION = 0 - ATTRIBUTE_FACTORY = SeluV0Attribute - - def forward(self, input, alpha, lambda_): - return lambda_ * F.elu(input=input, alpha=alpha) - - -@dataclass -class MishV4Attribute(Attribute): - pass - - -@OPS.register() -class MishV4(Operation[MishV4Attribute]): - TYPE = "Mish" - VERSION 
= 4 - ATTRIBUTE_FACTORY = MishV4Attribute - - def forward(self, input): - # NOTE: pytorch 1.8.2 does not have mish function - # return F.mish(input=input) - return input * F.tanh(F.softplus(input)) - - -@dataclass -class HSwishV4Attribute(Attribute): - pass - - -@OPS.register() -class HSwishV4(Operation[HSwishV4Attribute]): - TYPE = "HSwish" - VERSION = 4 - ATTRIBUTE_FACTORY = HSwishV4Attribute - - def forward(self, input): - return F.hardswish(input=input) - - -@dataclass -class HSigmoidV5Attribute(Attribute): - pass - - -@OPS.register() -class HSigmoidV5(Operation[HSigmoidV5Attribute]): - TYPE = "HSigmoid" - VERSION = 5 - ATTRIBUTE_FACTORY = HSigmoidV5Attribute - - def forward(self, input): - return F.hardsigmoid(input=input) - - -@dataclass -class ExpV0Attribute(Attribute): - pass - - -@OPS.register() -class ExpV0(Operation[ExpV0Attribute]): - TYPE = "Exp" - VERSION = 0 - ATTRIBUTE_FACTORY = ExpV0Attribute - - def forward(self, input): - return torch.exp(input) - - -@dataclass -class HardSigmoidV0Attribute(Attribute): - pass - - -@OPS.register() -class HardSigmoidV0(Operation[HardSigmoidV0Attribute]): - TYPE = "HardSigmoid" - VERSION = 0 - ATTRIBUTE_FACTORY = HardSigmoidV0Attribute - - def forward(self, input, alpha, beta): - return torch.maximum( - torch.zeros_like(input), - torch.minimum(torch.ones_like(input), input * alpha + beta), - ) - - -@dataclass -class GeluV7Attribute(Attribute): - approximation_mode: str = field(default="ERF") - - def __post_init__(self): - super().__post_init__() - valid_approximation_mode = ["ERF", "tanh"] - if self.approximation_mode not in valid_approximation_mode: - raise ValueError( - f"Invalid approximation_mode {self.approximation_mode}. " - f"It must be one of {valid_approximation_mode}." 
- ) - - -@OPS.register() -class GeluV7(Operation[GeluV7Attribute]): - TYPE = "Gelu" - VERSION = 7 - ATTRIBUTE_FACTORY = GeluV7Attribute - - def forward(self, input): - mode = self.attrs.approximation_mode - if mode == "ERF": - return F.gelu(input=input) - elif mode == "tanh": - return input * 0.5 * (1 + F.tanh(torch.sqrt(2 / torch.tensor(math.pi)) * (input + 0.044715 * input**3))) diff --git a/otx/mpa/modules/ov/ops/modules/__init__.py b/otx/mpa/modules/ov/ops/modules/__init__.py deleted file mode 100644 index 2b6e6bb7ef0..00000000000 --- a/otx/mpa/modules/ov/ops/modules/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -# flake8: noqa -from .op_module import OperationModule diff --git a/otx/mpa/modules/ov/ops/utils.py b/otx/mpa/modules/ov/ops/utils.py deleted file mode 100644 index 0877ec6409f..00000000000 --- a/otx/mpa/modules/ov/ops/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import math -from functools import partial - -from torch.nn import functional as F - - -def get_dynamic_shape(op): - shape = [str(i) for i in op.get_partial_shape()] - for i, shape_ in enumerate(shape): - try: - shape_ = int(shape_) - except ValueError: - shape_ = -1 - shape[i] = shape_ - return shape - - -def get_torch_padding(pads_begin, pads_end, auto_pad, input_size, weight_size, stride, dilation=None): - from .movements import PadV1 - - if dilation is None: - dilation = [1 for _ in input_size] - - if auto_pad == "valid": - return 0 - elif auto_pad == "same_upper" or auto_pad == "same_lower": - assert len(set(dilation)) == 1 and dilation[0] == 1 - pads_begin = [] - pads_end = [] - for input_size_, weight_size_, stride_, dilation_ in zip(input_size, weight_size, stride, dilation): - out_size = math.ceil(input_size_ / stride_) - padding_needed = max(0, (out_size - 1) * stride_ + weight_size_ - input_size_) - padding_lhs = 
int(padding_needed / 2) - padding_rhs = padding_needed - padding_lhs - - pads_begin.append(padding_lhs if auto_pad == "same_upper" else padding_rhs) - pads_end.append(padding_rhs if auto_pad == "same_upper" else padding_lhs) - pad = PadV1.get_torch_pad_dim(pads_begin, pads_end) - return partial(F.pad, pad=pad, mode="constant", value=0) - elif auto_pad == "explicit": - pad = PadV1.get_torch_pad_dim(pads_begin, pads_end) - return partial(F.pad, pad=pad, mode="constant", value=0) - else: - raise NotImplementedError diff --git a/otx/mpa/modules/ov/utils.py b/otx/mpa/modules/ov/utils.py index 7dc487565be..f3b8d12043c 100644 --- a/otx/mpa/modules/ov/utils.py +++ b/otx/mpa/modules/ov/utils.py @@ -9,10 +9,9 @@ from openvino.pyopenvino import Model, Node from openvino.runtime import Core +from otx.core.ov.omz_wrapper import AVAILABLE_OMZ_MODELS, get_omz_model from otx.mpa.utils.logger import get_logger -from .omz_wrapper import AVAILABLE_OMZ_MODELS, get_omz_model - logger = get_logger() @@ -125,7 +124,7 @@ def get_op_name(op: Node) -> str: def convert_op_to_torch(op: Node): - from .ops import OPS + from otx.core.ov.ops.builder import OPS op_type = op.get_type_name() op_version = op.get_version() @@ -143,7 +142,7 @@ def convert_op_to_torch(op: Node): def convert_op_to_torch_module(target_op: Node): - from .ops.modules import OperationModule + from otx.core.ov.ops.modules.op_module import OperationModule dependent_modules = [] for in_port in target_op.inputs(): diff --git a/otx/mpa/utils/file.py b/otx/mpa/utils/file.py deleted file mode 100644 index c3fd47846b0..00000000000 --- a/otx/mpa/utils/file.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import os - -MPA_CACHE = os.path.expanduser(os.getenv("MPA_CACHE", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "mpa"))) -os.makedirs(MPA_CACHE, exist_ok=True) diff --git a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py 
b/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py index 71eee579996..e13e4bd8df6 100644 --- a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py +++ b/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 # +from otx.core.ov.graph.parsers.cls import cls_base_parser from otx.mpa.modules.ov.graph.graph import Graph -from otx.mpa.modules.ov.graph.parsers.cls.cls_base_parser import cls_base_parser from otx.mpa.modules.ov.utils import load_ov_model from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_activations.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_activations.py index 553f002eafe..a2c6e8f08f7 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_activations.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_activations.py @@ -8,7 +8,7 @@ import torch from torch.nn import functional as F -from otx.mpa.modules.ov.ops.activations import ( +from otx.core.ov.ops.activations import ( ClampV0, EluV0, ExpV0, diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_arithmetics.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_arithmetics.py index 046d8f95f0f..e477286114b 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_arithmetics.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_arithmetics.py @@ -5,13 +5,7 @@ import pytest import torch -from otx.mpa.modules.ov.ops.arithmetics import ( - AddV1, - DivideV1, - MultiplyV1, - SubtractV1, - TanV0, -) +from otx.core.ov.ops.arithmetics import AddV1, DivideV1, MultiplyV1, SubtractV1, TanV0 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_builder.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_builder.py index 84a1f8bf6f2..f175747fc9b 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_builder.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_builder.py @@ -6,8 +6,8 @@ import pytest -from 
otx.mpa.modules.ov.ops.builder import OperationRegistry -from otx.mpa.modules.ov.ops.op import Attribute, Operation +from otx.core.ov.ops.builder import OperationRegistry +from otx.core.ov.ops.op import Attribute, Operation from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_convolutions.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_convolutions.py index 3cd37523080..ec0a6c41500 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_convolutions.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_convolutions.py @@ -6,7 +6,7 @@ import torch from torch.nn import functional as F -from otx.mpa.modules.ov.ops.convolutions import ConvolutionV1, GroupConvolutionV1 +from otx.core.ov.ops.convolutions import ConvolutionV1, GroupConvolutionV1 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_generation.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_generation.py index dc0919ceb2c..c8efe5d31de 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_generation.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_generation.py @@ -5,7 +5,7 @@ import pytest import torch -from otx.mpa.modules.ov.ops.generation import RangeV4 +from otx.core.ov.ops.generation import RangeV4 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_image_processings.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_image_processings.py index dc1dfa321e5..1c86c54f691 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_image_processings.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_image_processings.py @@ -6,7 +6,7 @@ import torch from torch.nn import functional as F -from otx.mpa.modules.ov.ops.image_processings import InterpolateV4 +from otx.core.ov.ops.image_processings import InterpolateV4 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git 
a/tests/unit/mpa/modules/ov/ops/test_ov_ops_infrastructures.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_infrastructures.py index f688229d693..91df181ac81 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_infrastructures.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_infrastructures.py @@ -7,7 +7,7 @@ import pytest import torch -from otx.mpa.modules.ov.ops.infrastructures import ConstantV0, ParameterV0, ResultV0 +from otx.core.ov.ops.infrastructures import ConstantV0, ParameterV0, ResultV0 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_matmuls.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_matmuls.py index e766472efca..e795aea7536 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_matmuls.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_matmuls.py @@ -5,7 +5,7 @@ import pytest import torch -from otx.mpa.modules.ov.ops.matmuls import EinsumV7, MatMulV0 +from otx.core.ov.ops.matmuls import EinsumV7, MatMulV0 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_module.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_module.py index 158bc7e4c9d..56e3942ada8 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_module.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_module.py @@ -5,8 +5,8 @@ import pytest import torch -from otx.mpa.modules.ov.ops import OPS -from otx.mpa.modules.ov.ops.modules.op_module import OperationModule +from otx.core.ov.ops.builder import OPS +from otx.core.ov.ops.modules.op_module import OperationModule from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_movements.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_movements.py index 52333569da2..ec5b0faeb04 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_movements.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_movements.py @@ -5,7 +5,7 @@ import pytest import torch -from 
otx.mpa.modules.ov.ops.movements import ( +from otx.core.ov.ops.movements import ( BroadcastV3, ConcatV0, GatherV0, diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_normalizations.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_normalizations.py index cba839b572c..268b12aa7c9 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_normalizations.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_normalizations.py @@ -6,7 +6,7 @@ import torch from torch.nn import functional as F -from otx.mpa.modules.ov.ops.normalizations import ( +from otx.core.ov.ops.normalizations import ( MVNV6, BatchNormalizationV0, LocalResponseNormalizationV0, diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_object_detections.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_object_detections.py index 91aa2e674d8..76315d67d61 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_object_detections.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_object_detections.py @@ -4,7 +4,7 @@ import pytest -from otx.mpa.modules.ov.ops.object_detections import ( +from otx.core.ov.ops.object_detections import ( DetectionOutputV0, PriorBoxClusteredV0, PriorBoxV0, diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_op.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_op.py index 275f06ab391..d4e9557503c 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_op.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_op.py @@ -4,7 +4,7 @@ import openvino.runtime as ov -from otx.mpa.modules.ov.ops.arithmetics import MultiplyV1 +from otx.core.ov.ops.arithmetics import MultiplyV1 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_poolings.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_poolings.py index 315f7e5d65c..192b6218cc5 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_poolings.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_poolings.py @@ -6,7 +6,7 @@ import torch from torch.nn import functional as F -from 
otx.mpa.modules.ov.ops.poolings import AvgPoolV1, MaxPoolV0 +from otx.core.ov.ops.poolings import AvgPoolV1, MaxPoolV0 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_reductions.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_reductions.py index d6749aa6434..34419a695e0 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_reductions.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_reductions.py @@ -4,7 +4,7 @@ import torch -from otx.mpa.modules.ov.ops.reductions import ( +from otx.core.ov.ops.reductions import ( ReduceMeanV1, ReduceMinV1, ReduceProdV1, diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_shape_manipulations.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_shape_manipulations.py index 8777d4ee964..ec449751074 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_shape_manipulations.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_shape_manipulations.py @@ -5,7 +5,7 @@ import pytest import torch -from otx.mpa.modules.ov.ops.shape_manipulations import ( +from otx.core.ov.ops.shape_manipulations import ( ReshapeV1, ShapeOfV0, ShapeOfV3, diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_sorting_maximization.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_sorting_maximization.py index 7ccdd358b1f..5f0d0a65b56 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_sorting_maximization.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_sorting_maximization.py @@ -4,7 +4,7 @@ import pytest -from otx.mpa.modules.ov.ops.sorting_maximization import ( +from otx.core.ov.ops.sorting_maximization import ( NonMaxSuppressionV5, NonMaxSuppressionV9, TopKV3, diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_type_conversions.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_type_conversions.py index 6db5128ab91..9df3ef3de88 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_type_conversions.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_type_conversions.py @@ -5,7 +5,7 @@ import 
pytest import torch -from otx.mpa.modules.ov.ops.type_conversions import ConvertV0 +from otx.core.ov.ops.type_conversions import ConvertV0 from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_utils.py b/tests/unit/mpa/modules/ov/ops/test_ov_ops_utils.py index a35cbe42b14..2c3c9f2d068 100644 --- a/tests/unit/mpa/modules/ov/ops/test_ov_ops_utils.py +++ b/tests/unit/mpa/modules/ov/ops/test_ov_ops_utils.py @@ -6,7 +6,8 @@ import pytest import torch -from otx.mpa.modules.ov.ops.utils import get_dynamic_shape, get_torch_padding +from otx.core.ov.ops.movements import get_torch_padding +from otx.core.ov.ops.utils import get_dynamic_shape from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/test_ov_omz_wrapper.py b/tests/unit/mpa/modules/ov/test_ov_omz_wrapper.py index 345e19ee592..a04c9edce74 100644 --- a/tests/unit/mpa/modules/ov/test_ov_omz_wrapper.py +++ b/tests/unit/mpa/modules/ov/test_ov_omz_wrapper.py @@ -7,7 +7,7 @@ from openvino.model_zoo._configuration import Model -from otx.mpa.modules.ov.omz_wrapper import ( +from otx.core.ov.omz_wrapper import ( download_model, get_model_configuration, get_omz_model, diff --git a/tests/unit/mpa/modules/ov/test_ov_registry.py b/tests/unit/mpa/modules/ov/test_ov_registry.py index cc9f163a4db..25c5e32be7d 100644 --- a/tests/unit/mpa/modules/ov/test_ov_registry.py +++ b/tests/unit/mpa/modules/ov/test_ov_registry.py @@ -4,7 +4,7 @@ import pytest -from otx.mpa.modules.ov.registry import Registry +from otx.core.ov.registry import Registry from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/test_ov_utils.py b/tests/unit/mpa/modules/ov/test_ov_utils.py index d5432e49ed2..170517379ab 100644 --- a/tests/unit/mpa/modules/ov/test_ov_utils.py +++ b/tests/unit/mpa/modules/ov/test_ov_utils.py @@ -8,8 +8,8 @@ import openvino.runtime as ov import pytest -from otx.mpa.modules.ov.omz_wrapper 
import get_omz_model -from otx.mpa.modules.ov.ops import ParameterV0 +from otx.core.ov.omz_wrapper import get_omz_model +from otx.core.ov.ops.infrastructures import ParameterV0 from otx.mpa.modules.ov.utils import ( convert_op_to_torch, convert_op_to_torch_module, From 0a489f36e1231b95d33bc449a7b2caf80d75be57 Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Thu, 23 Mar 2023 15:54:41 +0900 Subject: [PATCH 2/6] Refactor mpa.modules.ov -> otx.core.ov --- .../mmcls/models/backbones/mmov_backbone.py | 6 +- .../mmcls/models/heads/mmov_cls_head.py | 2 +- .../adapters/mmcls/models/necks/mmov_neck.py | 6 +- .../mmdet/models/backbones/mmov_backbone.py | 2 +- .../mmdet/models/dense_heads/mmov_rpn_head.py | 2 +- .../mmdet/models/dense_heads/mmov_ssd_head.py | 2 +- .../models/dense_heads/mmov_yolov3_head.py | 2 +- .../adapters/mmdet/models/necks/mmov_fpn.py | 2 +- .../mmdet/models/necks/mmov_ssd_neck.py | 2 +- .../mmdet/models/necks/mmov_yolov3_neck.py | 4 +- .../roi_heads/bbox_heads/mmov_bbox_head.py | 2 +- .../roi_heads/mask_heads/mmov_mask_head.py | 2 +- .../mmseg/models/backbones/mmov_backbone.py | 2 +- .../mmseg/models/heads/mmov_decode_head.py | 2 +- otx/core/ov/graph/__init__.py | 5 + otx/{mpa/modules => core}/ov/graph/graph.py | 148 +++++++++++++----- otx/core/ov/graph/parsers/__init__.py | 4 + otx/core/ov/graph/parsers/builder.py | 8 + .../ov/graph/parsers/cls/cls_base_parser.py | 18 +-- .../ov/graph/parsers/parser.py | 8 +- otx/{mpa/modules => core}/ov/graph/utils.py | 46 +++--- otx/core/ov/models/__init__.py | 10 ++ .../modules => core}/ov/models/mmov_model.py | 20 ++- .../modules => core}/ov/models/ov_model.py | 100 +++++++----- .../ov/models/parser_mixin.py | 14 +- otx/core/ov/ops/infrastructures.py | 11 +- otx/core/ov/ops/modules/op_module.py | 19 +++ otx/core/ov/ops/op.py | 3 +- otx/core/ov/ops/utils.py | 25 +++ otx/{mpa/modules => core}/ov/utils.py | 73 +++------ otx/mpa/modules/ov/graph/__init__.py | 6 - otx/mpa/modules/ov/graph/parsers/__init__.py | 8 - 
otx/mpa/modules/ov/graph/parsers/builder.py | 7 - otx/mpa/modules/ov/models/__init__.py | 3 - .../graph/parsers/test_ov_graph_cls_parser.py | 4 +- .../ov/graph/parsers/test_ov_graph_parser.py | 6 +- .../modules/ov/graph/test_ov_graph_grapy.py | 2 +- .../modules/ov/graph/test_ov_graph_utils.py | 6 +- .../ov/models/test_ov_models_ov_model.py | 2 +- tests/unit/mpa/modules/ov/test_ov_utils.py | 6 +- 40 files changed, 366 insertions(+), 234 deletions(-) rename otx/{mpa/modules => core}/ov/graph/graph.py (80%) create mode 100644 otx/core/ov/graph/parsers/builder.py rename otx/{mpa/modules => core}/ov/graph/parsers/parser.py (60%) rename otx/{mpa/modules => core}/ov/graph/utils.py (88%) create mode 100644 otx/core/ov/models/__init__.py rename otx/{mpa/modules => core}/ov/models/mmov_model.py (73%) rename otx/{mpa/modules => core}/ov/models/ov_model.py (84%) rename otx/{mpa/modules => core}/ov/models/parser_mixin.py (85%) rename otx/{mpa/modules => core}/ov/utils.py (67%) delete mode 100644 otx/mpa/modules/ov/graph/__init__.py delete mode 100644 otx/mpa/modules/ov/graph/parsers/__init__.py delete mode 100644 otx/mpa/modules/ov/graph/parsers/builder.py delete mode 100644 otx/mpa/modules/ov/models/__init__.py diff --git a/otx/algorithms/classification/adapters/mmcls/models/backbones/mmov_backbone.py b/otx/algorithms/classification/adapters/mmcls/models/backbones/mmov_backbone.py index 06ffafd487b..1480701c509 100644 --- a/otx/algorithms/classification/adapters/mmcls/models/backbones/mmov_backbone.py +++ b/otx/algorithms/classification/adapters/mmcls/models/backbones/mmov_backbone.py @@ -1,11 +1,11 @@ """Module for the MMOVBackbone class.""" -from typing import Dict, List +from typing import Dict, List, Union from mmcls.models.builder import BACKBONES from otx.core.ov.graph.parsers.cls import cls_base_parser -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel @BACKBONES.register_module() @@ -18,7 +18,7 @@ class 
MMOVBackbone(MMOVModel): """ @staticmethod - def parser(graph, **kwargs) -> Dict[str, List[str]]: + def parser(graph, **kwargs) -> Dict[str, Union[List[str], Dict[str, List[str]]]]: """Parses the input and output of the model. Args: diff --git a/otx/algorithms/classification/adapters/mmcls/models/heads/mmov_cls_head.py b/otx/algorithms/classification/adapters/mmcls/models/heads/mmov_cls_head.py index 4f9ceca5cb5..3f18ded84ab 100644 --- a/otx/algorithms/classification/adapters/mmcls/models/heads/mmov_cls_head.py +++ b/otx/algorithms/classification/adapters/mmcls/models/heads/mmov_cls_head.py @@ -11,7 +11,7 @@ from mmcls.models.heads import ClsHead from otx.core.ov.graph.parsers.cls import cls_base_parser -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel @HEADS.register_module() diff --git a/otx/algorithms/classification/adapters/mmcls/models/necks/mmov_neck.py b/otx/algorithms/classification/adapters/mmcls/models/necks/mmov_neck.py index d4569e17dcc..a4e96798b6a 100644 --- a/otx/algorithms/classification/adapters/mmcls/models/necks/mmov_neck.py +++ b/otx/algorithms/classification/adapters/mmcls/models/necks/mmov_neck.py @@ -3,12 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 # -from typing import Dict, List +from typing import Dict, List, Union from mmcls.models.builder import NECKS from otx.core.ov.graph.parsers.cls import cls_base_parser -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel @NECKS.register_module() @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def parser(graph, **kwargs) -> Dict[str, List[str]]: + def parser(graph, **kwargs) -> Dict[str, Union[List[str], Dict[str, List[str]]]]: """Parser function returns base_parser for given graph.""" output = cls_base_parser(graph, "neck") if output is None: diff --git 
a/otx/algorithms/detection/adapters/mmdet/models/backbones/mmov_backbone.py b/otx/algorithms/detection/adapters/mmdet/models/backbones/mmov_backbone.py index e976a44ebb9..baa12457bc9 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/backbones/mmov_backbone.py +++ b/otx/algorithms/detection/adapters/mmdet/models/backbones/mmov_backbone.py @@ -5,7 +5,7 @@ from mmdet.models.builder import BACKBONES -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel @BACKBONES.register_module() diff --git a/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_rpn_head.py b/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_rpn_head.py index a5be0a455fa..fbea73263f4 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_rpn_head.py +++ b/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_rpn_head.py @@ -11,7 +11,7 @@ from mmdet.models.builder import HEADS from mmdet.models.dense_heads.rpn_head import RPNHead -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel from otx.mpa.utils.logger import get_logger logger = get_logger() diff --git a/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_ssd_head.py b/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_ssd_head.py index 08f82c5d25b..90a2b573cdd 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_ssd_head.py +++ b/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_ssd_head.py @@ -12,7 +12,7 @@ from mmdet.models.builder import HEADS from mmdet.models.dense_heads.ssd_head import SSDHead -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # TODO: Need to fix pylint issues # pylint: disable=redefined-argument-from-local, too-many-instance-attributes diff --git 
a/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_yolov3_head.py b/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_yolov3_head.py index 89a2b2af830..be9bfb1ecd2 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_yolov3_head.py +++ b/otx/algorithms/detection/adapters/mmdet/models/dense_heads/mmov_yolov3_head.py @@ -11,7 +11,7 @@ from mmdet.models.builder import HEADS from mmdet.models.dense_heads.yolo_head import YOLOV3Head -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # TODO: Need to fix pylint issues # pylint: disable=too-many-instance-attributes, keyword-arg-before-vararg diff --git a/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_fpn.py b/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_fpn.py index b4fbe7ba01d..bbb9cea7849 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_fpn.py +++ b/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_fpn.py @@ -10,7 +10,7 @@ from mmdet.models.necks.fpn import FPN from torch import nn -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # TODO: Need to fix pylint issues # pylint: disable=keyword-arg-before-vararg, too-many-locals diff --git a/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_ssd_neck.py b/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_ssd_neck.py index b65403142cf..27df619302c 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_ssd_neck.py +++ b/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_ssd_neck.py @@ -12,7 +12,7 @@ from mmdet.models.builder import NECKS from torch import nn -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # pylint: disable=too-many-arguments, too-many-locals diff --git a/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_yolov3_neck.py 
b/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_yolov3_neck.py index d879d2f15d5..24473f4c34e 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_yolov3_neck.py +++ b/otx/algorithms/detection/adapters/mmdet/models/necks/mmov_yolov3_neck.py @@ -10,8 +10,8 @@ from mmdet.models.builder import NECKS from mmdet.models.necks.yolo_neck import YOLOV3Neck -from otx.mpa.modules.ov.models.mmov_model import MMOVModel -from otx.mpa.modules.ov.models.parser_mixin import ParserMixin +from otx.core.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.parser_mixin import ParserMixin # type: ignore[attr-defined] @NECKS.register_module() diff --git a/otx/algorithms/detection/adapters/mmdet/models/roi_heads/bbox_heads/mmov_bbox_head.py b/otx/algorithms/detection/adapters/mmdet/models/roi_heads/bbox_heads/mmov_bbox_head.py index f741ba3ffd7..c144fe9e7c4 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/roi_heads/bbox_heads/mmov_bbox_head.py +++ b/otx/algorithms/detection/adapters/mmdet/models/roi_heads/bbox_heads/mmov_bbox_head.py @@ -11,7 +11,7 @@ from mmdet.models.builder import HEADS from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # TODO: Need to fix pylint issues # pylint: disable=too-many-instance-attributes, too-many-arguments, keyword-arg-before-vararg, dangerous-default-value diff --git a/otx/algorithms/detection/adapters/mmdet/models/roi_heads/mask_heads/mmov_mask_head.py b/otx/algorithms/detection/adapters/mmdet/models/roi_heads/mask_heads/mmov_mask_head.py index 7343fa92043..29475519b44 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/roi_heads/mask_heads/mmov_mask_head.py +++ b/otx/algorithms/detection/adapters/mmdet/models/roi_heads/mask_heads/mmov_mask_head.py @@ -10,7 +10,7 @@ from mmdet.models.builder import HEADS from mmdet.models.roi_heads.mask_heads.fcn_mask_head import 
FCNMaskHead -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # TODO: Need to fix pylint issues # pylint: disable=too-many-instance-attributes, too-many-arguments, keyword-arg-before-vararg, dangerous-default-value diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/backbones/mmov_backbone.py b/otx/algorithms/segmentation/adapters/mmseg/models/backbones/mmov_backbone.py index 0b19fc7eb12..9ab524dcfb7 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/models/backbones/mmov_backbone.py +++ b/otx/algorithms/segmentation/adapters/mmseg/models/backbones/mmov_backbone.py @@ -6,7 +6,7 @@ from mmseg.models.builder import BACKBONES -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # pylint: disable=unused-argument diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/heads/mmov_decode_head.py b/otx/algorithms/segmentation/adapters/mmseg/models/heads/mmov_decode_head.py index 75fc3083919..f34c72b9840 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/models/heads/mmov_decode_head.py +++ b/otx/algorithms/segmentation/adapters/mmseg/models/heads/mmov_decode_head.py @@ -10,7 +10,7 @@ import openvino.runtime as ov from mmseg.models.decode_heads.decode_head import BaseDecodeHead -from otx.mpa.modules.ov.models.mmov_model import MMOVModel +from otx.core.ov.models.mmov_model import MMOVModel # pylint: disable=too-many-instance-attributes, keyword-arg-before-vararg diff --git a/otx/core/ov/graph/__init__.py b/otx/core/ov/graph/__init__.py index 9696d660171..ad121dad22d 100644 --- a/otx/core/ov/graph/__init__.py +++ b/otx/core/ov/graph/__init__.py @@ -2,3 +2,8 @@ # Copyright (C) 2023 Intel Corporation # # SPDX-License-Identifier: MIT + +# TODO: Need to remove comment with ignore mypy and fix mypy issues +from .graph import Graph # type: ignore[attr-defined] + +__all__ = ["Graph"] diff --git a/otx/mpa/modules/ov/graph/graph.py 
b/otx/core/ov/graph/graph.py similarity index 80% rename from otx/mpa/modules/ov/graph/graph.py rename to otx/core/ov/graph/graph.py index ee83586fc0e..75aad71282b 100644 --- a/otx/mpa/modules/ov/graph/graph.py +++ b/otx/core/ov/graph/graph.py @@ -1,6 +1,9 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# TODO: Need to remove line 1 (ignore mypy) and fix mypy issues +"""Modules for otx.core.ov.graph.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT import inspect from collections import OrderedDict @@ -10,52 +13,73 @@ import _collections_abc import networkx as nx -from openvino.pyopenvino import Model +from openvino.pyopenvino import Model # pylint: disable=no-name-in-module -from otx.core.ov.ops.op import Operation -from otx.mpa.modules.ov.utils import convert_op_to_torch, get_op_name from otx.mpa.utils.logger import get_logger +from ..ops.op import Operation +from ..ops.utils import convert_op_to_torch +from ..utils import get_op_name + logger = get_logger() +# pylint: disable=too-many-locals, too-many-nested-blocks, arguments-renamed, too-many-branches, too-many-statements + class SortedDictKeysView(_collections_abc.KeysView): + """SortedDictKeysView class.""" + def __repr__(self): - return f"{self.__class__.__name__}({[i for i in self._mapping]})" + """Function repr of SortedDictKeysView.""" + return f"{self.__class__.__name__}({list(self._mapping)})" def __reversed__(self): + """Function reversed of SortedDictKeysView.""" yield from reversed(self._mapping) class SortedDictValuesView(_collections_abc.ValuesView): + """SortedDictValuesView class.""" + def __repr__(self): + """Sorteddictvaluesview's repr function.""" return f"{self.__class__.__name__}({[self._mapping[i] for i in self._mapping]})" def __reversed__(self): + """Sorteddictvaluesview's reversed function.""" for key in reversed(self._mapping): yield self._mapping[key] class SortedDictItemsView(_collections_abc.ItemsView): + 
"""SortedDictItemsView class.""" + def __repr__(self): + """Sorteddictitemsview's repr function.""" return f"{self.__class__.__name__}({[(i, self._mapping[i]) for i in self._mapping]})" def __reversed__(self): + """Sorteddictitemsview's reversed function.""" for key in reversed(self._mapping): yield (key, self._mapping[key]) class NOOP: - pass + """NOOP class.""" + + pass # pylint: disable=unnecessary-pass class SortedDict(dict): + """SortedDict class.""" + def __init__(self, sort_key, *args, **kwargs): self._sort_key = sort_key self._sorted_keys = [] super().__init__(self, *args, **kwargs) def __setitem__(self, key, value): + """Sorteddict's setitem function.""" assert len(value) == 1 edge_key, edge_attr = next(iter(value.items())) sort_value = float("inf") if self._sort_key not in edge_attr else edge_attr[self._sort_key] @@ -68,43 +92,50 @@ def __setitem__(self, key, value): super().__setitem__(key, value) def __delitem__(self, key): + """Sorteddict's delitem function.""" super().__delitem__(key) for i, (_, key_in, _) in enumerate(self._sorted_keys): if key_in == key: break - self._sorted_keys.pop(i) + self._sorted_keys.pop(i) # pylint: disable=undefined-loop-variable def __iter__(self): + """Sorteddict's iter function.""" for _, key, _ in self._sorted_keys: yield key def __reversed__(self): + """Sorteddict's reversed function.""" for _, key, _ in self._sorted_keys[::-1]: yield key def __repr__(self): - if not len(self): + """Sorteddict's repr function.""" + if not len(self): # pylint: disable=use-implicit-booleaness-not-len return "{}" - repr = "{" + repr_ = "{" for _, key, _ in self._sorted_keys: - repr += f"{key}: {self[key]}, " - repr = repr[:-2] - repr += "}" - return repr + repr_ += f"{key}: {self[key]}, " + repr_ = repr_[:-2] + repr_ += "}" + return repr_ def __deepcopy__(self, memo): + """Sorteddict's deepcopy function.""" cls = self.__class__ result = cls(self._sort_key) memo[id(self)] = result - for k, v in self.items(): - result[k] = deepcopy(v, memo) 
+ for key, value in self.items(): + result[key] = deepcopy(value, memo) return result def clear(self): + """Sorteddict's clear function.""" super().clear() self._sorted_keys = [] def pop(self, key, default=NOOP()): + """Sorteddict's pop function.""" if isinstance(default, NOOP): value = super().pop(key) else: @@ -113,39 +144,49 @@ def pop(self, key, default=NOOP()): for i, (_, key_in, _) in enumerate(self._sorted_keys): if key_in == key: break - self._sorted_keys.pop(i) + self._sorted_keys.pop(i) # pylint: disable=undefined-loop-variable return value def popitem(self): + """Sorteddict's popitem function.""" raise NotImplementedError @staticmethod def fromkeys(iterable, value=None): + """Sorteddict's fromkeys function.""" raise NotImplementedError def keys(self): + """Sorteddict's keys function.""" return SortedDictKeysView(self) def values(self): + """Sorteddict's values function.""" return SortedDictValuesView(self) def items(self): + """Sorteddict's items function.""" return SortedDictItemsView(self) class SortedDictHelper(dict): - def __init__(self, sort_key=None, *args, **kwargs): + """SortedDictHelper class.""" + + def __init__(self, sort_key=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg self._sort_key = sort_key super().__init__(*args, **kwargs) def __setitem__(self, key, value): + """Sorteddicthelper's setitem function.""" super().__setitem__(key, SortedDict(self._sort_key)) for v_key, v_value in value.items(): self[key][v_key] = v_value class Graph(nx.MultiDiGraph): + """Graph class.""" + adjlist_outer_dict_factory = SortedDictHelper def __init__(self, *args, **kwargs): @@ -159,12 +200,13 @@ def __init__(self, *args, **kwargs): @staticmethod def from_ov(ov_model: Model) -> "Graph": + """Graph's from_ov function.""" graph = Graph() ov_ops = ov_model.get_ordered_ops() ops_dict = OrderedDict() - parents_dict = {} - children_dict = {} + parents_dict: Dict[str, List[Optional[List]]] = {} + children_dict: Dict[str, List[Optional[List]]] = {} 
for ov_op in ov_ops: op_name = get_op_name(ov_op) @@ -212,21 +254,26 @@ def from_ov(ov_model: Model) -> "Graph": ) # freeze normalization nodes - graph._freeze_normalize_nodes() + graph._freeze_normalize_nodes() # pylint: disable=protected-access return graph def get_edge_data(self, node_from: Operation, node_to: Operation, default=None) -> Optional[List[Dict[Any, Any]]]: + """Graph's get_edge_data function.""" edge_data = super().get_edge_data(node_from, node_to, None, default) if edge_data is not None: return list(edge_data.values()) - else: - return None + return None def remove_node(self, node: Operation, keep_connect: bool = False): + """Graph's remove_node function.""" edges_to_keep = [] if keep_connect: - predecessors = [predecessor for predecessor in self.predecessors(node) if predecessor.type != "Constant"] + predecessors = [ + predecessor + for predecessor in self.predecessors(node) + if hasattr(predecessor, "type") and predecessor.type != "Constant" + ] if predecessors: assert len(predecessors) == 1 predecessor = predecessors[0] @@ -245,6 +292,7 @@ def remove_node(self, node: Operation, keep_connect: bool = False): self.add_edge(node_from, node_to, **attrs) def replace_node(self, old_node: Operation, new_node: Operation): + """Graph's replace_node function.""" edges = [] for successor in self.successors(old_node): for edge_attrs in self.get_edge_data(old_node, successor): @@ -267,6 +315,7 @@ def add_edge( in_port: Optional[int] = None, **kwargs, ): + """Graph's add_edge function.""" if node_from not in self: self.add_node(node_from) @@ -319,6 +368,7 @@ def predecessors( node: Operation, with_edge_data: bool = False, ) -> Generator[Union[Tuple[Operation, Optional[List]], Operation], None, None]: + """Graph's predecessors function.""" for predecessor in super().predecessors(node): if with_edge_data: yield (predecessor, self.get_edge_data(predecessor, node)) @@ -330,6 +380,7 @@ def successors( node: Operation, with_edge_data: bool = False, ) -> 
Generator[Union[Tuple[Operation, Optional[List]], Operation], None, None]: + """Graph's successors function.""" for successor in super().successors(node): if with_edge_data: yield (successor, self.get_edge_data(node, successor)) @@ -337,6 +388,7 @@ def successors( yield successor def get_nodes_by_types(self, types: List[str]) -> List[Operation]: + """Graph's get_nodes_by_types function.""" found = [] for node in self.topological_sort(): if node.type in types: @@ -346,9 +398,10 @@ def get_nodes_by_types(self, types: List[str]) -> List[Operation]: def bfs( self, node: Operation, reverse: bool = False, depth_limit: Optional[int] = None ) -> Generator[Union[Tuple[Operation, Operation], Tuple[Operation, Tuple[Operation]]], None, None]: + """Graph's bfs function.""" if reverse: - for s, t in nx.bfs_edges(self, node, reverse=True, depth_limit=depth_limit): - yield (t, s) + for s_value, t_value in nx.bfs_edges(self, node, reverse=True, depth_limit=depth_limit): + yield (t_value, s_value) else: parent = node children = [] @@ -369,6 +422,7 @@ def bfs( # return nx.dfs_predecessors(self, node, depth_limit) def get_nodes_by_type_pattern(self, pattern: List[str], start_node: Optional[Operation] = None, reverse=False): + """Graph's get_nodes_by_type_pattern function.""" if len(pattern) < 1: raise ValueError(f"pattern must be longer than 2 but {len(pattern)} is given") pattern_pairs = [pattern[i : i + 2] for i in range(len(pattern) - 1)] @@ -383,16 +437,16 @@ def get_nodes_by_type_pattern(self, pattern: List[str], start_node: Optional[Ope start_nodes = [start_node] for pattern_pair in pattern_pairs: found_ = {start_node: None for start_node in start_nodes} - for start_node in start_nodes: - for s, ts in self.bfs(start_node, reverse, 1): - if not isinstance(ts, tuple): - ts = (ts,) - for t in ts: - if [s.type, t.type] == pattern_pair: + for start_node_ in start_nodes: + for s_value, ts_ in self.bfs(start_node_, reverse, 1): + if not isinstance(ts_, tuple): + ts_ = (ts_,) + for t in 
ts_: + if [s_value.type, t.type] == pattern_pair: if reverse: - found_[t] = s + found_[t] = s_value else: - found_[s] = t + found_[s_value] = t if founds: pop_indices = [] for i, found in enumerate(founds): @@ -409,9 +463,11 @@ def get_nodes_by_type_pattern(self, pattern: List[str], start_node: Optional[Ope return founds def _freeze_normalize_nodes(self): # noqa: C901 + """Graph's _freeze_normalize_nodes function.""" invariant_types = ["Transpose", "Convert"] def test_constant(node): + """Graph's test_constant function.""" constant_nodes = [node_ for node_ in self.predecessors(node) if node_.type == "Constant"] if len(constant_nodes) != 1: return False @@ -422,11 +478,13 @@ def test_constant(node): def get_nodes_by_type_from_node( node, - type, - ignore_types=[], + types, + ignore_types=None, reverse=False, depth_limit=-1, ): + """Graph's get_nodes_by_type_from_node function.""" + ignore_types = ignore_types if ignore_types else [] func = self.successors if reverse: func = self.predecessors @@ -434,15 +492,16 @@ def get_nodes_by_type_from_node( candidates = [(i, 1) for i in func(node)] found = [] for candidate, cur_depth in candidates: - if depth_limit > -1 and cur_depth > depth_limit: + if cur_depth > depth_limit > -1: break - if candidate.type == type: + if candidate.type == types: found.append(candidate) elif candidate.type in ignore_types: candidates.extend([(i, cur_depth + 1) for i in func(candidate)]) return found def find_multiply_add(node): + """Graph's find_multiply_add function.""" scale_node = None mean_node = None @@ -462,6 +521,7 @@ def find_multiply_add(node): return (scale_node, mean_node) def find_subtract_divide(node): + """Graph's find_subtract_divide function.""" mean_node = None scale_node = None @@ -481,6 +541,7 @@ def find_subtract_divide(node): return (mean_node, scale_node) def find_subtract_multiply(node): + """Graph's find_subtract_multiply function.""" mean_node = None scale_node = None @@ -512,7 +573,7 @@ def 
find_subtract_multiply(node): if len([i for i in found if i is not None]) < len([i for i in found_ if i is not None]): found = found_ - if not all([i is not None for i in found]): + if not all(i is not None for i in found): continue self._normalize_nodes.append(found) @@ -532,6 +593,7 @@ def find_subtract_multiply(node): self.replace_node(constant_node, new_constant_node) def remove_normalize_nodes(self): + """Graph's remove_normalize_nodes function.""" for nodes in self._normalize_nodes: first_node, second_node = nodes @@ -544,21 +606,25 @@ def remove_normalize_nodes(self): try: self.remove_node(second_node, keep_connect=True) logger.info(f"Remove normalize node {second_node.name}") - except Exception: + except Exception: # pylint: disable=broad-exception-caught pass self._normalize_nodes = [] def topological_sort(self): + """Graph's topological_sort function.""" return nx.topological_sort(self) def has_path(self, node_from: Operation, node_to: Operation): + """Graph's has_path function.""" return nx.has_path(self, node_from, node_to) def clean_up( self, - nodes_to_keep: List[Operation] = [], + nodes_to_keep: List[Operation] = None, remove_sub_components: bool = True, ): + """Graph's clean_up function.""" + nodes_to_keep = nodes_to_keep if nodes_to_keep else [] if remove_sub_components: # clean up sub components components = list(nx.connected_components(self.to_undirected())) diff --git a/otx/core/ov/graph/parsers/__init__.py b/otx/core/ov/graph/parsers/__init__.py index 63ec930e665..7a158471ce2 100644 --- a/otx/core/ov/graph/parsers/__init__.py +++ b/otx/core/ov/graph/parsers/__init__.py @@ -2,3 +2,7 @@ # Copyright (C) 2023 Intel Corporation # # SPDX-License-Identifier: MIT + +from .builder import PARSERS + +__all__ = ["PARSERS"] diff --git a/otx/core/ov/graph/parsers/builder.py b/otx/core/ov/graph/parsers/builder.py new file mode 100644 index 00000000000..b802f44365b --- /dev/null +++ b/otx/core/ov/graph/parsers/builder.py @@ -0,0 +1,8 @@ +"""Builder module for 
otx.core.ov.graph.parsers.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from otx.core.ov.registry import Registry + +PARSERS = Registry("ov graph parsers") diff --git a/otx/core/ov/graph/parsers/cls/cls_base_parser.py b/otx/core/ov/graph/parsers/cls/cls_base_parser.py index 4b88ab5e447..b5fd44d8beb 100644 --- a/otx/core/ov/graph/parsers/cls/cls_base_parser.py +++ b/otx/core/ov/graph/parsers/cls/cls_base_parser.py @@ -5,10 +5,11 @@ from typing import Dict, List, Optional -from otx.mpa.modules.ov.graph.parsers.builder import PARSERS -from otx.mpa.modules.ov.graph.parsers.parser import parameter_parser from otx.mpa.utils.logger import get_logger +from ..builder import PARSERS +from ..parser import parameter_parser + logger = get_logger() # pylint: disable=too-many-return-statements, too-many-branches @@ -83,26 +84,23 @@ def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, Li ) if component == "head": - inputs = list(graph.successors(neck_output)) - # if len(inputs) != 1: - # logger.debug(f"neck_output {neck_output.name} has more than one successors.") - # return None + head_inputs = list(graph.successors(neck_output)) outputs = graph.get_nodes_by_types(["Result"]) if len(outputs) != 1: - logger.debug("more than one network output are found.") + logger.debug("More than one network output is found.") return None for node_from, node_to in graph.bfs(outputs[0], True, 5): if node_to.type == "Softmax": outputs = [node_from] break - if not graph.has_path(inputs[0], outputs[0]): - logger.debug(f"input({inputs[0].name}) and output({outputs[0].name}) are reversed") + if not graph.has_path(head_inputs[0], outputs[0]): + logger.debug(f"input({head_inputs[0].name}) and output({outputs[0].name}) are reversed") return None return dict( - inputs=[input.name for input in inputs], + inputs=[input_.name for input_ in head_inputs], outputs=[output.name for output in outputs], ) return None diff --git 
a/otx/mpa/modules/ov/graph/parsers/parser.py b/otx/core/ov/graph/parsers/parser.py similarity index 60% rename from otx/mpa/modules/ov/graph/parsers/parser.py rename to otx/core/ov/graph/parsers/parser.py index 96046046107..9cbf0583651 100644 --- a/otx/mpa/modules/ov/graph/parsers/parser.py +++ b/otx/core/ov/graph/parsers/parser.py @@ -1,11 +1,13 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Parser modules for otx.core.ov.graph.parsers.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from typing import List def type_parser(graph, types) -> List[str]: + """Type Parser from graph, types.""" found = [] for node in graph: if node.type in types: @@ -14,8 +16,10 @@ def type_parser(graph, types) -> List[str]: def result_parser(graph) -> List[str]: + """Result Parser from graph.""" return type_parser(graph, ["Result"]) def parameter_parser(graph) -> List[str]: + """Parameter Parser from graph.""" return type_parser(graph, ["Parameter"]) diff --git a/otx/mpa/modules/ov/graph/utils.py b/otx/core/ov/graph/utils.py similarity index 88% rename from otx/mpa/modules/ov/graph/utils.py rename to otx/core/ov/graph/utils.py index 5b7818906ec..e7e55f77891 100644 --- a/otx/mpa/modules/ov/graph/utils.py +++ b/otx/core/ov/graph/utils.py @@ -1,31 +1,36 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +"""Utils for otx.core.ov.graph.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT -from typing import List +from typing import Any, List import torch +from otx.core.ov.graph import Graph from otx.core.ov.ops.builder import OPS from otx.core.ov.ops.infrastructures import ConstantV0 from otx.core.ov.ops.op import Operation from otx.mpa.utils.logger import get_logger -from .graph import Graph - logger = get_logger() +# pylint: disable=too-many-locals, protected-access, too-many-branches, too-many-statements, too-many-nested-blocks + def get_constant_input_nodes(graph: 
Graph, node: Operation) -> List[Operation]: + """Getter constant input nodes from graph, node.""" found = [] - for node in graph.predecessors(node): - if node.type == "Constant": - found.append(node) + for node_ in graph.predecessors(node): + if node_.type == "Constant": + found.append(node_) return found -def handle_merging_into_batchnorm(graph, type_patterns=[["Multiply", "Add"]], type_mappings=[{"gamma": 0, "beta": 1}]): - +def handle_merging_into_batchnorm(graph, type_patterns=None, type_mappings=None): # noqa: C901 + """Merge function graph into batchnorm.""" + type_patterns = type_patterns if type_patterns else [["Multiply", "Add"]] + type_mappings = type_mappings if type_mappings else [{"gamma": 0, "beta": 1}] assert len(type_patterns) == len(type_mappings) batchnorm_cls = OPS.get_by_type_version("BatchNormInference", 0) constant_cls = OPS.get_by_type_version("Constant", 0) @@ -153,7 +158,9 @@ def handle_merging_into_batchnorm(graph, type_patterns=[["Multiply", "Add"]], ty graph.add_edge(running_variance, batchnorm) -def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = ["Convolution", "GroupConvolution"]): +def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = None): + """Handle function paired batchnorm.""" + types = types if types else ["Convolution", "GroupConvolution"] batchnorm_cls = OPS.get_by_type_version("BatchNormInference", 0) constant_cls = OPS.get_by_type_version("Constant", 0) @@ -172,9 +179,9 @@ def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = ["C ) continue - bias_node = [n for n in graph.successors(node) if n.type == "Add"] - if len(bias_node) == 1: - bias_node = bias_node[0] + bias_node_list: List[Any] = [n for n in graph.successors(node) if n.type == "Add"] + if len(bias_node_list) == 1: + bias_node = bias_node_list[0] else: bias_node = None @@ -183,10 +190,9 @@ def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = ["C logger.info(f"Skip a 
paired batch normalization for {node.name} " "becuase it has no bias add node.") continue # if add node is not bias add node - elif not isinstance(list(graph.predecessors(bias_node))[1], ConstantV0): + if not isinstance(list(graph.predecessors(bias_node))[1], ConstantV0): logger.info( - f"Skip a pared batch normalization for {node.name } " - f"because {bias_node.name} is not a bias add node." + f"Skip a paired batch normalization for {node.name} " f"because {bias_node.name} is not a bias add node." ) continue @@ -270,7 +276,7 @@ def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = ["C def handle_reshape(graph): - + """Reshape function.""" for result in graph.get_nodes_by_types(["Result"]): for node in graph.predecessors(result): # some models, for example, dla-34, have reshape node as its predecessor @@ -281,5 +287,5 @@ def handle_reshape(graph): for shape_ in input_node.shape[0][::-1]: if shape_ != 1: break - logger.info(f"Change reshape to [-1, {shape_}]") - shape.data = torch.tensor([-1, shape_]) + logger.info(f"Change reshape to [-1, {shape_}]") # pylint: disable=undefined-loop-variable + shape.data = torch.tensor([-1, shape_]) # pylint: disable=undefined-loop-variable diff --git a/otx/core/ov/models/__init__.py b/otx/core/ov/models/__init__.py new file mode 100644 index 00000000000..4c62ff2a3ad --- /dev/null +++ b/otx/core/ov/models/__init__.py @@ -0,0 +1,10 @@ +"""Module for otx.core.ov.models.""" +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from .mmov_model import MMOVModel + +__all__ = [ + "MMOVModel", +] diff --git a/otx/mpa/modules/ov/models/mmov_model.py b/otx/core/ov/models/mmov_model.py similarity index 73% rename from otx/mpa/modules/ov/models/mmov_model.py rename to otx/core/ov/models/mmov_model.py index 389faa2e25a..3e3398125bc 100644 --- a/otx/mpa/modules/ov/models/mmov_model.py +++ b/otx/core/ov/models/mmov_model.py @@ -1,17 +1,24 @@ -# Copyright (C) 2022 Intel Corporation -#
SPDX-License-Identifier: Apache-2.0 +"""MMOVModel for otx.core.ov.models.mmov_model.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from typing import Dict, List, Optional, Union import openvino.runtime as ov import torch -from .ov_model import OVModel -from .parser_mixin import ParserMixin +# TODO: Need to remove line 1 (ignore mypy) and fix mypy issues +from .ov_model import OVModel # type: ignore[attr-defined] +from .parser_mixin import ParserMixin # type: ignore[attr-defined] + +# TODO: Need to fix pylint issues +# pylint: disable=keyword-arg-before-vararg class MMOVModel(OVModel, ParserMixin): + """MMOVModel for OMZ model type.""" + def __init__( self, model_path_or_model: Union[str, ov.Model], @@ -42,12 +49,13 @@ def __init__( ) def forward(self, inputs, gt_label=None): + """Function forward.""" if isinstance(inputs, torch.Tensor): inputs = (inputs,) assert len(inputs) == len(self.inputs) feed_dict = dict() - for key, input in zip(self.inputs, inputs): - feed_dict[key] = input + for key, input_ in zip(self.inputs, inputs): + feed_dict[key] = input_ if gt_label is not None: assert "gt_label" not in self.features diff --git a/otx/mpa/modules/ov/models/ov_model.py b/otx/core/ov/models/ov_model.py similarity index 84% rename from otx/mpa/modules/ov/models/ov_model.py rename to otx/core/ov/models/ov_model.py index 31919029754..c7a72f2ea29 100644 --- a/otx/mpa/modules/ov/models/ov_model.py +++ b/otx/core/ov/models/ov_model.py @@ -1,6 +1,9 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# TODO: Need to remove line 1 (ignore mypy) and fix mypy issues +"""Modules for otx.core.ov.models.ov_model.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT import math import os @@ -13,8 +16,6 @@ import torch from torch.nn import init -from otx.core.ov.ops.builder import OPS -from otx.mpa.modules.ov.utils import load_ov_model, normalize_name from otx.mpa.utils.logger import 
get_logger from ..graph import Graph @@ -23,20 +24,26 @@ handle_paired_batchnorm, handle_reshape, ) +from ..ops.builder import OPS +from ..utils import load_ov_model, normalize_name logger = get_logger() CONNECTION_SEPARATOR = "||" +# pylint: disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements -class OVModel(torch.nn.Module): - def __init__( + +class OVModel(torch.nn.Module): # pylint: disable=too-many-instance-attributes + """OVModel class.""" + + def __init__( # noqa: C901 self, model_path_or_model: Union[str, ov.Model] = None, weight_path: Optional[str] = None, - inputs: Union[str, List[str]] = [], - outputs: Union[str, List[str]] = [], + inputs: Optional[Union[str, List[str]]] = None, + outputs: Optional[Union[str, List[str]]] = None, features_to_keep: Optional[List] = None, remove_normalize: bool = False, merge_bn: bool = True, @@ -54,8 +61,8 @@ def __init__( self._init_weight = init_weight self._verify_shape = verify_shape - self._inputs = [] - self._outputs = [] + self._inputs: List[str] = [] + self._outputs: List[str] = [] self._feature_dict = OrderedDict() # build graph @@ -101,40 +108,42 @@ def __init__( if not isinstance(init_weight, Callable): # internal init weight - def init_weight(m, graph): - from .....core.ov.ops.op import Operation + def init_weight(module, graph): # pylint: disable=function-redefined + from ..ops.op import Operation - if not isinstance(m, Operation): + if not isinstance(module, Operation): return - if m.TYPE == "BatchNormInference": - _, gamma, beta, mean, var = list(graph.predecessors(m)) + if module.TYPE == "BatchNormInference": + _, gamma, beta, mean, var = list(graph.predecessors(module)) init.ones_(gamma.data) init.zeros_(beta.data) mean.data.zero_() var.data.fill_(1) - logger.info(f"Initialize {m.TYPE} -> {m.name}") - elif m.TYPE in [ + logger.info(f"Initialize {module.TYPE} -> {module.name}") + elif module.TYPE in [ "Convolution", "GroupConvolution", "MatMul", ]: - for weight in 
graph.predecessors(m): + for weight in graph.predecessors(module): if weight.TYPE == "Constant" and isinstance(weight.data, torch.nn.parameter.Parameter): init.kaiming_uniform_(weight.data, a=math.sqrt(5)) - logger.info(f"Initialize {m.TYPE} -> {m.name}") - elif m.TYPE in [ + logger.info(f"Initialize {module.TYPE} -> {module.name}") + elif module.TYPE in [ "Multiply", "Divide", "Add", "Subtract", ]: - for weight in graph.predecessors(m): + for weight in graph.predecessors(module): if weight.TYPE == "Constant" and isinstance(weight.data, torch.nn.parameter.Parameter): - fan_in, _ = init._calculate_fan_in_and_fan_out(weight.data) + fan_in, _ = init._calculate_fan_in_and_fan_out( # pylint: disable=protected-access + weight.data + ) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 init.uniform_(weight.data, -bound, bound) - logger.info(f"Initialize {m.TYPE} -> {m.name}") + logger.info(f"Initialize {module.TYPE} -> {module.name}") self.model.apply(lambda m: init_weight(m, graph)) @@ -150,33 +159,39 @@ def init_weight(m, graph): output_shapes[node.name] = node.shape[0] self._input_shapes = OrderedDict() self._output_shapes = OrderedDict() - for input in self._inputs: - self._input_shapes[input] = input_shapes[input] + for input_ in self._inputs: + self._input_shapes[input_] = input_shapes[input_] for output in self._outputs: self._output_shapes[output] = output_shapes[output] @property def inputs(self): + """Property inputs.""" return self._inputs @property def outputs(self): + """Property outputs.""" return self._outputs @property def features(self): + """Property features.""" return self._feature_dict @property def input_shapes(self): + """Property input_shapes.""" return self._input_shapes @property def output_shapes(self): + """Property output_shapes.""" return self._output_shapes @staticmethod def build_graph(model_path_or_model, weight_path=None): + """Function build_graph.""" with tempfile.TemporaryDirectory() as tempdir: if isinstance(model_path_or_model, 
ov.Model): assert weight_path is None, "if openvino model is given 'weight_path' must be None" @@ -193,7 +208,8 @@ def build_graph(model_path_or_model, weight_path=None): return graph @staticmethod - def build_custom_outputs(graph, outputs): + def build_custom_outputs(graph, outputs): # noqa: C901 + """Function build_custom_outputs.""" cls_result = OPS.get_by_type_version("Result", 0) node_dict = OrderedDict((i.name, i) for i in graph.topological_sort()) @@ -262,7 +278,7 @@ def build_custom_outputs(graph, outputs): for edges in edges_to_add.values(): for edge in edges: edge["in_port"] = 0 - assert set([len(edges) for edges in edges_to_add.values()]) == {1} + assert {len(edges) for edges in edges_to_add.values()} == {1} edges_to_add = [edge for edges in edges_to_add.values() for edge in edges] else: edges_to_add = [] @@ -274,7 +290,8 @@ def build_custom_outputs(graph, outputs): return outputs @staticmethod - def build_custom_inputs(graph, inputs: Union[str, List[str]]): + def build_custom_inputs(graph, inputs: Union[str, List[str]]): # noqa: C901 + """Function build_custom_inputs.""" cls_param = OPS.get_by_type_version("Parameter", 0) node_dict = OrderedDict((i.name, i) for i in graph.topological_sort()) @@ -283,16 +300,16 @@ def build_custom_inputs(graph, inputs: Union[str, List[str]]): edges_to_add = {} nodes_to_remove = [] - for i, input in enumerate(inputs): - input = normalize_name(input) - input = input.split(CONNECTION_SEPARATOR) + for i, input_ in enumerate(inputs): + input_ = normalize_name(input_) + input_ = input_.split(CONNECTION_SEPARATOR) explicit_src = False - if len(input) == 1: + if len(input_) == 1: src = None - tgt = input[0] - elif len(input) == 2: - src, tgt = input + tgt = input_[0] + elif len(input_) == 2: + src, tgt = input_ explicit_src = True else: raise ValueError() @@ -353,7 +370,7 @@ def build_custom_inputs(graph, inputs: Union[str, List[str]]): for edges in edges_to_add.values(): for edge in edges: edge["out_port"] = 0 - assert 
set([len(edges) for edges in edges_to_add.values()]) == {1} + assert {len(edges) for edges in edges_to_add.values()} == {1} edges_to_add = [edge for edges in edges_to_add.values() for edge in edges] else: edges_to_add = [] @@ -365,14 +382,18 @@ def build_custom_inputs(graph, inputs: Union[str, List[str]]): return inputs @staticmethod - def clean_up(graph, inputs=[], outputs=[]): + def clean_up(graph, inputs=None, outputs=None): + """Function clean_up.""" + inputs = inputs if inputs else [] + outputs = outputs if outputs else [] nodes = list(graph.topological_sort()) nodes_to_keep = [] for node in nodes: if node.name in inputs or node.name in outputs: nodes_to_keep.append(node) - def get_nodes_without_successors(graph, ignores=[]): + def get_nodes_without_successors(graph, ignores=None): + ignores = ignores if ignores else [] outputs = [] for node in reversed(list(graph.topological_sort())): if not list(graph.successors(node)) and node not in ignores: @@ -388,10 +409,12 @@ def get_nodes_without_successors(graph, ignores=[]): @staticmethod def build_torch_module(graph): + """Function build_torch_module.""" node_dict = OrderedDict((i.name, i) for i in graph.topological_sort()) return torch.nn.ModuleDict(list(node_dict.items())) def _build_forward_inputs(self, *args, **kwargs): + """Function _build_forward_inputs.""" inputs = {} if args: for key, arg in zip(self._inputs, args): @@ -404,6 +427,7 @@ def _build_forward_inputs(self, *args, **kwargs): return inputs def forward(self, *args, **kwargs): + """Function forward.""" self._feature_dict.clear() inputs = self._build_forward_inputs(*args, **kwargs) diff --git a/otx/mpa/modules/ov/models/parser_mixin.py b/otx/core/ov/models/parser_mixin.py similarity index 85% rename from otx/mpa/modules/ov/models/parser_mixin.py rename to otx/core/ov/models/parser_mixin.py index 48120ba5d87..2943d58dd4d 100644 --- a/otx/mpa/modules/ov/models/parser_mixin.py +++ b/otx/core/ov/models/parser_mixin.py @@ -1,6 +1,9 @@ -# Copyright (C) 2022 
Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# TODO: Need to remove line 1 (ignore mypy) and fix mypy issues +"""Parser mixin modules for otx.core.ov.models.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT from typing import Callable, Dict, List, Optional, Tuple, Union @@ -15,6 +18,8 @@ class ParserMixin: + """ParserMixin class.""" + def parse( self, model_path_or_model: Union[str, ov.Model], @@ -24,7 +29,7 @@ def parse( parser: Optional[Union[str, Callable]] = None, **kwargs, ) -> Tuple[Union[str, List[str]], Union[str, List[str]]]: - + """Parse function of ParserMixin class.""" parser = self.parser if parser is None else parser if isinstance(parser, str): parser = PARSERS.get(parser) @@ -52,5 +57,6 @@ def parse( return inputs, outputs @staticmethod - def parser(graph, **kwargs) -> Dict[str, Union[List[str], Dict[str, List[str]]]]: + def parser(graph, **kwargs) -> Dict[str, Union[List[str], Dict[str, List[str]]]]: # pylint: disable=unused-argument + """Function parser.""" return dict(inputs=[], outputs=[]) diff --git a/otx/core/ov/ops/infrastructures.py b/otx/core/ov/ops/infrastructures.py index 0a24882d8aa..f80356db842 100644 --- a/otx/core/ov/ops/infrastructures.py +++ b/otx/core/ov/ops/infrastructures.py @@ -10,13 +10,14 @@ import numpy as np import torch -from otx.core.ov.ops.builder import OPS -from otx.core.ov.ops.op import Attribute, Operation -from otx.core.ov.ops.type_conversions import ConvertV0 -from otx.core.ov.ops.utils import get_dynamic_shape -from otx.mpa.modules.ov.utils import get_op_name from otx.mpa.utils.logger import get_logger +from ..utils import get_op_name # type: ignore[attr-defined] +from .builder import OPS +from .op import Attribute, Operation +from .type_conversions import ConvertV0 +from .utils import get_dynamic_shape + logger = get_logger() diff --git a/otx/core/ov/ops/modules/op_module.py b/otx/core/ov/ops/modules/op_module.py index 00c50c62438..2f59da59cd3 100644 --- 
a/otx/core/ov/ops/modules/op_module.py +++ b/otx/core/ov/ops/modules/op_module.py @@ -7,8 +7,10 @@ from typing import Dict, List, Optional, Union import torch +from openvino.pyopenvino import Node # pylint: disable=no-name-in-module from ..op import Operation +from ..utils import convert_op_to_torch class OperationModule(torch.nn.Module): @@ -83,3 +85,20 @@ def shape(self): def attrs(self): """Operationmodule's attrs property.""" return self.op_v.attrs + + +def convert_op_to_torch_module(target_op: Node): + """Convert op Node to torch module.""" + dependent_modules = [] + for in_port in target_op.inputs(): + out_port = in_port.get_source_output() + parent = out_port.get_node() + + parent_type = parent.get_type_name() + if parent_type == "Constant": + dependent_modules.append(convert_op_to_torch(parent)) + else: + dependent_modules.append(None) + module = convert_op_to_torch(target_op) + module = OperationModule(module, dependent_modules) + return module diff --git a/otx/core/ov/ops/op.py b/otx/core/ov/ops/op.py index a5e6950b039..ee43731a2e9 100644 --- a/otx/core/ov/ops/op.py +++ b/otx/core/ov/ops/op.py @@ -9,8 +9,7 @@ import torch -from otx.mpa.modules.ov.utils import get_op_name - +from ..utils import get_op_name # type: ignore[attr-defined] from .utils import get_dynamic_shape diff --git a/otx/core/ov/ops/utils.py b/otx/core/ov/ops/utils.py index c1d4c51ce7d..50e4a9f3b1d 100644 --- a/otx/core/ov/ops/utils.py +++ b/otx/core/ov/ops/utils.py @@ -3,6 +3,14 @@ # # SPDX-License-Identifier: MIT +from openvino.pyopenvino import Node # pylint: disable=no-name-in-module + +from otx.mpa.utils.logger import get_logger + +from .builder import OPS + +logger = get_logger() + def get_dynamic_shape(output): """Getter function for dynamic shape.""" @@ -14,3 +22,20 @@ def get_dynamic_shape(output): shape_ = -1 shape[i] = shape_ return shape + + +def convert_op_to_torch(op_node: Node): + """Convert op Node to torch.""" + op_type = op_node.get_type_name() + op_version = 
op_node.get_version() + + try: + torch_module = OPS.get_by_type_version(op_type, op_version).from_ov(op_node) + except Exception as e: + logger.error(e) + logger.error(op_type) + logger.error(op_version) + logger.error(op_node.get_attributes()) + raise e + + return torch_module diff --git a/otx/mpa/modules/ov/utils.py b/otx/core/ov/utils.py similarity index 67% rename from otx/mpa/modules/ov/utils.py rename to otx/core/ov/utils.py index f3b8d12043c..edaf14fcf9e 100644 --- a/otx/mpa/modules/ov/utils.py +++ b/otx/core/ov/utils.py @@ -1,21 +1,28 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# TODO: Need to remove line 1 (ignore mypy) and fix mypy issues +"""Utils for otx.core.ov.""" +# Copyright (C) 2023 Intel Corporation # +# SPDX-License-Identifier: MIT import errno import os -from typing import List, Optional +from typing import Optional -from openvino.pyopenvino import Model, Node +from openvino.pyopenvino import Model, Node # pylint: disable=no-name-in-module from openvino.runtime import Core -from otx.core.ov.omz_wrapper import AVAILABLE_OMZ_MODELS, get_omz_model from otx.mpa.utils.logger import get_logger +from .omz_wrapper import AVAILABLE_OMZ_MODELS, get_omz_model + logger = get_logger() +# pylint: disable=too-many-locals + def to_dynamic_model(ov_model: Model) -> Model: + """Convert ov_model to dynamic Model.""" assert isinstance(ov_model, Model) shapes = {} @@ -56,14 +63,13 @@ def reshape_model(ov_model, shapes): try: ov_model.reshape(shapes) return True - except Exception: + except Exception: # pylint: disable=broad-exception-caught return False pop_targets = [["height", "width"], ["batch"]] pop_targets = pop_targets[::-1] while not reshape_model(ov_model, shapes): - for key in shapes.keys(): - shape = shapes[key] + for key, shape in shapes.items(): target_layout = target_layouts[key] targets = pop_targets.pop() @@ -80,6 +86,7 @@ def reshape_model(ov_model, shapes): def load_ov_model(model_path: str, 
weight_path: Optional[str] = None, convert_dynamic: bool = False) -> Model: + """Load ov_model from model_path.""" model_path = str(model_path) if model_path.startswith("omz://"): model_path = model_path.replace("omz://", "") @@ -96,8 +103,8 @@ def load_ov_model(model_path: str, weight_path: Optional[str] = None, convert_dy if not os.path.exists(weight_path): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), weight_path) - ie = Core() - ov_model = ie.read_model(model=model_path, weights=weight_path) + ie_core = Core() + ov_model = ie_core.read_model(model=model_path, weights=weight_path) if convert_dynamic: ov_model = to_dynamic_model(ov_model) @@ -106,54 +113,20 @@ def load_ov_model(model_path: str, weight_path: Optional[str] = None, convert_dy def normalize_name(name: str) -> str: + """Normalize name string.""" # ModuleDict does not allow '.' in module name string name = name.replace(".", "#") - return name + return f"{name}" def unnormalize_name(name: str) -> str: + """Unnormalize name string.""" name = name.replace("#", ".") return name -def get_op_name(op: Node) -> str: - op_name = op.get_friendly_name() +def get_op_name(op_node: Node) -> str: + """Get op name string.""" + op_name = op_node.get_friendly_name() op_name = normalize_name(op_name) return op_name - - -def convert_op_to_torch(op: Node): - - from otx.core.ov.ops.builder import OPS - - op_type = op.get_type_name() - op_version = op.get_version() - - try: - torch_module = OPS.get_by_type_version(op_type, op_version).from_ov(op) - except Exception as e: - logger.error(e) - logger.error(op_type) - logger.error(op_version) - logger.error(op.get_attributes()) - raise e - - return torch_module - - -def convert_op_to_torch_module(target_op: Node): - from otx.core.ov.ops.modules.op_module import OperationModule - - dependent_modules = [] - for in_port in target_op.inputs(): - out_port = in_port.get_source_output() - parent = out_port.get_node() - - parent_type = parent.get_type_name() - if 
parent_type == "Constant": - dependent_modules.append(convert_op_to_torch(parent)) - else: - dependent_modules.append(None) - module = convert_op_to_torch(target_op) - module = OperationModule(module, dependent_modules) - return module diff --git a/otx/mpa/modules/ov/graph/__init__.py b/otx/mpa/modules/ov/graph/__init__.py deleted file mode 100644 index 8f15e00642e..00000000000 --- a/otx/mpa/modules/ov/graph/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -# flake8: noqa -from .graph import Graph diff --git a/otx/mpa/modules/ov/graph/parsers/__init__.py b/otx/mpa/modules/ov/graph/parsers/__init__.py deleted file mode 100644 index c8f589f8ca2..00000000000 --- a/otx/mpa/modules/ov/graph/parsers/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -# from . import cls - -# flake8: noqa -from .builder import PARSERS diff --git a/otx/mpa/modules/ov/graph/parsers/builder.py b/otx/mpa/modules/ov/graph/parsers/builder.py deleted file mode 100644 index b9cd8d60317..00000000000 --- a/otx/mpa/modules/ov/graph/parsers/builder.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from otx.core.ov.registry import Registry - -PARSERS = Registry("ov graph parsers") diff --git a/otx/mpa/modules/ov/models/__init__.py b/otx/mpa/modules/ov/models/__init__.py deleted file mode 100644 index 1e19f1159d9..00000000000 --- a/otx/mpa/modules/ov/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# diff --git a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py b/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py index e13e4bd8df6..37e0519a318 100644 --- a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py +++ 
b/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # +from otx.core.ov.graph import Graph from otx.core.ov.graph.parsers.cls import cls_base_parser -from otx.mpa.modules.ov.graph.graph import Graph -from otx.mpa.modules.ov.utils import load_ov_model +from otx.core.ov.utils import load_ov_model from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_parser.py b/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_parser.py index 405f3f37ebc..8d26f8479cb 100644 --- a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_parser.py +++ b/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_parser.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # -from otx.mpa.modules.ov.graph.graph import Graph -from otx.mpa.modules.ov.graph.parsers.parser import parameter_parser, result_parser -from otx.mpa.modules.ov.utils import load_ov_model +from otx.core.ov.graph import Graph +from otx.core.ov.graph.parsers.parser import parameter_parser, result_parser +from otx.core.ov.utils import load_ov_model from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/graph/test_ov_graph_grapy.py b/tests/unit/mpa/modules/ov/graph/test_ov_graph_grapy.py index afa43e09e55..c00743df343 100644 --- a/tests/unit/mpa/modules/ov/graph/test_ov_graph_grapy.py +++ b/tests/unit/mpa/modules/ov/graph/test_ov_graph_grapy.py @@ -9,7 +9,7 @@ import openvino.runtime as ov import pytest -from otx.mpa.modules.ov.graph.graph import Graph, SortedDict +from otx.core.ov.graph.graph import Graph, SortedDict from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/graph/test_ov_graph_utils.py b/tests/unit/mpa/modules/ov/graph/test_ov_graph_utils.py index 5847bc76d98..7133f523da4 100644 --- a/tests/unit/mpa/modules/ov/graph/test_ov_graph_utils.py +++ 
b/tests/unit/mpa/modules/ov/graph/test_ov_graph_utils.py @@ -2,13 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 # -from otx.mpa.modules.ov.graph.graph import Graph -from otx.mpa.modules.ov.graph.utils import ( +from otx.core.ov.graph.graph import Graph +from otx.core.ov.graph.utils import ( get_constant_input_nodes, handle_paired_batchnorm, handle_reshape, ) -from otx.mpa.modules.ov.utils import load_ov_model +from otx.core.ov.utils import load_ov_model from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/models/test_ov_models_ov_model.py b/tests/unit/mpa/modules/ov/models/test_ov_models_ov_model.py index dda0f86c4b8..f3bf7fe78ca 100644 --- a/tests/unit/mpa/modules/ov/models/test_ov_models_ov_model.py +++ b/tests/unit/mpa/modules/ov/models/test_ov_models_ov_model.py @@ -6,7 +6,7 @@ import openvino.runtime as ov import torch -from otx.mpa.modules.ov.models.ov_model import OVModel +from otx.core.ov.models.ov_model import OVModel from tests.test_suite.e2e_test_system import e2e_pytest_unit diff --git a/tests/unit/mpa/modules/ov/test_ov_utils.py b/tests/unit/mpa/modules/ov/test_ov_utils.py index 170517379ab..3057de461db 100644 --- a/tests/unit/mpa/modules/ov/test_ov_utils.py +++ b/tests/unit/mpa/modules/ov/test_ov_utils.py @@ -10,9 +10,9 @@ from otx.core.ov.omz_wrapper import get_omz_model from otx.core.ov.ops.infrastructures import ParameterV0 -from otx.mpa.modules.ov.utils import ( - convert_op_to_torch, - convert_op_to_torch_module, +from otx.core.ov.ops.modules.op_module import convert_op_to_torch_module +from otx.core.ov.ops.utils import convert_op_to_torch +from otx.core.ov.utils import ( get_op_name, load_ov_model, normalize_name, From 7dcec130902f6e7fb9502bf6b6d8d18603ab60ca Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Thu, 23 Mar 2023 15:55:24 +0900 Subject: [PATCH 3/6] Remove otx.mpa.modules.ov --- otx/mpa/modules/ov/__init__.py | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 
otx/mpa/modules/ov/__init__.py diff --git a/otx/mpa/modules/ov/__init__.py b/otx/mpa/modules/ov/__init__.py deleted file mode 100644 index aaeb84acfb5..00000000000 --- a/otx/mpa/modules/ov/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -# flake8: noqa -# from .graph import * -# from .models import * - -# from otx.core.ov.ops import * From 267ffff80685f46c2511a3a56cc4ec4ab3df3444 Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Thu, 23 Mar 2023 16:26:53 +0900 Subject: [PATCH 4/6] Fix docs & unit test paths --- docs/source/guide/index.rst | 2 +- docs/source/guide/reference/core/index.rst | 8 ++ .../{mpa/modules => core}/ov/graph.rst | 8 +- .../{mpa/modules => core}/ov/index.rst | 8 +- .../{mpa/modules => core}/ov/models.rst | 8 +- docs/source/guide/reference/core/ov/ops.rst | 82 +++++++++++++++++++ .../guide/reference/mpa/modules/index.rst | 1 - .../guide/reference/mpa/modules/ov/ops.rst | 82 ------------------- docs/source/guide/reference/mpa/utils.rst | 3 - .../graph/parsers/test_ov_graph_cls_parser.py | 0 .../ov/graph/parsers/test_ov_graph_parser.py | 0 .../ov/graph/test_ov_graph_grapy.py | 0 .../ov/graph/test_ov_graph_utils.py | 0 .../modules => core}/ov/models/__init__.py | 0 .../ov/models/mmcls/__init__.py | 0 .../backbones/test_ov_mmcls_mmov_backbone.py | 2 +- .../mmcls/heads/test_ov_mmcls_cls_head.py | 0 .../mmcls/heads/test_ov_mmcls_conv_head.py | 0 .../heads/test_ov_mmcls_mmcv_cls_head.py | 2 +- .../mmcls/necks/test_ov_mmcls_mmov_neck.py | 2 +- .../ov/models/mmcls/test_helpers.py | 0 .../ov/models/test_ov_models_ov_model.py | 0 .../ov/ops/test_ov_ops_activations.py | 0 .../ov/ops/test_ov_ops_arithmetics.py | 0 .../ov/ops/test_ov_ops_builder.py | 0 .../ov/ops/test_ov_ops_convolutions.py | 0 .../ov/ops/test_ov_ops_generation.py | 0 .../ov/ops/test_ov_ops_image_processings.py | 0 .../ov/ops/test_ov_ops_infrastructures.py | 0 .../ov/ops/test_ov_ops_matmuls.py | 0 
.../ov/ops/test_ov_ops_module.py | 0 .../ov/ops/test_ov_ops_movements.py | 0 .../ov/ops/test_ov_ops_normalizations.py | 0 .../ov/ops/test_ov_ops_object_detections.py | 0 .../modules => core}/ov/ops/test_ov_ops_op.py | 0 .../ov/ops/test_ov_ops_poolings.py | 0 .../ov/ops/test_ov_ops_reductions.py | 0 .../ov/ops/test_ov_ops_shape_manipulations.py | 0 .../ops/test_ov_ops_sorting_maximization.py | 0 .../ov/ops/test_ov_ops_type_conversions.py | 0 .../ov/ops/test_ov_ops_utils.py | 0 .../ov/test_ov_omz_wrapper.py | 0 .../modules => core}/ov/test_ov_registry.py | 0 .../{mpa/modules => core}/ov/test_ov_utils.py | 0 44 files changed, 106 insertions(+), 102 deletions(-) create mode 100644 docs/source/guide/reference/core/index.rst rename docs/source/guide/reference/{mpa/modules => core}/ov/graph.rst (51%) rename docs/source/guide/reference/{mpa/modules => core}/ov/index.rst (56%) rename docs/source/guide/reference/{mpa/modules => core}/ov/models.rst (50%) create mode 100644 docs/source/guide/reference/core/ov/ops.rst delete mode 100644 docs/source/guide/reference/mpa/modules/ov/ops.rst rename tests/unit/{mpa/modules => core}/ov/graph/parsers/test_ov_graph_cls_parser.py (100%) rename tests/unit/{mpa/modules => core}/ov/graph/parsers/test_ov_graph_parser.py (100%) rename tests/unit/{mpa/modules => core}/ov/graph/test_ov_graph_grapy.py (100%) rename tests/unit/{mpa/modules => core}/ov/graph/test_ov_graph_utils.py (100%) rename tests/unit/{mpa/modules => core}/ov/models/__init__.py (100%) rename tests/unit/{mpa/modules => core}/ov/models/mmcls/__init__.py (100%) rename tests/unit/{mpa/modules => core}/ov/models/mmcls/backbones/test_ov_mmcls_mmov_backbone.py (94%) rename tests/unit/{mpa/modules => core}/ov/models/mmcls/heads/test_ov_mmcls_cls_head.py (100%) rename tests/unit/{mpa/modules => core}/ov/models/mmcls/heads/test_ov_mmcls_conv_head.py (100%) rename tests/unit/{mpa/modules => core}/ov/models/mmcls/heads/test_ov_mmcls_mmcv_cls_head.py (93%) rename tests/unit/{mpa/modules => 
core}/ov/models/mmcls/necks/test_ov_mmcls_mmov_neck.py (86%) rename tests/unit/{mpa/modules => core}/ov/models/mmcls/test_helpers.py (100%) rename tests/unit/{mpa/modules => core}/ov/models/test_ov_models_ov_model.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_activations.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_arithmetics.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_builder.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_convolutions.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_generation.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_image_processings.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_infrastructures.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_matmuls.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_module.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_movements.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_normalizations.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_object_detections.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_op.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_poolings.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_reductions.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_shape_manipulations.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_sorting_maximization.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_type_conversions.py (100%) rename tests/unit/{mpa/modules => core}/ov/ops/test_ov_ops_utils.py (100%) rename tests/unit/{mpa/modules => core}/ov/test_ov_omz_wrapper.py (100%) rename tests/unit/{mpa/modules => core}/ov/test_ov_registry.py (100%) rename tests/unit/{mpa/modules => core}/ov/test_ov_utils.py (100%) diff --git 
a/docs/source/guide/index.rst b/docs/source/guide/index.rst index f798f932241..2715ced60c0 100644 --- a/docs/source/guide/index.rst +++ b/docs/source/guide/index.rst @@ -30,7 +30,7 @@ Guide reference/api/index reference/algorithm/index - reference/core/data + reference/core/index reference/hpo/hpo reference/mpa/index diff --git a/docs/source/guide/reference/core/index.rst b/docs/source/guide/reference/core/index.rst new file mode 100644 index 00000000000..b55754cce07 --- /dev/null +++ b/docs/source/guide/reference/core/index.rst @@ -0,0 +1,8 @@ +Core +==== + +.. toctree:: + :maxdepth: 1 + + data + ov/index diff --git a/docs/source/guide/reference/mpa/modules/ov/graph.rst b/docs/source/guide/reference/core/ov/graph.rst similarity index 51% rename from docs/source/guide/reference/mpa/modules/ov/graph.rst rename to docs/source/guide/reference/core/ov/graph.rst index 01bab6ca9b4..6f47f9e154b 100644 --- a/docs/source/guide/reference/mpa/modules/ov/graph.rst +++ b/docs/source/guide/reference/core/ov/graph.rst @@ -5,18 +5,18 @@ Graph :maxdepth: 3 :caption: Contents: -.. automodule:: otx.mpa.modules.ov.graph +.. automodule:: otx.core.ov.graph :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.graph.graph +.. automodule:: otx.core.ov.graph.graph :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.graph.utils +.. automodule:: otx.core.ov.graph.utils :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.graph.parsers +.. 
automodule:: otx.core.ov.graph.parsers :members: :undoc-members: \ No newline at end of file diff --git a/docs/source/guide/reference/mpa/modules/ov/index.rst b/docs/source/guide/reference/core/ov/index.rst similarity index 56% rename from docs/source/guide/reference/mpa/modules/ov/index.rst rename to docs/source/guide/reference/core/ov/index.rst index ec585d95c7e..07ce1abd4cf 100644 --- a/docs/source/guide/reference/mpa/modules/ov/index.rst +++ b/docs/source/guide/reference/core/ov/index.rst @@ -8,18 +8,18 @@ OpenVINO models ops -.. automodule:: otx.mpa.modules.ov +.. automodule:: otx.core.ov :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.omz_wrapper +.. automodule:: otx.core.ov.omz_wrapper :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.registry +.. automodule:: otx.core.ov.registry :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.utils +.. automodule:: otx.core.ov.utils :members: :undoc-members: \ No newline at end of file diff --git a/docs/source/guide/reference/mpa/modules/ov/models.rst b/docs/source/guide/reference/core/ov/models.rst similarity index 50% rename from docs/source/guide/reference/mpa/modules/ov/models.rst rename to docs/source/guide/reference/core/ov/models.rst index 5531868fdea..c3f535e82ab 100644 --- a/docs/source/guide/reference/mpa/modules/ov/models.rst +++ b/docs/source/guide/reference/core/ov/models.rst @@ -5,18 +5,18 @@ Models :maxdepth: 3 :caption: Contents: -.. automodule:: otx.mpa.modules.ov.models +.. automodule:: otx.core.ov.models :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.models.mmov_model +.. automodule:: otx.core.ov.models.mmov_model :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.models.ov_model +.. automodule:: otx.core.ov.models.ov_model :members: :undoc-members: -.. automodule:: otx.mpa.modules.ov.models.parser_mixin +.. 
automodule:: otx.core.ov.models.parser_mixin :members: :undoc-members: diff --git a/docs/source/guide/reference/core/ov/ops.rst b/docs/source/guide/reference/core/ov/ops.rst new file mode 100644 index 00000000000..7b249e02702 --- /dev/null +++ b/docs/source/guide/reference/core/ov/ops.rst @@ -0,0 +1,82 @@ +OPS +^^^ + +.. toctree:: + :maxdepth: 3 + :caption: Contents: + +.. automodule:: otx.core.ov.ops + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.activations + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.arithmetics + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.builder + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.convolutions + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.generation + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.image_processings + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.infrastructures + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.matmuls + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.movements + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.normalizations + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.object_detections + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.op + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.poolings + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.reductions + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.shape_manipulations + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.sorting_maximization + :members: + :undoc-members: + +.. automodule:: otx.core.ov.ops.type_conversions + :members: + :undoc-members: + +.. 
automodule:: otx.core.ov.ops.utils + :members: + :undoc-members: \ No newline at end of file diff --git a/docs/source/guide/reference/mpa/modules/index.rst b/docs/source/guide/reference/mpa/modules/index.rst index 731dfbb571f..111f92ffdba 100644 --- a/docs/source/guide/reference/mpa/modules/index.rst +++ b/docs/source/guide/reference/mpa/modules/index.rst @@ -7,5 +7,4 @@ Modules models/index datasets hooks - ov/index utils diff --git a/docs/source/guide/reference/mpa/modules/ov/ops.rst b/docs/source/guide/reference/mpa/modules/ov/ops.rst deleted file mode 100644 index 72b59b1770f..00000000000 --- a/docs/source/guide/reference/mpa/modules/ov/ops.rst +++ /dev/null @@ -1,82 +0,0 @@ -OPS -^^^ - -.. toctree:: - :maxdepth: 3 - :caption: Contents: - -.. automodule:: otx.mpa.modules.ov.ops - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.activations - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.arithmetics - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.builder - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.convolutions - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.generation - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.image_processings - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.infrastructures - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.matmuls - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.movements - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.normalizations - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.object_detections - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.op - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.poolings - :members: - :undoc-members: - -.. 
automodule:: otx.mpa.modules.ov.ops.reductions - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.shape_manipulations - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.sorting_maximization - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.type_conversions - :members: - :undoc-members: - -.. automodule:: otx.mpa.modules.ov.ops.utils - :members: - :undoc-members: \ No newline at end of file diff --git a/docs/source/guide/reference/mpa/utils.rst b/docs/source/guide/reference/mpa/utils.rst index 9419a4a872c..67b73170d23 100644 --- a/docs/source/guide/reference/mpa/utils.rst +++ b/docs/source/guide/reference/mpa/utils.rst @@ -21,9 +21,6 @@ Utils :members: :undoc-members: -.. automodule:: otx.mpa.utils.file - :members: - :undoc-members: .. automodule:: otx.mpa.utils.mo_wrapper :members: diff --git a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py b/tests/unit/core/ov/graph/parsers/test_ov_graph_cls_parser.py similarity index 100% rename from tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_cls_parser.py rename to tests/unit/core/ov/graph/parsers/test_ov_graph_cls_parser.py diff --git a/tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_parser.py b/tests/unit/core/ov/graph/parsers/test_ov_graph_parser.py similarity index 100% rename from tests/unit/mpa/modules/ov/graph/parsers/test_ov_graph_parser.py rename to tests/unit/core/ov/graph/parsers/test_ov_graph_parser.py diff --git a/tests/unit/mpa/modules/ov/graph/test_ov_graph_grapy.py b/tests/unit/core/ov/graph/test_ov_graph_grapy.py similarity index 100% rename from tests/unit/mpa/modules/ov/graph/test_ov_graph_grapy.py rename to tests/unit/core/ov/graph/test_ov_graph_grapy.py diff --git a/tests/unit/mpa/modules/ov/graph/test_ov_graph_utils.py b/tests/unit/core/ov/graph/test_ov_graph_utils.py similarity index 100% rename from tests/unit/mpa/modules/ov/graph/test_ov_graph_utils.py rename to 
tests/unit/core/ov/graph/test_ov_graph_utils.py diff --git a/tests/unit/mpa/modules/ov/models/__init__.py b/tests/unit/core/ov/models/__init__.py similarity index 100% rename from tests/unit/mpa/modules/ov/models/__init__.py rename to tests/unit/core/ov/models/__init__.py diff --git a/tests/unit/mpa/modules/ov/models/mmcls/__init__.py b/tests/unit/core/ov/models/mmcls/__init__.py similarity index 100% rename from tests/unit/mpa/modules/ov/models/mmcls/__init__.py rename to tests/unit/core/ov/models/mmcls/__init__.py diff --git a/tests/unit/mpa/modules/ov/models/mmcls/backbones/test_ov_mmcls_mmov_backbone.py b/tests/unit/core/ov/models/mmcls/backbones/test_ov_mmcls_mmov_backbone.py similarity index 94% rename from tests/unit/mpa/modules/ov/models/mmcls/backbones/test_ov_mmcls_mmov_backbone.py rename to tests/unit/core/ov/models/mmcls/backbones/test_ov_mmcls_mmov_backbone.py index c3e48d1428b..87ed5aae9de 100644 --- a/tests/unit/mpa/modules/ov/models/mmcls/backbones/test_ov_mmcls_mmov_backbone.py +++ b/tests/unit/core/ov/models/mmcls/backbones/test_ov_mmcls_mmov_backbone.py @@ -9,7 +9,7 @@ MMOVBackbone, ) from tests.test_suite.e2e_test_system import e2e_pytest_unit -from tests.unit.mpa.modules.ov.models.mmcls.test_helpers import create_ov_model +from tests.unit.core.ov.models.mmcls.test_helpers import create_ov_model class TestMMOVBackbone: diff --git a/tests/unit/mpa/modules/ov/models/mmcls/heads/test_ov_mmcls_cls_head.py b/tests/unit/core/ov/models/mmcls/heads/test_ov_mmcls_cls_head.py similarity index 100% rename from tests/unit/mpa/modules/ov/models/mmcls/heads/test_ov_mmcls_cls_head.py rename to tests/unit/core/ov/models/mmcls/heads/test_ov_mmcls_cls_head.py diff --git a/tests/unit/mpa/modules/ov/models/mmcls/heads/test_ov_mmcls_conv_head.py b/tests/unit/core/ov/models/mmcls/heads/test_ov_mmcls_conv_head.py similarity index 100% rename from tests/unit/mpa/modules/ov/models/mmcls/heads/test_ov_mmcls_conv_head.py rename to 
tests/unit/core/ov/models/mmcls/heads/test_ov_mmcls_conv_head.py diff --git a/tests/unit/mpa/modules/ov/models/mmcls/heads/test_ov_mmcls_mmcv_cls_head.py b/tests/unit/core/ov/models/mmcls/heads/test_ov_mmcls_mmcv_cls_head.py similarity index 93% rename from tests/unit/mpa/modules/ov/models/mmcls/heads/test_ov_mmcls_mmcv_cls_head.py rename to tests/unit/core/ov/models/mmcls/heads/test_ov_mmcls_mmcv_cls_head.py index 71823180291..c12e05642b6 100644 --- a/tests/unit/mpa/modules/ov/models/mmcls/heads/test_ov_mmcls_mmcv_cls_head.py +++ b/tests/unit/core/ov/models/mmcls/heads/test_ov_mmcls_mmcv_cls_head.py @@ -9,7 +9,7 @@ MMOVClsHead, ) from tests.test_suite.e2e_test_system import e2e_pytest_unit -from tests.unit.mpa.modules.ov.models.mmcls.test_helpers import create_ov_model +from tests.unit.core.ov.models.mmcls.test_helpers import create_ov_model class TestMMOVClsHead: diff --git a/tests/unit/mpa/modules/ov/models/mmcls/necks/test_ov_mmcls_mmov_neck.py b/tests/unit/core/ov/models/mmcls/necks/test_ov_mmcls_mmov_neck.py similarity index 86% rename from tests/unit/mpa/modules/ov/models/mmcls/necks/test_ov_mmcls_mmov_neck.py rename to tests/unit/core/ov/models/mmcls/necks/test_ov_mmcls_mmov_neck.py index 8dff281798a..d6255891516 100644 --- a/tests/unit/mpa/modules/ov/models/mmcls/necks/test_ov_mmcls_mmov_neck.py +++ b/tests/unit/core/ov/models/mmcls/necks/test_ov_mmcls_mmov_neck.py @@ -6,7 +6,7 @@ from otx.algorithms.classification.adapters.mmcls.models.necks import MMOVNeck from tests.test_suite.e2e_test_system import e2e_pytest_unit -from tests.unit.mpa.modules.ov.models.mmcls.test_helpers import create_ov_model +from tests.unit.core.ov.models.mmcls.test_helpers import create_ov_model class TestMMOVNeck: diff --git a/tests/unit/mpa/modules/ov/models/mmcls/test_helpers.py b/tests/unit/core/ov/models/mmcls/test_helpers.py similarity index 100% rename from tests/unit/mpa/modules/ov/models/mmcls/test_helpers.py rename to tests/unit/core/ov/models/mmcls/test_helpers.py diff 
--git a/tests/unit/mpa/modules/ov/models/test_ov_models_ov_model.py b/tests/unit/core/ov/models/test_ov_models_ov_model.py similarity index 100% rename from tests/unit/mpa/modules/ov/models/test_ov_models_ov_model.py rename to tests/unit/core/ov/models/test_ov_models_ov_model.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_activations.py b/tests/unit/core/ov/ops/test_ov_ops_activations.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_activations.py rename to tests/unit/core/ov/ops/test_ov_ops_activations.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_arithmetics.py b/tests/unit/core/ov/ops/test_ov_ops_arithmetics.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_arithmetics.py rename to tests/unit/core/ov/ops/test_ov_ops_arithmetics.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_builder.py b/tests/unit/core/ov/ops/test_ov_ops_builder.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_builder.py rename to tests/unit/core/ov/ops/test_ov_ops_builder.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_convolutions.py b/tests/unit/core/ov/ops/test_ov_ops_convolutions.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_convolutions.py rename to tests/unit/core/ov/ops/test_ov_ops_convolutions.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_generation.py b/tests/unit/core/ov/ops/test_ov_ops_generation.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_generation.py rename to tests/unit/core/ov/ops/test_ov_ops_generation.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_image_processings.py b/tests/unit/core/ov/ops/test_ov_ops_image_processings.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_image_processings.py rename to tests/unit/core/ov/ops/test_ov_ops_image_processings.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_infrastructures.py 
b/tests/unit/core/ov/ops/test_ov_ops_infrastructures.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_infrastructures.py rename to tests/unit/core/ov/ops/test_ov_ops_infrastructures.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_matmuls.py b/tests/unit/core/ov/ops/test_ov_ops_matmuls.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_matmuls.py rename to tests/unit/core/ov/ops/test_ov_ops_matmuls.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_module.py b/tests/unit/core/ov/ops/test_ov_ops_module.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_module.py rename to tests/unit/core/ov/ops/test_ov_ops_module.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_movements.py b/tests/unit/core/ov/ops/test_ov_ops_movements.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_movements.py rename to tests/unit/core/ov/ops/test_ov_ops_movements.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_normalizations.py b/tests/unit/core/ov/ops/test_ov_ops_normalizations.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_normalizations.py rename to tests/unit/core/ov/ops/test_ov_ops_normalizations.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_object_detections.py b/tests/unit/core/ov/ops/test_ov_ops_object_detections.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_object_detections.py rename to tests/unit/core/ov/ops/test_ov_ops_object_detections.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_op.py b/tests/unit/core/ov/ops/test_ov_ops_op.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_op.py rename to tests/unit/core/ov/ops/test_ov_ops_op.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_poolings.py b/tests/unit/core/ov/ops/test_ov_ops_poolings.py similarity index 100% rename from 
tests/unit/mpa/modules/ov/ops/test_ov_ops_poolings.py rename to tests/unit/core/ov/ops/test_ov_ops_poolings.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_reductions.py b/tests/unit/core/ov/ops/test_ov_ops_reductions.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_reductions.py rename to tests/unit/core/ov/ops/test_ov_ops_reductions.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_shape_manipulations.py b/tests/unit/core/ov/ops/test_ov_ops_shape_manipulations.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_shape_manipulations.py rename to tests/unit/core/ov/ops/test_ov_ops_shape_manipulations.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_sorting_maximization.py b/tests/unit/core/ov/ops/test_ov_ops_sorting_maximization.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_sorting_maximization.py rename to tests/unit/core/ov/ops/test_ov_ops_sorting_maximization.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_type_conversions.py b/tests/unit/core/ov/ops/test_ov_ops_type_conversions.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_type_conversions.py rename to tests/unit/core/ov/ops/test_ov_ops_type_conversions.py diff --git a/tests/unit/mpa/modules/ov/ops/test_ov_ops_utils.py b/tests/unit/core/ov/ops/test_ov_ops_utils.py similarity index 100% rename from tests/unit/mpa/modules/ov/ops/test_ov_ops_utils.py rename to tests/unit/core/ov/ops/test_ov_ops_utils.py diff --git a/tests/unit/mpa/modules/ov/test_ov_omz_wrapper.py b/tests/unit/core/ov/test_ov_omz_wrapper.py similarity index 100% rename from tests/unit/mpa/modules/ov/test_ov_omz_wrapper.py rename to tests/unit/core/ov/test_ov_omz_wrapper.py diff --git a/tests/unit/mpa/modules/ov/test_ov_registry.py b/tests/unit/core/ov/test_ov_registry.py similarity index 100% rename from tests/unit/mpa/modules/ov/test_ov_registry.py rename to 
tests/unit/core/ov/test_ov_registry.py diff --git a/tests/unit/mpa/modules/ov/test_ov_utils.py b/tests/unit/core/ov/test_ov_utils.py similarity index 100% rename from tests/unit/mpa/modules/ov/test_ov_utils.py rename to tests/unit/core/ov/test_ov_utils.py From 6406c91d6b3224b75deb6b61b957cb4b4ccc072b Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Thu, 23 Mar 2023 20:04:19 +0900 Subject: [PATCH 5/6] Fix some import --- otx/core/ov/__init__.py | 5 + otx/core/ov/graph/graph.py | 8 +- .../ov/graph/parsers/cls/cls_base_parser.py | 18 +-- otx/core/ov/graph/utils.py | 49 ++++--- otx/core/ov/models/__init__.py | 4 + otx/core/ov/models/ov_model.py | 11 +- otx/core/ov/models/parser_mixin.py | 8 +- otx/core/ov/ops/__init__.py | 136 ++++++++++++++++++ otx/core/ov/ops/infrastructures.py | 9 +- otx/core/ov/ops/utils.py | 12 +- otx/core/ov/utils.py | 4 - 11 files changed, 189 insertions(+), 75 deletions(-) diff --git a/otx/core/ov/__init__.py b/otx/core/ov/__init__.py index 9ae460d03b2..1ac90928687 100644 --- a/otx/core/ov/__init__.py +++ b/otx/core/ov/__init__.py @@ -2,3 +2,8 @@ # Copyright (C) 2023 Intel Corporation # # SPDX-License-Identifier: MIT + +# flake8: noqa +from .graph import * +from .models import * +from .ops import * diff --git a/otx/core/ov/graph/graph.py b/otx/core/ov/graph/graph.py index 75aad71282b..7079473b01f 100644 --- a/otx/core/ov/graph/graph.py +++ b/otx/core/ov/graph/graph.py @@ -15,14 +15,10 @@ import networkx as nx from openvino.pyopenvino import Model # pylint: disable=no-name-in-module -from otx.mpa.utils.logger import get_logger - from ..ops.op import Operation from ..ops.utils import convert_op_to_torch from ..utils import get_op_name -logger = get_logger() - # pylint: disable=too-many-locals, too-many-nested-blocks, arguments-renamed, too-many-branches, too-many-statements @@ -602,10 +598,10 @@ def remove_normalize_nodes(self): elif second_node is None: second_node = first_node self.remove_node(first_node, keep_connect=True) - 
logger.info(f"Remove normalize node {first_node.name}") + # logger.info(f"Remove normalize node {first_node.name}") try: self.remove_node(second_node, keep_connect=True) - logger.info(f"Remove normalize node {second_node.name}") + # logger.info(f"Remove normalize node {second_node.name}") except Exception: # pylint: disable=broad-exception-caught pass self._normalize_nodes = [] diff --git a/otx/core/ov/graph/parsers/cls/cls_base_parser.py b/otx/core/ov/graph/parsers/cls/cls_base_parser.py index b5fd44d8beb..fee4c3c078e 100644 --- a/otx/core/ov/graph/parsers/cls/cls_base_parser.py +++ b/otx/core/ov/graph/parsers/cls/cls_base_parser.py @@ -5,13 +5,9 @@ from typing import Dict, List, Optional -from otx.mpa.utils.logger import get_logger - from ..builder import PARSERS from ..parser import parameter_parser -logger = get_logger() - # pylint: disable=too-many-return-statements, too-many-branches @@ -35,19 +31,19 @@ def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, Li result_nodes = graph.get_nodes_by_types(["Result"]) if len(result_nodes) != 1: - logger.debug("More than one reulst nodes are found.") + # logger.debug("More than one reulst nodes are found.") return None result_node = result_nodes[0] neck_input = None for _, node_to in graph.bfs(result_node, True, 20): if node_to.type in NECK_INPUT_TYPES: - logger.debug(f"Found neck_input: {node_to.name}") + # logger.debug(f"Found neck_input: {node_to.name}") neck_input = node_to break if neck_input is None: - logger.debug("Can not determine the output of backbone.") + # logger.debug("Can not determine the output of backbone.") return None neck_output = neck_input @@ -64,12 +60,12 @@ def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, Li if component == "backbone": outputs = [node.name for node in graph.predecessors(neck_input) if node.type != "Constant"] if len(outputs) != 1: - logger.debug(f"neck_input {neck_input.name} has more than one predecessors.") + # 
logger.debug(f"neck_input {neck_input.name} has more than one predecessors.") return None inputs = parameter_parser(graph) if len(inputs) != 1: - logger.debug("More than on parameter nodes are found.") + # logger.debug("More than on parameter nodes are found.") return None return dict( @@ -88,7 +84,7 @@ def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, Li outputs = graph.get_nodes_by_types(["Result"]) if len(outputs) != 1: - logger.debug("More than one network output is found.") + # logger.debug("More than one network output is found.") return None for node_from, node_to in graph.bfs(outputs[0], True, 5): if node_to.type == "Softmax": @@ -96,7 +92,7 @@ def cls_base_parser(graph, component: str = "backbone") -> Optional[Dict[str, Li break if not graph.has_path(head_inputs[0], outputs[0]): - logger.debug(f"input({head_inputs[0].name}) and output({outputs[0].name}) are reversed") + # logger.debug(f"input({head_inputs[0].name}) and output({outputs[0].name}) are reversed") return None return dict( diff --git a/otx/core/ov/graph/utils.py b/otx/core/ov/graph/utils.py index e7e55f77891..172b9050c90 100644 --- a/otx/core/ov/graph/utils.py +++ b/otx/core/ov/graph/utils.py @@ -11,9 +11,6 @@ from otx.core.ov.ops.builder import OPS from otx.core.ov.ops.infrastructures import ConstantV0 from otx.core.ov.ops.op import Operation -from otx.mpa.utils.logger import get_logger - -logger = get_logger() # pylint: disable=too-many-locals, protected-access, too-many-branches, too-many-statements, too-many-nested-blocks @@ -55,10 +52,10 @@ def handle_merging_into_batchnorm(graph, type_patterns=None, type_mappings=None) is_normalize = True break if is_normalize: - logger.info( - f"Skip merging {[i.name for i in nodes]} " - f"becuase they are part of normalization (preprocessing of IR)" - ) + # logger.info( + # f"Skip merging {[i.name for i in nodes]} " + # f"becuase they are part of normalization (preprocessing of IR)" + # ) continue shapes = [] @@ -73,19 +70,20 
@@ def handle_merging_into_batchnorm(graph, type_patterns=None, type_mappings=None) shapes.append(constant.shape) constants.append(constant) if not is_valid: - logger.info( - f"Skip merging {[i.name for i in nodes]} " f"becuase it has more than one weights for node {node.name}." - ) + # logger.info( + # f"Skip merging {[i.name for i in nodes]} " + # f"becuase it has more than one weights for node {node.name}." + # ) continue if len(set(shapes)) != 1: - logger.info( - f"Skip merging {[i.name for i in nodes]} " f"becuase shape of weights are not the same. ({shapes})" - ) + # logger.info( + # f"Skip merging {[i.name for i in nodes]} " f"becuase shape of weights are not the same. ({shapes})" + # ) continue if len(set(shapes[0][2:])) != 1 or shapes[0][2] != 1: - logger.info(f"Skip merging {[i.name for i in nodes]} " f"becuase shape of weights are not 1. ({shapes})") + # logger.info(f"Skip merging {[i.name for i in nodes]} " f"becuase shape of weights are not 1. ({shapes})") continue channel_dim = shapes[0][1] @@ -135,7 +133,7 @@ def handle_merging_into_batchnorm(graph, type_patterns=None, type_mappings=None) is_parameter=False, ) - logger.info(f"Merge {[i.name for i in nodes]} into batch normalization.") + # logger.info(f"Merge {[i.name for i in nodes]} into batch normalization.") edges = [] for predecessor in graph.predecessors(nodes[0]): if predecessor.type != "Constant": @@ -174,9 +172,9 @@ def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = Non edge = edge[0] input_shape = input_node.shape[edge["out_port"]][2:] if len(set(input_shape)) == 1 and input_shape[0] == 1: - logger.info( - f"Skip a paired batch normalization for {node.name} " f"becuase input shape to it is {input_shape}." - ) + # logger.info( + # f"Skip a paired batch normalization for {node.name} " f"becuase input shape to it is {input_shape}." 
+ # ) continue bias_node_list: List[Any] = [n for n in graph.successors(node) if n.type == "Add"] @@ -187,13 +185,14 @@ def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = Non # if bias node is not found we do not need to add batchnorm if bias_node is None: - logger.info(f"Skip a paired batch normalization for {node.name} " "becuase it has no bias add node.") + # logger.info(f"Skip a paired batch normalization for {node.name} " "becuase it has no bias add node.") continue # if add node is not bias add node if not isinstance(list(graph.predecessors(bias_node))[1], ConstantV0): - logger.info( - f"Skip a pared batch normalization for {node.name} " f"because {bias_node.name} is not a bias add node." - ) + # logger.info( + # f"Skip a pared batch normalization for {node.name} " + # f"because {bias_node.name} is not a bias add node." + # ) continue node_name = node.name @@ -234,7 +233,7 @@ def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = Non ) if replace and bias_node is not None: - logger.info(f"Replace {bias_node.name} with a paired batch normalization.") + # logger.info(f"Replace {bias_node.name} with a paired batch normalization.") edges = [] for successor in graph.successors(bias_node): edges_attrs = graph.get_edge_data(bias_node, successor) @@ -257,7 +256,7 @@ def handle_paired_batchnorm(graph, replace: bool = False, types: List[str] = Non for edge in edges: graph.add_edge(**edge) else: - logger.info(f"Append a paired batch normalization after {node.name}") + # logger.info(f"Append a paired batch normalization after {node.name}") edges = [] for successor in graph.successors(node): edges_attrs = graph.get_edge_data(node, successor) @@ -287,5 +286,5 @@ def handle_reshape(graph): for shape_ in input_node.shape[0][::-1]: if shape_ != 1: break - logger.info(f"Change reshape to [-1, {shape_}]") # pylint: disable=undefined-loop-variable + # logger.info(f"Change reshape to [-1, {shape_}]") # pylint: 
disable=undefined-loop-variable shape.data = torch.tensor([-1, shape_]) # pylint: disable=undefined-loop-variable diff --git a/otx/core/ov/models/__init__.py b/otx/core/ov/models/__init__.py index 4c62ff2a3ad..99fff8b1223 100644 --- a/otx/core/ov/models/__init__.py +++ b/otx/core/ov/models/__init__.py @@ -4,7 +4,11 @@ # SPDX-License-Identifier: MIT from .mmov_model import MMOVModel +from .ov_model import OVModel # type: ignore[attr-defined] +from .parser_mixin import ParserMixin # type: ignore[attr-defined] __all__ = [ "MMOVModel", + "OVModel", + "ParserMixin", ] diff --git a/otx/core/ov/models/ov_model.py b/otx/core/ov/models/ov_model.py index c7a72f2ea29..cbccdee2919 100644 --- a/otx/core/ov/models/ov_model.py +++ b/otx/core/ov/models/ov_model.py @@ -16,8 +16,6 @@ import torch from torch.nn import init -from otx.mpa.utils.logger import get_logger - from ..graph import Graph from ..graph.utils import ( handle_merging_into_batchnorm, @@ -27,9 +25,6 @@ from ..ops.builder import OPS from ..utils import load_ov_model, normalize_name -logger = get_logger() - - CONNECTION_SEPARATOR = "||" # pylint: disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements @@ -120,7 +115,7 @@ def init_weight(module, graph): # pylint: disable=function-redefined init.zeros_(beta.data) mean.data.zero_() var.data.fill_(1) - logger.info(f"Initialize {module.TYPE} -> {module.name}") + # logger.info(f"Initialize {module.TYPE} -> {module.name}") elif module.TYPE in [ "Convolution", "GroupConvolution", @@ -129,7 +124,7 @@ def init_weight(module, graph): # pylint: disable=function-redefined for weight in graph.predecessors(module): if weight.TYPE == "Constant" and isinstance(weight.data, torch.nn.parameter.Parameter): init.kaiming_uniform_(weight.data, a=math.sqrt(5)) - logger.info(f"Initialize {module.TYPE} -> {module.name}") + # logger.info(f"Initialize {module.TYPE} -> {module.name}") elif module.TYPE in [ "Multiply", "Divide", @@ -143,7 +138,7 @@ def 
init_weight(module, graph): # pylint: disable=function-redefined ) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 init.uniform_(weight.data, -bound, bound) - logger.info(f"Initialize {module.TYPE} -> {module.name}") + # logger.info(f"Initialize {module.TYPE} -> {module.name}") self.model.apply(lambda m: init_weight(m, graph)) diff --git a/otx/core/ov/models/parser_mixin.py b/otx/core/ov/models/parser_mixin.py index 2943d58dd4d..46791c919fa 100644 --- a/otx/core/ov/models/parser_mixin.py +++ b/otx/core/ov/models/parser_mixin.py @@ -9,13 +9,9 @@ import openvino.runtime as ov -from otx.mpa.utils.logger import get_logger - from ..graph.parsers.builder import PARSERS from .ov_model import OVModel -logger = get_logger() - class ParserMixin: """ParserMixin class.""" @@ -51,8 +47,8 @@ def parse( inputs = parsed["inputs"] if not inputs else inputs outputs = parsed["outputs"] if not outputs else outputs - logger.info(f"inputs: {inputs}") - logger.info(f"outputs: {outputs}") + # logger.info(f"inputs: {inputs}") + # logger.info(f"outputs: {outputs}") return inputs, outputs diff --git a/otx/core/ov/ops/__init__.py b/otx/core/ov/ops/__init__.py index df6df449c4a..938a11995a6 100644 --- a/otx/core/ov/ops/__init__.py +++ b/otx/core/ov/ops/__init__.py @@ -2,3 +2,139 @@ # Copyright (C) 2023 Intel Corporation # # SPDX-License-Identifier: MIT + +from .activations import ( + ClampV0, + EluV0, + ExpV0, + GeluV7, + HardSigmoidV0, + HSigmoidV5, + HSwishV4, + MishV4, + PReluV0, + ReluV0, + SeluV0, + SigmoidV0, + SoftMaxV0, + SoftMaxV1, + SwishV4, + TanhV0, +) +from .arithmetics import AddV1, DivideV1, MultiplyV1, SubtractV1, TanV0 +from .builder import OPS, OperationRegistry +from .convolutions import ConvolutionV1, GroupConvolutionV1 +from .generation import RangeV4 +from .image_processings import InterpolateV4 +from .infrastructures import ConstantV0, ParameterV0, ResultV0 +from .matmuls import EinsumV7, MatMulV0 +from .movements import ( + BroadcastV3, + ConcatV0, + GatherV0, + 
GatherV1, + PadV1, + ScatterNDUpdateV3, + ScatterUpdateV3, + ShuffleChannelsV0, + SplitV1, + StridedSliceV1, + TileV0, + TransposeV1, + VariadicSplitV1, +) +from .normalizations import ( + MVNV6, + BatchNormalizationV0, + LocalResponseNormalizationV0, + NormalizeL2V0, +) +from .object_detections import ( + DetectionOutputV0, + PriorBoxClusteredV0, + PriorBoxV0, + ProposalV4, + RegionYoloV0, + ROIPoolingV0, +) +from .op import Attribute, Operation +from .poolings import AvgPoolV1, MaxPoolV0 +from .reductions import ReduceMeanV1, ReduceMinV1, ReduceProdV1, ReduceSumV1 +from .shape_manipulations import ReshapeV1, ShapeOfV0, ShapeOfV3, SqueezeV0, UnsqueezeV0 +from .sorting_maximization import NonMaxSuppressionV5, NonMaxSuppressionV9, TopKV3 +from .type_conversions import ConvertV0 + +__all__ = [ + "SoftMaxV0", + "SoftMaxV1", + "ReluV0", + "SwishV4", + "SigmoidV0", + "ClampV0", + "PReluV0", + "TanhV0", + "EluV0", + "SeluV0", + "MishV4", + "HSwishV4", + "HSigmoidV5", + "ExpV0", + "HardSigmoidV0", + "GeluV7", + "MultiplyV1", + "DivideV1", + "AddV1", + "SubtractV1", + "TanV0", + "OPS", + "OperationRegistry", + "ConvolutionV1", + "GroupConvolutionV1", + "RangeV4", + "InterpolateV4", + "ParameterV0", + "ResultV0", + "ConstantV0", + "MatMulV0", + "EinsumV7", + "PadV1", + "ConcatV0", + "TransposeV1", + "GatherV0", + "GatherV1", + "StridedSliceV1", + "SplitV1", + "VariadicSplitV1", + "ShuffleChannelsV0", + "BroadcastV3", + "ScatterNDUpdateV3", + "ScatterUpdateV3", + "TileV0", + "BatchNormalizationV0", + "LocalResponseNormalizationV0", + "NormalizeL2V0", + "MVNV6", + "ProposalV4", + "ROIPoolingV0", + "DetectionOutputV0", + "RegionYoloV0", + "PriorBoxV0", + "PriorBoxClusteredV0", + "Operation", + "Attribute", + "MaxPoolV0", + "AvgPoolV1", + "ReduceMeanV1", + "ReduceProdV1", + "ReduceMinV1", + "ReduceSumV1", + "SqueezeV0", + "UnsqueezeV0", + "ReshapeV1", + "ShapeOfV0", + "ShapeOfV3", + "TopKV3", + "NonMaxSuppressionV5", + "NonMaxSuppressionV9", + "ConvertV0", +] diff --git 
a/otx/core/ov/ops/infrastructures.py b/otx/core/ov/ops/infrastructures.py index f80356db842..4a4ec3026ed 100644 --- a/otx/core/ov/ops/infrastructures.py +++ b/otx/core/ov/ops/infrastructures.py @@ -10,17 +10,12 @@ import numpy as np import torch -from otx.mpa.utils.logger import get_logger - from ..utils import get_op_name # type: ignore[attr-defined] from .builder import OPS from .op import Attribute, Operation from .type_conversions import ConvertV0 from .utils import get_dynamic_shape -logger = get_logger() - - NODE_TYPES_WITH_WEIGHT = set( [ "Convolution", @@ -231,8 +226,8 @@ def from_ov(cls, ov_op): data = ov_op.get_data() if data.dtype == np.uint64: data_ = data.astype(np.int64) - if not np.array_equal(data, data_): - logger.warning(f"Overflow detected in {op_name}") + # if not np.array_equal(data, data_): + # logger.warning(f"Overflow detected in {op_name}") data = torch.from_numpy(data_) else: data = torch.from_numpy(data) diff --git a/otx/core/ov/ops/utils.py b/otx/core/ov/ops/utils.py index 50e4a9f3b1d..d1397d460e1 100644 --- a/otx/core/ov/ops/utils.py +++ b/otx/core/ov/ops/utils.py @@ -5,12 +5,8 @@ from openvino.pyopenvino import Node # pylint: disable=no-name-in-module -from otx.mpa.utils.logger import get_logger - from .builder import OPS -logger = get_logger() - def get_dynamic_shape(output): """Getter function for dynamic shape.""" @@ -32,10 +28,10 @@ def convert_op_to_torch(op_node: Node): try: torch_module = OPS.get_by_type_version(op_type, op_version).from_ov(op_node) except Exception as e: - logger.error(e) - logger.error(op_type) - logger.error(op_version) - logger.error(op_node.get_attributes()) + # logger.error(e) + # logger.error(op_type) + # logger.error(op_version) + # logger.error(op_node.get_attributes()) raise e return torch_module diff --git a/otx/core/ov/utils.py b/otx/core/ov/utils.py index edaf14fcf9e..2e95e500249 100644 --- a/otx/core/ov/utils.py +++ b/otx/core/ov/utils.py @@ -12,12 +12,8 @@ from openvino.pyopenvino import Model, 
Node # pylint: disable=no-name-in-module from openvino.runtime import Core -from otx.mpa.utils.logger import get_logger - from .omz_wrapper import AVAILABLE_OMZ_MODELS, get_omz_model -logger = get_logger() - # pylint: disable=too-many-locals From c66b28b5ff677ed38be0cf36d1cb37a997f57ada Mon Sep 17 00:00:00 2001 From: "Kang, Harim" Date: Thu, 23 Mar 2023 20:32:38 +0900 Subject: [PATCH 6/6] Add TODO comments --- otx/core/ov/graph/graph.py | 1 + otx/core/ov/graph/parsers/cls/cls_base_parser.py | 1 + otx/core/ov/graph/utils.py | 1 + otx/core/ov/models/ov_model.py | 1 + otx/core/ov/models/parser_mixin.py | 2 ++ otx/core/ov/ops/utils.py | 2 ++ 6 files changed, 8 insertions(+) diff --git a/otx/core/ov/graph/graph.py b/otx/core/ov/graph/graph.py index 7079473b01f..9626c4f267c 100644 --- a/otx/core/ov/graph/graph.py +++ b/otx/core/ov/graph/graph.py @@ -19,6 +19,7 @@ from ..ops.utils import convert_op_to_torch from ..utils import get_op_name +# TODO: We moved the location of otx.mpa.utils.logger, we need to revert the logger in that code again. # pylint: disable=too-many-locals, too-many-nested-blocks, arguments-renamed, too-many-branches, too-many-statements diff --git a/otx/core/ov/graph/parsers/cls/cls_base_parser.py b/otx/core/ov/graph/parsers/cls/cls_base_parser.py index fee4c3c078e..ac869278281 100644 --- a/otx/core/ov/graph/parsers/cls/cls_base_parser.py +++ b/otx/core/ov/graph/parsers/cls/cls_base_parser.py @@ -8,6 +8,7 @@ from ..builder import PARSERS from ..parser import parameter_parser +# TODO: We moved the location of otx.mpa.utils.logger, we need to revert the logger in that code again. 
# pylint: disable=too-many-return-statements, too-many-branches diff --git a/otx/core/ov/graph/utils.py b/otx/core/ov/graph/utils.py index 172b9050c90..f8c95d62b8c 100644 --- a/otx/core/ov/graph/utils.py +++ b/otx/core/ov/graph/utils.py @@ -12,6 +12,7 @@ from otx.core.ov.ops.infrastructures import ConstantV0 from otx.core.ov.ops.op import Operation +# TODO: We moved the location of otx.mpa.utils.logger, we need to revert the logger in that code again. # pylint: disable=too-many-locals, protected-access, too-many-branches, too-many-statements, too-many-nested-blocks diff --git a/otx/core/ov/models/ov_model.py b/otx/core/ov/models/ov_model.py index cbccdee2919..bcb4928722f 100644 --- a/otx/core/ov/models/ov_model.py +++ b/otx/core/ov/models/ov_model.py @@ -27,6 +27,7 @@ CONNECTION_SEPARATOR = "||" +# TODO: We moved the location of otx.mpa.utils.logger, we need to revert the logger in that code again. # pylint: disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements diff --git a/otx/core/ov/models/parser_mixin.py b/otx/core/ov/models/parser_mixin.py index 46791c919fa..2f5f87c5662 100644 --- a/otx/core/ov/models/parser_mixin.py +++ b/otx/core/ov/models/parser_mixin.py @@ -12,6 +12,8 @@ from ..graph.parsers.builder import PARSERS from .ov_model import OVModel +# TODO: We moved the location of otx.mpa.utils.logger, we need to revert the logger in that code again. + class ParserMixin: """ParserMixin class.""" diff --git a/otx/core/ov/ops/utils.py b/otx/core/ov/ops/utils.py index d1397d460e1..a5cf201a581 100644 --- a/otx/core/ov/ops/utils.py +++ b/otx/core/ov/ops/utils.py @@ -7,6 +7,8 @@ from .builder import OPS +# TODO: We moved the location of otx.mpa.utils.logger, we need to revert the logger in that code again. + def get_dynamic_shape(output): """Getter function for dynamic shape."""