From ddc2194350f126e6a8e9bc8216f508bf65681366 Mon Sep 17 00:00:00 2001 From: Harim Kang Date: Thu, 15 Feb 2024 14:08:47 +0900 Subject: [PATCH 1/6] Separate integration tests by task (#2914) * Split intg-test per task * Revert Some command in tox * Change pytest_generate_tests to pytest.parametrize & Unify some task group * Modify tox.ini --- .github/workflows/pre_merge.yaml | 16 ++-- .../api/test_auto_configuration.py | 27 ++++--- tests/integration/api/test_engine_api.py | 4 +- .../cli/test_auto_configuration.py | 15 ++-- tests/integration/cli/test_cli.py | 35 +++++---- .../integration/cli/test_export_inference.py | 15 ++-- tests/integration/conftest.py | 75 +++++++++++++++++++ tox.ini | 14 +++- 8 files changed, 149 insertions(+), 52 deletions(-) diff --git a/.github/workflows/pre_merge.yaml b/.github/workflows/pre_merge.yaml index a667a51fa2f..ee5bd10bbc4 100644 --- a/.github/workflows/pre_merge.yaml +++ b/.github/workflows/pre_merge.yaml @@ -68,12 +68,16 @@ jobs: fail-fast: false matrix: include: - - python-version: "3.10" - tox-env: "py310" - name: Integration-Test-Py${{ matrix.python-version }} + - task: "action" + - task: "classification" + - task: "detection" + - task: "instance_segmentation" + - task: "semantic_segmentation" + - task: "visual_prompting" + name: Integration-Test-${{ matrix.task }}-py310 # This is what will cancel the job concurrency concurrency: - group: ${{ github.workflow }}-Integration-${{ github.event.pull_request.number || github.ref }} + group: ${{ github.workflow }}-Integration-${{ github.event.pull_request.number || github.ref }}-${{ matrix.task }} cancel-in-progress: true steps: - name: Checkout repository @@ -81,8 +85,8 @@ jobs: - name: Install Python uses: actions/setup-python@v4 with: - python-version: ${{ matrix.python-version }} + python-version: "3.10" - name: Install tox run: python -m pip install tox - name: Run Integration Test - run: tox -vv -e integration-test + run: tox -vv -e integration-test-${{ matrix.task }} diff 
--git a/tests/integration/api/test_auto_configuration.py b/tests/integration/api/test_auto_configuration.py index 69a5e5bafe4..825042a1209 100644 --- a/tests/integration/api/test_auto_configuration.py +++ b/tests/integration/api/test_auto_configuration.py @@ -11,22 +11,28 @@ from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK -@pytest.mark.parametrize("task", [task.value.lower() for task in DEFAULT_CONFIG_PER_TASK]) -def test_auto_configuration(task: str, tmp_path: Path, fxt_accelerator: str, fxt_target_dataset_per_task: dict) -> None: +@pytest.mark.parametrize("task", pytest.TASK_LIST) +def test_auto_configuration( + task: OTXTaskType, + tmp_path: Path, + fxt_accelerator: str, + fxt_target_dataset_per_task: dict, +) -> None: """Test the auto configuration functionality. Args: - task (str): The task for which auto configuration is being tested. + task (OTXTaskType): The task for which auto configuration is being tested. tmp_path (Path): The temporary path for storing training data. fxt_accelerator (str): The accelerator used for training. fxt_target_dataset_per_task (dict): A dictionary mapping tasks to target datasets. 
""" + if task not in DEFAULT_CONFIG_PER_TASK: + pytest.skip(f"Task {task} is not supported in the auto-configuration.") tmp_path_train = tmp_path / f"auto_train_{task}" - data_root = fxt_target_dataset_per_task[task] - task_type = OTXTaskType(task.upper()) + data_root = fxt_target_dataset_per_task[task.lower()] engine = Engine( data_root=data_root, - task=task_type, + task=task, work_dir=tmp_path_train, device=fxt_accelerator, ) @@ -36,12 +42,12 @@ def test_auto_configuration(task: str, tmp_path: Path, fxt_accelerator: str, fxt assert isinstance(engine.datamodule, OTXDataModule) # Check Auto-Configurator task - assert engine._auto_configurator.task == task_type + assert engine._auto_configurator.task == task # Check Default Configuration from otx.cli.utils.jsonargparse import get_configuration - default_config = get_configuration(DEFAULT_CONFIG_PER_TASK[task_type]) + default_config = get_configuration(DEFAULT_CONFIG_PER_TASK[task]) default_config["data"]["config"]["data_root"] = data_root num_classes = engine.datamodule.meta_info.num_classes @@ -49,8 +55,9 @@ def test_auto_configuration(task: str, tmp_path: Path, fxt_accelerator: str, fxt assert engine._auto_configurator.config == default_config - train_metric = engine.train(max_epochs=default_config.get("max_epochs", 2)) - if task != "zero_shot_visual_prompting": + max_epochs = 2 if task.lower() != "zero_shot_visual_prompting" else 1 + train_metric = engine.train(max_epochs=max_epochs) + if task.lower() != "zero_shot_visual_prompting": assert len(train_metric) > 0 test_metric = engine.test() diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py index 51570d69c29..a0baf5c1fd5 100644 --- a/tests/integration/api/test_engine_api.py +++ b/tests/integration/api/test_engine_api.py @@ -11,7 +11,7 @@ from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, OVMODEL_PER_TASK -@pytest.mark.parametrize("task", list(DEFAULT_CONFIG_PER_TASK)) 
+@pytest.mark.parametrize("task", pytest.TASK_LIST) def test_engine_from_config( task: OTXTaskType, tmp_path: Path, @@ -26,6 +26,8 @@ def test_engine_from_config( fxt_accelerator (str): The accelerator used for training. fxt_target_dataset_per_task (dict): A dictionary mapping tasks to target datasets. """ + if task not in DEFAULT_CONFIG_PER_TASK: + pytest.skip("Only the Task has Default config is tested to reduce unnecessary resources.") if task.lower() in ("action_classification"): pytest.xfail(reason="xFail until this root cause is resolved on the Datumaro side.") diff --git a/tests/integration/cli/test_auto_configuration.py b/tests/integration/cli/test_auto_configuration.py index 49e7ac3dc4f..62b62dcecc9 100644 --- a/tests/integration/cli/test_auto_configuration.py +++ b/tests/integration/cli/test_auto_configuration.py @@ -5,14 +5,15 @@ from pathlib import Path import pytest +from otx.core.types.task import OTXTaskType from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK from tests.integration.cli.utils import run_main -@pytest.mark.parametrize("task", [task.value.lower() for task in DEFAULT_CONFIG_PER_TASK]) +@pytest.mark.parametrize("task", pytest.TASK_LIST) def test_otx_cli_auto_configuration( - task: str, + task: OTXTaskType, tmp_path: Path, fxt_accelerator: str, fxt_target_dataset_per_task: dict, @@ -22,7 +23,7 @@ def test_otx_cli_auto_configuration( """Test the OTX auto configuration with CLI. Args: - task (str): The task to be performed. + task (OTXTaskType): The task to be performed. tmp_path (Path): The temporary path for storing outputs. fxt_accelerator (str): The accelerator to be used. fxt_target_dataset_per_task (dict): The target dataset per task. 
@@ -30,14 +31,16 @@ def test_otx_cli_auto_configuration( Returns: None """ - if task in ("action_classification"): + if task not in DEFAULT_CONFIG_PER_TASK: + pytest.skip(f"Task {task} is not supported in the auto-configuration.") + if task.lower() in ("action_classification"): pytest.xfail(reason="xFail until this root cause is resolved on the Datumaro side.") tmp_path_train = tmp_path / f"otx_auto_train_{task}" command_cfg = [ "otx", "train", "--data_root", - fxt_target_dataset_per_task[task], + fxt_target_dataset_per_task[task.lower()], "--task", task.upper(), "--engine.work_dir", @@ -46,7 +49,7 @@ def test_otx_cli_auto_configuration( fxt_accelerator, "--max_epochs", "2", - *fxt_cli_override_command_per_task[task], + *fxt_cli_override_command_per_task[task.lower()], ] run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index 4b0d53957bc..55b29e92370 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -2,8 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import importlib -import inspect from pathlib import Path import pytest @@ -11,15 +9,12 @@ from tests.integration.cli.utils import run_main -# This assumes have OTX installed in environment. 
-otx_module = importlib.import_module("otx") -RECIPE_PATH = Path(inspect.getfile(otx_module)).parent / "recipe" -RECIPE_LIST = [str(p) for p in RECIPE_PATH.glob("**/*.yaml") if "_base_" not in p.parts] -RECIPE_OV_LIST = [str(p) for p in RECIPE_PATH.glob("**/openvino_model.yaml") if "_base_" not in p.parts] -RECIPE_LIST = set(RECIPE_LIST) - set(RECIPE_OV_LIST) - -@pytest.mark.parametrize("recipe", RECIPE_LIST) +@pytest.mark.parametrize( + "recipe", + pytest.RECIPE_LIST, + ids=lambda x: "/".join(Path(x).parts[-2:]), +) def test_otx_e2e( recipe: str, tmp_path: Path, @@ -177,7 +172,11 @@ def test_otx_e2e( assert (tmp_path_test / "outputs").exists() -@pytest.mark.parametrize("recipe", RECIPE_LIST) +@pytest.mark.parametrize( + "recipe", + pytest.RECIPE_LIST, + ids=lambda x: "/".join(Path(x).parts[-2:]), +) def test_otx_explain_e2e( recipe: str, tmp_path: Path, @@ -249,9 +248,13 @@ def test_otx_explain_e2e( assert np.max(np.abs(actual_sal_vals - ref_sal_vals) <= 3) -@pytest.mark.parametrize("recipe", RECIPE_OV_LIST) +# @pytest.mark.skipif(len(pytest.RECIPE_OV_LIST) < 1, reason="No OV recipe found.") +@pytest.mark.parametrize( + "ov_recipe", + pytest.RECIPE_OV_LIST, +) def test_otx_ov_test( - recipe: str, + ov_recipe: str, tmp_path: Path, fxt_target_dataset_per_task: dict, fxt_open_subprocess: bool, @@ -268,8 +271,8 @@ def test_otx_ov_test( Returns: None """ - task = recipe.split("/")[-2] - model_name = recipe.split("/")[-1].split(".")[0] + task = ov_recipe.split("/")[-2] + model_name = ov_recipe.split("/")[-1].split(".")[0] if task in ["multi_label_cls", "instance_segmentation", "h_label_cls"]: # OMZ doesn't have proper model for Pytorch MaskRCNN interface @@ -282,7 +285,7 @@ def test_otx_ov_test( "otx", "test", "--config", - recipe, + ov_recipe, "--data_root", fxt_target_dataset_per_task[task], "--engine.work_dir", diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index 552e96aad7c..9d28753acd6 100644 --- 
a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -2,8 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import importlib -import inspect import logging from pathlib import Path @@ -14,13 +12,6 @@ log = logging.getLogger(__name__) -# This assumes have OTX installed in environment. -otx_module = importlib.import_module("otx") -RECIPE_PATH = Path(inspect.getfile(otx_module)).parent / "recipe" -RECIPE_LIST = [str(p) for p in RECIPE_PATH.glob("**/*.yaml") if "_base_" not in p.parts] -RECIPE_OV_LIST = [str(p) for p in RECIPE_PATH.glob("**/openvino_model.yaml") if "_base_" not in p.parts] -RECIPE_LIST = set(RECIPE_LIST) - set(RECIPE_OV_LIST) - def _check_relative_metric_diff(ref: float, value: float, eps: float) -> None: assert ref >= 0 @@ -55,7 +46,11 @@ def fxt_local_seed() -> int: } -@pytest.mark.parametrize("recipe", RECIPE_LIST, ids=lambda x: "/".join(Path(x).parts[-2:])) +@pytest.mark.parametrize( + "recipe", + pytest.RECIPE_LIST, + ids=lambda x: "/".join(Path(x).parts[-2:]), +) def test_otx_export_infer( recipe: str, tmp_path: Path, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 9d36b114c25..e51f793e0ea 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,11 +1,15 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # +from __future__ import annotations +import importlib +import inspect from pathlib import Path import pytest from mmengine.config import Config as MMConfig +from otx.core.types.task import OTXTaskType def pytest_addoption(parser: pytest.Parser) -> None: @@ -16,6 +20,13 @@ def pytest_addoption(parser: pytest.Parser) -> None: "This option can be used for easy memory management " "while running consecutive multiple tests (default: false).", ) + parser.addoption( + "--task", + action="store", + default="all", + type=str, + help="Task type of OTX to use integration test.", + ) @pytest.fixture(scope="module", 
autouse=True) @@ -28,6 +39,70 @@ def fxt_open_subprocess(request: pytest.FixtureRequest) -> bool: return request.config.getoption("--open-subprocess") +def find_recipe_folder(base_path: Path, folder_name: str) -> Path: + """ + Find the folder with the given name within the specified base path. + + Args: + base_path (Path): The base path to search within. + folder_name (str): The name of the folder to find. + + Returns: + Path: The path to the folder. + """ + for folder_path in base_path.rglob(folder_name): + if folder_path.is_dir(): + return folder_path + msg = f"Folder {folder_name} not found in {base_path}." + raise FileNotFoundError(msg) + + +def get_task_list(task: str) -> list[OTXTaskType]: + if task == "all": + return [task_type for task_type in OTXTaskType if task_type != OTXTaskType.DETECTION_SEMI_SL] + if task == "classification": + return [OTXTaskType.MULTI_CLASS_CLS, OTXTaskType.MULTI_LABEL_CLS, OTXTaskType.H_LABEL_CLS] + if task == "action": + return [OTXTaskType.ACTION_CLASSIFICATION, OTXTaskType.ACTION_DETECTION] + if task == "visual_prompting": + return [OTXTaskType.VISUAL_PROMPTING, OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING] + return [OTXTaskType(task.upper())] + + +def pytest_configure(config): + """Configure pytest options and set task, recipe, and recipe_ov lists. + + Args: + config (pytest.Config): The pytest configuration object. + + Returns: + None + """ + task = config.getoption("--task") + + # This assumes have OTX installed in environment. 
+ otx_module = importlib.import_module("otx") + # Modify RECIPE_PATH based on the task + recipe_path = Path(inspect.getfile(otx_module)).parent / "recipe" + task_list = get_task_list(task.lower()) + recipe_dir = [find_recipe_folder(recipe_path, task_type.value.lower()) for task_type in task_list] + + # Update RECIPE_LIST + target_recipe_list = [] + target_ov_recipe_list = [] + for task_recipe_dir in recipe_dir: + recipe_list = [str(p) for p in task_recipe_dir.glob("**/*.yaml") if "_base_" not in p.parts] + recipe_ov_list = [str(p) for p in task_recipe_dir.glob("**/openvino_model.yaml") if "_base_" not in p.parts] + recipe_list = set(recipe_list) - set(recipe_ov_list) + + target_recipe_list.extend(recipe_list) + target_ov_recipe_list.extend(recipe_ov_list) + + pytest.TASK_LIST = task_list + pytest.RECIPE_LIST = target_recipe_list + pytest.RECIPE_OV_LIST = target_ov_recipe_list + + @pytest.fixture(scope="session") def fxt_asset_dir() -> Path: return Path(__file__).parent.parent / "assets" diff --git a/tox.ini b/tox.ini index ff405f6b7dc..5011930e8ee 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,15 @@ addopts = --csv=.tox/tests-{env:TOXENV_TASK}-{env:TOXENV_PYVER}.csv [testenv] setenv = TOX_WORK_DIR={toxworkdir} +task = + all: "all" + action: "action" + classification: "classification" + detection: "detection" + rotated_detection: "rotated_detection" + instance_segmentation: "instance_segmentation" + semantic_segmentation: "semantic_segmentation" + visual_prompting: "visual_prompting" passenv = ftp_proxy HTTP_PROXY @@ -39,7 +48,7 @@ commands = {posargs} -[testenv:integration-test] +[testenv:integration-test-{all, action, classification, detection, rotated_detection, instance_segmentation, semantic_segmentation, visual_prompting}] setenv = CUBLAS_WORKSPACE_CONFIG=:4096:8 deps = @@ -48,8 +57,7 @@ commands_pre = ; [TODO]: Needs to be fixed so that this is not duplicated for each test run otx install -v commands = - python -m pytest -ra --showlocals 
--csv={toxworkdir}/{envname}.csv --open-subprocess {posargs:tests/integration} - + python -m pytest tests/integration -ra --showlocals --csv={toxworkdir}/{envname}.csv --task {[testenv]task} --open-subprocess {posargs} [testenv:performance-test] deps = From 9edb4b76942d4f96a12e7a7ca1abbfa933f9ddf0 Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Thu, 15 Feb 2024 05:12:54 +0000 Subject: [PATCH 2/6] Refine Tile Processing and Expand Tile Configuration (#2903) * tile init commit * remove task from DataModuleConfig * minor polish * minor polish * Add OTX Tile Dataset and Merger * pre-commit changes * style changes * add TileDetDataEntity and TileBatchDetDataEntity * pre-commit fixes * tidying up * Add tiling support for InstanceSeg task * solve merging conflicts * fix mypy issues * polish up * add missing docstrings * add back blankline * partly fix unit tests * create unbinding methods in tile entities * add OTX tile generic type to OTX based models * update tile detection recipe * fix tests * fix tests * batch inference for tiling * remove attributes from ImageInfo * update tile entity and fix test * update tile merge module * fix bug * update tile merger * add comment to base.py * improve docstrings * solve merge conflicts * add todo to model config * update data entity * fix test * fix test * fix recipes * reformat recipes * remove tiling recipes for the time being * add yolox tile recipes * update yolox tile recipes * remove datumaro TileMerge and use native torch ops * update yolo tile recipes * fix memcache * fix mem cache force FP32 * add inst-seg tile recipes * update InstanceSegTileMerge * update tile recipes * add gradient clipping to maskrcnn r50 * impl custom OTXInstSegMeanAveragePrecision * revise customised OTXInstSegMeanAveragePrecision * return rle from polygon_to_bitmap and cache rle in gt_caches * revert cli change * polish up * update inst seg lit module * update inst seg lit module * fix unit test * fix api test * remove gt cache * fix integration 
test * address PR comments * fix integration test * add unittests * reformat * update tile recipes * Implement OTXTileTransform for taking tile_size & Implement NMS in TileMerge * enable configurable tile merger * introduce a way to pass tile_config to OTXModel * add tile adapter and change name * save adapted tile_config back to config * update tile recipes * coalesce sparse tensor * update tile adaptor * add tile sampling * update tile sampling * fix typo * update tiling test cli * add tiling integration test * update tile integration test * skip explain test on tiling * revert cli changes * revert cli changes * Update src/otx/algo/callbacks/tile_sampling_hook.py Co-authored-by: Sungman Cho * add docstring header * improve docstring comment * update tile recipes * remove sampling hook --------- Co-authored-by: Sungman Cho --- src/otx/cli/cli.py | 11 ++ src/otx/core/config/data.py | 13 +- src/otx/core/data/dataset/tile.py | 96 ++++++++- src/otx/core/data/module.py | 36 +++- src/otx/core/data/tile_adaptor.py | 183 ++++++++++++++++++ src/otx/core/model/entity/detection.py | 11 +- .../model/entity/instance_segmentation.py | 13 +- src/otx/core/utils/tile_merge.py | 74 +++---- src/otx/engine/engine.py | 4 +- src/otx/engine/utils/auto_configurator.py | 6 +- src/otx/recipe/detection/yolox_l_tile.yaml | 1 + src/otx/recipe/detection/yolox_s_tile.yaml | 1 + src/otx/recipe/detection/yolox_tiny_tile.yaml | 1 + src/otx/recipe/detection/yolox_x_tile.yaml | 1 + .../maskrcnn_efficientnetb2b_tile.yaml | 1 + .../maskrcnn_r50_tile.yaml | 1 + tests/integration/cli/test_cli.py | 5 +- tests/integration/conftest.py | 4 - tests/integration/detection/conftest.py | 4 +- tests/integration/test_tiling.py | 149 ++++++++++++++ tests/unit/core/data/test_factory.py | 4 +- tests/unit/core/data/test_module.py | 4 +- 22 files changed, 543 insertions(+), 80 deletions(-) create mode 100644 src/otx/core/data/tile_adaptor.py create mode 100644 tests/integration/test_tiling.py diff --git 
a/src/otx/cli/cli.py b/src/otx/cli/cli.py index 764e84794a8..1a8918fe73d 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -6,6 +6,7 @@ from __future__ import annotations +import dataclasses import sys from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -339,6 +340,16 @@ def instantiate_model(self, model_config: Namespace) -> tuple: model_parser.add_subclass_arguments(OTXModel, "model", required=False, fail_untyped=False) model = model_parser.instantiate_classes(Namespace(model=model_config)).get("model") + # Update tile config due to adaptive tiling + if self.datamodule.config.tile_config.enable_tiler: + if not hasattr(model, "tile_config"): + msg = "The model does not have a tile_config attribute. Please check if the model supports tiling." + raise AttributeError(msg) + model.tile_config = self.datamodule.config.tile_config + self.config[self.subcommand].data.config.tile_config.update( + Namespace(dataclasses.asdict(model.tile_config)), + ) + # Update self.config with model self.config[self.subcommand].update(Namespace(model=model_config)) diff --git a/src/otx/core/config/data.py b/src/otx/core/config/data.py index ddd0de26943..a8142580bdc 100644 --- a/src/otx/core/config/data.py +++ b/src/otx/core/config/data.py @@ -61,12 +61,17 @@ class SubsetConfig: @dataclass -class TilerConfig: +class TileConfig: """DTO for tiler configuration.""" enable_tiler: bool = False - grid_size: tuple[int, int] = (2, 2) - overlap: float = 0.0 + enable_adaptive_tiling: bool = True + tile_size: tuple[int, int] = (400, 400) + overlap: float = 0.2 + iou_threshold: float = 0.45 + max_num_instances: int = 1500 + object_tile_ratio: float = 0.03 + sampling_ratio: float = 1.0 @dataclass @@ -88,7 +93,7 @@ class DataModuleConfig: val_subset: SubsetConfig test_subset: SubsetConfig - tile_config: TilerConfig = field(default_factory=lambda: TilerConfig()) + tile_config: TileConfig = field(default_factory=lambda: TileConfig()) vpm_config: VisualPromptingConfig = 
field(default_factory=lambda: VisualPromptingConfig()) mem_cache_size: str = "1GB" diff --git a/src/otx/core/data/dataset/tile.py b/src/otx/core/data/dataset/tile.py index 9017d953c31..82a8a67e139 100644 --- a/src/otx/core/data/dataset/tile.py +++ b/src/otx/core/data/dataset/tile.py @@ -5,6 +5,7 @@ from __future__ import annotations +import logging as log from typing import TYPE_CHECKING, Callable import numpy as np @@ -12,6 +13,12 @@ from datumaro import Bbox, DatasetItem, DatasetSubset, Image, Polygon from datumaro import Dataset as DmDataset from datumaro.plugins.tiling import Tile +from datumaro.plugins.tiling.util import ( + clip_x1y1x2y2, + cxcywh_to_x1y1x2y2, + x1y1x2y2_to_cxcywh, + x1y1x2y2_to_xywh, +) from torchvision import tv_tensors from otx.core.data.entity.base import ImageInfo @@ -29,7 +36,9 @@ from .base import OTXDataset if TYPE_CHECKING: - from otx.core.config.data import TilerConfig + from datumaro.components.media import BboxIntCoords + + from otx.core.config.data import TileConfig from otx.core.data.dataset.detection import OTXDetectionDataset from otx.core.data.dataset.instance_segmentation import OTXInstanceSegDataset from otx.core.data.entity.base import OTXDataEntity @@ -39,6 +48,69 @@ # This is a workaround so we could apply the same transforms to tiles as the original dataset. +class OTXTileTransform(Tile): + """OTX tile transform. + + Different from the original Datumaro Tile transform, + OTXTileTransform takes tile_size and overlap as input instead of grid size + + Args: + extractor (DatasetSubset): Dataset subset to extract tiles from. + tile_size (tuple[int, int]): Tile size. + overlap (tuple[float, float]): Overlap ratio. + threshold_drop_ann (float): Threshold to drop annotations. 
+ """ + + def __init__( + self, + extractor: DatasetSubset, + tile_size: tuple[int, int], + overlap: tuple[float, float], + threshold_drop_ann: float, + ) -> None: + super().__init__( + extractor, + (0, 0), + overlap=overlap, + threshold_drop_ann=threshold_drop_ann, + ) + self._tile_size = tile_size + + def _extract_rois(self, image: Image) -> list[BboxIntCoords]: + """Extracts Tile ROIs from the given image. + + Args: + image (Image): Full image. + + Returns: + list[BboxIntCoords]: list of ROIs. + """ + if image.size is None: + msg = "Image size is None" + raise ValueError(msg) + + img_h, img_w = image.size + tile_h, tile_w = self._tile_size + h_ovl, w_ovl = self._overlap + stride_h, stride_w = int(tile_h * (1 - h_ovl)), int(tile_w * (1 - w_ovl)) + n_row, n_col = (img_h + stride_h - 1) // stride_h, (img_w + stride_w - 1) // stride_w + + rois: list[BboxIntCoords] = [] + + for r in range(n_row): + for c in range(n_col): + y1, x1 = stride_h * r, stride_w * c + y2, x2 = y1 + stride_h, x1 + stride_w + + c_x, c_y, w, h = x1y1x2y2_to_cxcywh(x1, y1, x2, y2) + x1, y1, x2, y2 = cxcywh_to_x1y1x2y2(c_x, c_y, w, h) + x1, y1, x2, y2 = clip_x1y1x2y2(x1, y1, x2, y2, img_w, img_h) + rois += [x1y1x2y2_to_xywh(x1, y1, x2, y2)] + log.info(f"image: {img_h}x{img_w} ~ tile_size: {self._tile_size}") + log.info(f"{n_row}x{n_col} tiles -> {len(rois)} tiles") + return rois + + class OTXTileDatasetFactory: """OTX tile dataset factory.""" @@ -47,7 +119,7 @@ def create( cls, task: OTXTaskType, dataset: OTXDataset, - tile_config: TilerConfig, + tile_config: TileConfig, ) -> OTXTileDataset: """Create a tile dataset based on the task type and subset type. @@ -82,7 +154,7 @@ class OTXTileDataset(OTXDataset): tile_config (TilerConfig): Tile configuration. 
""" - def __init__(self, dataset: OTXDataset, tile_config: TilerConfig) -> None: + def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None: super().__init__( dataset.dm_subset, dataset.transforms, @@ -124,12 +196,16 @@ def get_tiles(self, image: np.ndarray, item: DatasetItem) -> tuple[list[OTXDataE """ tile_ds = DmDataset.from_iterable([item]) tile_ds = tile_ds.transform( - Tile, - grid_size=self.tile_config.grid_size, + OTXTileTransform, + tile_size=self.tile_config.tile_size, overlap=(self.tile_config.overlap, self.tile_config.overlap), threshold_drop_ann=0.5, ) + if self.dm_subset.name == "val": + # NOTE: filter validation tiles with annotations only to avoid evaluation on empty tiles. + tile_ds = tile_ds.filter("/item/annotation", filter_annotations=True, remove_empty=True) + tile_entities: list[OTXDataEntity] = [] tile_attrs: list[dict] = [] for tile in tile_ds: @@ -152,11 +228,11 @@ class OTXTileTrainDataset(OTXTileDataset): tile_config (TilerConfig): Tile configuration. """ - def __init__(self, dataset: OTXDataset, tile_config: TilerConfig) -> None: + def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None: dm_dataset = dataset.dm_subset.as_dataset() dm_dataset = dm_dataset.transform( - Tile, - grid_size=tile_config.grid_size, + OTXTileTransform, + tile_size=tile_config.tile_size, overlap=(tile_config.overlap, tile_config.overlap), threshold_drop_ann=0.5, ) @@ -177,7 +253,7 @@ class OTXTileDetTestDataset(OTXTileDataset): tile_config (TilerConfig): Tile configuration. """ - def __init__(self, dataset: OTXDetectionDataset, tile_config: TilerConfig) -> None: + def __init__(self, dataset: OTXDetectionDataset, tile_config: TileConfig) -> None: super().__init__(dataset, tile_config) @property @@ -268,7 +344,7 @@ class OTXTileInstSegTestDataset(OTXTileDataset): tile_config (TilerConfig): Tile configuration. 
""" - def __init__(self, dataset: OTXInstanceSegDataset, tile_config: TilerConfig) -> None: + def __init__(self, dataset: OTXInstanceSegDataset, tile_config: TileConfig) -> None: super().__init__(dataset, tile_config) @property diff --git a/src/otx/core/data/module.py b/src/otx/core/data/module.py index bd54ac3820e..c4027558692 100644 --- a/src/otx/core/data/module.py +++ b/src/otx/core/data/module.py @@ -11,7 +11,7 @@ from datumaro import Dataset as DmDataset from lightning import LightningDataModule from omegaconf import DictConfig, OmegaConf -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, RandomSampler from otx.core.data.dataset.base import LabelInfo from otx.core.data.dataset.tile import OTXTileDatasetFactory @@ -21,6 +21,7 @@ parse_mem_cache_size_to_int, ) from otx.core.data.pre_filtering import pre_filtering +from otx.core.data.tile_adaptor import adapt_tile_config from otx.core.types.task import OTXTaskType if TYPE_CHECKING: @@ -55,6 +56,8 @@ def __init__( dataset = DmDataset.import_from(self.config.data_root, format=self.config.data_format) if self.task != "H_LABEL_CLS": dataset = pre_filtering(dataset, self.config.data_format) + if config.tile_config.enable_tiler and config.tile_config.enable_adaptive_tiling: + adapt_tile_config(config.tile_config, dataset=dataset) config_mapping = { self.config.train_subset.subset_name: self.config.train_subset, @@ -121,15 +124,28 @@ def train_dataloader(self) -> DataLoader: config = self.config.train_subset dataset = self._get_dataset(config.subset_name) - return DataLoader( - dataset=dataset, - batch_size=config.batch_size, - shuffle=True, - num_workers=config.num_workers, - pin_memory=True, - collate_fn=dataset.collate_fn, - persistent_workers=config.num_workers > 0, - ) + common_args = { + "dataset": dataset, + "batch_size": config.batch_size, + "num_workers": config.num_workers, + "pin_memory": True, + "collate_fn": dataset.collate_fn, + "persistent_workers": config.num_workers > 0, 
+ } + + tile_config = self.config.tile_config + if tile_config.enable_tiler and tile_config.sampling_ratio < 1: + num_samples = max(1, int(len(dataset) * tile_config.sampling_ratio)) + log.info(f"Using tiled sampling with {num_samples} samples") + common_args.update( + { + "shuffle": False, + "sampler": RandomSampler(dataset, num_samples=num_samples), + }, + ) + else: + common_args["shuffle"] = True + return DataLoader(**common_args) def val_dataloader(self) -> DataLoader: """Get val dataloader.""" diff --git a/src/otx/core/data/tile_adaptor.py b/src/otx/core/data/tile_adaptor.py new file mode 100644 index 00000000000..dedbe890423 --- /dev/null +++ b/src/otx/core/data/tile_adaptor.py @@ -0,0 +1,183 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Tile Adaptor for OTX.""" +from __future__ import annotations + +import logging as log +from typing import Any + +import numpy as np +from datumaro import Bbox, Dataset, DatasetSubset, Polygon + +from otx.core.config.data import TileConfig + + +def compute_robust_statistics(values: np.array) -> dict[str, float]: + """Computes robust statistics of given samples. 
+ + Args: + values (np.array): Array of samples + + Returns: + dict[str, float]: Robust avg, min, max values + """ + stat: dict = {} + if values.size == 0: + return stat + + avg_value = np.mean(values) + std_value = np.std(values) + avg_3std_min_value = avg_value - 3 * std_value + avg_3std_max_value = avg_value + 3 * std_value + min_value = np.min(values) + max_value = np.max(values) + + # Refine min/max to reduce outlier effect + robust_min_value = max(min_value, avg_3std_min_value) + robust_max_value = min(max_value, avg_3std_max_value) + + stat["avg"] = float(avg_value) + stat["std"] = float(std_value) + stat["min"] = float(min_value) + stat["max"] = float(max_value) + stat["robust_min"] = float(robust_min_value) + stat["robust_max"] = float(robust_max_value) + return stat + + +def compute_robust_scale_statistics(values: np.array) -> dict[str, float]: + """Computes robust statistics of scale values. + + Average of 0.5x scale and 2x scale should be 1x + + Args: + values (np.array): Array of positive scale values + + Returns: + dict[str, float]: Robust avg, min, max values + """ + # Compute stat in log scale & convert back to original scale + if values.size == 0: + return {} + + stat = compute_robust_statistics(np.log(values)) + stat = {k: float(np.exp(v)) for k, v in stat.items()} + # Normal scale std is easier to understand + stat["std"] = float(np.std(values)) + return stat + + +def compute_robust_dataset_statistics( + dataset: DatasetSubset, + ann_stat: bool = False, + max_samples: int = 1000, +) -> dict[str, Any]: + """Computes robust statistics of image & annotation sizes. + + Args: + dataset (DatasetSubset): Input dataset. + ann_stat (bool, optional): Whether to compute annotation size statistics. Defaults to False. + max_samples (int, optional): Maximum number of dataset subsamples to analyze. Defaults to 1000. + + Returns: + Dict[str, Any]: Robust avg, min, max values for images, and annotations optionally. 
+ ex) stat = { + "image": {"avg": ...}, + "annotation": { + "num_per_image": {"avg": ...}, + "size_of_shape": {"avg": ...}, + } + } + """ + stat: dict = {} + if len(dataset) == 0 or max_samples <= 0: + return stat + + data_ids = [item.id for item in dataset] + max_image_samples = min(max_samples, len(dataset)) + # NOTE: current OTX does not set seed globally + rng = np.random.default_rng(42) + data_ids = rng.choice(data_ids, max_image_samples, replace=False)[:max_image_samples] + + image_sizes = [] + for idx in data_ids: + data = dataset.get(id=idx, subset=dataset.name) + height, width = data.media.size + image_sizes.append(np.sqrt(width * height)) + stat["image"] = compute_robust_scale_statistics(np.array(image_sizes)) + + if ann_stat: + stat["annotation"] = {} + num_per_images: list[int] = [] + size_of_box_shapes: list[float] = [] + size_of_polygon_shapes: list[float] = [] + for idx in data_ids: + data = dataset.get(id=idx, subset=dataset.name) + annotations: dict[str, list] = {"boxes": [], "polygons": []} + for ann in data.annotations: + if isinstance(ann, Bbox): + annotations["boxes"].append(ann) + elif isinstance(ann, Polygon): + annotations["polygons"].append(ann) + + num_per_images.append(max(len(annotations["boxes"]), len(annotations["polygons"]))) + + if len(size_of_box_shapes) >= max_samples or len(size_of_polygon_shapes) >= max_samples: + continue + + size_of_box_shapes.extend( + filter(lambda x: x >= 1, [np.sqrt(anno.get_area()) for anno in annotations["boxes"]]), + ) + size_of_polygon_shapes.extend( + filter(lambda x: x >= 1, [np.sqrt(anno.get_area()) for anno in annotations["polygons"]]), + ) + + stat["annotation"]["num_per_image"] = compute_robust_statistics(np.array(num_per_images)) + stat["annotation"]["size_of_shape"] = compute_robust_scale_statistics( + np.array(size_of_polygon_shapes) if len(size_of_polygon_shapes) else np.array(size_of_box_shapes), + ) + + return stat + + +def adapt_tile_config(tile_config: TileConfig, dataset: Dataset) -> 
None: + """Config tile parameters. + + Adapt based on annotation statistics. + i.e. tile size, tile overlap, ratio and max objects per sample + + Args: + tile_config (TileConfig): tiling parameters of the model + dataset (Dataset): Datumaro dataset including all subsets + """ + if (train_dataset := dataset.subsets().get("train")) is not None: + stat = compute_robust_dataset_statistics(train_dataset, ann_stat=True) + max_num_objects = round(stat["annotation"]["num_per_image"]["max"]) + avg_size = stat["annotation"]["size_of_shape"]["avg"] + min_size = stat["annotation"]["size_of_shape"]["robust_min"] + max_size = stat["annotation"]["size_of_shape"]["robust_max"] + log.info(f"----> [stat] scale avg: {avg_size}") + log.info(f"----> [stat] scale min: {min_size}") + log.info(f"----> [stat] scale max: {max_size}") + + log.info("[Adaptive tiling pararms]") + object_tile_ratio = tile_config.object_tile_ratio + tile_size = int(avg_size / object_tile_ratio) + tile_overlap = max_size / tile_size + log.info(f"----> avg_object_size: {avg_size}") + log.info(f"----> max_object_size: {max_size}") + log.info(f"----> object_tile_ratio: {object_tile_ratio}") + log.info(f"----> tile_size: {avg_size} / {object_tile_ratio} = {tile_size}") + log.info(f"----> tile_overlap: {max_size} / {tile_size} = {tile_overlap}") + + if tile_overlap >= 0.9: + # Use the average object area if the tile overlap is too large to prevent 0 stride. + tile_overlap = avg_size / tile_size + log.info(f"----> (too big) tile_overlap: {avg_size} / {tile_size} = {tile_overlap}") + + # TODO(Eugene): how to validate lower/upper_bound? dataclass? pydantic? 
+ # https://github.com/openvinotoolkit/training_extensions/pull/2903 + tile_config.tile_size = (tile_size, tile_size) + tile_config.max_num_instances = max_num_objects + tile_config.overlap = tile_overlap diff --git a/src/otx/core/model/entity/detection.py b/src/otx/core/model/entity/detection.py index 1dbfd24e2c9..0925f18cc8b 100644 --- a/src/otx/core/model/entity/detection.py +++ b/src/otx/core/model/entity/detection.py @@ -11,6 +11,7 @@ import torch from torchvision import tv_tensors +from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity from otx.core.data.entity.tile import TileBatchDetDataEntity @@ -31,6 +32,10 @@ class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]): """Base class for the detection models used in OTX.""" + def __init__(self, *arg, **kwargs) -> None: + super().__init__(*arg, **kwargs) + self.tile_config = TileConfig() + def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: """Unpack detection tiles. 
@@ -42,7 +47,11 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: """ tile_preds: list[DetBatchPredEntity] = [] tile_attrs: list[list[dict[str, int | str]]] = [] - merger = DetectionTileMerge(inputs.imgs_info) + merger = DetectionTileMerge( + inputs.imgs_info, + self.tile_config.iou_threshold, + self.tile_config.max_num_instances, + ) for batch_tile_attrs, batch_tile_input in inputs.unbind(): output = self.forward(batch_tile_input) if isinstance(output, OTXBatchLossEntity): diff --git a/src/otx/core/model/entity/instance_segmentation.py b/src/otx/core/model/entity/instance_segmentation.py index 124046ed82b..c8dd1818b36 100644 --- a/src/otx/core/model/entity/instance_segmentation.py +++ b/src/otx/core/model/entity/instance_segmentation.py @@ -12,6 +12,7 @@ import torch from torchvision import tv_tensors +from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.instance_segmentation import ( InstanceSegBatchDataEntity, @@ -36,6 +37,10 @@ class OTXInstanceSegModel( ): """Base class for the Instance Segmentation models used in OTX.""" + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.tile_config = TileConfig() + def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchPredEntity: """Unpack instance segmentation tiles. 
@@ -47,7 +52,11 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP """ tile_preds: list[InstanceSegBatchPredEntity] = [] tile_attrs: list[list[dict[str, int | str]]] = [] - merger = InstanceSegTileMerge(inputs.imgs_info) + merger = InstanceSegTileMerge( + inputs.imgs_info, + self.tile_config.iou_threshold, + self.tile_config.max_num_instances, + ) for batch_tile_attrs, batch_tile_input in inputs.unbind(): output = self.forward(batch_tile_input) if isinstance(output, OTXBatchLossEntity): @@ -242,7 +251,7 @@ def _customize_outputs( tv_tensors.BoundingBoxes( output.pred_instances.bboxes, format="XYXY", - canvas_size=output.img_shape, + canvas_size=output.ori_shape, ), ) output_masks = tv_tensors.Mask( diff --git a/src/otx/core/utils/tile_merge.py b/src/otx/core/utils/tile_merge.py index 4b73aa09c92..5f147c77a04 100644 --- a/src/otx/core/utils/tile_merge.py +++ b/src/otx/core/utils/tile_merge.py @@ -11,6 +11,7 @@ import torch from torchvision import tv_tensors +from torchvision.ops import batched_nms from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity @@ -25,21 +26,19 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]): Args: img_infos (list[ImageInfo]): Original image information before tiling. - score_thres (float, optional): Score threshold to filter out low score predictions. Defaults to 0.1. - max_num_instances (int, optional): Maximum number of instances to keep. Defaults to 100. + iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.45. + max_num_instances (int, optional): Maximum number of instances to keep. Defaults to 500. - TODO (Eugene): Find a way to configure tile merge parameters(score_thres, max_num, etc) from tile config. 
- # https://github.com/openvinotoolkit/datumaro/pull/1194 """ def __init__( self, img_infos: list[ImageInfo], - score_thres: float = 0.25, - max_num_instances: int = 100, + iou_threshold: float = 0.45, + max_num_instances: int = 500, ) -> None: self.img_infos = img_infos - self.score_thres = score_thres + self.iou_threshold = iou_threshold self.max_num_instances = max_num_instances @abstractmethod @@ -69,6 +68,25 @@ def merge( """ raise NotImplementedError + def nms_postprocess( + self, + bboxes: torch.Tensor, + scores: torch.Tensor, + labels: torch.Tensor, + masks: None | list[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None | torch.Tensor]: + """Non-maximum suppression and post-process.""" + keep = batched_nms(bboxes, scores, labels, self.iou_threshold) + if len(keep) > self.max_num_instances: + keep = keep[: self.max_num_instances] + bboxes = bboxes[keep] + labels = labels[keep] + scores = scores[keep] + if masks is not None and len(masks) > 0: + # coalesce sparse tensors to prevent them from growing too large. 
+ masks = torch.stack([masks[idx] for idx in keep]).coalesce().to_dense() + return bboxes, labels, scores, masks + class DetectionTileMerge(TileMerge): """Detection tile merge.""" @@ -96,15 +114,9 @@ def merge( tile_preds.labels, tile_preds.scores, ): - keep_indices = tile_scores > self.score_thres - keep_indices = keep_indices.nonzero(as_tuple=True)[0] - _bboxes = tile_bboxes[keep_indices] - _labels = tile_labels[keep_indices] - _scores = tile_scores[keep_indices] - offset_x, offset_y, _, _ = tile_attr["roi"] - _bboxes[:, 0::2] += offset_x - _bboxes[:, 1::2] += offset_y + tile_bboxes[:, 0::2] += offset_x + tile_bboxes[:, 1::2] += offset_y tile_id = tile_attr["tile_id"] if tile_id not in img_ids: @@ -115,9 +127,9 @@ def merge( DetPredEntity( image=torch.empty(tile_img_info.ori_shape), img_info=tile_img_info, - bboxes=_bboxes, - labels=_labels, - score=_scores, + bboxes=tile_bboxes, + labels=tile_labels, + score=tile_scores, ), ) return [ @@ -150,12 +162,11 @@ def _merge_entities(self, img_info: ImageInfo, entities: list[DetPredEntity]) -> labels = torch.stack(labels) if len(labels) > 0 else torch.empty((0,), device=img_info.device) scores = torch.stack(scores) if len(scores) > 0 else torch.empty((0,), device=img_info.device) - sort_inds = torch.argsort(scores, descending=True) - if len(sort_inds) > self.max_num_instances: - sort_inds = sort_inds[: self.max_num_instances] - bboxes = bboxes[sort_inds] - labels = labels[sort_inds] - scores = scores[sort_inds] + bboxes, labels, scores, _ = self.nms_postprocess( + bboxes, + scores, + labels, + ) return DetPredEntity( image=torch.empty(img_size), @@ -197,7 +208,7 @@ def merge( tile_preds.scores, tile_preds.masks, ): - keep_indices = (tile_scores > self.score_thres) & (tile_masks.sum((1, 2)) > 0) + keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0 keep_indices = keep_indices.nonzero(as_tuple=True)[0] _bboxes = tile_bboxes[keep_indices] _labels = tile_labels[keep_indices] @@ -264,18 +275,9 @@ def 
_merge_entities(self, img_info: ImageInfo, entities: list[InstanceSegPredEnt bboxes = torch.stack(bboxes) if len(bboxes) > 0 else torch.empty((0, 4), device=img_info.device) labels = torch.stack(labels) if len(labels) > 0 else torch.empty((0,), device=img_info.device) scores = torch.stack(scores) if len(scores) > 0 else torch.empty((0,), device=img_info.device) + masks = masks if len(masks) > 0 else torch.empty((0, *img_size)) - sort_inds = torch.argsort(scores, descending=True) - if len(sort_inds) > self.max_num_instances: - sort_inds = sort_inds[: self.max_num_instances] - - bboxes = bboxes[sort_inds] - labels = labels[sort_inds] - scores = scores[sort_inds] - masks = ( - torch.stack([masks[idx] for idx in sort_inds]).to_dense() if len(masks) > 0 else torch.empty((0, *img_size)) - ) - + bboxes, labels, scores, masks = self.nms_postprocess(bboxes, scores, labels, masks) return InstanceSegPredEntity( image=torch.empty(img_size), img_info=img_info, diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 694a29827ac..484fcd452ff 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -11,7 +11,7 @@ import torch from lightning import Trainer, seed_everything -from otx.core.config.data import DataModuleConfig, SubsetConfig, TilerConfig +from otx.core.config.data import DataModuleConfig, SubsetConfig, TileConfig from otx.core.config.device import DeviceConfig from otx.core.config.explain import ExplainConfig from otx.core.data.module import OTXDataModule @@ -573,7 +573,7 @@ def from_config(cls, config_path: PathLike, data_root: PathLike | None = None, * train_subset=SubsetConfig(**data_config["config"].pop("train_subset")), val_subset=SubsetConfig(**data_config["config"].pop("val_subset")), test_subset=SubsetConfig(**data_config["config"].pop("test_subset")), - tile_config=TilerConfig(**data_config["config"].pop("tile_config", {})), + tile_config=TileConfig(**data_config["config"].pop("tile_config", {})), **data_config["config"], ), ) diff 
--git a/src/otx/engine/utils/auto_configurator.py b/src/otx/engine/utils/auto_configurator.py index edeceb9ce50..2c57a0ad713 100644 --- a/src/otx/engine/utils/auto_configurator.py +++ b/src/otx/engine/utils/auto_configurator.py @@ -13,7 +13,7 @@ import datumaro from lightning.pytorch.cli import instantiate_class -from otx.core.config.data import DataModuleConfig, SubsetConfig, TilerConfig +from otx.core.config.data import DataModuleConfig, SubsetConfig, TileConfig from otx.core.data.dataset.base import LabelInfo from otx.core.data.module import OTXDataModule from otx.core.model.entity.base import OVModel @@ -211,7 +211,7 @@ def get_datamodule(self) -> OTXDataModule | None: train_subset=SubsetConfig(**data_config.pop("train_subset")), val_subset=SubsetConfig(**data_config.pop("val_subset")), test_subset=SubsetConfig(**data_config.pop("test_subset")), - tile_config=TilerConfig(**data_config.pop("tile_config", {})), + tile_config=TileConfig(**data_config.pop("tile_config", {})), **data_config, ), ) @@ -309,7 +309,7 @@ def get_ov_datamodule(self) -> OTXDataModule: train_subset=SubsetConfig(**data_config.pop("train_subset")), val_subset=SubsetConfig(**data_config.pop("val_subset")), test_subset=SubsetConfig(**data_config.pop("test_subset")), - tile_config=TilerConfig(**data_config.pop("tile_config", {})), + tile_config=TileConfig(**data_config.pop("tile_config", {})), **data_config, ), ) diff --git a/src/otx/recipe/detection/yolox_l_tile.yaml b/src/otx/recipe/detection/yolox_l_tile.yaml index af45f977f3b..3bc2727e805 100644 --- a/src/otx/recipe/detection/yolox_l_tile.yaml +++ b/src/otx/recipe/detection/yolox_l_tile.yaml @@ -35,6 +35,7 @@ overrides: config: tile_config: enable_tiler: true + enable_adaptive_tiling: true image_color_channel: BGR train_subset: num_workers: 4 diff --git a/src/otx/recipe/detection/yolox_s_tile.yaml b/src/otx/recipe/detection/yolox_s_tile.yaml index 5f44cbe8964..f9806972601 100644 --- a/src/otx/recipe/detection/yolox_s_tile.yaml +++ 
b/src/otx/recipe/detection/yolox_s_tile.yaml @@ -35,6 +35,7 @@ overrides: config: tile_config: enable_tiler: true + enable_adaptive_tiling: true image_color_channel: BGR train_subset: num_workers: 4 diff --git a/src/otx/recipe/detection/yolox_tiny_tile.yaml b/src/otx/recipe/detection/yolox_tiny_tile.yaml index 54d37914817..2e305cbb9c3 100644 --- a/src/otx/recipe/detection/yolox_tiny_tile.yaml +++ b/src/otx/recipe/detection/yolox_tiny_tile.yaml @@ -35,6 +35,7 @@ overrides: config: tile_config: enable_tiler: true + enable_adaptive_tiling: true train_subset: num_workers: 4 batch_size: 8 diff --git a/src/otx/recipe/detection/yolox_x_tile.yaml b/src/otx/recipe/detection/yolox_x_tile.yaml index f5c74413cf8..5b23cc530fe 100644 --- a/src/otx/recipe/detection/yolox_x_tile.yaml +++ b/src/otx/recipe/detection/yolox_x_tile.yaml @@ -35,6 +35,7 @@ overrides: config: tile_config: enable_tiler: true + enable_adaptive_tiling: true image_color_channel: BGR train_subset: num_workers: 4 diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml index c795668081c..b495600617c 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml @@ -37,6 +37,7 @@ overrides: config: tile_config: enable_tiler: true + enable_adaptive_tiling: true include_polygons: true train_subset: num_workers: 4 diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml index 398943e1744..3b9c6e5e9b0 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml @@ -37,6 +37,7 @@ overrides: config: tile_config: enable_tiler: true + enable_adaptive_tiling: true include_polygons: true train_subset: num_workers: 4 diff --git a/tests/integration/cli/test_cli.py 
b/tests/integration/cli/test_cli.py index 55b29e92370..a64c673e117 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -39,7 +39,6 @@ def test_otx_e2e( None """ task = recipe.split("/")[-2] - tile_param = fxt_cli_override_command_per_task["tile"] if "tile" in recipe else [] model_name = recipe.split("/")[-1].split(".")[0] if task in ("action_classification"): pytest.xfail(reason="xFail until this root cause is resolved on the Datumaro side.") @@ -60,7 +59,6 @@ def test_otx_e2e( "--max_epochs", "2", *fxt_cli_override_command_per_task[task], - *tile_param, ] run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) @@ -195,6 +193,9 @@ def test_otx_explain_e2e( Returns: None """ + if "tile" in recipe: + pytest.skip("Explain is not supported for tiling yet.") + import cv2 import numpy as np diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index e51f793e0ea..03337f1a120 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -159,8 +159,4 @@ def fxt_cli_override_command_per_task() -> dict: "1", "--disable-infer-num-classes", ], - "tile": [ - "--data.config.tile_config.grid_size", - "[1,1]", - ], } diff --git a/tests/integration/detection/conftest.py b/tests/integration/detection/conftest.py index a03599b040f..1464fc7d5ac 100644 --- a/tests/integration/detection/conftest.py +++ b/tests/integration/detection/conftest.py @@ -10,7 +10,7 @@ from otx.core.config.data import ( DataModuleConfig, SubsetConfig, - TilerConfig, + TileConfig, VisualPromptingConfig, ) from otx.core.data.module import OTXDataModule @@ -56,7 +56,7 @@ def fxt_datamodule(fxt_asset_dir, fxt_mmcv_det_transform_config) -> OTXDataModul transform_lib_type="MMDET", transforms=fxt_mmcv_det_transform_config, ), - tile_config=TilerConfig(), + tile_config=TileConfig(), vpm_config=VisualPromptingConfig(), ) datamodule = OTXDataModule( diff --git a/tests/integration/test_tiling.py b/tests/integration/test_tiling.py 
new file mode 100644 index 00000000000..01549184291 --- /dev/null +++ b/tests/integration/test_tiling.py @@ -0,0 +1,149 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from __future__ import annotations + +import numpy as np +import pytest +from datumaro import Dataset as DmDataset +from omegaconf import DictConfig, OmegaConf +from otx.core.config.data import ( + DataModuleConfig, + SubsetConfig, + TileConfig, + VisualPromptingConfig, +) +from otx.core.data.dataset.tile import OTXTileTransform +from otx.core.data.entity.detection import DetBatchDataEntity +from otx.core.data.entity.tile import TileBatchDetDataEntity +from otx.core.data.module import OTXDataModule +from otx.core.types.task import OTXTaskType + + +class TestOTXTiling: + @pytest.fixture() + def fxt_mmcv_det_transform_config(self) -> list[DictConfig]: + mmdet_base = OmegaConf.load("src/otx/recipe/_base_/data/mmdet_base.yaml") + return mmdet_base.config.train_subset.transforms + + @pytest.fixture() + def fxt_det_data_config(self, fxt_asset_dir, fxt_mmcv_det_transform_config) -> OTXDataModule: + data_root = fxt_asset_dir / "car_tree_bug" + + batch_size = 8 + num_workers = 0 + return DataModuleConfig( + data_format="coco_instances", + data_root=data_root, + train_subset=SubsetConfig( + subset_name="train", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + val_subset=SubsetConfig( + subset_name="val", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + test_subset=SubsetConfig( + subset_name="test", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + tile_config=TileConfig(), + vpm_config=VisualPromptingConfig(), + ) + + def test_tile_transform(self): + dataset = DmDataset.import_from("tests/assets/car_tree_bug", 
format="coco_instances") + first_item = next(iter(dataset), None) + height, width = first_item.media.data.shape[:2] + + rng = np.random.default_rng() + tile_size = rng.integers(low=100, high=500, size=(2,)) + overlap = rng.random(2) + threshold_drop_ann = rng.random() + tiled_dataset = DmDataset.import_from("tests/assets/car_tree_bug", format="coco_instances") + tiled_dataset.transform( + OTXTileTransform, + tile_size=tile_size, + overlap=overlap, + threshold_drop_ann=threshold_drop_ann, + ) + + h_stride = int((1 - overlap[0]) * tile_size[0]) + w_stride = int((1 - overlap[1]) * tile_size[1]) + num_tile_rows = (height + h_stride - 1) // h_stride + num_tile_cols = (width + w_stride - 1) // w_stride + assert len(tiled_dataset) == (num_tile_rows * num_tile_cols * len(dataset)), "Incorrect number of tiles" + + def test_adaptive_tiling(self, fxt_det_data_config): + # Enable tile adapter + fxt_det_data_config.tile_config.enable_tiler = True + fxt_det_data_config.tile_config.enable_adaptive_tiling = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + + assert tile_datamodule.config.tile_config.tile_size == (6750, 6750), "Tile size should be [6750, 6750]" + assert ( + pytest.approx(tile_datamodule.config.tile_config.overlap, rel=1e-3) == 0.03608 + ), "Overlap should be 0.03608" + assert tile_datamodule.config.tile_config.max_num_instances == 3, "Max num instances should be 3" + + def test_tile_sampler(self, fxt_det_data_config): + rng = np.random.default_rng() + + fxt_det_data_config.tile_config.enable_tiler = True + fxt_det_data_config.tile_config.enable_adaptive_tiling = False + fxt_det_data_config.tile_config.sampling_ratio = rng.random() + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + sampled_count = max( + 1, + int(len(tile_datamodule._get_dataset("train")) * 
fxt_det_data_config.tile_config.sampling_ratio), + ) + + count = 0 + for batch in tile_datamodule.train_dataloader(): + count += batch.batch_size + assert isinstance(batch, DetBatchDataEntity) + + assert sampled_count == count, "Sampled count should be equal to the count of the dataloader batch size" + + def test_train_dataloader(self, fxt_det_data_config) -> None: + # Enable tile adapter + fxt_det_data_config.tile_config.enable_tiler = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + for batch in tile_datamodule.train_dataloader(): + assert isinstance(batch, DetBatchDataEntity) + + def test_val_dataloader(self, fxt_det_data_config) -> None: + # Enable tile adapter + fxt_det_data_config.tile_config.enable_tiler = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + for batch in tile_datamodule.val_dataloader(): + assert isinstance(batch, TileBatchDetDataEntity) + + def test_tile_merge(self): + pytest.skip("Not implemented yet") diff --git a/tests/unit/core/data/test_factory.py b/tests/unit/core/data/test_factory.py index e9073dd9fc5..328ad03d173 100644 --- a/tests/unit/core/data/test_factory.py +++ b/tests/unit/core/data/test_factory.py @@ -4,7 +4,7 @@ """Test Factory classes for dataset and transforms.""" import pytest -from otx.core.config.data import DataModuleConfig, SubsetConfig, TilerConfig, VisualPromptingConfig +from otx.core.config.data import DataModuleConfig, SubsetConfig, TileConfig, VisualPromptingConfig from otx.core.data.dataset.classification import OTXMulticlassClsDataset from otx.core.data.dataset.detection import OTXDetectionDataset from otx.core.data.dataset.segmentation import OTXSegmentationDataset @@ -53,7 +53,7 @@ def test_create(self, fxt_mock_dm_subset, fxt_mem_cache_handler, task_type, data mocker.patch.object(TransformLibFactory, "generate", return_value=None) 
cfg_subset = mocker.MagicMock(spec=SubsetConfig) cfg_data_module = mocker.MagicMock(spec=DataModuleConfig) - cfg_data_module.tile_config = mocker.MagicMock(spec=TilerConfig) + cfg_data_module.tile_config = mocker.MagicMock(spec=TileConfig) cfg_data_module.tile_config.enable_tiler = False cfg_data_module.vpm_config = mocker.MagicMock(spec=VisualPromptingConfig) cfg_data_module.vpm_config.use_bbox = False diff --git a/tests/unit/core/data/test_module.py b/tests/unit/core/data/test_module.py index 7971b227c65..0ab4c0db8bd 100644 --- a/tests/unit/core/data/test_module.py +++ b/tests/unit/core/data/test_module.py @@ -11,7 +11,7 @@ from otx.core.config.data import ( DataModuleConfig, SubsetConfig, - TilerConfig, + TileConfig, ) from otx.core.data.module import ( OTXDataModule, @@ -37,7 +37,7 @@ def fxt_config(self) -> DataModuleConfig: mock.val_subset.num_workers = 0 mock.test_subset = MagicMock(spec=SubsetConfig) mock.test_subset.num_workers = 0 - mock.tile_config = MagicMock(spec=TilerConfig) + mock.tile_config = MagicMock(spec=TileConfig) mock.tile_config.enable_tiler = False return mock From 8ebc0930ee3fdbde6d2b3586e9e03139db21f488 Mon Sep 17 00:00:00 2001 From: Jaeguk Hyun Date: Thu, 15 Feb 2024 15:33:40 +0900 Subject: [PATCH 3/6] Migrate SSD anchor generator (#2915) * Add AnchorGenerator callback * Add anchor load and save methods * Support anchor loading from v1 weight * Add unit tests * Reflect comments * Move auto anchor generating to OTXSSD --- .../heads/custom_anchor_generator.py | 6 +- src/otx/algo/detection/ssd.py | 124 +++++++++++++++++- src/otx/algo/utils/support_otx_v1.py | 13 ++ src/otx/core/model/entity/base.py | 8 ++ src/otx/core/model/module/base.py | 5 +- tests/unit/algo/detection/test_ssd.py | 26 ++++ tests/unit/algo/utils/test_support_otx_v1.py | 97 ++++++++++---- 7 files changed, 250 insertions(+), 29 deletions(-) create mode 100644 tests/unit/algo/detection/test_ssd.py diff --git a/src/otx/algo/detection/heads/custom_anchor_generator.py 
b/src/otx/algo/detection/heads/custom_anchor_generator.py index d5bf061f279..0637991847f 100644 --- a/src/otx/algo/detection/heads/custom_anchor_generator.py +++ b/src/otx/algo/detection/heads/custom_anchor_generator.py @@ -34,10 +34,10 @@ def __init__( self.centers = [(stride / 2.0, stride / 2.0) for stride in strides] self.center_offset = 0 - self.base_anchors = self.gen_base_anchors() + self.gen_base_anchors() self.use_box_type = False - def gen_base_anchors(self) -> list[torch.Tensor]: + def gen_base_anchors(self) -> None: """Generate base anchor for SSD.""" multi_level_base_anchors = [] for widths, heights, centers in zip(self.widths, self.heights, self.centers): @@ -47,7 +47,7 @@ def gen_base_anchors(self) -> list[torch.Tensor]: center=torch.Tensor(centers), ) multi_level_base_anchors.append(base_anchors) - return multi_level_base_anchors + self.base_anchors = multi_level_base_anchors def gen_single_level_base_anchors( self, diff --git a/src/otx/algo/detection/ssd.py b/src/otx/algo/detection/ssd.py index 9152f524299..69cfb89f79d 100644 --- a/src/otx/algo/detection/ssd.py +++ b/src/otx/algo/detection/ssd.py @@ -5,9 +5,13 @@ from __future__ import annotations +import logging from copy import deepcopy from typing import TYPE_CHECKING, Any, Literal +import numpy as np +from datumaro.components.annotation import Bbox + from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.model.entity.detection import MMDetCompatibleModel @@ -15,10 +19,17 @@ if TYPE_CHECKING: import torch + from lightning import Trainer + from mmdet.models.task_modules.prior_generators.anchor_generator import AnchorGenerator from mmengine.registry import Registry from omegaconf import DictConfig from torch import device, nn + from otx.core.data.dataset.base import OTXDataset + + +logger = logging.getLogger() + class SSD(MMDetCompatibleModel): """Detecion model class for SSD.""" @@ -28,6 +39,7 @@ def __init__(self, num_classes: int, 
variant: Literal["mobilenetv2"]) -> None: config = read_mmconfig(model_name=model_name) super().__init__(num_classes=num_classes, config=config) self.image_size = (1, 3, 864, 864) + self._register_load_state_dict_pre_hook(self._set_anchors_hook) def _create_model(self) -> nn.Module: from mmdet.models.data_preprocessors import ( @@ -52,6 +64,94 @@ def device(self) -> device: self.classification_layers = self.get_classification_layers(self.config, MODELS, "model.") return build_mm_model(self.config, MODELS, self.load_from) + def setup_callback(self, trainer: Trainer) -> None: + """Callback for setup OTX Model. + + OTXSSD requires auto anchor generating w.r.t. training dataset for better accuracy. + This callback will provide training dataset to model's anchor generator. + + Args: + trainer(Trainer): Lightning trainer contains OTXLitModule and OTXDatamodule. + """ + if trainer.training: + anchor_generator = self.model.bbox_head.anchor_generator + dataset = trainer.datamodule.train_dataloader().dataset + new_anchors = self._get_new_anchors(dataset, anchor_generator) + if new_anchors is not None: + logger.warning("Anchor will be updated by Dataset's statistics") + logger.warning(f"{anchor_generator.widths} -> {new_anchors[0]}") + logger.warning(f"{anchor_generator.heights} -> {new_anchors[1]}") + anchor_generator.widths = new_anchors[0] + anchor_generator.heights = new_anchors[1] + anchor_generator.gen_base_anchors() + + def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: AnchorGenerator) -> tuple | None: + """Get new anchors for SSD from OTXDataset.""" + from mmdet.datasets.transforms import Resize + + target_wh = None + if isinstance(dataset.transforms, list): + for transform in dataset.transforms: + if isinstance(transform, Resize): + target_wh = transform.scale + if target_wh is None: + target_wh = (864, 864) + msg = f"Cannot get target_wh from the dataset. 
Assign it with the default value: {target_wh}" + logger.warning(msg) + group_as = [len(width) for width in anchor_generator.widths] + wh_stats = self._get_sizes_from_dataset_entity(dataset, list(target_wh)) + + if len(wh_stats) < sum(group_as): + logger.warning( + f"There are not enough objects to cluster: {len(wh_stats)} were detected, while it should be " + f"at least {sum(group_as)}. Anchor box clustering was skipped.", + ) + return None + + return self._get_anchor_boxes(wh_stats, group_as) + + @staticmethod + def _get_sizes_from_dataset_entity(dataset: OTXDataset, target_wh: list[int]) -> list[tuple[int, int]]: + """Function to get width and height size of items in OTXDataset. + + Args: + dataset(OTXDataset): OTXDataset in which to get statistics + target_wh(list[int]): target width and height of the dataset + Return + list[tuple[int, int]]: tuples with width and height of each instance + """ + wh_stats: list[tuple[int, int]] = [] + for item in dataset.dm_subset: + for ann in item.annotations: + if isinstance(ann, Bbox): + x1, y1, x2, y2 = ann.points + x1 = x1 / item.media.size[1] * target_wh[0] + y1 = y1 / item.media.size[0] * target_wh[1] + x2 = x2 / item.media.size[1] * target_wh[0] + y2 = y2 / item.media.size[0] * target_wh[1] + wh_stats.append((x2 - x1, y2 - y1)) + return wh_stats + + @staticmethod + def _get_anchor_boxes(wh_stats: list[tuple[int, int]], group_as: list[int]) -> tuple: + """Get new anchor box widths & heights using KMeans.""" + from sklearn.cluster import KMeans + + kmeans = KMeans(init="k-means++", n_clusters=sum(group_as), random_state=0).fit(wh_stats) + centers = kmeans.cluster_centers_ + + areas = np.sqrt(np.prod(centers, axis=1)) + idx = np.argsort(areas) + + widths = centers[idx, 0] + heights = centers[idx, 1] + + group_as = np.cumsum(group_as[:-1]) + widths, heights = np.split(widths, group_as), np.split(heights, group_as) + widths = [width.tolist() for width in widths] + heights = [height.tolist() for height in heights] + return 
widths, heights + @staticmethod def get_classification_layers( config: DictConfig, @@ -95,6 +195,19 @@ def get_classification_layers( classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors} return classification_layers + def state_dict(self, *args, **kwargs) -> dict[str, Any]: + """Return state dictionary of model entity with anchor information. + + Returns: + A dictionary containing SSD state. + + """ + state_dict = super().state_dict(*args, **kwargs) + anchor_generator = self.model.bbox_head.anchor_generator + anchors = {"heights": anchor_generator.heights, "widths": anchor_generator.widths} + state_dict["model.model.anchors"] = anchors + return state_dict + def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None: """Modify input state_dict according to class name matching before weight loading.""" model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes) @@ -138,6 +251,15 @@ def _export_parameters(self) -> dict[str, Any]: return export_params + def _set_anchors_hook(self, state_dict: dict[str, Any], *args, **kwargs) -> None: + """Pre hook for pop anchor statistics from checkpoint state_dict.""" + anchors = state_dict.pop("model.model.anchors", None) + if anchors is not None: + anchor_generator = self.model.bbox_head.anchor_generator + anchor_generator.widths = anchors["widths"] + anchor_generator.heights = anchors["heights"] + anchor_generator.gen_base_anchors() + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_det_ckpt(state_dict, add_prefix) + return OTXv1Helper.load_ssd_ckpt(state_dict, add_prefix) diff --git a/src/otx/algo/utils/support_otx_v1.py b/src/otx/algo/utils/support_otx_v1.py index f3f105f7382..60f8475054b 100644 --- a/src/otx/algo/utils/support_otx_v1.py +++ b/src/otx/algo/utils/support_otx_v1.py @@ -11,6 +11,7 @@ class 
OTXv1Helper: @staticmethod def load_common_ckpt(state_dict: dict, add_prefix: str = "") -> dict: """Load the OTX1.x model checkpoints that don't need special handling.""" + state_dict = state_dict["model"]["state_dict"] for key in list(state_dict.keys()): val = state_dict.pop(key) state_dict[add_prefix + key] = val @@ -19,6 +20,7 @@ def load_common_ckpt(state_dict: dict, add_prefix: str = "") -> dict: @staticmethod def load_cls_effnet_b0_ckpt(state_dict: dict, label_type: str, add_prefix: str = "") -> dict: """Load the OTX1.x efficientnet b0 classification checkpoints.""" + state_dict = state_dict["model"]["state_dict"] for key in list(state_dict.keys()): val = state_dict.pop(key) if key.startswith("features."): @@ -34,6 +36,7 @@ def load_cls_effnet_b0_ckpt(state_dict: dict, label_type: str, add_prefix: str = @staticmethod def load_cls_effnet_v2_ckpt(state_dict: dict, label_type: str, add_prefix: str = "") -> dict: """Load the OTX1.x efficientnet v2 classification checkpoints.""" + state_dict = state_dict["model"]["state_dict"] for key in list(state_dict.keys()): val = state_dict.pop(key) if key.startswith("model.classifier."): @@ -48,6 +51,7 @@ def load_cls_effnet_v2_ckpt(state_dict: dict, label_type: str, add_prefix: str = @staticmethod def load_cls_mobilenet_v3_ckpt(state_dict: dict, label_type: str, add_prefix: str = "") -> dict: """Load the OTX1.x mobilenet v3 classification checkpoints.""" + state_dict = state_dict["model"]["state_dict"] for key in list(state_dict.keys()): val = state_dict.pop(key) if key.startswith("classifier."): @@ -72,12 +76,19 @@ def load_cls_deit_ckpt(state_dict: dict, add_prefix: str = "") -> dict: @staticmethod def load_det_ckpt(state_dict: dict, add_prefix: str = "") -> dict: """Load the OTX1.x detection model checkpoints.""" + state_dict = state_dict["model"]["state_dict"] for key in list(state_dict.keys()): val = state_dict.pop(key) if not key.startswith("ema_"): state_dict[add_prefix + key] = val return state_dict + @staticmethod 
+ def load_ssd_ckpt(state_dict: dict, add_prefix: str = "") -> dict: + """Load OTX1.x SSD model checkpoints.""" + state_dict["model"]["state_dict"]["anchors"] = state_dict.pop("anchors", None) + return OTXv1Helper.load_det_ckpt(state_dict, add_prefix) + @staticmethod def load_iseg_ckpt(state_dict: dict, add_prefix: str = "") -> dict: """Load the instance segmentation model checkpoints.""" @@ -86,6 +97,7 @@ def load_iseg_ckpt(state_dict: dict, add_prefix: str = "") -> dict: @staticmethod def load_seg_segnext_ckpt(state_dict: dict, add_prefix: str = "") -> dict: """Load the OTX1.x segnext segmentation checkpoints.""" + state_dict = state_dict["model"]["state_dict"] for key in list(state_dict.keys()): val = state_dict.pop(key) if "ham.bases" not in key: @@ -95,6 +107,7 @@ def load_seg_segnext_ckpt(state_dict: dict, add_prefix: str = "") -> dict: @staticmethod def load_seg_lite_hrnet_ckpt(state_dict: dict, add_prefix: str = "") -> dict: """Load the OTX1.x lite hrnet segmentation checkpoints.""" + state_dict = state_dict["model"]["state_dict"] for key in list(state_dict.keys()): val = state_dict.pop(key) state_dict[add_prefix + key] = val diff --git a/src/otx/core/model/entity/base.py b/src/otx/core/model/entity/base.py index e76f19abc3b..9b9acc9e513 100644 --- a/src/otx/core/model/entity/base.py +++ b/src/otx/core/model/entity/base.py @@ -32,6 +32,7 @@ from pathlib import Path import torch + from lightning import Trainer from otx.core.data.module import OTXDataModule @@ -52,6 +53,13 @@ def __init__(self, num_classes: int) -> None: self.classification_layers: dict[str, dict[str, Any]] = {} self.model = self._create_model() + def setup_callback(self, trainer: Trainer) -> None: + """Callback for setup OTX Model. + + Args: + trainer(Trainer): Lightning trainer contains OTXLitModule and OTXDatamodule. 
+ """ + @property def label_info(self) -> LabelInfo: """Get this model label information.""" diff --git a/src/otx/core/model/module/base.py b/src/otx/core/model/module/base.py index 8dd4f0e5618..20edfc02795 100644 --- a/src/otx/core/model/module/base.py +++ b/src/otx/core/model/module/base.py @@ -111,6 +111,8 @@ def setup(self, stage: str) -> None: if self.torch_compile and stage == "fit": self.model = torch.compile(self.model) + self.model.setup_callback(self.trainer) + def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]: """Choose what optimizers and learning-rate schedulers to use in your optimization. @@ -171,10 +173,9 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: load_state_pre_hook for smart weight loading will be registered. """ if is_ckpt_from_otx_v1(ckpt): - model_state_dict = ckpt["model"]["state_dict"] msg = "The checkpoint comes from OTXv1, checkpoint keys will be updated automatically." warnings.warn(msg, stacklevel=2) - state_dict = self.model.load_from_otx_v1_ckpt(model_state_dict) + state_dict = self.model.load_from_otx_v1_ckpt(ckpt) elif is_ckpt_for_finetuning(ckpt): state_dict = ckpt["state_dict"] else: diff --git a/tests/unit/algo/detection/test_ssd.py b/tests/unit/algo/detection/test_ssd.py new file mode 100644 index 00000000000..9a21a1a570d --- /dev/null +++ b/tests/unit/algo/detection/test_ssd.py @@ -0,0 +1,26 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Test of OTX SSD architecture.""" + +import pytest +from otx.algo.detection.ssd import SSD + + +class TestSSD: + @pytest.fixture() + def fxt_model(self) -> SSD: + return SSD(num_classes=3, variant="mobilenetv2") + + def test_save_and_load_anchors(self, fxt_model) -> None: + anchor_widths = fxt_model.model.bbox_head.anchor_generator.widths + anchor_heights = fxt_model.model.bbox_head.anchor_generator.heights + state_dict = fxt_model.state_dict() + assert anchor_widths == 
state_dict["model.model.anchors"]["widths"] + assert anchor_heights == state_dict["model.model.anchors"]["heights"] + + state_dict["model.model.anchors"]["widths"][0][0] = 40 + state_dict["model.model.anchors"]["heights"][0][0] = 50 + + fxt_model.load_state_dict(state_dict) + assert fxt_model.model.bbox_head.anchor_generator.widths[0][0] == 40 + assert fxt_model.model.bbox_head.anchor_generator.heights[0][0] == 50 diff --git a/tests/unit/algo/utils/test_support_otx_v1.py b/tests/unit/algo/utils/test_support_otx_v1.py index 40e0a313fbf..57eeca84ea2 100644 --- a/tests/unit/algo/utils/test_support_otx_v1.py +++ b/tests/unit/algo/utils/test_support_otx_v1.py @@ -22,9 +22,13 @@ def _check_ckpt_pairs(self, src_state_dict: dict, dst_state_dict: dict) -> None: @pytest.mark.parametrize("label_type", ["multiclass", "multilabel", "hlabel"]) def test_load_cls_effnet_b0_ckpt(self, label_type: str, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "features.weights": fxt_random_tensor, - "features.activ.weights": fxt_random_tensor, - "output.asl.weights": fxt_random_tensor, + "model": { + "state_dict": { + "features.weights": fxt_random_tensor, + "features.activ.weights": fxt_random_tensor, + "output.asl.weights": fxt_random_tensor, + }, + }, } if label_type != "hlabel": @@ -50,8 +54,12 @@ def test_load_cls_effnet_b0_ckpt(self, label_type: str, fxt_random_tensor: torch @pytest.mark.parametrize("label_type", ["multiclass", "multilabel", "hlabel"]) def test_load_cls_effnet_v2_ckpt(self, label_type: str, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "model.weights": fxt_random_tensor, - "model.classifier.weights": fxt_random_tensor, + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "model.classifier.weights": fxt_random_tensor, + }, + }, } if label_type != "hlabel": @@ -75,10 +83,14 @@ def test_load_cls_effnet_v2_ckpt(self, label_type: str, fxt_random_tensor: torch @pytest.mark.parametrize("label_type", ["multiclass", "multilabel", 
"hlabel"]) def test_load_cls_mobilenet_v3_ckpt(self, label_type: str, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "model.weights": fxt_random_tensor, - "classifier.2.weights": fxt_random_tensor, - "classifier.4.weights": fxt_random_tensor, - "act.weights": fxt_random_tensor, + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "classifier.2.weights": fxt_random_tensor, + "classifier.4.weights": fxt_random_tensor, + "act.weights": fxt_random_tensor, + }, + }, } if label_type == "multilabel": @@ -105,9 +117,13 @@ def test_load_cls_mobilenet_v3_ckpt(self, label_type: str, fxt_random_tensor: to def test_load_det_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "model.weights": fxt_random_tensor, - "head.weights": fxt_random_tensor, - "ema_model.weights": fxt_random_tensor, + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "head.weights": fxt_random_tensor, + "ema_model.weights": fxt_random_tensor, + }, + }, } dst_state_dict = { @@ -118,11 +134,34 @@ def test_load_det_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: converted_state_dict = OTXv1Helper.load_det_ckpt(src_state_dict, add_prefix="model.model.") self._check_ckpt_pairs(converted_state_dict, dst_state_dict) + def test_load_ssd_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: + src_state_dict = { + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "head.weights": fxt_random_tensor, + "ema_model.weights": fxt_random_tensor, + }, + }, + "anchors": fxt_random_tensor, + } + dst_state_dict = { + "model.model.model.weights": fxt_random_tensor, + "model.model.head.weights": fxt_random_tensor, + "model.model.anchors": fxt_random_tensor, + } + converted_state_dict = OTXv1Helper.load_det_ckpt(src_state_dict, add_prefix="model.model.") + self._check_ckpt_pairs(converted_state_dict, dst_state_dict) + def test_load_iseg_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "model.weights": 
fxt_random_tensor, - "head.weights": fxt_random_tensor, - "ema_model.weights": fxt_random_tensor, + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "head.weights": fxt_random_tensor, + "ema_model.weights": fxt_random_tensor, + }, + }, } dst_state_dict = { @@ -135,9 +174,13 @@ def test_load_iseg_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: def test_load_seg_segnext_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "model.weights": fxt_random_tensor, - "head.weights": fxt_random_tensor, - "ham.bases.weights": fxt_random_tensor, + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "head.weights": fxt_random_tensor, + "ham.bases.weights": fxt_random_tensor, + }, + }, } dst_state_dict = { @@ -150,9 +193,13 @@ def test_load_seg_segnext_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: def test_load_seg_lite_hrnet_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "model.weights": fxt_random_tensor, - "head.weights": fxt_random_tensor, - "decode_head.aggregator.projects.weights": fxt_random_tensor, + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "head.weights": fxt_random_tensor, + "decode_head.aggregator.projects.weights": fxt_random_tensor, + }, + }, } dst_state_dict = { @@ -165,8 +212,12 @@ def test_load_seg_lite_hrnet_ckpt(self, fxt_random_tensor: torch.Tensor) -> None def test_load_action_ckpt(self, fxt_random_tensor: torch.Tensor) -> None: src_state_dict = { - "model.weights": fxt_random_tensor, - "head.weights": fxt_random_tensor, + "model": { + "state_dict": { + "model.weights": fxt_random_tensor, + "head.weights": fxt_random_tensor, + }, + }, } dst_state_dict = { From 079d3b58e1f64414fabd7bbc9ea3fffa4ac78327 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 15 Feb 2024 16:37:21 +0100 Subject: [PATCH 4/6] Add entities for storing explain results (#2913) * Add entities for storing explain results * Add feature vectors to explain results 
* Fix type checking in cls lit module * Add explain data structures for detection and iseg * Add XAI data for sseg * Reduce duplication in cli test * Fix open subprocess fixture * Update types for seg predictions * Update types for det predictions * Update types for iseg * Don't use hook as a storage for predictions in otx explain * Add predictions with XAI for sseg and det torch models * Fix import * Fix unit tests for py 3.9 --- src/otx/algo/hooks/recording_forward_hook.py | 14 +- src/otx/algo/utils/xai_utils.py | 13 +- src/otx/core/data/entity/base.py | 22 +++ src/otx/core/data/entity/classification.py | 32 +++++ src/otx/core/data/entity/detection.py | 12 ++ .../core/data/entity/instance_segmentation.py | 12 +- src/otx/core/data/entity/segmentation.py | 19 ++- .../model/entity/action_classification.py | 6 +- src/otx/core/model/entity/action_detection.py | 6 +- src/otx/core/model/entity/base.py | 21 ++- src/otx/core/model/entity/classification.py | 135 ++++++++++++++++-- src/otx/core/model/entity/detection.py | 50 +++++-- .../model/entity/instance_segmentation.py | 33 ++++- src/otx/core/model/entity/segmentation.py | 42 +++++- src/otx/core/model/entity/visual_prompting.py | 15 +- src/otx/core/model/module/classification.py | 21 +-- src/otx/core/model/module/detection.py | 7 +- .../model/module/instance_segmentation.py | 7 +- src/otx/core/model/module/segmentation.py | 7 +- src/otx/core/utils/tile_merge.py | 7 +- .../integration/cli/test_export_inference.py | 78 ++++------ tests/integration/conftest.py | 2 +- 22 files changed, 429 insertions(+), 132 deletions(-) diff --git a/src/otx/algo/hooks/recording_forward_hook.py b/src/otx/algo/hooks/recording_forward_hook.py index 1df6ede0c13..0e3ee7e71a3 100644 --- a/src/otx/algo/hooks/recording_forward_hook.py +++ b/src/otx/algo/hooks/recording_forward_hook.py @@ -10,7 +10,7 @@ import numpy as np import torch -from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity +from
otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegBatchPredEntityWithXAI if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle @@ -409,7 +409,11 @@ def create_and_register_hook(cls, num_classes: int) -> BaseRecordingForwardHook: """Create this object and register it to the module forward hook.""" return cls(num_classes) - def func(self, preds: list[InstanceSegBatchPredEntity], _: int = -1) -> list[np.array]: + def func( + self, + preds: list[InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI], + _: int = -1, + ) -> list[np.array]: """Generate saliency maps from predicted masks by averaging and normalizing them per-class. Args: @@ -428,7 +432,11 @@ def func(self, preds: list[InstanceSegBatchPredEntity], _: int = -1) -> list[np. return batch_saliency_maps @classmethod - def average_and_normalize(cls, pred: InstanceSegBatchPredEntity, num_classes: int) -> np.array: + def average_and_normalize( + cls, + pred: InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI, + num_classes: int, + ) -> np.array: """Average and normalize masks in prediction per-class. 
Args: diff --git a/src/otx/algo/utils/xai_utils.py b/src/otx/algo/utils/xai_utils.py index 35aa2583351..d6d9da4b92f 100644 --- a/src/otx/algo/utils/xai_utils.py +++ b/src/otx/algo/utils/xai_utils.py @@ -4,12 +4,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import cv2 from otx.algo.hooks.recording_forward_hook import BaseRecordingForwardHook, MaskRCNNRecordingForwardHook from otx.core.config.explain import ExplainConfig +from otx.core.data.entity.base import OTXBatchPredEntityWithXAI +from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntityWithXAI if TYPE_CHECKING: from pathlib import Path @@ -18,19 +20,22 @@ def get_processed_saliency_maps( explain_hook: BaseRecordingForwardHook, explain_config: ExplainConfig, - predictions: list | None, + predictions: list[Any] | list[OTXBatchPredEntityWithXAI | InstanceSegBatchPredEntityWithXAI] | None, work_dir: Path | None, ) -> list: """Implement saliency map filtering and post-processing.""" # Optimize for memory <- TODO(negvet) - raw_saliency_maps = explain_hook.records + raw_saliency_maps: list = [] + + if predictions is not None: + raw_saliency_maps = predictions[0].saliency_maps if predictions is not None and isinstance(explain_hook, MaskRCNNRecordingForwardHook): # TODO: It is a temporary workaround. This function will be removed after we # noqa: TD003, TD002 # refactor XAI logics into `OTXModel.forward_explain()`. 
# Mask-RCNN case, receive saliency maps from predictions - raw_saliency_maps = explain_hook.func(predictions) + raw_saliency_maps = explain_hook.func(predictions) # type: ignore[arg-type] if work_dir: # Temporary saving saliency map for image 0, class 0 (for tests) diff --git a/src/otx/core/data/entity/base.py b/src/otx/core/data/entity/base.py index 7f97e023397..ca9e24b7fc3 100644 --- a/src/otx/core/data/entity/base.py +++ b/src/otx/core/data/entity/base.py @@ -506,6 +506,14 @@ class OTXPredEntity(OTXDataEntity): score: np.ndarray | Tensor +@dataclass +class OTXPredEntityWithXAI(OTXPredEntity): + """Data entity to represent model output prediction with explanations.""" + + saliency_map: np.ndarray | Tensor + feature_vector: np.ndarray | list + + T_OTXBatchDataEntity = TypeVar( "T_OTXBatchDataEntity", bound="OTXBatchDataEntity", @@ -590,6 +598,12 @@ def pin_memory(self: T_OTXBatchDataEntity) -> T_OTXBatchDataEntity: ) +T_OTXBatchPredEntityWithXAI = TypeVar( + "T_OTXBatchPredEntityWithXAI", + bound="OTXBatchPredEntityWithXAI", +) + + @dataclass class OTXBatchPredEntity(OTXBatchDataEntity): """Data entity to represent model output predictions.""" @@ -597,6 +611,14 @@ class OTXBatchPredEntity(OTXBatchDataEntity): scores: list[np.ndarray] | list[Tensor] +@dataclass +class OTXBatchPredEntityWithXAI(OTXBatchPredEntity): + """Data entity to represent model output predictions with explanations.""" + + saliency_maps: list[np.ndarray] | list[Tensor] + feature_vectors: list[np.ndarray] | list[Tensor] + + T_OTXBatchLossEntity = TypeVar( "T_OTXBatchLossEntity", bound="OTXBatchLossEntity", diff --git a/src/otx/core/data/entity/classification.py b/src/otx/core/data/entity/classification.py index 6cbd7525a0f..2367cc127ca 100644 --- a/src/otx/core/data/entity/classification.py +++ b/src/otx/core/data/entity/classification.py @@ -15,8 +15,10 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, + OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, + 
OTXPredEntityWithXAI, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -47,6 +49,11 @@ class MulticlassClsPredEntity(MulticlassClsDataEntity, OTXPredEntity): """Data entity to represent the multi-class classification model output prediction.""" +@dataclass +class MulticlassClsPredEntityWithXAI(MulticlassClsDataEntity, OTXPredEntityWithXAI): + """Data entity to represent the multi-class classification model output prediction with explanations.""" + + @dataclass class MulticlassClsBatchDataEntity(OTXBatchDataEntity[MulticlassClsDataEntity]): """Data entity for multi-class classification task. @@ -91,6 +98,11 @@ class MulticlassClsBatchPredEntity(MulticlassClsBatchDataEntity, OTXBatchPredEnt """Data entity to represent model output predictions for multi-class classification task.""" +@dataclass +class MulticlassClsBatchPredEntityWithXAI(MulticlassClsBatchDataEntity, OTXBatchPredEntityWithXAI): + """Data entity to represent model output predictions for multi-class classification task with explanations.""" + + @register_pytree_node @dataclass class MultilabelClsDataEntity(OTXDataEntity): @@ -112,6 +124,11 @@ class MultilabelClsPredEntity(MultilabelClsDataEntity, OTXPredEntity): """Data entity to represent the multi-label classification model output prediction.""" +@dataclass +class MultilabelClsPredEntityWithXAI(MultilabelClsDataEntity, OTXPredEntityWithXAI): + """Data entity to represent the multi-label classification model output prediction with explanations.""" + + @dataclass class MultilabelClsBatchDataEntity(OTXBatchDataEntity[MultilabelClsDataEntity]): """Data entity for multi-label classification task. 
@@ -156,6 +173,11 @@ class MultilabelClsBatchPredEntity(MultilabelClsBatchDataEntity, OTXBatchPredEnt """Data entity to represent model output predictions for multi-label classification task.""" +@dataclass +class MultilabelClsBatchPredEntityWithXAI(MultilabelClsBatchDataEntity, OTXBatchPredEntityWithXAI): + """Data entity to represent model output predictions for multi-label classification task with explanations.""" + + @dataclass class HLabelInfo: """The label information represents the hierarchy. @@ -337,6 +359,11 @@ class HlabelClsPredEntity(HlabelClsDataEntity, OTXPredEntity): """Data entity to represent the H-label classification model output prediction.""" +@dataclass +class HlabelClsPredEntityWithXAI(HlabelClsDataEntity, OTXPredEntityWithXAI): + """Data entity to represent the H-label classification model output prediction with explanation.""" + + @dataclass class HlabelClsBatchDataEntity(OTXBatchDataEntity[HlabelClsDataEntity]): """Data entity for H-label classification task. @@ -374,3 +401,8 @@ def collate_fn( @dataclass class HlabelClsBatchPredEntity(HlabelClsBatchDataEntity, OTXBatchPredEntity): """Data entity to represent model output predictions for H-label classification task.""" + + +@dataclass +class HlabelClsBatchPredEntityWithXAI(HlabelClsBatchDataEntity, OTXBatchPredEntityWithXAI): + """Data entity to represent model output predictions for H-label classification task with explanations.""" diff --git a/src/otx/core/data/entity/detection.py b/src/otx/core/data/entity/detection.py index 093d42dcf11..3346fcf822c 100644 --- a/src/otx/core/data/entity/detection.py +++ b/src/otx/core/data/entity/detection.py @@ -13,8 +13,10 @@ from otx.core.data.entity.base import ( OTXBatchDataEntity, OTXBatchPredEntity, + OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity, + OTXPredEntityWithXAI, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -47,6 +49,11 @@ class DetPredEntity(DetDataEntity, 
OTXPredEntity): """Data entity to represent the detection model output prediction.""" +@dataclass +class DetPredEntityWithXAI(DetDataEntity, OTXPredEntityWithXAI): + """Data entity to represent the detection model output prediction with explanations.""" + + @dataclass class DetBatchDataEntity(OTXBatchDataEntity[DetDataEntity]): """Data entity for detection task. @@ -87,3 +94,8 @@ def pin_memory(self) -> DetBatchDataEntity: @dataclass class DetBatchPredEntity(DetBatchDataEntity, OTXBatchPredEntity): """Data entity to represent model output predictions for detection task.""" + + +@dataclass +class DetBatchPredEntityWithXAI(DetBatchDataEntity, OTXBatchPredEntityWithXAI): + """Data entity to represent model output predictions for detection task with explanations.""" diff --git a/src/otx/core/data/entity/instance_segmentation.py b/src/otx/core/data/entity/instance_segmentation.py index 8b843ae3950..316a263c0f4 100644 --- a/src/otx/core/data/entity/instance_segmentation.py +++ b/src/otx/core/data/entity/instance_segmentation.py @@ -12,7 +12,7 @@ from otx.core.types.task import OTXTaskType -from .base import OTXBatchDataEntity, OTXBatchPredEntity, OTXDataEntity, OTXPredEntity +from .base import OTXBatchDataEntity, OTXBatchPredEntity, OTXBatchPredEntityWithXAI, OTXDataEntity, OTXPredEntity if TYPE_CHECKING: from datumaro import Polygon @@ -46,6 +46,11 @@ class InstanceSegPredEntity(InstanceSegDataEntity, OTXPredEntity): """Data entity to represent the detection model output prediction.""" +@dataclass +class InstanceSegPredEntityWithXAI(InstanceSegDataEntity, OTXBatchPredEntityWithXAI): + """Data entity to represent the detection model output prediction with explanation.""" + + @dataclass class InstanceSegBatchDataEntity(OTXBatchDataEntity[InstanceSegDataEntity]): """Batch entity for InstanceSegDataEntity. 
@@ -100,3 +105,8 @@ def pin_memory(self) -> InstanceSegBatchDataEntity: @dataclass class InstanceSegBatchPredEntity(InstanceSegBatchDataEntity, OTXBatchPredEntity): """Data entity to represent model output predictions for instance segmentation task.""" + + +@dataclass +class InstanceSegBatchPredEntityWithXAI(InstanceSegBatchDataEntity, OTXBatchPredEntityWithXAI): + """Data entity to represent model output predictions for instance segmentation task with explanations.""" diff --git a/src/otx/core/data/entity/segmentation.py b/src/otx/core/data/entity/segmentation.py index 49843c03e91..91a96e88ac9 100644 --- a/src/otx/core/data/entity/segmentation.py +++ b/src/otx/core/data/entity/segmentation.py @@ -9,7 +9,14 @@ from torchvision import tv_tensors -from otx.core.data.entity.base import OTXBatchDataEntity, OTXBatchPredEntity, OTXDataEntity, OTXPredEntity +from otx.core.data.entity.base import ( + OTXBatchDataEntity, + OTXBatchPredEntity, + OTXBatchPredEntityWithXAI, + OTXDataEntity, + OTXPredEntity, + OTXPredEntityWithXAI, +) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -35,6 +42,11 @@ class SegPredEntity(SegDataEntity, OTXPredEntity): """Data entity to represent the segmentation model output prediction.""" +@dataclass +class SegPredEntityWithXAI(SegDataEntity, OTXPredEntityWithXAI): + """Data entity to represent the segmentation model output prediction with explanation.""" + + @dataclass class SegBatchDataEntity(OTXBatchDataEntity[SegDataEntity]): """Data entity for segmentation task. 
@@ -70,3 +82,8 @@ def pin_memory(self) -> SegBatchDataEntity: @dataclass class SegBatchPredEntity(SegBatchDataEntity, OTXBatchPredEntity): """Data entity to represent model output predictions for segmentation task.""" + + +@dataclass +class SegBatchPredEntityWithXAI(SegBatchDataEntity, OTXBatchPredEntityWithXAI): + """Data entity to represent model output predictions for segmentation task with explanations.""" diff --git a/src/otx/core/model/entity/action_classification.py b/src/otx/core/model/entity/action_classification.py index aab5665be5a..035727b7085 100644 --- a/src/otx/core/model/entity/action_classification.py +++ b/src/otx/core/model/entity/action_classification.py @@ -11,7 +11,7 @@ ActionClsBatchDataEntity, ActionClsBatchPredEntity, ) -from otx.core.data.entity.base import OTXBatchLossEntity +from otx.core.data.entity.base import OTXBatchLossEntity, T_OTXBatchPredEntityWithXAI from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.model.entity.base import OTXModel from otx.core.utils.config import inplace_num_classes @@ -21,7 +21,9 @@ from torch import nn -class OTXActionClsModel(OTXModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity, T_OTXTileBatchDataEntity]): +class OTXActionClsModel( + OTXModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity, T_OTXBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], +): """Base class for the action classification models used in OTX.""" diff --git a/src/otx/core/model/entity/action_detection.py b/src/otx/core/model/entity/action_detection.py index 4c2132a408f..ee244b0f778 100644 --- a/src/otx/core/model/entity/action_detection.py +++ b/src/otx/core/model/entity/action_detection.py @@ -10,7 +10,7 @@ from torchvision import tv_tensors from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetBatchPredEntity -from otx.core.data.entity.base import OTXBatchLossEntity +from otx.core.data.entity.base import OTXBatchLossEntity, T_OTXBatchPredEntityWithXAI from 
otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.model.entity.base import OTXModel from otx.core.utils.config import inplace_num_classes @@ -20,7 +20,9 @@ from torch import nn -class OTXActionDetModel(OTXModel[ActionDetBatchDataEntity, ActionDetBatchPredEntity, T_OTXTileBatchDataEntity]): +class OTXActionDetModel( + OTXModel[ActionDetBatchDataEntity, ActionDetBatchPredEntity, T_OTXBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], +): """Base class for the action detection models used in OTX.""" diff --git a/src/otx/core/model/entity/base.py b/src/otx/core/model/entity/base.py index 9b9acc9e513..084e44e021d 100644 --- a/src/otx/core/model/entity/base.py +++ b/src/otx/core/model/entity/base.py @@ -21,6 +21,7 @@ OTXBatchLossEntity, T_OTXBatchDataEntity, T_OTXBatchPredEntity, + T_OTXBatchPredEntityWithXAI, ) from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter @@ -37,7 +38,10 @@ from otx.core.data.module import OTXDataModule -class OTXModel(nn.Module, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXTileBatchDataEntity]): +class OTXModel( + nn.Module, + Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], +): """Base class for the models used in OTX. 
Args: @@ -103,14 +107,14 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: + ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: """Customize OTX output batch data entity if needed for model.""" raise NotImplementedError def forward( self, inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: + ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: """Model forward function.""" # If customize_inputs is overridden if isinstance(inputs, OTXTileBatchDataEntity): @@ -128,7 +132,10 @@ def forward( else outputs ) - def forward_tiles(self, inputs: T_OTXTileBatchDataEntity) -> T_OTXBatchPredEntity | OTXBatchLossEntity: + def forward_tiles( + self, + inputs: T_OTXTileBatchDataEntity, + ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: """Model forward function for tile task.""" raise NotImplementedError @@ -272,7 +279,7 @@ def _optimization_config(self) -> dict[str, str]: return {} -class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]): +class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXBatchPredEntityWithXAI]): """Base class for the OpenVINO model. 
This is a base class representing interface for interacting with OpenVINO @@ -328,14 +335,14 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: + ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: """Customize OTX output batch data entity if needed for model.""" raise NotImplementedError def forward( self, inputs: T_OTXBatchDataEntity, - ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: + ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: """Model forward function.""" def _callback(result: NamedTuple, idx: int) -> None: diff --git a/src/otx/core/model/entity/classification.py b/src/otx/core/model/entity/classification.py index f3d5f80c98b..4526632162e 100644 --- a/src/otx/core/model/entity/classification.py +++ b/src/otx/core/model/entity/classification.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy import json from typing import TYPE_CHECKING, Any @@ -12,14 +13,22 @@ import torch from otx.core.data.dataset.classification import HLabelMetaInfo -from otx.core.data.entity.base import OTXBatchLossEntity, T_OTXBatchDataEntity, T_OTXBatchPredEntity +from otx.core.data.entity.base import ( + OTXBatchLossEntity, + T_OTXBatchDataEntity, + T_OTXBatchPredEntity, + T_OTXBatchPredEntityWithXAI, +) from otx.core.data.entity.classification import ( HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, + HlabelClsBatchPredEntityWithXAI, MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, + MulticlassClsBatchPredEntityWithXAI, MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, + MultilabelClsBatchPredEntityWithXAI, ) from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter @@ -37,7 +46,9 @@ from otx.core.data.entity.classification import HLabelInfo -class ExplainableOTXClsModel(OTXModel[T_OTXBatchDataEntity, T_OTXBatchPredEntity, 
T_OTXTileBatchDataEntity]): +class ExplainableOTXClsModel( + OTXModel[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], +): """OTX classification model which can attach a XAI hook.""" @property @@ -85,7 +96,12 @@ def reset_explain_hook(self) -> None: class OTXMulticlassClsModel( - ExplainableOTXClsModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, T_OTXTileBatchDataEntity], + ExplainableOTXClsModel[ + MulticlassClsBatchDataEntity, + MulticlassClsBatchPredEntity, + MulticlassClsBatchPredEntityWithXAI, + T_OTXTileBatchDataEntity, + ], ): """Base class for the classification models used in OTX.""" @@ -158,7 +174,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: + ) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI | OTXBatchLossEntity: from mmpretrain.structures import DataSample if self.training: @@ -180,6 +196,20 @@ def _customize_outputs( scores.append(output.pred_score) labels.append(output.pred_label) + if hasattr(self, "explain_hook"): + hook_records = self.explain_hook.records + explain_results = copy.deepcopy(hook_records[-len(outputs) :]) + + return MulticlassClsBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=labels, + saliency_maps=explain_results, + feature_vectors=[], + ) + return MulticlassClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -213,7 +243,12 @@ def _exporter(self) -> OTXModelExporter: class OTXMultilabelClsModel( - ExplainableOTXClsModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, T_OTXTileBatchDataEntity], + ExplainableOTXClsModel[ + MultilabelClsBatchDataEntity, + MultilabelClsBatchPredEntity, + MultilabelClsBatchPredEntityWithXAI, + T_OTXTileBatchDataEntity, + ], ): """Multi-label classification models used in OTX.""" @@ -289,7 
+324,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: + ) -> MultilabelClsBatchPredEntity | MultilabelClsBatchPredEntityWithXAI | OTXBatchLossEntity: from mmpretrain.structures import DataSample if self.training: @@ -311,6 +346,20 @@ def _customize_outputs( scores.append(output.pred_score) labels.append(output.pred_label) + if hasattr(self, "explain_hook"): + hook_records = self.explain_hook.records + explain_results = copy.deepcopy(hook_records[-len(outputs) :]) + + return MultilabelClsBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=labels, + saliency_maps=explain_results, + feature_vectors=[], + ) + return MultilabelClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -340,7 +389,12 @@ def _exporter(self) -> OTXModelExporter: class OTXHlabelClsModel( - ExplainableOTXClsModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, T_OTXTileBatchDataEntity], + ExplainableOTXClsModel[ + HlabelClsBatchDataEntity, + HlabelClsBatchPredEntity, + HlabelClsBatchPredEntityWithXAI, + T_OTXTileBatchDataEntity, + ], ): """H-label classification models used in OTX.""" @@ -440,7 +494,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: + ) -> HlabelClsBatchPredEntity | HlabelClsBatchPredEntityWithXAI | OTXBatchLossEntity: from mmpretrain.structures import DataSample if self.training: @@ -462,6 +516,20 @@ def _customize_outputs( scores.append(output.pred_score) labels.append(output.pred_label) + if hasattr(self, "explain_hook"): + hook_records = self.explain_hook.records + explain_results = copy.deepcopy(hook_records[-len(outputs) :]) + + return HlabelClsBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + 
labels=labels, + saliency_maps=explain_results, + feature_vectors=[], + ) + return HlabelClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -491,7 +559,7 @@ def _exporter(self) -> OTXModelExporter: class OVMulticlassClassificationModel( - OVModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity], + OVModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, MulticlassClsBatchPredEntityWithXAI], ): """Classification model compatible for OpenVINO IR inference. @@ -523,10 +591,23 @@ def _customize_outputs( self, outputs: list[ClassificationResult], inputs: MulticlassClsBatchDataEntity, - ) -> MulticlassClsBatchPredEntity: + ) -> MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI: pred_labels = [torch.tensor(out.top_labels[0][0], dtype=torch.long) for out in outputs] pred_scores = [torch.tensor(out.top_labels[0][2]) for out in outputs] + if outputs and outputs[0].saliency_map.size != 0: + predicted_s_maps = [out.saliency_map for out in outputs] + predicted_f_vectors = [out.feature_vector for out in outputs] + return MulticlassClsBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=pred_scores, + labels=pred_labels, + saliency_maps=predicted_s_maps, + feature_vectors=predicted_f_vectors, + ) + return MulticlassClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -537,7 +618,7 @@ def _customize_outputs( class OVHlabelClassificationModel( - OVModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity], + OVModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, HlabelClsBatchPredEntityWithXAI], ): """Hierarchical classification model compatible for OpenVINO IR inference. 
@@ -585,7 +666,7 @@ def _customize_outputs( self, outputs: list[ClassificationResult], inputs: HlabelClsBatchDataEntity, - ) -> HlabelClsBatchPredEntity: + ) -> HlabelClsBatchPredEntity | HlabelClsBatchPredEntityWithXAI: all_pred_labels = [] all_pred_scores = [] for output in outputs: @@ -614,6 +695,19 @@ def _customize_outputs( all_pred_labels.append(torch.tensor(predicted_labels, dtype=torch.long)) all_pred_scores.append(torch.tensor(predicted_scores)) + if outputs and outputs[0].saliency_map.size != 1: + predicted_s_maps = [out.saliency_map for out in outputs] + predicted_f_vectors = [out.feature_vector for out in outputs] + return HlabelClsBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=all_pred_scores, + labels=all_pred_labels, + saliency_maps=predicted_s_maps, + feature_vectors=predicted_f_vectors, + ) + return HlabelClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -624,7 +718,7 @@ def _customize_outputs( class OVMultilabelClassificationModel( - OVModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity], + OVModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, MultilabelClsBatchPredEntityWithXAI], ): """Multilabel classification model compatible for OpenVINO IR inference. 
@@ -658,9 +752,22 @@ def _customize_outputs( self, outputs: list[ClassificationResult], inputs: MultilabelClsBatchDataEntity, - ) -> MultilabelClsBatchPredEntity: + ) -> MultilabelClsBatchPredEntity | MultilabelClsBatchPredEntityWithXAI: pred_scores = [torch.tensor([top_label[2] for top_label in out.top_labels]) for out in outputs] + if outputs and outputs[0].saliency_map.size != 1: + predicted_s_maps = [out.saliency_map for out in outputs] + predicted_f_vectors = [out.feature_vector for out in outputs] + return MultilabelClsBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=pred_scores, + labels=[], + saliency_maps=predicted_s_maps, + feature_vectors=predicted_f_vectors, + ) + return MultilabelClsBatchPredEntity( batch_size=len(outputs), images=inputs.images, diff --git a/src/otx/core/model/entity/detection.py b/src/otx/core/model/entity/detection.py index 0925f18cc8b..d33fb74324d 100644 --- a/src/otx/core/model/entity/detection.py +++ b/src/otx/core/model/entity/detection.py @@ -5,7 +5,7 @@ from __future__ import annotations -from copy import copy +import copy from typing import TYPE_CHECKING, Any import torch @@ -13,7 +13,7 @@ from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity +from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity, DetBatchPredEntityWithXAI from otx.core.data.entity.tile import TileBatchDetDataEntity from otx.core.model.entity.base import OTXModel, OVModel from otx.core.utils.config import inplace_num_classes @@ -29,14 +29,16 @@ from otx.core.exporter.base import OTXModelExporter -class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]): +class OTXDetectionModel( + OTXModel[DetBatchDataEntity, DetBatchPredEntity, DetBatchPredEntityWithXAI, TileBatchDetDataEntity], +): """Base class for 
the detection models used in OTX.""" def __init__(self, *arg, **kwargs) -> None: super().__init__(*arg, **kwargs) self.tile_config = TileConfig() - def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: + def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity | DetBatchPredEntityWithXAI: """Unpack detection tiles. Args: @@ -45,7 +47,7 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: Returns: DetBatchPredEntity: Merged detection prediction. """ - tile_preds: list[DetBatchPredEntity] = [] + tile_preds: list[DetBatchPredEntity | DetBatchPredEntityWithXAI] = [] tile_attrs: list[list[dict[str, int | str]]] = [] merger = DetectionTileMerge( inputs.imgs_info, @@ -171,7 +173,7 @@ def _export_parameters(self) -> dict[str, Any]: export_params = super()._export_parameters export_params.update(get_mean_std_from_data_processing(self.config)) export_params["model_builder"] = self._create_model - export_params["model_cfg"] = copy(self.config) + export_params["model_cfg"] = copy.copy(self.config) export_params["test_pipeline"] = self._make_fake_test_pipeline() return export_params @@ -233,7 +235,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: DetBatchDataEntity, - ) -> DetBatchPredEntity | OTXBatchLossEntity: + ) -> DetBatchPredEntity | DetBatchPredEntityWithXAI | OTXBatchLossEntity: from mmdet.structures import DetDataSample if self.training: @@ -268,6 +270,21 @@ def _customize_outputs( ) labels.append(output.pred_instances.labels) + if hasattr(self, "explain_hook"): + hook_records = self.explain_hook.records + explain_results = copy.deepcopy(hook_records[-len(outputs) :]) + + return DetBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + bboxes=bboxes, + labels=labels, + saliency_maps=explain_results, + feature_vectors=[], + ) + return DetBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ 
-285,7 +302,7 @@ def _exporter(self) -> OTXModelExporter: return MMdeployExporter(**self._export_parameters) -class OVDetectionModel(OVModel[DetBatchDataEntity, DetBatchPredEntity]): +class OVDetectionModel(OVModel[DetBatchDataEntity, DetBatchPredEntity, DetBatchPredEntityWithXAI]): """Object detection model compatible for OpenVINO IR inference. It can consume OpenVINO IR model path or model name from Intel OMZ repository @@ -316,7 +333,7 @@ def _customize_outputs( self, outputs: list[DetectionResult], inputs: DetBatchDataEntity, - ) -> DetBatchPredEntity | OTXBatchLossEntity: + ) -> DetBatchPredEntity | DetBatchPredEntityWithXAI | OTXBatchLossEntity: # add label index bboxes = [] scores = [] @@ -342,6 +359,21 @@ def _customize_outputs( else: labels.append(torch.tensor([output.id for output in output_objects])) + if outputs and outputs[0].saliency_map.size != 1: + predicted_s_maps = [out.saliency_map for out in outputs] + predicted_f_vectors = [out.feature_vector for out in outputs] + + return DetBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + bboxes=bboxes, + labels=labels, + saliency_maps=predicted_s_maps, + feature_vectors=predicted_f_vectors, + ) + return DetBatchPredEntity( batch_size=len(outputs), images=inputs.images, diff --git a/src/otx/core/model/entity/instance_segmentation.py b/src/otx/core/model/entity/instance_segmentation.py index c8dd1818b36..138a46a306a 100644 --- a/src/otx/core/model/entity/instance_segmentation.py +++ b/src/otx/core/model/entity/instance_segmentation.py @@ -17,6 +17,7 @@ from otx.core.data.entity.instance_segmentation import ( InstanceSegBatchDataEntity, InstanceSegBatchPredEntity, + InstanceSegBatchPredEntityWithXAI, ) from otx.core.data.entity.tile import TileBatchInstSegDataEntity from otx.core.exporter.base import OTXModelExporter @@ -33,7 +34,12 @@ class OTXInstanceSegModel( - OTXModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity, 
TileBatchInstSegDataEntity], + OTXModel[ + InstanceSegBatchDataEntity, + InstanceSegBatchPredEntity, + InstanceSegBatchPredEntityWithXAI, + TileBatchInstSegDataEntity, + ], ): """Base class for the Instance Segmentation models used in OTX.""" @@ -50,7 +56,7 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP Returns: InstanceSegBatchPredEntity: Merged instance segmentation prediction. """ - tile_preds: list[InstanceSegBatchPredEntity] = [] + tile_preds: list[InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI] = [] tile_attrs: list[list[dict[str, int | str]]] = [] merger = InstanceSegTileMerge( inputs.imgs_info, @@ -222,7 +228,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: InstanceSegBatchDataEntity, - ) -> InstanceSegBatchPredEntity | OTXBatchLossEntity: + ) -> InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI | OTXBatchLossEntity: from mmdet.structures import DetDataSample if self.training: @@ -281,7 +287,7 @@ def _exporter(self) -> OTXModelExporter: class OVInstanceSegmentationModel( - OVModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity], + OVModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity, InstanceSegBatchPredEntityWithXAI], ): """Instance segmentation model compatible for OpenVINO IR inference. 
@@ -313,7 +319,7 @@ def _customize_outputs( self, outputs: list[InstanceSegmentationResult], inputs: InstanceSegBatchDataEntity, - ) -> InstanceSegBatchPredEntity | OTXBatchLossEntity: + ) -> InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI | OTXBatchLossEntity: # add label index bboxes = [] scores = [] @@ -336,6 +342,23 @@ def _customize_outputs( masks.append(torch.tensor([output.mask for output in output_objects])) labels.append(torch.tensor([output.id - 1 for output in output_objects])) + if outputs and outputs[0].saliency_map: + predicted_s_maps = [out.saliency_map for out in outputs] + predicted_f_vectors = [out.feature_vector for out in outputs] + + return InstanceSegBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + bboxes=bboxes, + masks=masks, + polygons=[], + labels=labels, + saliency_maps=predicted_s_maps, + feature_vectors=predicted_f_vectors, + ) + return InstanceSegBatchPredEntity( batch_size=len(outputs), images=inputs.images, diff --git a/src/otx/core/model/entity/segmentation.py b/src/otx/core/model/entity/segmentation.py index 79c2d5d63e1..c313c8a42fb 100644 --- a/src/otx/core/model/entity/segmentation.py +++ b/src/otx/core/model/entity/segmentation.py @@ -5,12 +5,13 @@ from __future__ import annotations +import copy from typing import TYPE_CHECKING, Any from torchvision import tv_tensors from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity +from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity, SegBatchPredEntityWithXAI from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter @@ -25,7 +26,9 @@ from torch import nn -class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity, T_OTXTileBatchDataEntity]): +class 
OTXSegmentationModel( + OTXModel[SegBatchDataEntity, SegBatchPredEntity, SegBatchPredEntityWithXAI, T_OTXTileBatchDataEntity], +): """Base class for the detection models used in OTX.""" @property @@ -104,7 +107,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: SegBatchDataEntity, - ) -> SegBatchPredEntity | OTXBatchLossEntity: + ) -> SegBatchPredEntity | SegBatchPredEntityWithXAI | OTXBatchLossEntity: from mmseg.structures import SegDataSample if self.training: @@ -124,6 +127,20 @@ def _customize_outputs( raise TypeError(output) masks.append(output.pred_sem_seg.data) + if hasattr(self, "explain_hook"): + hook_records = self.explain_hook.records + explain_results = copy.deepcopy(hook_records[-len(outputs) :]) + + return SegBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=[], + masks=masks, + saliency_maps=explain_results, + feature_vectors=[], + ) + return SegBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -152,7 +169,7 @@ def _exporter(self) -> OTXModelExporter: return OTXNativeModelExporter(**self._export_parameters) -class OVSegmentationModel(OVModel[SegBatchDataEntity, SegBatchPredEntity]): +class OVSegmentationModel(OVModel[SegBatchDataEntity, SegBatchPredEntity, SegBatchPredEntityWithXAI]): """Semantic segmentation model compatible for OpenVINO IR inference. 
It can consume OpenVINO IR model path or model name from Intel OMZ repository @@ -183,11 +200,22 @@ def _customize_outputs( self, outputs: list[ImageResultWithSoftPrediction], inputs: SegBatchDataEntity, - ) -> SegBatchPredEntity | OTXBatchLossEntity: - # add label index + ) -> SegBatchPredEntity | SegBatchPredEntityWithXAI | OTXBatchLossEntity: + if outputs and outputs[0].saliency_map.size != 1: + predicted_s_maps = [out.saliency_map for out in outputs] + predicted_f_vectors = [out.feature_vector for out in outputs] + return SegBatchPredEntityWithXAI( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=[], + masks=[tv_tensors.Mask(mask.resultImage) for mask in outputs], + saliency_maps=predicted_s_maps, + feature_vectors=predicted_f_vectors, + ) return SegBatchPredEntity( - batch_size=1, + batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, scores=[], diff --git a/src/otx/core/model/entity/visual_prompting.py b/src/otx/core/model/entity/visual_prompting.py index 1d56df65f95..8fccbe2a3ff 100644 --- a/src/otx/core/model/entity/visual_prompting.py +++ b/src/otx/core/model/entity/visual_prompting.py @@ -7,6 +7,7 @@ from typing import Any +from otx.core.data.entity.base import T_OTXBatchPredEntityWithXAI from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.data.entity.visual_prompting import ( VisualPromptingBatchDataEntity, @@ -18,7 +19,12 @@ class OTXVisualPromptingModel( - OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, T_OTXTileBatchDataEntity], + OTXModel[ + VisualPromptingBatchDataEntity, + VisualPromptingBatchPredEntity, + T_OTXBatchPredEntityWithXAI, + T_OTXTileBatchDataEntity, + ], ): """Base class for the visual prompting models used in OTX.""" @@ -27,7 +33,12 @@ def __init__(self, num_classes: int = 0) -> None: class OTXZeroShotVisualPromptingModel( - OTXModel[ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, 
T_OTXTileBatchDataEntity], + OTXModel[ + ZeroShotVisualPromptingBatchDataEntity, + ZeroShotVisualPromptingBatchPredEntity, + T_OTXBatchPredEntityWithXAI, + T_OTXTileBatchDataEntity, + ], ): """Base class for the zero-shot visual prompting models used in OTX.""" diff --git a/src/otx/core/model/module/classification.py b/src/otx/core/model/module/classification.py index fd1f9cf5431..8722befec7c 100644 --- a/src/otx/core/model/module/classification.py +++ b/src/otx/core/model/module/classification.py @@ -16,10 +16,13 @@ from otx.core.data.entity.classification import ( HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, + HlabelClsBatchPredEntityWithXAI, MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity, + MulticlassClsBatchPredEntityWithXAI, MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, + MultilabelClsBatchPredEntityWithXAI, ) from otx.core.model.entity.classification import OTXHlabelClsModel, OTXMulticlassClsModel, OTXMultilabelClsModel from otx.core.model.module.base import OTXLitModule @@ -83,7 +86,7 @@ def validation_step(self, inputs: MulticlassClsBatchDataEntity, batch_idx: int) """ preds = self.model(inputs) - if not isinstance(preds, MulticlassClsBatchPredEntity): + if not isinstance(preds, (MulticlassClsBatchPredEntity, MulticlassClsBatchPredEntityWithXAI)): raise TypeError(preds) self.val_metric.update( **self._convert_pred_entity_to_compute_metric(preds, inputs), @@ -91,7 +94,7 @@ def validation_step(self, inputs: MulticlassClsBatchDataEntity, batch_idx: int) def _convert_pred_entity_to_compute_metric( self, - preds: MulticlassClsBatchPredEntity, + preds: MulticlassClsBatchPredEntity | MulticlassClsBatchPredEntityWithXAI, inputs: MulticlassClsBatchDataEntity, ) -> dict[str, list[dict[str, Tensor]]]: pred = torch.tensor(preds.labels) @@ -110,7 +113,7 @@ def test_step(self, inputs: MulticlassClsBatchDataEntity, batch_idx: int) -> Non """ preds = self.model(inputs) - if not isinstance(preds, MulticlassClsBatchPredEntity): + if not 
isinstance(preds, (MulticlassClsBatchPredEntity, MulticlassClsBatchPredEntityWithXAI)): raise TypeError(preds) self.test_metric.update( @@ -172,7 +175,7 @@ def validation_step(self, inputs: MultilabelClsBatchDataEntity, batch_idx: int) """ preds = self.model(inputs) - if not isinstance(preds, MultilabelClsBatchPredEntity): + if not isinstance(preds, (MultilabelClsBatchPredEntity, MultilabelClsBatchPredEntityWithXAI)): raise TypeError(preds) self.val_metric.update( @@ -181,7 +184,7 @@ def validation_step(self, inputs: MultilabelClsBatchDataEntity, batch_idx: int) def _convert_pred_entity_to_compute_metric( self, - preds: MultilabelClsBatchPredEntity, + preds: MultilabelClsBatchPredEntity | MultilabelClsBatchPredEntityWithXAI, inputs: MultilabelClsBatchDataEntity, ) -> dict[str, list[dict[str, Tensor]]]: return { @@ -198,7 +201,7 @@ def test_step(self, inputs: MultilabelClsBatchDataEntity, batch_idx: int) -> Non """ preds = self.model(inputs) - if not isinstance(preds, MultilabelClsBatchPredEntity): + if not isinstance(preds, (MultilabelClsBatchPredEntity, MultilabelClsBatchPredEntityWithXAI)): raise TypeError(preds) self.test_metric.update( @@ -284,7 +287,7 @@ def validation_step(self, inputs: HlabelClsBatchDataEntity, batch_idx: int) -> N """ preds = self.model(inputs) - if not isinstance(preds, HlabelClsBatchPredEntity): + if not isinstance(preds, (HlabelClsBatchPredEntity, HlabelClsBatchPredEntityWithXAI)): raise TypeError(preds) self.val_metric.update( @@ -293,7 +296,7 @@ def validation_step(self, inputs: HlabelClsBatchDataEntity, batch_idx: int) -> N def _convert_pred_entity_to_compute_metric( self, - preds: HlabelClsBatchPredEntity, + preds: HlabelClsBatchPredEntity | HlabelClsBatchPredEntityWithXAI, inputs: HlabelClsBatchDataEntity, ) -> dict[str, list[dict[str, Tensor]]]: if self.num_multilabel_classes > 0: @@ -316,7 +319,7 @@ def test_step(self, inputs: HlabelClsBatchDataEntity, batch_idx: int) -> None: """ preds = self.model(inputs) - if not 
isinstance(preds, HlabelClsBatchPredEntity): + if not isinstance(preds, (HlabelClsBatchPredEntity, HlabelClsBatchPredEntityWithXAI)): raise TypeError(preds) self.test_metric.update( diff --git a/src/otx/core/model/module/detection.py b/src/otx/core/model/module/detection.py index f2d9938874a..f8d95bf64b8 100644 --- a/src/otx/core/model/module/detection.py +++ b/src/otx/core/model/module/detection.py @@ -14,6 +14,7 @@ from otx.core.data.entity.detection import ( DetBatchDataEntity, DetBatchPredEntity, + DetBatchPredEntityWithXAI, ) from otx.core.model.entity.detection import ExplainableOTXDetModel from otx.core.model.module.base import OTXLitModule @@ -88,7 +89,7 @@ def validation_step(self, inputs: DetBatchDataEntity, batch_idx: int) -> None: """ preds = self.model(inputs) - if not isinstance(preds, DetBatchPredEntity): + if not isinstance(preds, (DetBatchPredEntity, DetBatchPredEntityWithXAI)): raise TypeError(preds) self.val_metric.update( @@ -97,7 +98,7 @@ def validation_step(self, inputs: DetBatchDataEntity, batch_idx: int) -> None: def _convert_pred_entity_to_compute_metric( self, - preds: DetBatchPredEntity, + preds: DetBatchPredEntity | DetBatchPredEntityWithXAI, inputs: DetBatchDataEntity, ) -> dict[str, list[dict[str, Tensor]]]: return { @@ -131,7 +132,7 @@ def test_step(self, inputs: DetBatchDataEntity, batch_idx: int) -> None: """ preds = self.model(inputs) - if not isinstance(preds, DetBatchPredEntity): + if not isinstance(preds, (DetBatchPredEntity, DetBatchPredEntityWithXAI)): raise TypeError(preds) self.test_metric.update( diff --git a/src/otx/core/model/module/instance_segmentation.py b/src/otx/core/model/module/instance_segmentation.py index 40bf4fb4fb3..205ff1d27bf 100644 --- a/src/otx/core/model/module/instance_segmentation.py +++ b/src/otx/core/model/module/instance_segmentation.py @@ -16,6 +16,7 @@ from otx.core.data.entity.instance_segmentation import ( InstanceSegBatchDataEntity, InstanceSegBatchPredEntity, + 
InstanceSegBatchPredEntityWithXAI, ) from otx.core.model.entity.instance_segmentation import ExplainableOTXInstanceSegModel from otx.core.model.module.base import OTXLitModule @@ -99,7 +100,7 @@ def validation_step(self, inputs: InstanceSegBatchDataEntity, batch_idx: int) -> """ preds = self.model(inputs) - if not isinstance(preds, InstanceSegBatchPredEntity): + if not isinstance(preds, (InstanceSegBatchPredEntity, InstanceSegBatchPredEntityWithXAI)): raise TypeError(preds) self.val_metric.update( @@ -108,7 +109,7 @@ def validation_step(self, inputs: InstanceSegBatchDataEntity, batch_idx: int) -> def _convert_pred_entity_to_compute_metric( self, - preds: InstanceSegBatchPredEntity, + preds: InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI, inputs: InstanceSegBatchDataEntity, ) -> dict[str, list[dict[str, Tensor]]]: """Convert the prediction entity to the format that the metric can compute and cache the ground truth. @@ -173,7 +174,7 @@ def test_step(self, inputs: InstanceSegBatchDataEntity, batch_idx: int) -> None: """ preds = self.model(inputs) - if not isinstance(preds, InstanceSegBatchPredEntity): + if not isinstance(preds, (InstanceSegBatchPredEntity, InstanceSegBatchPredEntityWithXAI)): raise TypeError(preds) self.test_metric.update( diff --git a/src/otx/core/model/module/segmentation.py b/src/otx/core/model/module/segmentation.py index 000b2cdea3d..c5450b46af4 100644 --- a/src/otx/core/model/module/segmentation.py +++ b/src/otx/core/model/module/segmentation.py @@ -14,6 +14,7 @@ from otx.core.data.entity.segmentation import ( SegBatchDataEntity, SegBatchPredEntity, + SegBatchPredEntityWithXAI, ) from otx.core.model.entity.segmentation import OTXSegmentationModel from otx.core.model.module.base import OTXLitModule @@ -97,7 +98,7 @@ def validation_step(self, inputs: SegBatchDataEntity, batch_idx: int) -> None: """ preds = self.model(inputs) - if not isinstance(preds, SegBatchPredEntity): + if not isinstance(preds, (SegBatchPredEntity, 
SegBatchPredEntityWithXAI)): raise TypeError(preds) predictions = self._convert_pred_entity_to_compute_metric(preds, inputs) @@ -106,7 +107,7 @@ def validation_step(self, inputs: SegBatchDataEntity, batch_idx: int) -> None: def _convert_pred_entity_to_compute_metric( self, - preds: SegBatchPredEntity, + preds: SegBatchPredEntity | SegBatchPredEntityWithXAI, inputs: SegBatchDataEntity, ) -> list[dict[str, Tensor]]: return [ @@ -125,7 +126,7 @@ def test_step(self, inputs: SegBatchDataEntity, batch_idx: int) -> None: :param batch_idx: The index of the current batch. """ preds = self.model(inputs) - if not isinstance(preds, SegBatchPredEntity): + if not isinstance(preds, (SegBatchPredEntity, SegBatchPredEntityWithXAI)): raise TypeError(preds) predictions = self._convert_pred_entity_to_compute_metric(preds, inputs) for prediction in predictions: diff --git a/src/otx/core/utils/tile_merge.py b/src/otx/core/utils/tile_merge.py index 5f147c77a04..97f19660981 100644 --- a/src/otx/core/utils/tile_merge.py +++ b/src/otx/core/utils/tile_merge.py @@ -14,9 +14,10 @@ from torchvision.ops import batched_nms from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity -from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity +from otx.core.data.entity.detection import DetBatchPredEntity, DetBatchPredEntityWithXAI, DetPredEntity from otx.core.data.entity.instance_segmentation import ( InstanceSegBatchPredEntity, + InstanceSegBatchPredEntityWithXAI, InstanceSegPredEntity, ) @@ -93,7 +94,7 @@ class DetectionTileMerge(TileMerge): def merge( self, - batch_tile_preds: list[DetBatchPredEntity], + batch_tile_preds: list[DetBatchPredEntity | DetBatchPredEntityWithXAI], batch_tile_attrs: list[list[dict]], ) -> list[DetPredEntity]: """Merge batch tile predictions to a list of full-size prediction data entities. 
@@ -186,7 +187,7 @@ class InstanceSegTileMerge(TileMerge): def merge( self, - batch_tile_preds: list[InstanceSegBatchPredEntity], + batch_tile_preds: list[InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI], batch_tile_attrs: list[list[dict]], ) -> list[InstanceSegPredEntity]: """Merge inst-seg tile predictions to one single prediction. diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index 9d28753acd6..a0df2d68665 100644 --- a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -117,24 +117,30 @@ def test_otx_export_infer( assert len(ckpt_files) > 0 # 2) otx test - tmp_path_test = tmp_path / f"otx_test_{model_name}" - command_cfg = [ - "otx", - "test", - "--config", - recipe, - "--data_root", - fxt_target_dataset_per_task[task], - "--engine.work_dir", - str(tmp_path_test / "outputs" / "torch"), - "--engine.device", - fxt_accelerator, - *fxt_cli_override_command_per_task[task], - "--checkpoint", - str(ckpt_files[-1]), - ] + def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: str = fxt_accelerator) -> Path: + tmp_path_test = tmp_path / f"otx_test_{model_name}" + command_cfg = [ + "otx", + "test", + "--config", + test_recipe, + "--data_root", + fxt_target_dataset_per_task[task], + "--engine.work_dir", + str(tmp_path_test / work_dir), + "--engine.device", + device, + *fxt_cli_override_command_per_task[task], + "--checkpoint", + checkpoint_path, + ] + run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) - run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) + return tmp_path_test + + tmp_path_test = run_cli_test(recipe, str(ckpt_files[-1]), Path("outputs") / "torch") + + assert (tmp_path_test / "outputs").exists() # 3) otx export format_to_ext = {"OPENVINO": "xml"} # [TODO](@Vlad): extend to "ONNX": "onnx" @@ -171,24 +177,7 @@ def test_otx_export_infer( export_test_recipe = 
f"src/otx/recipe/{task}/openvino_model.yaml" exported_model_path = str(tmp_path_test / "outputs" / "exported_model.xml") - command_cfg = [ - "otx", - "test", - "--config", - export_test_recipe, - "--data_root", - fxt_target_dataset_per_task[task], - "--engine.work_dir", - str(tmp_path_test / "outputs" / "openvino"), - "--engine.device", - "cpu", - *fxt_cli_override_command_per_task[task], - "--checkpoint", - exported_model_path, - ] - - run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) - + tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "openvino", "cpu") assert (tmp_path_test / "outputs").exists() # 5) test optimize @@ -214,24 +203,7 @@ def test_otx_export_infer( exported_model_path = str(tmp_path_test / "outputs" / "optimized_model.xml") # 6) test optimized model - command_cfg = [ - "otx", - "test", - "--config", - export_test_recipe, - "--data_root", - fxt_target_dataset_per_task[task], - "--engine.work_dir", - str(tmp_path_test / "outputs" / "nncf_ptq"), - "--engine.device", - "cpu", - *fxt_cli_override_command_per_task[task], - "--checkpoint", - exported_model_path, - ] - - run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) - + tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "nncf_ptq", "cpu") assert (tmp_path_test / "outputs").exists() df_torch = pd.read_csv(next((tmp_path_test / "outputs" / "torch").glob("**/metrics.csv"))) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 03337f1a120..94c40c05411 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -36,7 +36,7 @@ def fxt_open_subprocess(request: pytest.FixtureRequest) -> bool: This option can be used for easy memory management while running consecutive multiple tests (default: false). 
""" - return request.config.getoption("--open-subprocess") + return request.config.getoption("--open-subprocess", False) def find_recipe_folder(base_path: Path, folder_name: str) -> Path: From 0c2de2dd0b0f3e8142eef5749b6a5b8c71ae337b Mon Sep 17 00:00:00 2001 From: Eunwoo Shin Date: Fri, 16 Feb 2024 11:45:48 +0900 Subject: [PATCH 5/6] Enable HPO in OTX 2.0 (#2912) * make mem_cache obj picklizeable * revert mem_cache implementation * just implement for run hpo * just implement for run hpo w/ all tasks * implement draft * pass trial id to report function * change order of reverting os.enviorn in hpo runner * add code to use hpo algo * refine append_signal_handler code flow * refactor code * add docstring * align with pre-commit * clean minor thing * add missing type hint * fix a bug that best model weight isn't kept * align batch size to train set size * align with pre-commit * replace function * update unit test * add unit test for utils * align with pre-commit * skip hpo if task is zero shot vp * move property below * remove _prepare_hpo_args function * rename hpo.py to hpo_api.py * move signal realated things to signal.py * update unit test * change setter argument name * revert hpo_api * remove unused import * change hpo_cfg_file to HpoConfig * deal with the case where optimizer or scheduler is list type * add integration test for hpo * move otx/engine/utils/hpo directory * refine HpoConfig --- src/otx/core/config/hpo.py | 32 ++++ src/otx/core/data/mem_cache.py | 10 +- src/otx/engine/engine.py | 47 ++++- src/otx/engine/hpo/__init__.py | 9 + src/otx/engine/hpo/hpo_api.py | 259 ++++++++++++++++++++++++++++ src/otx/engine/hpo/hpo_trial.py | 140 +++++++++++++++ src/otx/engine/hpo/utils.py | 80 +++++++++ src/otx/hpo/hpo_base.py | 5 +- src/otx/hpo/hpo_runner.py | 47 +++-- src/otx/hpo/hyperband.py | 8 +- src/otx/hpo/search_space.py | 3 +- src/otx/utils/__init__.py | 4 +- src/otx/utils/signal.py | 70 ++++++++ src/otx/utils/utils.py | 120 +++++++++++-- 
tests/integration/cli/test_cli.py | 61 +++++++ tests/unit/hpo/test_hyperband.py | 7 +- tests/unit/hpo/test_search_space.py | 10 +- tests/unit/utils/test_signal.py | 61 +++++++ tests/unit/utils/test_utils.py | 113 ++++++++++-- 19 files changed, 1004 insertions(+), 82 deletions(-) create mode 100644 src/otx/core/config/hpo.py create mode 100644 src/otx/engine/hpo/__init__.py create mode 100644 src/otx/engine/hpo/hpo_api.py create mode 100644 src/otx/engine/hpo/hpo_trial.py create mode 100644 src/otx/engine/hpo/utils.py create mode 100644 src/otx/utils/signal.py create mode 100644 tests/unit/utils/test_signal.py diff --git a/src/otx/core/config/hpo.py b/src/otx/core/config/hpo.py new file mode 100644 index 00000000000..87efdcda849 --- /dev/null +++ b/src/otx/core/config/hpo.py @@ -0,0 +1,32 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Config objects for HPO.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import torch + + +@dataclass +class HpoConfig: + """DTO for HPO configuration.""" + + search_space: dict[str, dict[str, Any]] | None = None + save_path: str | None = None + mode: Literal["max", "min"] = "max" + num_trials: int | None = None + num_workers: int = 1 + expected_time_ratio: int | float | None = 4 + maximum_resource: int | float | None = None + subset_ratio: float | int | None = None + min_subset_size: int = 500 + prior_hyper_parameters: dict | list[dict] | None = None + acceptable_additional_time_ratio: float | int = 1.0 + minimum_resource: int | float | None = None + reduction_factor: int = 3 + asynchronous_bracket: bool = True + asynchronous_sha: bool = torch.cuda.device_count() != 1 diff --git a/src/otx/core/data/mem_cache.py b/src/otx/core/data/mem_cache.py index 4440d00155d..92dafe44e22 100644 --- a/src/otx/core/data/mem_cache.py +++ b/src/otx/core/data/mem_cache.py @@ -15,6 +15,8 @@ import numpy as np import psutil +from otx.utils import 
append_signal_handler + if TYPE_CHECKING: from multiprocessing.managers import DictProxy from multiprocessing.synchronize import Lock @@ -300,13 +302,11 @@ def create(cls, mode: str, mem_size: int) -> MemCacheHandlerBase: raise MemCacheHandlerError(msg) # Should delete if receive sigint to gracefully terminate - original_handler = signal.getsignal(signal.SIGINT) - - def _new_handler(signum, frame) -> None: # noqa: ANN001 - original_handler(signum, frame) # type: ignore[operator, misc] + def _new_handler(signum_, frame_) -> None: # noqa: ARG001, ANN001 instance.shutdown() - signal.signal(signal.SIGINT, _new_handler) + append_signal_handler(signal.SIGINT, _new_handler) + append_signal_handler(signal.SIGTERM, _new_handler) cls.instances.append(instance) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 484fcd452ff..a6fdf28211c 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -14,6 +14,7 @@ from otx.core.config.data import DataModuleConfig, SubsetConfig, TileConfig from otx.core.config.device import DeviceConfig from otx.core.config.explain import ExplainConfig +from otx.core.config.hpo import HpoConfig from otx.core.data.module import OTXDataModule from otx.core.model.entity.base import OTXModel, OVModel from otx.core.model.module.base import OTXLitModule @@ -23,6 +24,7 @@ from otx.core.types.task import OTXTaskType from otx.core.utils.cache import TrainerArgumentsCache +from .hpo import execute_hpo, update_hyper_parameter from .utils.auto_configurator import AutoConfigurator, PathLike if TYPE_CHECKING: @@ -105,15 +107,10 @@ def __init__( device (DeviceType, optional): The device type to use. Defaults to DeviceType.auto. **kwargs: Additional keyword arguments for pl.Trainer. 
""" - self.work_dir = work_dir + self._cache = TrainerArgumentsCache(**kwargs) self.checkpoint = checkpoint - self.device = DeviceConfig(accelerator=device) - self._cache = TrainerArgumentsCache( - default_root_dir=self.work_dir, - accelerator=self.device.accelerator, - devices=self.device.devices, - **kwargs, - ) + self.work_dir = work_dir + self.device = device # type: ignore[assignment] self._auto_configurator = AutoConfigurator( data_root=data_root, task=datamodule.task if datamodule is not None else task, @@ -156,6 +153,8 @@ def train( callbacks: list[Callback] | Callback | None = None, logger: Logger | Iterable[Logger] | bool | None = None, resume: bool = False, + run_hpo: bool = False, + hpo_config: HpoConfig | None = None, **kwargs, ) -> dict[str, Any]: """Trains the model using the provided LightningModule and OTXDataModule. @@ -171,6 +170,8 @@ def train( callbacks (list[Callback] | Callback | None, optional): The callbacks to be used during training. logger (Logger | Iterable[Logger] | bool | None, optional): The logger(s) to be used. Defaults to None. resume (bool, optional): If True, tries to resume training from existing checkpoint. + run_hpo (bool, optional): If True, optimizer hyper parameters before training a model. + hpo_config (HpoConfig | None, optional): Configuration for HPO. **kwargs: Additional keyword arguments for pl.Trainer configuration. 
Returns: @@ -206,6 +207,16 @@ def train( otx train --data_root --config ``` """ + if run_hpo: + if hpo_config is None: + hpo_config = HpoConfig() + best_config, best_trial_weight = execute_hpo(engine=self, **locals()) + if best_config is not None: + update_hyper_parameter(self, best_config) + if best_trial_weight is not None: + self.checkpoint = best_trial_weight + resume = True + lit_module = self._build_lightning_module( model=self.model, optimizer=self.optimizer, @@ -599,6 +610,26 @@ def from_config(cls, config_path: PathLike, data_root: PathLike | None = None, * # Property and setter functions provided by Engine. # ------------------------------------------------------------------------ # + @property + def work_dir(self) -> PathLike: + """Work directory.""" + return self._work_dir + + @work_dir.setter + def work_dir(self, work_dir: PathLike) -> None: + self._work_dir = work_dir + self._cache.update(default_root_dir=work_dir) + + @property + def device(self) -> DeviceConfig: + """Device engine uses.""" + return self._device + + @device.setter + def device(self, device: DeviceType) -> None: + self._device = DeviceConfig(accelerator=device) + self._cache.update(accelerator=self._device.accelerator, devices=self._device.devices) + @property def trainer(self) -> Trainer: """Returns the trainer object associated with the engine. 
diff --git a/src/otx/engine/hpo/__init__.py b/src/otx/engine/hpo/__init__.py new file mode 100644 index 00000000000..d82e8c52e64 --- /dev/null +++ b/src/otx/engine/hpo/__init__.py @@ -0,0 +1,9 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Functions and Classes to run HPO in the engine.""" + +from .hpo_api import execute_hpo +from .hpo_trial import update_hyper_parameter + +__all__ = ["execute_hpo", "update_hyper_parameter"] diff --git a/src/otx/engine/hpo/hpo_api.py b/src/otx/engine/hpo/hpo_api.py new file mode 100644 index 00000000000..bcafb6039ae --- /dev/null +++ b/src/otx/engine/hpo/hpo_api.py @@ -0,0 +1,259 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Components to run HPO.""" + +from __future__ import annotations + +import dataclasses +import logging +import time +from functools import partial +from pathlib import Path +from threading import Thread +from typing import TYPE_CHECKING, Any, Callable + +import torch +from lightning.pytorch.cli import OptimizerCallable + +from otx.core.config.hpo import HpoConfig +from otx.core.types.task import OTXTaskType +from otx.hpo import HyperBand, run_hpo_loop +from otx.utils.utils import get_decimal_point, get_using_dot_delimited_key, remove_matched_files + +from .hpo_trial import run_hpo_trial +from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir + +if TYPE_CHECKING: + from otx.engine.engine import Engine + from otx.hpo.hpo_base import HpoBase + +logger = logging.getLogger(__name__) + +AVAILABLE_HP_NAME_MAP = { + "data.config.train_subset.batch_size": "datamodule.config.train_subset.batch_size", + "optimizer": "optimizer.keywords", + "scheduler": "scheduler.keywords", +} + + +def execute_hpo( + engine: Engine, + max_epochs: int, + hpo_config: HpoConfig | None = None, + progress_update_callback: Callable[[int | float], None] | None = None, + **train_args, +) -> tuple[dict[str, Any] | None, Path | None]: + 
"""Execute HPO. + + Args: + engine (Engine): engine instance. + max_epochs (int): max epochs to train. + hpo_config (HpoConfig | None, optional): Configuration for HPO. + progress_update_callback (Callable[[int | float], None] | None, optional): + callback to update progress. If it's given, it's called with progress every second. Defaults to None. + + Returns: + tuple[dict[str, Any] | None, Path | None]: + best hyper parameters and model weight trained with best hyper parameters. If it doesn't exist, + return None. + """ + if engine.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: # type: ignore[has-type] + logger.warning("Zero shot visual prompting task doesn't support HPO.") + return None, None + + hpo_workdir = Path(engine.work_dir) / "hpo" + hpo_workdir.mkdir(exist_ok=True) + hpo_configurator = HPOConfigurator( + engine, + max_epochs, + hpo_workdir, + hpo_config, + ) + if (hpo_algo := hpo_configurator.get_hpo_algo()) is None: + logger.warning("HPO is skipped.") + return None, None + + if progress_update_callback is not None: + Thread(target=_update_hpo_progress, args=[progress_update_callback, hpo_algo], daemon=True).start() + + run_hpo_loop( + hpo_algo, + partial( + run_hpo_trial, + hpo_workdir=hpo_workdir, + engine=engine, + max_epochs=max_epochs, + **_adjust_train_args(train_args), + ), + "gpu" if torch.cuda.is_available() else "cpu", + ) + + best_trial = hpo_algo.get_best_config() + if best_trial is None: + best_config = None + best_hpo_weight = None + else: + best_config = best_trial["configuration"] + if (trial_file := find_trial_file(hpo_workdir, best_trial["id"])) is not None: + best_hpo_weight = get_best_hpo_weight(get_hpo_weight_dir(hpo_workdir, best_trial["id"]), trial_file) + + hpo_algo.print_result() + _remove_unused_model_weights(hpo_workdir, best_hpo_weight) + + return best_config, best_hpo_weight + + +class HPOConfigurator: + """HPO configurator. Prepare a configuration and provide an HPO algorithm based on the configuration. 
+ + Args: + engine (Engine): engine instance. + max_epoch (int): max epochs to train. + hpo_workdir (Path | None, optional): HPO work directory. Defaults to None. + hpo_config (HpoConfig | None, optional): Configuration for HPO. + """ + + def __init__( + self, + engine: Engine, + max_epoch: int, + hpo_workdir: Path | None = None, + hpo_config: HpoConfig | None = None, + ) -> None: + self._engine = engine + self._max_epoch = max_epoch + self._hpo_workdir = hpo_workdir if hpo_workdir is not None else Path(engine.work_dir) / "hpo" + self.hpo_config: dict[str, Any] = hpo_config # type: ignore[assignment] + + @property + def hpo_config(self) -> dict[str, Any]: + """Configuration for HPO algorithm.""" + return self._hpo_config + + @hpo_config.setter + def hpo_config(self, hpo_config: HpoConfig | None) -> None: + train_dataset_size = len(self._engine.datamodule.subsets["train"]) + val_dataset_size = len(self._engine.datamodule.subsets["val"]) + + self._hpo_config: dict[str, Any] = { # default setting + "save_path": str(self._hpo_workdir), + "num_full_iterations": self._max_epoch, + "full_dataset_size": train_dataset_size, + "non_pure_train_ratio": val_dataset_size / (train_dataset_size + val_dataset_size), + } + + if hpo_config is not None: + self._hpo_config.update( + {key: val for key, val in dataclasses.asdict(hpo_config).items() if val is not None}, + ) + + if "search_space" not in self._hpo_config: + self._hpo_config["search_space"] = self._get_default_search_space() + else: + self._align_hp_name(self._hpo_config["search_space"]) + + if ( # align batch size to train set size + "datamodule.config.train_subset.batch_size" in self._hpo_config["search_space"] + and self._hpo_config["search_space"]["datamodule.config.train_subset.batch_size"]["max"] + > train_dataset_size + ): + logger.info( + "Max value of batch size in HPO search space is lower than train dataset size. 
" + "Decrease it to train dataset size.", + ) + self._hpo_config["search_space"]["datamodule.config.train_subset.batch_size"]["max"] = train_dataset_size + + self._remove_wrong_search_space(self._hpo_config["search_space"]) + + if "prior_hyper_parameters" not in self._hpo_config: # default hyper parameters are tried first + self._hpo_config["prior_hyper_parameters"] = { + hp: get_using_dot_delimited_key(hp, self._engine) + for hp in self._hpo_config["search_space"].keys() # noqa: SIM118 + } + + def _get_default_search_space(self) -> dict[str, Any]: + """Set learning rate and batch size as search space.""" + search_space = {} + + if isinstance(self._engine.optimizer, list): + for i, optimizer in enumerate(self._engine.optimizer): + search_space[f"optimizer.{i}.keywords.lr"] = self._make_lr_search_space(optimizer) + elif isinstance(self._engine.optimizer, OptimizerCallable): + search_space["optimizer.keywords.lr"] = self._make_lr_search_space(self._engine.optimizer) + + cur_bs = self._engine.datamodule.config.train_subset.batch_size + search_space["datamodule.config.train_subset.batch_size"] = { + "type": "qloguniform", + "min": cur_bs // 2, + "max": cur_bs * 2, + "step": 2, + } + + return search_space + + @staticmethod + def _make_lr_search_space(optimizer: OptimizerCallable) -> dict[str, Any]: + cur_lr = optimizer.keywords["lr"] # type: ignore[union-attr] + min_lr = cur_lr / 10 + return { + "type": "qloguniform", + "min": min_lr, + "max": min(cur_lr * 10, 0.1), + "step": 10 ** -get_decimal_point(min_lr), + } + + @staticmethod + def _align_hp_name(search_space: dict[str, Any]) -> None: + for hp_name in list(search_space.keys()): + for valid_hp in AVAILABLE_HP_NAME_MAP: + if valid_hp in hp_name: + new_hp_name = hp_name.replace(valid_hp, AVAILABLE_HP_NAME_MAP[valid_hp]) + search_space[new_hp_name] = search_space.pop(hp_name) + break + else: + error_msg = ( + "Given hyper parameter can't be optimized by HPO. 
" + f"Please choose one from {','.join(AVAILABLE_HP_NAME_MAP)}." + ) + raise ValueError(error_msg) + + @staticmethod + def _remove_wrong_search_space(search_space: dict[str, dict[str, Any]]) -> None: + for hp_name, config in list(search_space.items()): + if config["type"] == "choice": + if not config["choice_list"]: + search_space.pop(hp_name) + logger.warning(f"choice_list is empty. {hp_name} is excluded from HPO serach space.") + elif config["max"] < config["min"] + config.get("step", 0): + search_space.pop(hp_name) + if "step" in config: + reason_to_exclude = "max is smaller than sum of min and step" + else: + reason_to_exclude = "max is smaller than min" + logger.warning(f"{reason_to_exclude}. {hp_name} is excluded from HPO serach space.") + + def get_hpo_algo(self) -> HpoBase | None: + """Get HPO algorithm based on prepared configuration.""" + if not self.hpo_config["search_space"]: + logger.warning("There is no hyper parameter to optimize.") + return None + return HyperBand(**self.hpo_config) + + +def _update_hpo_progress(progress_update_callback: Callable[[int | float], None], hpo_algo: HpoBase) -> None: + while not hpo_algo.is_done(): + progress_update_callback(hpo_algo.get_progress() * 100) + time.sleep(1) + + +def _adjust_train_args(train_args: dict[str, Any]) -> dict[str, Any]: + train_args.update(train_args.pop("kwargs", {})) + train_args.pop("self", None) + train_args.pop("run_hpo", None) + + return train_args + + +def _remove_unused_model_weights(hpo_workdir: Path, best_hpo_weight: Path | None = None) -> None: + remove_matched_files(hpo_workdir, "*.ckpt", best_hpo_weight) diff --git a/src/otx/engine/hpo/hpo_trial.py b/src/otx/engine/hpo/hpo_trial.py new file mode 100644 index 00000000000..6b519af504c --- /dev/null +++ b/src/otx/engine/hpo/hpo_trial.py @@ -0,0 +1,140 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Components to run HPO trial.""" + +from __future__ import annotations + +from pathlib import Path 
+from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Callable + +from lightning import Callback +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint + +from otx.algo.callbacks.adaptive_train_scheduling import AdaptiveTrainScheduling +from otx.hpo import TrialStatus +from otx.utils.utils import find_file_recursively, remove_matched_files, set_using_dot_delimited_key + +from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir + +if TYPE_CHECKING: + from lightning import LightningModule, Trainer + + from otx.engine.engine import Engine + + +def update_hyper_parameter(engine: Engine, hyper_parameter: dict[str, Any]) -> None: + """Update hyper parameter in the engine.""" + for key, val in hyper_parameter.items(): + set_using_dot_delimited_key(key, val, engine) + + +class HPOCallback(Callback): + """HPO callback class which reports a score to HPO algorithm every epoch.""" + + def __init__(self, report_func: Callable[[float | int, float | int], TrialStatus], metric: str) -> None: + super().__init__() + self._report_func = report_func + self.metric = metric + + def on_train_epoch_end(self, trainer: Trainer, pl_module_: LightningModule) -> None: + """Report scores if score exists at the end of each epoch.""" + score = trainer.callback_metrics.get(self.metric) + if score is not None and self._report_func(score.item(), trainer.current_epoch + 1) == TrialStatus.STOP: + trainer.should_stop = True + + +def run_hpo_trial( + hp_config: dict[str, Any], + report_func: Callable[[int | float, int | float, bool], None], + hpo_workdir: Path, + engine: Engine, + callbacks: list[Callback] | Callback | None = None, + **train_args, +) -> None: + """Run HPO trial. After it's done, best weight and last weight are saved for later use. + + Args: + hp_config (dict[str, Any]): trial's hyper parameter. + report_func (Callable): function to report score. + hpo_workdir (Path): HPO work directory. + engine (Engine): engine instance. 
+ callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None. + train_args: Arguments for 'engine.train'. + """ + trial_id = hp_config["id"] + hpo_weight_dir = get_hpo_weight_dir(hpo_workdir, trial_id) + + _set_trial_hyper_parameter(hp_config["configuration"], engine, train_args) + + if (checkpoint := _find_last_weight(hpo_weight_dir)) is not None: + engine.checkpoint = checkpoint + train_args["resume"] = True + + callbacks = _register_hpo_callback(report_func, callbacks) + _set_to_validate_every_epoch(callbacks, train_args) + + with TemporaryDirectory(prefix="OTX-HPO-") as temp_dir: + _change_work_dir(temp_dir, callbacks, engine) + engine.train(callbacks=callbacks, **train_args) + + _keep_best_and_last_weight(Path(temp_dir), hpo_workdir, trial_id) + + report_func(0, 0, done=True) # type: ignore[call-arg] + + +def _set_trial_hyper_parameter(hyper_parameter: dict[str, Any], engine: Engine, train_args: dict[str, Any]) -> None: + train_args["max_epochs"] = round(hyper_parameter.pop("iterations")) + update_hyper_parameter(engine, hyper_parameter) + + +def _find_last_weight(weight_dir: Path) -> Path | None: + return find_file_recursively(weight_dir, "last.ckpt") + + +def _register_hpo_callback(report_func: Callable, callbacks: list[Callback] | Callback | None) -> list[Callback]: + if isinstance(callbacks, Callback): + callbacks = [callbacks] + elif callbacks is None: + callbacks = [] + callbacks.append(HPOCallback(report_func, _get_metric(callbacks))) + return callbacks + + +def _get_metric(callbacks: list[Callback]) -> str: + for callback in callbacks: + if isinstance(callback, ModelCheckpoint): + return callback.monitor + error_msg = "Failed to find a metric. There is no ModelCheckpoint in callback list." 
+ raise RuntimeError(error_msg) + + +def _set_to_validate_every_epoch(callbacks: list[Callback], train_args: dict[str, Any]) -> None: + for callback in callbacks: + if isinstance(callback, AdaptiveTrainScheduling): + callback.max_interval = 1 + break + else: + train_args["check_val_every_n_epoch"] = 1 + + +def _change_work_dir(work_dir: str, callbacks: list[Callback], engine: Engine) -> None: + for callback in callbacks: + if isinstance(callback, ModelCheckpoint): + callback.dirpath = work_dir + break + engine.work_dir = work_dir + + +def _keep_best_and_last_weight(trial_work_dir: Path, hpo_workdir: Path, trial_id: str) -> None: + weight_dir = get_hpo_weight_dir(hpo_workdir, trial_id) + _move_all_ckpt(trial_work_dir, weight_dir) + if (trial_file := find_trial_file(hpo_workdir, trial_id)) is not None: + best_weight = get_best_hpo_weight(weight_dir, trial_file) + remove_matched_files(weight_dir, "epoch_*.ckpt", best_weight) + + +def _move_all_ckpt(src: Path, dest: Path) -> None: + for ckpt_file in src.rglob("*.ckpt"): + ckpt_file.replace(dest / ckpt_file.name) diff --git a/src/otx/engine/hpo/utils.py b/src/otx/engine/hpo/utils.py new file mode 100644 index 00000000000..b2c43846f8d --- /dev/null +++ b/src/otx/engine/hpo/utils.py @@ -0,0 +1,80 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Util functions to run HPO.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +from otx.utils.utils import find_file_recursively + +if TYPE_CHECKING: + from pathlib import Path + + +def find_trial_file(hpo_workdir: Path, trial_id: str) -> Path | None: + """Find a trial file which stores trial record. + + Args: + hpo_workdir (Path): HPO work directory. + trial_id (str): trial id. + + Returns: + Path | None: trial file. If it doesn't exist, return None. 
+ """ + return find_file_recursively(hpo_workdir, f"{trial_id}.json") + + +def get_best_hpo_weight(weight_dir: Path, trial_file: Path) -> Path | None: + """Get best model weight path of the HPO trial. + + Args: + weight_dir (Path): directory where model weights are saved. + trial_file (Path): json format trial file which stores trial record. + + Returns: + Path | None: best HPO model weight. If it doesn't exist, return None. + """ + if not trial_file.exists(): + return None + + with trial_file.open("r") as f: + trial_output = json.load(f) + + best_epochs = [] + best_score = None + for epoch, score in trial_output["score"].items(): + eph = str(int(epoch) - 1) # lightning uses index starting from 0 + if best_score is None: + best_score = score + best_epochs.append(eph) + elif best_score < score: + best_score = score + best_epochs = [eph] + elif best_score == score: + best_epochs.append(eph) + + best_epochs.sort(key=int, reverse=True) + for best_epoch in best_epochs: + if (best_weight_path := find_file_recursively(weight_dir, f"epoch_*{best_epoch}.ckpt")) is not None: + return best_weight_path + + return None + + +def get_hpo_weight_dir(hpo_workdir: Path, trial_id: str) -> Path: + """Get HPO weight directory. If it doesn't exist, directory is made. + + Args: + hpo_workdir (Path): HPO work directory. + trial_id (str): trial id. + + Returns: + Path: HPO weight directory path. + """ + hpo_weight_dir: Path = hpo_workdir / "weight" / trial_id + if not hpo_weight_dir.exists(): + hpo_weight_dir.mkdir(parents=True) + return hpo_weight_dir diff --git a/src/otx/hpo/hpo_base.py b/src/otx/hpo/hpo_base.py index 1f663c06a87..f0e03452a57 100644 --- a/src/otx/hpo/hpo_base.py +++ b/src/otx/hpo/hpo_base.py @@ -32,13 +32,12 @@ class HpoBase(ABC): search_space (dict[str, dict[str, Any]]): hyper parameter search space to find. save_path (str | None, optional): path where result of HPO is saved. mode ("max" | "min", optional): One of {min, max}. 
Determines whether objective is - minimizing or maximizing the metric attribute. + minimizing or maximizing the score. num_trials (int | None, optional): How many training to conduct for HPO. num_workers (int, optional): How many trains are executed in parallel. num_full_iterations (int, optional): epoch for traninig after HPO. non_pure_train_ratio (float, optional): ratio of validation time to (train time + validation time) full_dataset_size (int, optional): train dataset size - metric (str, optional): Which score metric to use. expected_time_ratio (int | float | None, optional): Time to use for HPO. If HPO is configured automatically, HPO use time about exepected_time_ratio * @@ -67,7 +66,6 @@ def __init__( num_full_iterations: int | float = 1, non_pure_train_ratio: float = 0.2, full_dataset_size: int = 0, - metric: str = "mAP", expected_time_ratio: int | float | None = None, maximum_resource: int | float | None = None, subset_ratio: float | int | None = None, @@ -113,7 +111,6 @@ def __init__( self.min_subset_size = min_subset_size self.resume = resume self.hpo_status: dict = {} - self.metric = metric self.acceptable_additional_time_ratio = acceptable_additional_time_ratio if prior_hyper_parameters is None: prior_hyper_parameters = [] diff --git a/src/otx/hpo/hpo_runner.py b/src/otx/hpo/hpo_runner.py index 34f37fd39b3..2a936ff1a49 100644 --- a/src/otx/hpo/hpo_runner.py +++ b/src/otx/hpo/hpo_runner.py @@ -18,7 +18,7 @@ from otx.hpo.hpo_base import HpoBase, Trial, TrialStatus from otx.hpo.resource_manager import get_resource_manager -from otx.utils import append_signal_handler +from otx.utils import append_main_proc_signal_handler if TYPE_CHECKING: from collections.abc import Hashable @@ -76,8 +76,8 @@ def __init__( ) self._main_pid = os.getpid() - append_signal_handler(signal.SIGINT, self._terminate_signal_handler) - append_signal_handler(signal.SIGTERM, self._terminate_signal_handler) + append_main_proc_signal_handler(signal.SIGINT, 
self._terminate_signal_handler) + append_main_proc_signal_handler(signal.SIGTERM, self._terminate_signal_handler) def run(self) -> None: """Run a HPO loop.""" @@ -123,14 +123,20 @@ def _start_trial_process(self, trial: Trial) -> None: args=( self._train_func, trial.get_train_configuration(), - partial(_report_score, recv_queue=trial_queue, send_queue=self._report_queue, uid=uid), + partial( + _report_score, + recv_queue=trial_queue, + send_queue=self._report_queue, + uid=uid, + trial_id=trial.id, + ), ), ) + self._running_trials[uid] = RunningTrial(process, trial, trial_queue) # type: ignore[arg-type] + process.start() os.environ.clear() for key, val in origin_env.items(): os.environ[key] = val - self._running_trials[uid] = RunningTrial(process, trial, trial_queue) # type: ignore[arg-type] - process.start() def _remove_finished_process(self) -> None: trial_to_remove = [] @@ -150,14 +156,14 @@ def _remove_finished_process(self) -> None: def _get_reports(self) -> None: while not self._report_queue.empty(): report = self._report_queue.get_nowait() - trial = self._running_trials[report["uid"]] trial_status = self._hpo_algo.report_score( report["score"], report["progress"], - trial.trial.id, + report["trial_id"], report["done"], ) - trial.queue.put_nowait(trial_status) + if report["uid"] in self._running_trials: + self._running_trials[report["uid"]].queue.put_nowait(trial_status) self._hpo_algo.save_results() @@ -181,13 +187,9 @@ def _terminate_all_running_processes(self) -> None: process = trial.process if process.is_alive(): logger.info(f"Kill child process {process.pid}") - process.kill() + process.terminate() def _terminate_signal_handler(self, signum: Signals, frame_) -> None: # noqa: ANN001 - # This code prevents child processses from being killed unintentionally by proccesses forked from main process - if self._main_pid != os.getpid(): - return - self._terminate_all_running_processes() singal_name = {2: "SIGINT", 15: "SIGTERM"} @@ -206,11 +208,24 @@ def 
_report_score( recv_queue: multiprocessing.Queue, send_queue: multiprocessing.Queue, uid: Hashable, + trial_id: Hashable, done: bool = False, ) -> TrialStatus: - logger.debug(f"score : {score}, progress : {progress}, uid : {uid}, pid : {os.getpid()}, done : {done}") + logger.debug( + f"score : {score}, progress : {progress}, uid : {uid}, trial_id : {trial_id}, " + f"pid : {os.getpid()}, done : {done}", + ) try: - send_queue.put_nowait({"score": score, "progress": progress, "uid": uid, "pid": os.getpid(), "done": done}) + send_queue.put_nowait( + { + "score": score, + "progress": progress, + "uid": uid, + "trial_id": trial_id, + "pid": os.getpid(), + "done": done, + }, + ) except ValueError: return TrialStatus.STOP diff --git a/src/otx/hpo/hyperband.py b/src/otx/hpo/hyperband.py index 5c8b56190c4..82eb46d0036 100644 --- a/src/otx/hpo/hyperband.py +++ b/src/otx/hpo/hyperband.py @@ -8,6 +8,7 @@ import json import logging import math +from copy import copy from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -723,7 +724,7 @@ def save_results(self) -> None: """Save a ASHA result.""" for idx, bracket in self._brackets.items(): save_path = Path(self.save_path) / str(idx) - save_path.mkdir(parents=True) + save_path.mkdir(parents=True, exist_ok=True) bracket.save_results(str(save_path)) def auto_config(self) -> list[dict[str, Any]]: @@ -967,7 +968,10 @@ def get_best_config(self) -> dict[str, Any] | None: if best_trial is None: return None - return {"id": best_trial.id, "config": best_trial.configuration} + config = copy(best_trial.configuration) + if "iterations" in config: + config.pop("iterations") + return {"id": best_trial.id, "configuration": config} def print_result(self) -> None: """Print a ASHA result.""" diff --git a/src/otx/hpo/search_space.py b/src/otx/hpo/search_space.py index b0e85c0fed2..912e7efccf6 100644 --- a/src/otx/hpo/search_space.py +++ b/src/otx/hpo/search_space.py @@ -276,7 +276,7 @@ class SearchSpace: arguemnt format is as 
bellow. { "some_hyper_parameter_name" : { - "param_type": type of search space of hyper parameter. + "type": type of search space of hyper parameter. supported types: uniform, loguniform, quniform, qloguniform or choice # At this point, there are two available formats. @@ -310,7 +310,6 @@ def __init__( self.search_space: dict[str, SingleSearchSpace] = {} for key, val in search_space.items(): # pylint: disable=too-many-nested-blocks - val["type"] = val.pop("param_type") self.search_space[key] = SingleSearchSpace(**val) def __getitem__(self, key: str) -> SingleSearchSpace: diff --git a/src/otx/utils/__init__.py b/src/otx/utils/__init__.py index 01c097c2980..2d433f25bee 100644 --- a/src/otx/utils/__init__.py +++ b/src/otx/utils/__init__.py @@ -3,6 +3,6 @@ # """Utility files.""" -from .utils import append_signal_handler +from .signal import append_main_proc_signal_handler, append_signal_handler -__all__ = ["append_signal_handler"] +__all__ = ["append_signal_handler", "append_main_proc_signal_handler"] diff --git a/src/otx/utils/signal.py b/src/otx/utils/signal.py new file mode 100644 index 00000000000..1ca3311be89 --- /dev/null +++ b/src/otx/utils/signal.py @@ -0,0 +1,70 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Functions to append a signal handler.""" + +from __future__ import annotations + +import os +import signal +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from types import FrameType + + +@dataclass +class SigHandler: + """Signal handler dataclass having handler function and pid which registers the handler.""" + + handler: Callable + pid: int + + +_SIGNAL_HANDLERS: dict[int, list] = {} + + +def append_signal_handler(sig_num: int, sig_handler: Callable) -> None: + """Append the handler for a signal. The function appended at last is called first. + + Args: + sig_num (signal.Signals): Signal number to add a handler to. 
+ sig_handler (Callable): Callable function to be executed when the signal is sent. + """ + _register_signal_handler(sig_num, sig_handler, -1) + + +def append_main_proc_signal_handler(sig_num: int, sig_handler: Callable) -> None: + """Append the handler for a signal triggered only by main process. The function appended at last is called first. + + It's almost same as append_signal_handler except that handler will be executed only by signal to + process which registers handler. + + Args: + sig_num (signal.Signals): Signal number to add a handler to. + sig_handler (Callable): Callable function to be executed when the signal is sent. + """ + _register_signal_handler(sig_num, sig_handler, os.getpid()) + + +def _register_signal_handler(sig_num: int, sig_handler: Callable, pid: int) -> None: + if sig_num not in _SIGNAL_HANDLERS: + old_sig_handler = signal.getsignal(sig_num) + _SIGNAL_HANDLERS[sig_num] = [old_sig_handler] + signal.signal(sig_num, _run_signal_handlers) + + _SIGNAL_HANDLERS[sig_num].insert(0, SigHandler(sig_handler, pid)) + + +def _run_signal_handlers(sig_num: int, frame: FrameType | None) -> None: + pid = os.getpid() + for handler in _SIGNAL_HANDLERS[sig_num]: + if handler == signal.SIG_DFL: + signal.signal(sig_num, signal.SIG_DFL) + signal.raise_signal(sig_num) + elif isinstance(handler, SigHandler): + if handler.pid < 0 or handler.pid == pid: + handler.handler(sig_num, frame) + else: + handler(sig_num, frame) diff --git a/src/otx/utils/utils.py b/src/otx/utils/utils.py index eee9c18e4db..89cf03a2c79 100644 --- a/src/otx/utils/utils.py +++ b/src/otx/utils/utils.py @@ -3,26 +3,114 @@ """OTX utility functions.""" -import signal -from functools import partial -from typing import Callable +from __future__ import annotations +from decimal import Decimal +from typing import TYPE_CHECKING, Any -def append_signal_handler(sig_num: signal.Signals, sig_handler: Callable) -> None: - """Append the handler for a signal. The function appended at last is called first. 
+if TYPE_CHECKING: + from pathlib import Path + + +def get_using_dot_delimited_key(key: str, target: Any) -> Any: # noqa: ANN401 + """Get values of attribute in target object using dot delimited key. + + For example, if key is "a.b.c", then get a value of 'target.a.b.c'. + Target should be object having attributes, dictionary or list. + To get an element in a list, an integer that is the index of corresponding value can be set as a key. Args: - sig_num (signal.Signals): Signal number to add a handler to. - sig_handler (Callable): Callable function to be executed when the signal is sent. + key (str): dot delimited key. + val (Any): value to set. + target (Any): target to get value from. """ - old_sig_handler = signal.getsignal(sig_num) + splited_key = key.split(".") + for each_key in splited_key: + if isinstance(target, dict): + target = target[each_key] + elif isinstance(target, list): + if not each_key.isdigit(): + error_msg = f"Key should be integer but '{each_key}'." + raise ValueError(error_msg) + target = target[int(each_key)] + else: + target = getattr(target, each_key) + return target + - def helper(*args, old_func: Callable, **kwargs) -> None: - sig_handler(*args, **kwargs) - if old_func == signal.SIG_DFL: - signal.signal(sig_num, signal.SIG_DFL) - signal.raise_signal(sig_num) - elif callable(old_func): - old_func(*args, **kwargs) +def set_using_dot_delimited_key(key: str, val: Any, target: Any) -> None: # noqa: ANN401 + """Set values to attribute in target object using dot delimited key. + + For example, if key is "a.b.c", then value is set at 'target.a.b.c'. + Target should be object having attributes, dictionary or list. + To set an element in a list, an integer that is the index of corresponding value can be set as a key. + + Args: + key (str): dot delimited key. + val (Any): value to set. + target (Any): target to set value to.
+ """ + splited_key = key.split(".") + for each_key in splited_key[:-1]: + if isinstance(target, dict): + target = target[each_key] + elif isinstance(target, list): + if not each_key.isdigit(): + error_msg = f"Key should be integer but '{each_key}'." + raise ValueError(error_msg) + target = target[int(each_key)] + else: + target = getattr(target, each_key) + + if isinstance(target, dict): + target[splited_key[-1]] = val + elif isinstance(target, list): + if not splited_key[-1].isdigit(): + error_msg = f"Key should be integer but '{splited_key[-1]}'." + raise ValueError(error_msg) + target[int(splited_key[-1])] = val + else: + setattr(target, splited_key[-1], val) + + +def get_decimal_point(num: int | float) -> int: + """Find a decimal point from the given float. + + Args: + num (int | float): float to find a decimal point from. - signal.signal(sig_num, partial(helper, old_func=old_sig_handler)) + Returns: + int: decimal point. + """ + if isinstance((exponent := Decimal(str(num)).as_tuple().exponent), int): + return abs(exponent) + error_msg = f"Can't get an exponent from {num}." + raise ValueError(error_msg) + + +def find_file_recursively(directory: Path, file_name: str) -> Path | None: + """Find the file from the directory recursively. If multiple files have the same name, return one of them. + + Args: + directory (Path): directory where to find. + file_name (str): file name to find. + + Returns: + Path | None: Found file. If it's failed to find a file, return None. + """ + if found_file := list(directory.rglob(file_name)): + return found_file[0] + return None + + +def remove_matched_files(directory: Path, pattern: str, file_to_leave: Path | None = None) -> None: + """Remove all files matched to pattern except file_to_leave. + + Args: + directory (Path): directory to find files to remove. + pattern (str): pattern to match a file name. + file_to_leave (Path | None, optional): file to leave. Defaults to None.
+ """ + for weight in directory.rglob(pattern): + if weight != file_to_leave: + weight.unlink() diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index a64c673e117..f6d933a4014 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -6,6 +6,7 @@ import pytest import yaml +from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK from tests.integration.cli.utils import run_main @@ -302,3 +303,63 @@ def test_otx_ov_test( assert (tmp_path_test / "outputs" / "csv").exists() metric_result = list((tmp_path_test / "outputs" / "csv").glob(pattern="**/metrics.csv")) assert len(metric_result) > 0 + + +@pytest.mark.parametrize("task", pytest.TASK_LIST) +def test_otx_hpo_e2e( + task: str, + tmp_path: Path, + fxt_accelerator: str, + fxt_target_dataset_per_task: dict, + fxt_cli_override_command_per_task: dict, + fxt_open_subprocess: bool, +) -> None: + """ + Test HPO e2e commands with default template of each task. + + Args: + task (OTXTaskType): The task to run HPO with. + tmp_path (Path): The temporary path for storing the training outputs. + + Returns: + None + """ + if task in ("action_classification"): + pytest.xfail(reason="xFail until this root cause is resolved on the Datumaro side.") + if task not in DEFAULT_CONFIG_PER_TASK: + pytest.skip(f"Task {task} is not supported in the auto-configuration.") + + task = task.lower() + tmp_path_hpo = tmp_path / f"otx_hpo_{task}" + tmp_path_hpo.mkdir(parents=True) + + command_cfg = [ + "otx", + "train", + "--task", + task.upper(), + "--data_root", + fxt_target_dataset_per_task[task], + "--engine.work_dir", + str(tmp_path_hpo), + "--engine.device", + fxt_accelerator, + "--max_epochs", + "2", + "--run_hpo", + "true", + "--hpo_config.expected_time_ratio", + "2", + *fxt_cli_override_command_per_task[task], + ] + + run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) + + # zero_shot_visual_prompting doesn't support HPO. 
Check just there is no error. + if task in ("zero_shot_visual_prompting"): + return + + hpo_work_dor = tmp_path_hpo / "hpo" + assert hpo_work_dor.exists() + assert len([val for val in hpo_work_dor.rglob("*.json") if str(val.stem).isdigit()]) == 2 + assert len(list(hpo_work_dor.rglob("*.ckpt"))) == 1 diff --git a/tests/unit/hpo/test_hyperband.py b/tests/unit/hpo/test_hyperband.py index fb9151f2bda..00dad36d58c 100644 --- a/tests/unit/hpo/test_hyperband.py +++ b/tests/unit/hpo/test_hyperband.py @@ -67,8 +67,8 @@ def good_hyperband_args(): with TemporaryDirectory() as tmp_dir: yield { "search_space": { - "hp1": {"param_type": "uniform", "max": 100, "min": 10}, - "hp2": {"param_type": "qloguniform", "max": 1000, "min": 100, "step": 2, "log_base": 10}, + "hp1": {"type": "uniform", "max": 100, "min": 10}, + "hp2": {"type": "qloguniform", "max": 1000, "min": 100, "step": 2, "log_base": 10}, }, "save_path": tmp_dir, "mode": "max", @@ -76,7 +76,6 @@ def good_hyperband_args(): "num_full_iterations": 64, "non_pure_train_ratio": 0.2, "full_dataset_size": 100, - "metric": "mAP", "maximum_resource": 64, "minimum_resource": 1, "reduction_factor": 4, @@ -659,7 +658,7 @@ def test_report_score_trial_done(self, hyper_band): def test_get_best_config(self, hyper_band): max_score = 9999999 trial = hyper_band.get_next_sample() - expected_configuration = {"id": trial.id, "config": trial.configuration} + expected_configuration = {"id": trial.id, "configuration": trial.configuration} hyper_band.report_score(score=max_score, resource=trial.iteration, trial_id=trial.id, done=False) hyper_band.report_score(score=max_score, resource=trial.iteration, trial_id=trial.id, done=True) while True: diff --git a/tests/unit/hpo/test_search_space.py b/tests/unit/hpo/test_search_space.py index ab1829ec440..d72b76c1eba 100644 --- a/tests/unit/hpo/test_search_space.py +++ b/tests/unit/hpo/test_search_space.py @@ -398,28 +398,28 @@ def get_search_space_depending_on_type(types) -> dict: @staticmethod def 
add_uniform_search_space(search_space) -> None: - search_space["uniform_search_space"] = {"param_type": "uniform"} + search_space["uniform_search_space"] = {"type": "uniform"} search_space["uniform_search_space"].update({"min": 1, "max": 10}) @staticmethod def add_quniform_search_space(search_space) -> None: - search_space["quniform_search_space"] = {"param_type": "quniform"} + search_space["quniform_search_space"] = {"type": "quniform"} search_space["quniform_search_space"].update({"min": 1, "max": 10, "step": 3}) @staticmethod def add_loguniform_search_space(search_space) -> None: - search_space["loguniform_search_space"] = {"param_type": "loguniform"} + search_space["loguniform_search_space"] = {"type": "loguniform"} search_space["loguniform_search_space"].update({"min": 1, "max": 10, "log_base": 2}) @staticmethod def add_qloguniform_search_space(search_space) -> None: - search_space["qloguniform_search_space"] = {"param_type": "qloguniform"} + search_space["qloguniform_search_space"] = {"type": "qloguniform"} search_space["qloguniform_search_space"].update({"min": 1, "max": 10, "step": 3, "log_base": 2}) @staticmethod def add_choice_search_space(search_space) -> None: search_space["choice_search_space"] = { - "param_type": "choice", + "type": "choice", "choice_list": ["somevalue1", "somevalue2", "somevalue3"], } diff --git a/tests/unit/utils/test_signal.py b/tests/unit/utils/test_signal.py new file mode 100644 index 00000000000..918d66d4497 --- /dev/null +++ b/tests/unit/utils/test_signal.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import signal +from contextlib import contextmanager +from copy import copy + +from otx.utils import signal as target_file +from otx.utils.signal import append_main_proc_signal_handler, append_signal_handler + + +@contextmanager +def register_signal_temporally(sig_num: signal.Signals): + old_sig_handler = signal.getsignal(sig_num) + ori_handler_arr = copy(target_file._SIGNAL_HANDLERS) + yield + signal.signal(sig_num, 
old_sig_handler) + target_file._SIGNAL_HANDLERS = ori_handler_arr + + +def test_append_signal_handler(mocker): + with register_signal_temporally(signal.SIGTERM): + # prepare + mocker.patch("signal.raise_signal") + spy_signal = mocker.spy(target_file.signal, "signal") + sig_hand_1 = mocker.MagicMock() + sig_hand_2 = mocker.MagicMock() + + # run + append_signal_handler(signal.SIGTERM, sig_hand_1) + append_signal_handler(signal.SIGTERM, sig_hand_2) + + old_sig_handler = signal.getsignal(signal.SIGTERM) + old_sig_handler(signal.SIGTERM, mocker.MagicMock()) + + # check + sig_hand_1.assert_called_once() + sig_hand_2.assert_called_once() + assert spy_signal.call_args == ((signal.SIGTERM, signal.SIG_DFL),) + + +def test_append_main_proc_signal_handler(mocker): + with register_signal_temporally(signal.SIGTERM): + # prepare + mocker.patch("os.getpid", return_value=1) + mocker.patch("signal.raise_signal") + spy_signal = mocker.spy(target_file.signal, "signal") + sig_hand_1 = mocker.MagicMock() + sig_hand_2 = mocker.MagicMock() + + # run + append_main_proc_signal_handler(signal.SIGTERM, sig_hand_1) + append_main_proc_signal_handler(signal.SIGTERM, sig_hand_2) + + mocker.patch("os.getpid", return_value=2) + old_sig_handler = signal.getsignal(signal.SIGTERM) + old_sig_handler(signal.SIGTERM, mocker.MagicMock()) + + # check + sig_hand_1.assert_not_called() + sig_hand_2.assert_not_called() + assert spy_signal.call_args == ((signal.SIGTERM, signal.SIG_DFL),) diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 552a4487971..10a6a939257 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -1,24 +1,101 @@ -import signal +from __future__ import annotations -from otx.utils import utils as target_file -from otx.utils.utils import append_signal_handler +from pathlib import Path +import pytest +from otx.utils.utils import ( + find_file_recursively, + get_decimal_point, + get_using_dot_delimited_key, + remove_matched_files, + 
set_using_dot_delimited_key, +) -def test_append_signal_handler(mocker): - # prepare - mocker.patch("signal.raise_signal") - spy_signal = mocker.spy(target_file.signal, "signal") - sig_hand_1 = mocker.MagicMock() - sig_hand_2 = mocker.MagicMock() - # run - append_signal_handler(signal.SIGTERM, sig_hand_1) - append_signal_handler(signal.SIGTERM, sig_hand_2) +@pytest.fixture() +def fake_obj(mocker): + target = mocker.MagicMock() + target.a.b.c = {"d": mocker.MagicMock()} + target.a.b.c["d"].e = [0, 1, 2] + return target - old_sig_handler = signal.getsignal(signal.SIGTERM) - old_sig_handler() - # check - sig_hand_1.assert_called_once() - sig_hand_2.assert_called_once() - assert spy_signal.call_args == ((signal.SIGTERM, signal.SIG_DFL),) +def test_get_using_dot_delimited_key(fake_obj): + assert get_using_dot_delimited_key("a.b.c.d.e.2", fake_obj) == 2 + + +def test_set_using_dot_delimited_key(fake_obj): + expected_val = 2 + set_using_dot_delimited_key("a.b.c.d.e.0", expected_val, fake_obj) + assert fake_obj.a.b.c["d"].e[0] == expected_val + + +@pytest.mark.parametrize(("val", "decimal_point"), [(0.001, 3), (-0.0001, 4), (1, 0), (100, 0), (-2, 0)]) +def test_get_decimal_point(val, decimal_point): + assert get_decimal_point(val) == decimal_point + + +def test_find_file_recursively(tmp_path): + file_name = "some_file.txt" + target = tmp_path / "foo" / "bar" / file_name + target.parent.mkdir(parents=True) + target.touch() + + assert find_file_recursively(tmp_path, file_name) == target + + +def test_find_file_recursively_multiple_files_exist(tmp_path): + file_name = "some_file.txt" + + target1 = tmp_path / "foo" / file_name + target1.parent.mkdir(parents=True) + target1.touch() + + target2 = tmp_path / "foo" / "bar" / file_name + target2.parent.mkdir(parents=True) + target2.touch() + + assert find_file_recursively(tmp_path, file_name) in [target1, target2] + + +def test_find_file_recursively_not_exist(tmp_path): + file_name = "some_file.txt" + assert 
find_file_recursively(tmp_path, file_name) is None + + +def make_dir_and_file(dir_path: Path, file_path: str | Path) -> Path: + file = dir_path / file_path + file.parent.mkdir(parents=True, exist_ok=True) + file.touch() + + return file + + +@pytest.fixture() +def temporary_dir_w_some_txt(tmp_path): + some_txt = ["a/b/c/d.txt", "1/2/3/4.txt", "e.txt", "f/g.txt", "5/6/7.txt"] + for file_path in some_txt: + make_dir_and_file(tmp_path, file_path) + return tmp_path + + +def test_remove_matched_files(temporary_dir_w_some_txt): + file_path_to_leave = "foo/bar/file_to_leave.txt" + file_to_leave = make_dir_and_file(temporary_dir_w_some_txt, file_path_to_leave) + + remove_matched_files(temporary_dir_w_some_txt, "*.txt", file_to_leave) + + assert file_to_leave.exists() + assert len(list(temporary_dir_w_some_txt.rglob("*.txt"))) == 1 + + +def test_remove_matched_files_remove_all(temporary_dir_w_some_txt): + remove_matched_files(temporary_dir_w_some_txt, "*.txt") + + assert len(list(temporary_dir_w_some_txt.rglob("*.txt"))) == 0 + + +def test_remove_matched_files_no_file_to_remove(temporary_dir_w_some_txt): + remove_matched_files(temporary_dir_w_some_txt, "*.log") + + assert len(list(temporary_dir_w_some_txt.rglob("*.txt"))) == 5 From d9f7e15f7a88fe14f2cfad2d45132ee86685698a Mon Sep 17 00:00:00 2001 From: Sungman Cho Date: Fri, 16 Feb 2024 13:05:11 +0900 Subject: [PATCH 6/6] Move Linearwarmup scheduler from base to the algo (#2924) --- src/otx/algo/schedulers/__init__.py | 4 +- src/otx/algo/schedulers/warmup_schedulers.py | 130 ++---------------- src/otx/core/model/module/base.py | 17 --- .../action/action_classification/x3d.yaml | 2 +- .../action/action_detection/x3d_fastrcnn.yaml | 2 +- .../h_label_cls/mobilenet_v3_large_light.yaml | 2 +- .../h_label_cls/otx_deit_tiny.yaml | 2 +- .../mobilenet_v3_large_light.yaml | 2 +- .../multi_class_cls/otx_deit_tiny.yaml | 2 +- .../otx_mobilenet_v3_large.yaml | 2 +- .../mobilenet_v3_large_light.yaml | 2 +- 
.../multi_label_cls/otx_deit_tiny.yaml | 2 +- .../recipe/detection/atss_mobilenetv2.yaml | 2 +- src/otx/recipe/detection/atss_r50_fpn.yaml | 2 +- src/otx/recipe/detection/atss_resnext101.yaml | 2 +- src/otx/recipe/detection/ssd_mobilenetv2.yaml | 2 +- src/otx/recipe/detection/yolox_l.yaml | 2 +- src/otx/recipe/detection/yolox_l_tile.yaml | 2 +- src/otx/recipe/detection/yolox_s.yaml | 2 +- src/otx/recipe/detection/yolox_s_tile.yaml | 2 +- src/otx/recipe/detection/yolox_tiny.yaml | 2 +- src/otx/recipe/detection/yolox_tiny_tile.yaml | 2 +- src/otx/recipe/detection/yolox_x.yaml | 2 +- src/otx/recipe/detection/yolox_x_tile.yaml | 2 +- .../maskrcnn_efficientnetb2b.yaml | 2 +- .../maskrcnn_efficientnetb2b_tile.yaml | 2 +- .../instance_segmentation/maskrcnn_r50.yaml | 2 +- .../maskrcnn_r50_tile.yaml | 2 +- .../instance_segmentation/maskrcnn_swint.yaml | 2 +- .../maskrcnn_efficientnetb2b.yaml | 2 +- .../rotated_detection/maskrcnn_r50.yaml | 2 +- .../semantic_segmentation/litehrnet_18.yaml | 2 +- .../semantic_segmentation/litehrnet_s.yaml | 2 +- .../semantic_segmentation/litehrnet_x.yaml | 2 +- .../semantic_segmentation/segnext_b.yaml | 2 +- .../semantic_segmentation/segnext_s.yaml | 2 +- .../semantic_segmentation/segnext_t.yaml | 2 +- tests/unit/core/model/module/test_base.py | 3 +- 38 files changed, 50 insertions(+), 172 deletions(-) diff --git a/src/otx/algo/schedulers/__init__.py b/src/otx/algo/schedulers/__init__.py index 9f1e8b57d0e..9ff8f508750 100644 --- a/src/otx/algo/schedulers/__init__.py +++ b/src/otx/algo/schedulers/__init__.py @@ -3,6 +3,6 @@ # """Custom schedulers for the OTX2.0.""" -from .warmup_schedulers import WarmupReduceLROnPlateau +from .warmup_schedulers import LinearWarmupScheduler -__all__ = ["WarmupReduceLROnPlateau"] +__all__ = ["LinearWarmupScheduler"] diff --git a/src/otx/algo/schedulers/warmup_schedulers.py b/src/otx/algo/schedulers/warmup_schedulers.py index b1244250621..ff72a2d44bc 100644 --- a/src/otx/algo/schedulers/warmup_schedulers.py +++ 
b/src/otx/algo/schedulers/warmup_schedulers.py @@ -4,127 +4,21 @@ """Warm-up schedulers for the OTX2.0.""" from __future__ import annotations -from typing import TYPE_CHECKING +import torch -from lightning.pytorch.cli import ReduceLROnPlateau -from torch.optim.lr_scheduler import PolynomialLR -if TYPE_CHECKING: - from torch.optim import Optimizer - - -class BaseWarmupScheduler: - """Base Warumup Scheduler class. - - It should be inherited if want to implement warmup based Custom LRScheduler. - i.e. WarmupCustomLRScheduler(BaseWarmupScheduler, ...), WarmupCosineAnnealingLR(BaseWarmupScheduler, ...) - - Args: - warmup_steps (int): The total number of the warmup steps. it could be epoch or iter. - warmup_by_epoch (bool): If True, warmup_steps represent the epoch. - - """ - - warmup_steps: int - - -class WarmupReduceLROnPlateau(BaseWarmupScheduler, ReduceLROnPlateau): - """ReduceLROnPlateau for enabling the warmup. - - Args: - optimizer (Optimizer): Wrapped optimizer. - warmup_steps (int): The total number of the warmup steps. it could be epoch or iter. - monitor (str): The name of monitoring value. - mode (str): One of `min`, `max`. In `min` mode, lr will - be reduced when the quantity monitored has stopped - decreasing; in `max` mode it will be reduced when the - quantity monitored has stopped increasing. Default: 'min'. - factor (float): Factor by which the learning rate will be - reduced. new_lr = lr * factor. Default: 0.1. - patience (int): Number of epochs with no improvement after - which learning rate will be reduced. For example, if - `patience = 2`, then we will ignore the first 2 epochs - with no improvement, and will only decrease the LR after the - 3rd epoch if the loss still hasn't improved then. - Default: 10. - threshold (float): Threshold for measuring the new optimum, - to only focus on significant changes. Default: 1e-4. - threshold_mode (str): One of `rel`, `abs`. 
In `rel` mode, - dynamic_threshold = best * ( 1 + threshold ) in 'max' - mode or best * ( 1 - threshold ) in `min` mode. - In `abs` mode, dynamic_threshold = best + threshold in - `max` mode or best - threshold in `min` mode. Default: 'rel'. - cooldown (int): Number of epochs to wait before resuming - normal operation after lr has been reduced. Default: 0. - min_lr (float or list): A scalar or a list of scalars. A - lower bound on the learning rate of all param groups - or each group respectively. Default: 0. - eps (float): Minimal decay applied to lr. If the difference - between new and old lr is smaller than eps, the update is - ignored. Default: 1e-8. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - - """ - - def __init__( - self, - optimizer: Optimizer, - warmup_steps: int, - monitor: str, - mode: str = "min", - factor: float = 0.1, - patience: int = 10, - threshold: float = 1e-4, - threshold_mode: str = "rel", - cooldown: int = 0, - min_lr: float = 0, - eps: float = 1e-8, - verbose: bool = False, - ): - self.warmup_steps = warmup_steps - super().__init__( - optimizer, - monitor, - mode, - factor, - patience, - threshold, - threshold_mode, - cooldown, - min_lr, - eps, - verbose, - ) - - -class WarmupPolynomialLR(BaseWarmupScheduler, PolynomialLR): - """PolynomialLR for enabling the warmup. - - Args: - optimizer (Optimizer): Wrapped optimizer. - warmup_steps (int): The total number of the warmup steps. it could be epoch or iter. - total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. - power (int): The power of the polynomial. Default: 1.0. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. 
- - """ +class LinearWarmupScheduler(torch.optim.lr_scheduler.LambdaLR): + """Linear Warmup scheduler.""" def __init__( self, - optimizer: Optimizer, - warmup_steps: int, - total_iters: int = 5, - power: float = 1.0, - last_epoch: int = -1, - verbose: bool = False, + optimizer: torch.optim.Optimizer, + num_warmup_steps: int = 1000, + interval: str = "step", ): - self.warmup_steps = warmup_steps - super().__init__( - optimizer, - total_iters, - power, - last_epoch, - verbose, - ) + if not num_warmup_steps > 0: + msg = f"num_warmup_steps should be > 0, got {num_warmup_steps}" + raise ValueError(msg) + self.num_warmup_steps = num_warmup_steps + self.interval = interval + super().__init__(optimizer, lambda step: min(step / num_warmup_steps, 1.0)) diff --git a/src/otx/core/model/module/base.py b/src/otx/core/model/module/base.py index 20edfc02795..7077a2a4e9d 100644 --- a/src/otx/core/model/module/base.py +++ b/src/otx/core/model/module/base.py @@ -26,23 +26,6 @@ from otx.core.data.dataset.base import LabelInfo -class LinearWarmupScheduler(torch.optim.lr_scheduler.LambdaLR): - """Linear Warmup scheduler.""" - - def __init__( - self, - optimizer: torch.optim.Optimizer, - num_warmup_steps: int = 1000, - interval: str = "step", - ): - if not num_warmup_steps > 0: - msg = f"num_warmup_steps should be > 0, got {num_warmup_steps}" - raise ValueError(msg) - self.num_warmup_steps = num_warmup_steps - self.interval = interval - super().__init__(optimizer, lambda step: min(step / num_warmup_steps, 1.0)) - - class OTXLitModule(LightningModule): """Base class for the lightning module used in OTX.""" diff --git a/src/otx/recipe/action/action_classification/x3d.yaml b/src/otx/recipe/action/action_classification/x3d.yaml index e43d0b2c1e9..c3063b50c59 100644 --- a/src/otx/recipe/action/action_classification/x3d.yaml +++ b/src/otx/recipe/action/action_classification/x3d.yaml @@ -10,7 +10,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: 
otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/action/action_detection/x3d_fastrcnn.yaml b/src/otx/recipe/action/action_detection/x3d_fastrcnn.yaml index 36c51f4247e..3d153625491 100644 --- a/src/otx/recipe/action/action_detection/x3d_fastrcnn.yaml +++ b/src/otx/recipe/action/action_detection/x3d_fastrcnn.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.00001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/classification/h_label_cls/mobilenet_v3_large_light.yaml b/src/otx/recipe/classification/h_label_cls/mobilenet_v3_large_light.yaml index 12f731da739..5d46245bc1c 100644 --- a/src/otx/recipe/classification/h_label_cls/mobilenet_v3_large_light.yaml +++ b/src/otx/recipe/classification/h_label_cls/mobilenet_v3_large_light.yaml @@ -13,7 +13,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 10 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/classification/h_label_cls/otx_deit_tiny.yaml b/src/otx/recipe/classification/h_label_cls/otx_deit_tiny.yaml index a6d2e62b6a3..85ba8a71105 100644 --- a/src/otx/recipe/classification/h_label_cls/otx_deit_tiny.yaml +++ b/src/otx/recipe/classification/h_label_cls/otx_deit_tiny.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.05 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 10 - class_path: 
lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/classification/multi_class_cls/mobilenet_v3_large_light.yaml b/src/otx/recipe/classification/multi_class_cls/mobilenet_v3_large_light.yaml index 080cc830be7..ccf36885d67 100644 --- a/src/otx/recipe/classification/multi_class_cls/mobilenet_v3_large_light.yaml +++ b/src/otx/recipe/classification/multi_class_cls/mobilenet_v3_large_light.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 10 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/classification/multi_class_cls/otx_deit_tiny.yaml b/src/otx/recipe/classification/multi_class_cls/otx_deit_tiny.yaml index da0c5522854..525a3eaf795 100644 --- a/src/otx/recipe/classification/multi_class_cls/otx_deit_tiny.yaml +++ b/src/otx/recipe/classification/multi_class_cls/otx_deit_tiny.yaml @@ -10,7 +10,7 @@ optimizer: weight_decay: 0.05 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 10 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/classification/multi_class_cls/otx_mobilenet_v3_large.yaml b/src/otx/recipe/classification/multi_class_cls/otx_mobilenet_v3_large.yaml index 7058f87da0e..ed7c1a805fd 100644 --- a/src/otx/recipe/classification/multi_class_cls/otx_mobilenet_v3_large.yaml +++ b/src/otx/recipe/classification/multi_class_cls/otx_mobilenet_v3_large.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 10 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git 
a/src/otx/recipe/classification/multi_label_cls/mobilenet_v3_large_light.yaml b/src/otx/recipe/classification/multi_label_cls/mobilenet_v3_large_light.yaml index 5f9f82ae0f8..a86c5e50dfe 100644 --- a/src/otx/recipe/classification/multi_label_cls/mobilenet_v3_large_light.yaml +++ b/src/otx/recipe/classification/multi_label_cls/mobilenet_v3_large_light.yaml @@ -11,7 +11,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 10 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/classification/multi_label_cls/otx_deit_tiny.yaml b/src/otx/recipe/classification/multi_label_cls/otx_deit_tiny.yaml index 9e66ece0fbf..37b3b652c89 100644 --- a/src/otx/recipe/classification/multi_label_cls/otx_deit_tiny.yaml +++ b/src/otx/recipe/classification/multi_label_cls/otx_deit_tiny.yaml @@ -10,7 +10,7 @@ optimizer: weight_decay: 0.05 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 10 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/atss_mobilenetv2.yaml b/src/otx/recipe/detection/atss_mobilenetv2.yaml index 69d1cd52c7d..0589627669e 100644 --- a/src/otx/recipe/detection/atss_mobilenetv2.yaml +++ b/src/otx/recipe/detection/atss_mobilenetv2.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/atss_r50_fpn.yaml b/src/otx/recipe/detection/atss_r50_fpn.yaml index 43a20e292f4..604f6abbcbe 100644 --- a/src/otx/recipe/detection/atss_r50_fpn.yaml +++ 
b/src/otx/recipe/detection/atss_r50_fpn.yaml @@ -10,7 +10,7 @@ optimizer: weight_decay: 0.0 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/atss_resnext101.yaml b/src/otx/recipe/detection/atss_resnext101.yaml index bb0a7b939f9..d4680fa4a54 100644 --- a/src/otx/recipe/detection/atss_resnext101.yaml +++ b/src/otx/recipe/detection/atss_resnext101.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/ssd_mobilenetv2.yaml b/src/otx/recipe/detection/ssd_mobilenetv2.yaml index 09b10bc4eea..fbf7bada443 100644 --- a/src/otx/recipe/detection/ssd_mobilenetv2.yaml +++ b/src/otx/recipe/detection/ssd_mobilenetv2.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_l.yaml b/src/otx/recipe/detection/yolox_l.yaml index 690f7bfd4f4..e545a04bdb1 100644 --- a/src/otx/recipe/detection/yolox_l.yaml +++ b/src/otx/recipe/detection/yolox_l.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_l_tile.yaml 
b/src/otx/recipe/detection/yolox_l_tile.yaml index 3bc2727e805..73561ae19d5 100644 --- a/src/otx/recipe/detection/yolox_l_tile.yaml +++ b/src/otx/recipe/detection/yolox_l_tile.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_s.yaml b/src/otx/recipe/detection/yolox_s.yaml index 0bf3a268446..d1004f2df44 100644 --- a/src/otx/recipe/detection/yolox_s.yaml +++ b/src/otx/recipe/detection/yolox_s.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_s_tile.yaml b/src/otx/recipe/detection/yolox_s_tile.yaml index f9806972601..36fd609cf89 100644 --- a/src/otx/recipe/detection/yolox_s_tile.yaml +++ b/src/otx/recipe/detection/yolox_s_tile.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_tiny.yaml b/src/otx/recipe/detection/yolox_tiny.yaml index 9997a0022e8..7913d070da9 100644 --- a/src/otx/recipe/detection/yolox_tiny.yaml +++ b/src/otx/recipe/detection/yolox_tiny.yaml @@ -11,7 +11,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: 
lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_tiny_tile.yaml b/src/otx/recipe/detection/yolox_tiny_tile.yaml index 2e305cbb9c3..5124eb0aa9b 100644 --- a/src/otx/recipe/detection/yolox_tiny_tile.yaml +++ b/src/otx/recipe/detection/yolox_tiny_tile.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_x.yaml b/src/otx/recipe/detection/yolox_x.yaml index 68162a0164d..c8dccc74853 100644 --- a/src/otx/recipe/detection/yolox_x.yaml +++ b/src/otx/recipe/detection/yolox_x.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/detection/yolox_x_tile.yaml b/src/otx/recipe/detection/yolox_x_tile.yaml index 5b23cc530fe..cc5d75d1590 100644 --- a/src/otx/recipe/detection/yolox_x_tile.yaml +++ b/src/otx/recipe/detection/yolox_x_tile.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.0001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 3 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml index bdba41b089f..2cc304dfa6f 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.001 scheduler: - 
- class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml index b495600617c..cdfe2047164 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml index 2ea4a57884f..a46d4cd80df 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml index 3b9c6e5e9b0..827b509fb6d 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: 
lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_swint.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_swint.yaml index 8806f46b905..a495be1227f 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_swint.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_swint.yaml @@ -10,7 +10,7 @@ optimizer: weight_decay: 0.05 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml b/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml index b5f5a9a8cfe..429017549d5 100644 --- a/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml +++ b/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml b/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml index 415e11b66fa..53af59fb3fb 100644 --- a/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml +++ b/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml @@ -12,7 +12,7 @@ optimizer: weight_decay: 0.001 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml index 99995b831da..8a59e129f30 100644 --- 
a/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml @@ -14,7 +14,7 @@ optimizer: weight_decay: 0.0 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml index fdf93f5734c..66da17fe4a2 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml @@ -14,7 +14,7 @@ optimizer: weight_decay: 0.0 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml index 3df08bb4eb3..266f06d1925 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml @@ -14,7 +14,7 @@ optimizer: weight_decay: 0.0 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 100 - class_path: lightning.pytorch.cli.ReduceLROnPlateau diff --git a/src/otx/recipe/semantic_segmentation/segnext_b.yaml b/src/otx/recipe/semantic_segmentation/segnext_b.yaml index f2330302006..5b6b58ea0cd 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_b.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_b.yaml @@ -14,7 +14,7 @@ optimizer: weight_decay: 0.01 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: 
otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 20 - class_path: torch.optim.lr_scheduler.PolynomialLR diff --git a/src/otx/recipe/semantic_segmentation/segnext_s.yaml b/src/otx/recipe/semantic_segmentation/segnext_s.yaml index 7f814f34119..e93d1a9cdc4 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_s.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_s.yaml @@ -14,7 +14,7 @@ optimizer: weight_decay: 0.01 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 20 - class_path: torch.optim.lr_scheduler.PolynomialLR diff --git a/src/otx/recipe/semantic_segmentation/segnext_t.yaml b/src/otx/recipe/semantic_segmentation/segnext_t.yaml index 3de98141813..7c7b18c64d9 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_t.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_t.yaml @@ -14,7 +14,7 @@ optimizer: weight_decay: 0.01 scheduler: - - class_path: otx.core.model.module.base.LinearWarmupScheduler + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler init_args: num_warmup_steps: 20 - class_path: torch.optim.lr_scheduler.PolynomialLR diff --git a/tests/unit/core/model/module/test_base.py b/tests/unit/core/model/module/test_base.py index bf9961da165..f4262b77156 100644 --- a/tests/unit/core/model/module/test_base.py +++ b/tests/unit/core/model/module/test_base.py @@ -10,8 +10,9 @@ import pytest from lightning.pytorch.cli import ReduceLROnPlateau from lightning.pytorch.trainer import Trainer +from otx.algo.schedulers.warmup_schedulers import LinearWarmupScheduler from otx.core.model.entity.base import OTXModel -from otx.core.model.module.base import LinearWarmupScheduler, OTXLitModule +from otx.core.model.module.base import OTXLitModule from torch.optim import Optimizer