From 10fce7955dca39a0a170f75a159cf15a62214304 Mon Sep 17 00:00:00 2001 From: Jaeguk Hyun Date: Wed, 12 Apr 2023 16:58:32 +0900 Subject: [PATCH] Split export function and _export_model function for detection task (#1997) * Split export function and _export_model function for detection task * Add basic unit test for detection task * Update change log * Reflect changes from https://github.com/openvinotoolkit/training_extensions/pull/1976 * Fix unit test failure --- CHANGELOG.md | 1 + .../detection/adapters/mmdet/task.py | 61 +-- otx/algorithms/detection/task.py | 60 ++- .../detection/adapters/mmdet/test_task.py | 362 ++++++++++++++++++ .../unit/algorithms/detection/test_helpers.py | 8 +- tests/unit/algorithms/detection/test_task.py | 150 -------- 6 files changed, 431 insertions(+), 211 deletions(-) create mode 100644 tests/unit/algorithms/detection/adapters/mmdet/test_task.py delete mode 100644 tests/unit/algorithms/detection/test_task.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 255db9ce4d2..962dcca651e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file. - Clean up and refactor the output of the OTX CLI () - Enhance DetCon logic and SupCon for semantic segmentation() +- Detection task refactoring () - Classification task refactoring () - Extend OTX explain CLI () - Segmentation task refactoring () diff --git a/otx/algorithms/detection/adapters/mmdet/task.py b/otx/algorithms/detection/adapters/mmdet/task.py index ee99c02758e..37758f298da 100644 --- a/otx/algorithms/detection/adapters/mmdet/task.py +++ b/otx/algorithms/detection/adapters/mmdet/task.py @@ -22,7 +22,6 @@ from copy import deepcopy from typing import Any, Dict, Optional, Union -import numpy as np import torch from mmcv.runner import wrap_fp16_model from mmcv.utils import Config, ConfigDict, get_git_hash @@ -52,7 +51,6 @@ from otx.algorithms.common.utils import set_random_seed from otx.algorithms.common.utils.callback import InferenceProgressCallback from otx.algorithms.common.utils.data import get_dataset -from otx.algorithms.common.utils.ir import embed_ir_model_data from otx.algorithms.common.utils.logger import get_logger from otx.algorithms.detection.adapters.mmdet.configurer import ( DetectionConfigurer, @@ -69,24 +67,20 @@ ) from otx.algorithms.detection.adapters.mmdet.utils.exporter import DetectionExporter from otx.algorithms.detection.task import OTXDetectionTask -from otx.algorithms.detection.utils import get_det_model_api_configuration from otx.algorithms.detection.utils.data import adaptive_tile_params from otx.api.configuration import cfg_helper -from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings +from otx.api.configuration.helper.utils import ids_to_strings from otx.api.entities.datasets import DatasetEntity from otx.api.entities.explain_parameters import ExplainParameters from otx.api.entities.inference_parameters import InferenceParameters from otx.api.entities.model import ( ModelEntity, - ModelFormat, - ModelOptimizationType, ModelPrecision, ) from otx.api.entities.subset import Subset from otx.api.entities.task_environment import TaskEnvironment from otx.api.entities.train_parameters import default_progress_callback from otx.api.serialization.label_mapper import label_schema_to_bytes -from otx.api.usecases.tasks.interfaces.export_interface import ExportType from otx.core.data import caching logger = get_logger() @@ -461,21 +455,12 @@ def hook(module, inp, outp): return prediction_results, 
metric # pylint: disable=too-many-statements - def export( + def _export_model( self, - export_type: ExportType, - output_model: ModelEntity, - precision: ModelPrecision = ModelPrecision.FP32, - dump_features: bool = True, + precision: ModelPrecision, + dump_features: bool, ): - """Export function of OTX Detection Task.""" - # copied from OTX inference_task.py - logger.info("Exporting the model") - if export_type != ExportType.OPENVINO: - raise RuntimeError(f"not supported export type {export_type}") - output_model.model_format = ModelFormat.OPENVINO - output_model.optimization_type = ModelOptimizationType.MO - + """Main export function of OTX MMDetection Task.""" self._init_task(export=True) cfg = self.configure(False, "test", None) @@ -506,41 +491,7 @@ def export( **export_options, ) - outputs = results.get("outputs") - logger.debug(f"results of run_task = {outputs}") - if outputs is None: - raise RuntimeError(results.get("msg")) - - bin_file = outputs.get("bin") - xml_file = outputs.get("xml") - onnx_file = outputs.get("onnx") - - ir_extra_data = get_det_model_api_configuration( - self._task_environment.label_schema, self._task_type, self.confidence_threshold - ) - embed_ir_model_data(xml_file, ir_extra_data) - - if xml_file is None or bin_file is None or onnx_file is None: - raise RuntimeError("invalid status of exporting. bin and xml or onnx should not be None") - with open(bin_file, "rb") as f: - output_model.set_data("openvino.bin", f.read()) - with open(xml_file, "rb") as f: - output_model.set_data("openvino.xml", f.read()) - with open(onnx_file, "rb") as f: - output_model.set_data("model.onnx", f.read()) - output_model.set_data( - "confidence_threshold", - np.array([self.confidence_threshold], dtype=np.float32).tobytes(), - ) - output_model.set_data("config.json", config_to_bytes(self._hyperparams)) - output_model.precision = self._precision - output_model.optimization_methods = self._optimization_methods - output_model.has_xai = dump_features - output_model.set_data( - "label_schema.json", - label_schema_to_bytes(self._task_environment.label_schema), - ) - logger.info("Exporting completed") + return results def explain( self, diff --git a/otx/algorithms/detection/task.py b/otx/algorithms/detection/task.py index b2f80de1662..98db8f056bb 100644 --- a/otx/algorithms/detection/task.py +++ b/otx/algorithms/detection/task.py @@ -30,10 +30,12 @@ InferenceProgressCallback, TrainingProgressCallback, ) +from otx.algorithms.common.utils.ir import embed_ir_model_data from otx.algorithms.common.utils.logger import get_logger from otx.algorithms.detection.configs.base import DetectionConfig +from otx.algorithms.detection.utils import get_det_model_api_configuration from otx.api.configuration import cfg_helper -from otx.api.configuration.helper.utils import ids_to_strings +from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings from otx.api.entities.annotation import Annotation from otx.api.entities.datasets import DatasetEntity from otx.api.entities.explain_parameters import ExplainParameters @@ -50,7 +52,12 @@ ScoreMetric, VisualizationType, ) -from otx.api.entities.model import ModelEntity, ModelPrecision +from otx.api.entities.model import ( + ModelEntity, + ModelFormat, + ModelOptimizationType, + ModelPrecision, +) from otx.api.entities.model_template import TaskType from otx.api.entities.resultset import ResultSetEntity from otx.api.entities.scored_label import ScoredLabel @@ -246,7 +253,6 @@ def _infer_model( """Get inference results from dataset.""" raise 
NotImplementedError
 
-    @abstractmethod
     def export(
         self,
         export_type: ExportType,
@@ -254,7 +260,53 @@ def export(
         precision: ModelPrecision = ModelPrecision.FP32,
         dump_features: bool = True,
     ):
-        """Export function of OTX Task."""
+        """Export function of OTX Detection Task."""
+        logger.info("Exporting the model")
+        if export_type != ExportType.OPENVINO:
+            raise RuntimeError(f"Unsupported export type: {export_type}")
+        output_model.model_format = ModelFormat.OPENVINO
+        output_model.optimization_type = ModelOptimizationType.MO
+
+        results = self._export_model(precision, dump_features)
+        outputs = results.get("outputs")
+        logger.debug(f"results of run_task = {outputs}")
+        if outputs is None:
+            raise RuntimeError(results.get("msg"))
+
+        bin_file = outputs.get("bin")
+        xml_file = outputs.get("xml")
+        onnx_file = outputs.get("onnx")
+        if xml_file is None or bin_file is None or onnx_file is None:
+            raise RuntimeError("Export failed: bin, xml, and onnx outputs must not be None")
+
+        ir_extra_data = get_det_model_api_configuration(
+            self._task_environment.label_schema, self._task_type, self.confidence_threshold
+        )
+        embed_ir_model_data(xml_file, ir_extra_data)
+
+        with open(bin_file, "rb") as f:
+            output_model.set_data("openvino.bin", f.read())
+        with open(xml_file, "rb") as f:
+            output_model.set_data("openvino.xml", f.read())
+        with open(onnx_file, "rb") as f:
+            output_model.set_data("model.onnx", f.read())
+        output_model.set_data(
+            "confidence_threshold",
+            np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
+        )
+        output_model.set_data("config.json", config_to_bytes(self._hyperparams))
+        output_model.precision = self._precision
+        output_model.optimization_methods = self._optimization_methods
+        output_model.has_xai = dump_features
+        output_model.set_data(
+            "label_schema.json",
+            label_schema_to_bytes(self._task_environment.label_schema),
+        )
+        logger.info("Exporting completed")
+
+    @abstractmethod
+    def _export_model(self, precision: ModelPrecision, dump_features: bool):
+        """Main export function using training backend."""
        raise NotImplementedError
 
     @abstractmethod
diff --git a/tests/unit/algorithms/detection/adapters/mmdet/test_task.py b/tests/unit/algorithms/detection/adapters/mmdet/test_task.py
new file mode 100644
index 00000000000..db35bda88aa
--- /dev/null
+++ b/tests/unit/algorithms/detection/adapters/mmdet/test_task.py
@@ -0,0 +1,362 @@
+"""Unit Test for otx.algorithms.detection.adapters.mmdet.task."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import os
+import json
+from contextlib import nullcontext
+from copy import deepcopy
+from typing import Any, Dict
+
+import numpy as np
+import pytest
+import torch
+from mmcv.utils import Config
+from torch import nn
+
+from otx.algorithms.common.adapters.mmcv.utils.config_utils import MPAConfig
+from otx.algorithms.detection.adapters.mmdet.task import MMDetectionTask
+from otx.algorithms.detection.adapters.mmdet.models.detectors.custom_atss_detector import CustomATSS
+from otx.algorithms.detection.configs.base import DetectionConfig
+from otx.api.configuration import ConfigurableParameters
+from otx.api.configuration.helper import create
+from otx.api.entities.dataset_item import DatasetItemEntity
+from otx.api.entities.datasets import DatasetEntity
+from otx.api.entities.explain_parameters import ExplainParameters
+from otx.api.entities.inference_parameters import InferenceParameters
+from otx.api.entities.label import Domain
+from otx.api.entities.label_schema import LabelGroup, LabelGroupType, LabelSchemaEntity
+from otx.api.entities.model import (
+    ModelConfiguration,
+    ModelEntity,
+    ModelFormat,
+    ModelOptimizationType,
+    ModelPrecision,
+)
+from otx.api.entities.model_template import InstantiationType, parse_model_template, TaskFamily, TaskType
+from otx.api.entities.resultset import ResultSetEntity
+from otx.api.usecases.tasks.interfaces.export_interface import ExportType
+from tests.test_suite.e2e_test_system import e2e_pytest_unit
+from tests.unit.algorithms.detection.test_helpers import (
+    DEFAULT_DET_TEMPLATE_DIR,
+    DEFAULT_ISEG_TEMPLATE_DIR,
+    init_environment,
+    generate_det_dataset,
+)
+
+
+class MockModule(nn.Module):
+    """Mock class for nn.Module."""
+
+    def forward(self, inputs: Any):
+        return inputs
+
+
+class MockModel(nn.Module):
+    """Mock class for a PyTorch model."""
+
+    def __init__(self, task_type):
+        super().__init__()
+        self.module = MockModule()
+        self.module.backbone = MockModule()
+        self.backbone = MockModule()
+        self.task_type = task_type
+
+    def forward(self, *args, **kwargs):
+        forward_hooks = list(self.module.backbone._forward_hooks.values())
+        for hook in forward_hooks:
+            hook(1, 2, 3)
+        return [[np.array([[0, 0, 1, 1, 0.1]]), np.array([[0, 0, 1, 1, 0.2]]), np.array([[0, 0, 1, 1, 0.7]])]]
+
+    @staticmethod
+    def named_parameters():
+        return {"name": torch.Tensor([0.5])}.items()
+
+
+class MockDataset(DatasetEntity):
+    """Mock class for mm_dataset."""
+
+    def __init__(self, dataset: DatasetEntity, task_type: str):
+        self.dataset = dataset
+        self.task_type = task_type
+        self.CLASSES = ["1", "2", "3"]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def evaluate(self, prediction, *args, **kwargs):
+        # The mock reports a perfect score for every task type; the det
+        # and iseg variants expose the same mAP metric, so no per-task
+        # branching is needed here.
+        return {"mAP": 1.0}
+
+
+class MockDataLoader:
+    """Mock class for data loader."""
+
+    def __init__(self, dataset: DatasetEntity):
+        self.dataset = dataset
+        self.iter = iter(self.dataset)
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __next__(self) -> Dict[str, DatasetItemEntity]:
+        return {"imgs": next(self.iter)}
+
+    def __iter__(self):
+        return self
+
+
+class MockExporter:
+    """Mock class for Exporter."""
+
+    def __init__(self, task):
+        self._output_path = task._output_path
+
+    def run(self, *args, **kwargs):
+        with open(os.path.join(self._output_path, "openvino.bin"), "wb") as f:
+            f.write(b"\x00")  # dummy payload standing in for a real exported file
+        with open(os.path.join(self._output_path, "openvino.xml"), "wb") as f:
+            f.write(b"\x00")
+        with open(os.path.join(self._output_path, "model.onnx"), "wb") as f:
+            f.write(b"\x00")
+
+        return {
+            "outputs": {
+                "bin": os.path.join(self._output_path, "openvino.bin"),
+                "xml": os.path.join(self._output_path, "openvino.xml"),
+                "onnx": os.path.join(self._output_path, "model.onnx"),
+            }
+        }
+
+
+class TestMMDetectionTask:
+    """Test class for MMDetectionTask.
+
+    Details are explained in each test function.
+ """ + + @pytest.fixture(autouse=True) + def setup(self) -> None: + model_template = parse_model_template(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "template.yaml")) + hyper_parameters = create(model_template.hyper_parameters.data) + task_env = init_environment(hyper_parameters, model_template, task_type=TaskType.DETECTION) + + self.det_task = MMDetectionTask(task_env) + + self.det_dataset, self.det_labels = generate_det_dataset(TaskType.DETECTION, 100) + self.det_label_schema = LabelSchemaEntity() + det_label_group = LabelGroup( + name="labels", + labels=self.det_labels, + group_type=LabelGroupType.EXCLUSIVE, + ) + self.det_label_schema.add_group(det_label_group) + + model_template = parse_model_template(os.path.join(DEFAULT_ISEG_TEMPLATE_DIR, "template.yaml")) + hyper_parameters = create(model_template.hyper_parameters.data) + task_env = init_environment(hyper_parameters, model_template, task_type=TaskType.INSTANCE_SEGMENTATION) + + self.iseg_task = MMDetectionTask(task_env) + + self.iseg_dataset, self.iseg_labels = generate_det_dataset(TaskType.INSTANCE_SEGMENTATION, 100) + self.iseg_label_schema = LabelSchemaEntity() + iseg_label_group = LabelGroup( + name="labels", + labels=self.iseg_labels, + group_type=LabelGroupType.EXCLUSIVE, + ) + self.iseg_label_schema.add_group(iseg_label_group) + + @e2e_pytest_unit + def test_build_model(self, mocker) -> None: + """Test build_model function.""" + _mock_recipe_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "model.py")) + model = self.det_task.build_model(_mock_recipe_cfg, True) + assert isinstance(model, CustomATSS) + + @e2e_pytest_unit + def test_train(self, mocker) -> None: + """Test train function.""" + + def _mock_train_detector_det(*args, **kwargs): + with open(os.path.join(self.det_task._output_path, "latest.pth"), "wb") as f: + torch.save({"dummy": torch.randn(1, 3, 3, 3)}, f) + + def _mock_train_detector_iseg(*args, **kwargs): + with open(os.path.join(self.iseg_task._output_path, "latest.pth"), "wb") as f: + torch.save({"dummy": torch.randn(1, 3, 3, 3)}, f) + + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.build_dataset", + return_value=MockDataset(self.det_dataset, "det"), + ) + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.build_dataloader", + return_value=MockDataLoader(self.det_dataset), + ) + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.patch_data_pipeline", + return_value=True, + ) + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.train_detector", + side_effect=_mock_train_detector_det, + ) + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.single_gpu_test", + return_value=[ + np.array([np.array([[0, 0, 1, 1, 0.1]]), np.array([[0, 0, 1, 1, 0.2]]), np.array([[0, 0, 1, 1, 0.7]])]) + ] + * 100, + ) + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.FeatureVectorHook", + return_value=nullcontext(), + ) + + _config = ModelConfiguration(DetectionConfig(), self.det_label_schema) + output_model = ModelEntity(self.det_dataset, _config) + self.det_task.train(self.det_dataset, output_model) + output_model.performance == 1.0 + + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.train_detector", + side_effect=_mock_train_detector_iseg, + ) + mocker.patch( + "otx.algorithms.detection.adapters.mmdet.task.single_gpu_test", + return_value=[(np.array([[[0, 0, 1, 1, 1]]]), np.array([[[0, 0, 1, 1, 1, 1, 1]]]))] * 100, + ) + _config = ModelConfiguration(DetectionConfig(), self.iseg_label_schema) + output_model = 
+        self.iseg_task.train(self.iseg_dataset, output_model)
+        assert output_model.performance == 1.0
+
+    @e2e_pytest_unit
+    def test_infer(self, mocker) -> None:
+        """Test infer function."""
+
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.build_dataset",
+            return_value=MockDataset(self.det_dataset, "det"),
+        )
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.build_dataloader",
+            return_value=MockDataLoader(self.det_dataset),
+        )
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.patch_data_pipeline",
+            return_value=True,
+        )
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.single_gpu_test",
+            return_value=[
+                np.array([np.array([[0, 0, 1, 1, 0.1]]), np.array([[0, 0, 1, 1, 0.2]]), np.array([[0, 0, 1, 1, 0.7]])])
+            ]
+            * 100,
+        )
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.FeatureVectorHook",
+            return_value=nullcontext(),
+        )
+
+        inference_parameters = InferenceParameters(is_evaluation=True)
+        outputs = self.det_task.infer(self.det_dataset, inference_parameters)
+        for output in outputs:
+            assert output.get_annotations()[-1].get_labels()[0].probability == 0.7
+
+    @e2e_pytest_unit
+    def test_det_evaluate(self) -> None:
+        """Test evaluate function for detection."""
+
+        _config = ModelConfiguration(DetectionConfig(), self.det_label_schema)
+        _model = ModelEntity(self.det_dataset, _config)
+        resultset = ResultSetEntity(_model, self.det_dataset, self.det_dataset)
+        self.det_task.evaluate(resultset)
+        assert resultset.performance.score.value == 1.0
+
+    @e2e_pytest_unit
+    def test_det_evaluate_with_empty_annotations(self) -> None:
+        """Test evaluate function for detection with empty predictions."""
+
+        _config = ModelConfiguration(DetectionConfig(), self.det_label_schema)
+        _model = ModelEntity(self.det_dataset, _config)
+        resultset = ResultSetEntity(_model, self.det_dataset, self.det_dataset.with_empty_annotations())
+        self.det_task.evaluate(resultset)
+        assert resultset.performance.score.value == 0.0
+
+    @e2e_pytest_unit
+    def test_iseg_evaluate(self) -> None:
+        """Test evaluate function for instance segmentation."""
+
+        _config = ModelConfiguration(DetectionConfig(), self.iseg_label_schema)
+        _model = ModelEntity(self.iseg_dataset, _config)
+        resultset = ResultSetEntity(_model, self.iseg_dataset, self.iseg_dataset)
+        self.iseg_task.evaluate(resultset)
+        assert resultset.performance.score.value == 1.0
+
+    @pytest.mark.parametrize("precision", [ModelPrecision.FP16, ModelPrecision.FP32])
+    @e2e_pytest_unit
+    def test_export(self, mocker, precision: ModelPrecision) -> None:
+        """Test export function.
+
+        Steps:
+        1. Create model entity
+        2. Run export function
+        3. Check output model attributes
+        """
+        _config = ModelConfiguration(DetectionConfig(), self.det_label_schema)
+        _model = ModelEntity(self.det_dataset, _config)
+
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.DetectionExporter",
+            return_value=MockExporter(self.det_task),
+        )
+        mocker.patch(
+            "otx.algorithms.detection.task.embed_ir_model_data",
+            return_value=True,
+        )
+
+        self.det_task.export(ExportType.OPENVINO, _model, precision, False)
+
+        assert _model.model_format == ModelFormat.OPENVINO
+        assert _model.optimization_type == ModelOptimizationType.MO
+        assert _model.precision[0] == precision
+        assert _model.get_data("openvino.bin") is not None
+        assert _model.get_data("openvino.xml") is not None
+        assert _model.get_data("confidence_threshold") is not None
+        assert _model.precision == self.det_task._precision
+        assert _model.optimization_methods == self.det_task._optimization_methods
+        assert _model.get_data("label_schema.json") is not None
+
+    @e2e_pytest_unit
+    def test_explain(self, mocker):
+        """Test explain function."""
+
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.build_dataset",
+            return_value=MockDataset(self.det_dataset, "det"),
+        )
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.build_dataloader",
+            return_value=MockDataLoader(self.det_dataset),
+        )
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.patch_data_pipeline",
+            return_value=True,
+        )
+        mocker.patch(
+            "otx.algorithms.detection.adapters.mmdet.task.build_data_parallel",
+            return_value=MockModel(TaskType.DETECTION),
+        )
+
+        explain_parameters = ExplainParameters(
+            explainer="ClassWiseSaliencyMap",
+            process_saliency_maps=False,
+            explain_predicted_classes=True,
+        )
+        assert self.det_task.explain(self.det_dataset, explain_parameters) is not None
diff --git a/tests/unit/algorithms/detection/test_helpers.py b/tests/unit/algorithms/detection/test_helpers.py
index 0e940b9662d..4b3bc73434f 100644
--- a/tests/unit/algorithms/detection/test_helpers.py
+++ b/tests/unit/algorithms/detection/test_helpers.py
@@ -73,7 +73,11 @@ def generate_det_dataset(task_type, number_of_images=1):
     label_schema = generate_label_schema(classes, task_type_to_label_domain(task_type))
 
     items = []
-    for _ in range(number_of_images):
+    for idx in range(number_of_images):
+        if idx < 30:
+            subset = Subset.VALIDATION
+        else:
+            subset = Subset.TRAINING
         image_numpy, annos = generate_random_annotated_image(
             image_width=640,
             image_height=480,
@@ -87,7 +91,7 @@ def generate_det_dataset(task_type, number_of_images=1):
             anno.shape = ShapeFactory.shape_as_polygon(anno.shape)
         image = Image(data=image_numpy)
         annotation_scene = AnnotationSceneEntity(kind=AnnotationSceneKind.ANNOTATION, annotations=annos)
-        items.append(DatasetItemEntity(media=image, annotation_scene=annotation_scene, subset=Subset.VALIDATION))
+        items.append(DatasetItemEntity(media=image, annotation_scene=annotation_scene, subset=subset))
     dataset = DatasetEntity(items)
 
     return dataset, dataset.get_labels()
diff --git a/tests/unit/algorithms/detection/test_task.py b/tests/unit/algorithms/detection/test_task.py
deleted file mode 100644
index 6b2dadd9079..00000000000
--- a/tests/unit/algorithms/detection/test_task.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""Test otx detection task."""
-# Copyright (C) 2023 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-#
-
-import os
-
-import numpy as np
-import pytest
-from mmcv import Config
-
-from otx.algorithms.detection.task import OTXDetectionTask
-from otx.algorithms.detection.utils import 
generate_label_schema -from otx.api.configuration.helper import create -from otx.api.entities.model_template import ( - TaskType, - parse_model_template, - task_type_to_label_domain, -) -from tests.unit.algorithms.detection.test_helpers import ( - DEFAULT_DET_TEMPLATE_DIR, - generate_det_dataset, - init_environment, -) - - -class MockOTXDetectionTask(OTXDetectionTask): - def _infer_model(*args, **kwargs): - return zip([[np.array([[0, 0, 1, 1, 1]])]] * 50, np.ndarray([50, 472, 1, 1]), [None] * 50), 1.0 - - def _train_model(*args, **kwargs): - return {"final_ckpt": "dummy.pth"} - - def explain(*args, **kwargs): - pass - - def export(*args, **kwargs): - pass - - -class MockOTXIsegTask(OTXDetectionTask): - def _infer_model(*args, **kwargs): - predictions = zip([[np.array([[0, 0, 1, 1, 1]])]] * 50, [np.array([[0, 0, 1, 1, 1, 1, 1]])] * 50) - return ( - zip( - predictions, - np.ndarray([50, 472, 1, 1]), - [None] * 50, - ), - 1.0, - ) - - def _train_model(*args, **kwargs): - return {"final_ckpt": "dummy.pth"} - - def explain(*args, **kwargs): - pass - - def export(*args, **kwargs): - pass - - -class MockModel: - class _Configuration: - def __init__(self, label_schema): - self.label_schema = label_schema - - def get_label_schema(self): - return self.label_schema - - def __init__(self): - self.model_adapters = ["weights.pth"] - self.data = np.ndarray(1) - - classes = ("rectangle", "ellipse", "triangle") - label_schema = generate_label_schema(classes, task_type_to_label_domain(TaskType.DETECTION)) - - self.configuration = self._Configuration(label_schema) - - def get_data(self, name): - return self.data - - def set_data(self, *args, **kwargs): - return - - -class TestOTXDetectionTask: - @pytest.fixture(autouse=True) - def setup(self, mocker): - model_template = parse_model_template(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "template.yaml")) - hyper_parameters = create(model_template.hyper_parameters.data) - task_env = init_environment(hyper_parameters, model_template, task_type=TaskType.DETECTION) - - self.det_task = MockOTXDetectionTask(task_env) - - def test_load_model_ckpt(self, mocker): - mocker.patch( - "torch.load", - return_value={ - "anchors": [1], - "confidence_threshold": 0.1, - "config": { - "tiling_parameters": { - "enable_tiling": {"value": True}, - "tile_size": {"value": 256}, - "tile_overlap": {"value": 0}, - "tile_max_number": {"value": 500}, - } - }, - }, - ) - - self.det_task._load_model_ckpt(MockModel()) - - assert self.det_task._anchors == [1] - assert self.det_task._hyperparams.tiling_parameters.enable_tiling is True - assert self.det_task._hyperparams.tiling_parameters.tile_size == 256 - assert self.det_task._hyperparams.tiling_parameters.tile_overlap == 0 - assert self.det_task._hyperparams.tiling_parameters.tile_max_number == 500 - - def test_train(self, mocker): - dataset = generate_det_dataset(TaskType.DETECTION, 50)[0] - mocker.patch("torch.load", return_value=np.ndarray([1])) - self.det_task.train(dataset, MockModel()) - assert self.det_task._model_ckpt == "dummy.pth" - - def test_infer(self): - dataset = generate_det_dataset(TaskType.DETECTION, 50)[0] - predicted_dataset = self.det_task.infer(dataset.with_empty_annotations()) - assert predicted_dataset[0].annotation_scene.annotations[0].shape.x1 == 0.0 - assert predicted_dataset[0].annotation_scene.annotations[0].shape.y1 == 0.0 - - def test_evaluate(self, mocker): - class _MockMetric: - f_measure = Config({"value": 1.0}) - - def get_performance(self): - return 1.0 - - class _MockResultEntity: - performance = 0.0 - - 
mocker.patch( - "otx.algorithms.detection.task.MetricsHelper.compute_f_measure", - return_value=_MockMetric(), - ) - - _result_entity = _MockResultEntity() - self.det_task.evaluate(_result_entity) - assert _result_entity.performance == 1.0