diff --git a/.github/workflows/code_scan.yaml b/.github/workflows/code_scan.yaml index 1473bc09aa5..7f78a37603a 100644 --- a/.github/workflows/code_scan.yaml +++ b/.github/workflows/code_scan.yaml @@ -20,7 +20,7 @@ jobs: - name: Checkout code uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Set up Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install dependencies @@ -45,7 +45,7 @@ jobs: - name: Checkout repository uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Set up Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install dependencies diff --git a/.github/workflows/daily.yaml b/.github/workflows/daily.yaml index e645af781e0..6f57b0dcf00 100644 --- a/.github/workflows/daily.yaml +++ b/.github/workflows/daily.yaml @@ -32,7 +32,7 @@ jobs: - name: Checkout repository uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Install Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install tox diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 80e38e9faea..11e518a0c7a 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -18,7 +18,7 @@ jobs: - name: Checkout repository uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Set up Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install tox diff --git 
a/.github/workflows/docs_stable.yaml b/.github/workflows/docs_stable.yaml index be60899b206..387e37c6fb0 100644 --- a/.github/workflows/docs_stable.yaml +++ b/.github/workflows/docs_stable.yaml @@ -18,7 +18,7 @@ jobs: with: fetch-depth: 0 # otherwise, you will failed to push refs to dest repo - name: Set up Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install tox diff --git a/.github/workflows/perf_benchmark.yaml b/.github/workflows/perf_benchmark.yaml index d7d13bb7e68..da90217fe72 100644 --- a/.github/workflows/perf_benchmark.yaml +++ b/.github/workflows/perf_benchmark.yaml @@ -121,7 +121,7 @@ jobs: - name: Checkout repository uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Install Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install tox diff --git a/.github/workflows/pre_merge.yaml b/.github/workflows/pre_merge.yaml index baa477ff0ca..6bd5b633223 100644 --- a/.github/workflows/pre_merge.yaml +++ b/.github/workflows/pre_merge.yaml @@ -26,7 +26,7 @@ jobs: - name: Checkout repository uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Set up Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install tox @@ -56,7 +56,7 @@ jobs: - name: Checkout repository uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Install Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: ${{ 
matrix.python-version }} - name: Install tox @@ -105,7 +105,7 @@ jobs: - name: Checkout repository uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Install Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install tox diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index cfa0e832522..eb98a70d979 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -29,7 +29,7 @@ jobs: - name: Checkout uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Set up Python 3.10 - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install pypa/build diff --git a/.github/workflows/publish_internal.yaml b/.github/workflows/publish_internal.yaml index f487b175daf..d273565c0a2 100644 --- a/.github/workflows/publish_internal.yaml +++ b/.github/workflows/publish_internal.yaml @@ -27,7 +27,7 @@ jobs: - name: Checkout uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Set up Python 3.10 - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install pypa/build @@ -50,7 +50,7 @@ jobs: - name: Checkout uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Set up Python - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: python-version: "3.10" - name: Install dependencies diff --git a/CHANGELOG.md b/CHANGELOG.md index 044503b8ef0..526eb49b651 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ 
-6,13 +6,137 @@ All notable changes to this project will be documented in this file. ### New features -- Add zero-shot visual prompting (, , ) -- Add support for the training and validation on the XPU devices (https://github.com/openvinotoolkit/training_extensions/pull/3058) +### Enhancements + +## \[1.6.0\] + +### New features + +- Changed supported Python version range (>=3.9, <=3.11) + () +- Support MMDetection COCO format + () +- Develop JsonSectionPageMapper in Rust API + () +- Add Filtering via User-Provided Python Functions + (, ) +- Remove supporting MacOS platform + () +- Support Kaggle image data (`KaggleImageCsvBase`, `KaggleImageTxtBase`, `KaggleImageMaskBase`, `KaggleVocBase`, `KaggleYoloBase`) + () +- Add `__getitem__()` for random accessing with O(1) time complexity + () +- Add Data-aware Anchor Generator + () +- Support bounding box import within Kaggle extractors and add `KaggleCocoBase` + () + +### Enhancements + +- Optimize Python import to make CLI entrypoint faster + () +- Add ImageColorScale context manager + () +- Enhance visualizer to toggle plot title visibility + () +- Enhance Datumaro data format detect() to be memory-bounded and performant + () +- Change RoIImage and MosaicImage to have np.uint8 dtype as default + () +- Enable image backend and color channel format to be selectable + () +- Boost up `CityscapesBase` and `KaggleImageMaskBase` by dropping `np.unique` + () +- Enhance RISE algorithm for explainable AI + () +- Enhance explore unit test to use real dataset from ImageNet + () +- Fix each method of the comparator to be used separately + () +- Bump ONNX version to 1.16.0 + () +- Print the color channel format (RGB) for datum stats command + () +- Add ignore_index argument to Mask.as_class_mask() and Mask.as_instance_mask() + () + +### Bug fixes + +- Fix wrong example of Datumaro dataset creation in document + () +- Fix wrong command to install datumaro from github + (, ) +- Update document to correct wrong `datum project import`
command and add filtering example to filter out items containing annotations. + () +- Fix label compare of distance method + () +- Fix Datumaro visualizer's import errors after introducing lazy import + () +- Fix broken link to supported formats in readme + () +- Fix Kinetics data format to have media data + () +- Handling undefined labels at the annotation statistics + () +- Add unit test for item rename + () +- Fix a bug in the previous behavior when importing nested datasets in the project + () +- Fix Kaggle importer when adding duplicated labels + () +- Fix input tensor shape in model interpreter for OpenVINO 2023.3 + () +- Add default value for target in prune cli + () +- Remove deprecated MediaManager + () +- Fix explore command without project + () +- Fix enable COCO to import only bboxes + () +- Fix resize transform for RleMask annotation + () +- Fix import YOLO variants from extractor when `urls` is not specified + () + +## \[1.5.2\] ### Enhancements -- Upgrade OpenVINO to 2023.3 () -- Automate performance benchmark () +- Add memory bounded datumaro data format detect to release 1.5.1 + () +- Bump version string to 1.5.2 + () +- Remove Protobuf version limitation (<4) + () + +## \[1.5.1\] + +### Enhancements + +- Enhance Datumaro data format stream importer performance + () +- Change image default dtype from float32 to uint8 + () +- Add comparison level-up doc + () +- Add ImportError to catch GitPython import error + () + +### Bug fixes + +- Modify the draw function in the visualizer not to raise an error for unsupported annotation types. + () +- Correct explore path in the related document. + () +- Fix errata in the voc document. Color values in the labelmap.txt should be separated by commas, not colons.
+ () +- Fix hyperlink errors in the document + (, ) +- Fix memory unbounded Arrow data format export/import + () +- Update CVAT format doc to bypass warning + () ## \[v1.5.0\] diff --git a/docs/source/guide/release_notes/index.rst b/docs/source/guide/release_notes/index.rst index a1653700ac9..9df0d147087 100644 --- a/docs/source/guide/release_notes/index.rst +++ b/docs/source/guide/release_notes/index.rst @@ -8,6 +8,79 @@ Releases v2.0.0 (1Q24) ------------- +v1.6.0 (2024.04) +---------------- + +New features +^^^^^^^^^^^^ +- Changed supported Python version range (>=3.9, <=3.11) +- Support MMDetection COCO format +- Develop JsonSectionPageMapper in Rust API +- Add Filtering via User-Provided Python Functions +- Remove supporting MacOS platform +- Support Kaggle image data (`KaggleImageCsvBase`, `KaggleImageTxtBase`, `KaggleImageMaskBase`, `KaggleVocBase`, `KaggleYoloBase`) +- Add `__getitem__()` for random accessing with O(1) time complexity +- Add Data-aware Anchor Generator +- Support bounding box import within Kaggle extractors and add `KaggleCocoBase` + +Enhancements +^^^^^^^^^^^^ +- Optimize Python import to make CLI entrypoint faster +- Add ImageColorScale context manager +- Enhance visualizer to toggle plot title visibility +- Enhance Datumaro data format detect() to be memory-bounded and performant +- Change RoIImage and MosaicImage to have np.uint8 dtype as default +- Enable image backend and color channel format to be selectable +- Boost up `CityscapesBase` and `KaggleImageMaskBase` by dropping `np.unique` +- Enhance RISE algorithm for explainable AI +- Enhance explore unit test to use real dataset from ImageNet +- Fix each method of the comparator to be used separately + +Bug fixes +^^^^^^^^^ +- Fix wrong example of Datumaro dataset creation in document +- Fix wrong command to install datumaro from github +- Update document to correct wrong `datum project import` command and add filtering example to filter out items containing annotations.
+- Fix label compare of distance method +- Fix Datumaro visualizer's import errors after introducing lazy import +- Fix broken link to supported formats in readme +- Fix Kinetics data format to have media data +- Handling undefined labels at the annotation statistics +- Add unit test for item rename +- Fix a bug in the previous behavior when importing nested datasets in the project +- Fix Kaggle importer when adding duplicated labels +- Fix input tensor shape in model interpreter for OpenVINO 2023.3 +- Add default value for target in prune cli +- Remove deprecated MediaManager +- Fix explore command without project + +v1.5.2 (2024.01) +---------------- + +Enhancements +^^^^^^^^^^^^ + +- Add memory bounded datumaro data format detect +- Remove Protobuf version limitation (<4) + +v1.5.1 (2023.11) +---------------- + +Enhancements +^^^^^^^^^^^^ +- Enhance Datumaro data format stream importer performance +- Change image default dtype from float32 to uint8 +- Add comparison level-up doc +- Add ImportError to catch GitPython import error + +Bug fixes +^^^^^^^^^ +- Modify the draw function in the visualizer not to raise an error for unsupported annotation types. +- Correct explore path in the related document. +- Fix errata in the voc document. Color values in the labelmap.txt should be separated by commas, not colons. +- Fix hyperlink errors in the document. +- Fix memory unbounded Arrow data format export/import. +- Update CVAT format doc to bypass warning. 
v1.5.0 (4Q23) ------------- diff --git a/for_developers/regression_test/requirements.txt b/for_developers/regression_test/requirements.txt index 382c01b6338..a13e11ab567 100644 --- a/for_developers/regression_test/requirements.txt +++ b/for_developers/regression_test/requirements.txt @@ -314,9 +314,9 @@ gunicorn==21.2.0 \ --hash=sha256:3213aa5e8c24949e792bcacfc176fef362e7aac80b76c56f6b5122bf350722f0 \ --hash=sha256:88ec8bff1d634f98e61b9f65bc4bf3cd918a90806c6f5c48bc5603849ec81033 # via mlflow -idna==3.6 \ - --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ - --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests importlib-metadata==6.11.0 \ --hash=sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443 \ @@ -1023,9 +1023,9 @@ sqlalchemy==2.0.28 \ # via # alembic # mlflow -sqlparse==0.4.4 \ - --hash=sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3 \ - --hash=sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c +sqlparse==0.5.0 \ + --hash=sha256:714d0a4932c059d16189f58ef5411ec2287a4360f17cdd0edd2d09d4c5087c93 \ + --hash=sha256:c204494cd97479d0e39f28c93d46c0b2d5959c7b9ab904762ea6c7af211c8663 # via mlflow threadpoolctl==3.3.0 \ --hash=sha256:5dac632b4fa2d43f42130267929af3ba01399ef4bd1882918e92dbc30365d30c \ diff --git a/pyproject.toml b/pyproject.toml index a0141001072..45d428f5e38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -288,6 +288,9 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # minimum target version target-version = "py38" +# Enumerate all fixed violations. +show-fixes = true + [tool.ruff.mccabe] # Unlike Flake8, default to a complexity level of 10. 
max-complexity = 20 diff --git a/src/otx/__init__.py b/src/otx/__init__.py index 1a8804b3693..d821d0b8458 100644 --- a/src/otx/__init__.py +++ b/src/otx/__init__.py @@ -3,7 +3,7 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -__version__ = "2.0.0" +__version__ = "2.1.0rc0" from otx.core.types import * # noqa: F403 diff --git a/src/otx/algo/__init__.py b/src/otx/algo/__init__.py index 968d579e5f7..29312f92f25 100644 --- a/src/otx/algo/__init__.py +++ b/src/otx/algo/__init__.py @@ -3,6 +3,24 @@ # """Module for OTX custom algorithms, e.g., model, losses, hook, etc...""" -from . import action_classification, classification, detection, segmentation, visual_prompting +from . import ( + accelerators, + action_classification, + classification, + detection, + plugins, + segmentation, + strategies, + visual_prompting, +) -__all__ = ["action_classification", "classification", "detection", "segmentation", "visual_prompting"] +__all__ = [ + "action_classification", + "classification", + "detection", + "segmentation", + "visual_prompting", + "strategies", + "accelerators", + "plugins", +] diff --git a/src/otx/algo/accelerators/__init__.py b/src/otx/algo/accelerators/__init__.py new file mode 100644 index 00000000000..5fc4b9d9d1d --- /dev/null +++ b/src/otx/algo/accelerators/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Lightning accelerator for XPU device.""" + +from .xpu import XPUAccelerator + +__all__ = ["XPUAccelerator"] diff --git a/src/otx/algo/accelerators/xpu.py b/src/otx/algo/accelerators/xpu.py new file mode 100644 index 00000000000..f5969336ab4 --- /dev/null +++ b/src/otx/algo/accelerators/xpu.py @@ -0,0 +1,88 @@ +"""Lightning accelerator for XPU device.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from typing import Any, Union + +import numpy as np +import torch +from 
lightning.pytorch.accelerators import AcceleratorRegistry +from lightning.pytorch.accelerators.accelerator import Accelerator +from mmcv.ops.nms import NMSop +from mmcv.ops.roi_align import RoIAlign +from mmengine.structures import instance_data + +from otx.algo.detection.utils import monkey_patched_nms, monkey_patched_roi_align +from otx.utils.utils import is_xpu_available + + +class XPUAccelerator(Accelerator): + """Support for an XPU, optimized for large-scale machine learning.""" + + accelerator_name = "xpu" + + def setup_device(self, device: torch.device) -> None: + """Sets up the specified device.""" + if device.type != "xpu": + msg = f"Device should be xpu, got {device} instead" + raise RuntimeError(msg) + + torch.xpu.set_device(device) + self.patch_packages_xpu() + + @staticmethod + def parse_devices(devices: str | list | torch.device) -> list: + """Parses devices for multi-GPU training.""" + if isinstance(devices, list): + return devices + return [devices] + + @staticmethod + def get_parallel_devices(devices: list) -> list[torch.device]: + """Generates a list of parallel devices.""" + return [torch.device("xpu", idx) for idx in devices] + + @staticmethod + def auto_device_count() -> int: + """Returns number of XPU devices available.""" + return torch.xpu.device_count() + + @staticmethod + def is_available() -> bool: + """Checks if XPU available.""" + return is_xpu_available() + + def get_device_stats(self, device: str | torch.device) -> dict[str, Any]: + """Returns XPU devices stats.""" + return {} + + def teardown(self) -> None: + """Cleans-up XPU-related resources.""" + self.revert_packages_xpu() + + def patch_packages_xpu(self) -> None: + """Patch packages when xpu is available.""" + # patch instance_data from mmengine + long_type_tensor = Union[torch.LongTensor, torch.xpu.LongTensor] + bool_type_tensor = Union[torch.BoolTensor, torch.xpu.BoolTensor] + instance_data.IndexType = Union[str, slice, int, list, long_type_tensor, bool_type_tensor, np.ndarray]
+ + # patch nms and roi_align + self._nms_op_forward = NMSop.forward + self._roi_align_forward = RoIAlign.forward + NMSop.forward = monkey_patched_nms + RoIAlign.forward = monkey_patched_roi_align + + def revert_packages_xpu(self) -> None: + """Revert packages when xpu is available.""" + NMSop.forward = self._nms_op_forward + RoIAlign.forward = self._roi_align_forward + + +AcceleratorRegistry.register( + XPUAccelerator.accelerator_name, + XPUAccelerator, + description="Accelerator supports XPU devices", +) diff --git a/src/otx/algo/classification/deit_tiny.py b/src/otx/algo/classification/deit_tiny.py index 653a724a0f2..b78567f70de 100644 --- a/src/otx/algo/classification/deit_tiny.py +++ b/src/otx/algo/classification/deit_tiny.py @@ -10,7 +10,7 @@ import torch from mmpretrain.models.utils import resize_pos_embed -from otx.algo.hooks.recording_forward_hook import ViTReciproCAMHook +from otx.algo.explain.explain_algo import ViTReciproCAM from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable @@ -33,7 +33,7 @@ class ForwardExplainMixInForDeit(ExplainableMixInMMPretrainModel): - """Deit model which can attach a XAI hook.""" + """Deit model which can attach a XAI (Explainable AI) branch.""" @torch.no_grad() def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor: @@ -133,7 +133,7 @@ def _forward_explain_image_classifier( def get_explain_fn(self) -> Callable: """Returns explain function.""" - explainer = ViTReciproCAMHook( + explainer = ViTReciproCAM( self.head_forward_fn, num_classes=self.num_classes, ) diff --git a/src/otx/algo/classification/efficientnet_b0.py b/src/otx/algo/classification/efficientnet_b0.py index 488149dce24..29497a4c791 100644 --- a/src/otx/algo/classification/efficientnet_b0.py +++ b/src/otx/algo/classification/efficientnet_b0.py @@ -101,6 +101,9 @@ def 
load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix) + def _reset_prediction_layer(self, num_classes: int) -> None: + return + class EfficientNetB0ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel): """EfficientNetB0 Model for multi-class classification task.""" diff --git a/src/otx/algo/classification/mobilenet_v3_large.py b/src/otx/algo/classification/mobilenet_v3_large.py index 7ef27adfe23..766d83378fc 100644 --- a/src/otx/algo/classification/mobilenet_v3_large.py +++ b/src/otx/algo/classification/mobilenet_v3_large.py @@ -4,10 +4,12 @@ """MobileNetV3 model implementation.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.classification import ( @@ -18,6 +20,7 @@ from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import HLabelInfo +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -48,11 +51,21 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - parent_parameters = super()._export_parameters - parent_parameters.update({"via_onnx": True}) - return parent_parameters 
+ def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + mean, std = get_mean_std_from_data_processing(self.config) + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, # TODO(someone): Check if this model can be exported directly with OV > 2024.0 + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" @@ -83,11 +96,21 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - parent_parameters = super()._export_parameters - parent_parameters.update({"via_onnx": True}) - return parent_parameters + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + mean, std = get_mean_std_from_data_processing(self.config) + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, # NOTE: This should be done via onnx + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" @@ -116,11 +139,21 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - parent_parameters = super()._export_parameters - 
parent_parameters.update({"via_onnx": True}) - return parent_parameters + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + mean, std = get_mean_std_from_data_processing(self.config) + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, # NOTE: This should be done via onnx + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" diff --git a/src/otx/algo/classification/otx_dino_v2.py b/src/otx/algo/classification/otx_dino_v2.py index adabe2ebf2e..63f5c82f97c 100644 --- a/src/otx/algo/classification/otx_dino_v2.py +++ b/src/otx/algo/classification/otx_dino_v2.py @@ -142,29 +142,21 @@ def _customize_outputs( labels=labels, ) - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - export_params: dict[str, Any] = {} - - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - export_params["via_onnx"] = False - export_params["input_size"] = (1, 3, 224, 224) - export_params["onnx_export_configuration"] = None - export_params["mean"] = [123.675, 116.28, 103.53] - export_params["std"] = [58.395, 57.12, 57.375] - - parent_parameters = super()._export_parameters - parent_parameters.update(export_params) - - return parent_parameters - @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter(**self._export_parameters) + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + 
input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) @property def _optimization_config(self) -> dict[str, Any]: diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index 796160f3cb5..210fea7def4 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -11,7 +11,7 @@ from torch import nn from torchvision.models import get_model, get_model_weights -from otx.algo.hooks.recording_forward_hook import ReciproCAMHook +from otx.algo.explain.explain_algo import ReciproCAM from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity from otx.core.exporter.base import OTXModelExporter @@ -139,7 +139,7 @@ def __init__( self.softmax = nn.Softmax(dim=-1) self.loss = loss - self.explainer = ReciproCAMHook( + self.explainer = ReciproCAM( self._head_forward_fn, num_classes=num_classes, optimize_gap=True, @@ -285,25 +285,18 @@ def _customize_outputs( @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter(**self._export_parameters) - - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - export_params: dict[str, Any] = {} - export_params["input_size"] = (1, 3, 224, 224) - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - export_params["via_onnx"] = False - export_params["onnx_export_configuration"] = None - export_params["mean"] = [123.675, 116.28, 103.53] - 
export_params["std"] = [58.395, 57.12, 57.375] - - parameters = super()._export_parameters - parameters.update(export_params) - - return parameters + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: """Model forward explain function.""" diff --git a/src/otx/algo/detection/atss.py b/src/otx/algo/detection/atss.py index 3478a903a48..91a4a1b5bbe 100644 --- a/src/otx/algo/detection/atss.py +++ b/src/otx/algo/detection/atss.py @@ -5,14 +5,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from copy import deepcopy +from typing import TYPE_CHECKING, Literal from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.mmdeploy import MMdeployExporter from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.detection import MMDetCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -46,16 +50,27 @@ def __init__( self.tile_image_size = self.image_size @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - export_params = super()._export_parameters - export_params["deploy_cfg"] = "otx.algo.detection.mmdeploy.atss" - export_params["input_size"] = self.image_size - 
export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - - return export_params + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.image_size is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.atss", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" diff --git a/src/otx/algo/detection/heads/anchor_head.py b/src/otx/algo/detection/heads/anchor_head.py index ab3ec9e9d8c..1f56adb373c 100644 --- a/src/otx/algo/detection/heads/anchor_head.py +++ b/src/otx/algo/detection/heads/anchor_head.py @@ -9,19 +9,17 @@ from typing import TYPE_CHECKING import torch -from mmdet.models.task_modules.prior_generators import anchor_inside_flags -from mmdet.models.utils import images_to_levels, multi_apply, unmap from mmdet.registry import MODELS, TASK_UTILS -from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor from mmengine.structures import InstanceData from torch import Tensor, nn from otx.algo.detection.heads.base_head import BaseDenseHead from otx.algo.detection.heads.base_sampler import PseudoSampler from otx.algo.detection.heads.custom_anchor_generator import AnchorGenerator +from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap if TYPE_CHECKING: - from mmdet.utils import InstanceList, OptConfigType, 
OptInstanceList, OptMultiConfig + from mmengine import ConfigDict # This class and its supporting functions below lightly adapted from the mmdet AnchorHead available at: @@ -56,11 +54,11 @@ def __init__( bbox_coder: dict, loss_cls: dict, loss_bbox: dict, + train_cfg: ConfigDict | dict, feat_channels: int = 256, reg_decoded_bbox: bool = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - init_cfg: OptMultiConfig = None, + test_cfg: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict] | None = None, ) -> None: super().__init__(init_cfg=init_cfg) self.in_channels = in_channels @@ -142,7 +140,7 @@ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]: bbox_pred = self.conv_reg(x) return cls_score, bbox_pred - def forward(self, x: tuple[Tensor]) -> tuple[list[Tensor], list[Tensor]]: + def forward(self, x: tuple[Tensor]) -> tuple: """Forward features from the upstream network. Args: @@ -199,7 +197,7 @@ def get_anchors( def _get_targets_single( self, - flat_anchors: Tensor | BaseBoxes, + flat_anchors: Tensor, valid_flags: Tensor, gt_instances: InstanceData, img_meta: dict, @@ -209,7 +207,7 @@ def _get_targets_single( """Compute regression and classification targets for anchors in a single image. 
Args: - flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors + flat_anchors (Tensor): Multi-level anchors of the image, which are concatenated into a single tensor or box type of shape (num_anchors, 4) valid_flags (Tensor): Multi level valid flags of the image, @@ -277,7 +275,6 @@ def _get_targets_single( pos_bbox_targets = self.bbox_coder.encode(sampling_result.pos_priors, sampling_result.pos_gt_bboxes) else: pos_bbox_targets = sampling_result.pos_gt_bboxes - pos_bbox_targets = get_box_tensor(pos_bbox_targets) bbox_targets[pos_inds, :] = pos_bbox_targets bbox_weights[pos_inds, :] = 1.0 @@ -303,9 +300,9 @@ def get_targets( self, anchor_list: list[list[Tensor]], valid_flag_list: list[list[Tensor]], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, unmap_outputs: bool = True, ) -> tuple: """Compute regression and classification targets for anchors in multiple images. @@ -364,7 +361,7 @@ def get_targets( concat_anchor_list = [] concat_valid_flag_list = [] for i in range(num_imgs): - concat_anchor_list.append(cat_boxes(anchor_list[i])) + concat_anchor_list.append(torch.cat(anchor_list[i])) concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) # compute targets for each image @@ -455,7 +452,6 @@ def loss_by_feat_single( # decodes the already encoded coordinates to absolute format. 
anchors = anchors.reshape(-1, anchors.size(-1)) bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) - bbox_pred = get_box_tensor(bbox_pred) loss_bbox = self.loss_bbox(bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor) return loss_cls, loss_bbox @@ -463,9 +459,9 @@ def loss_by_feat( self, cls_scores: list[Tensor], bbox_preds: list[Tensor], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, ) -> dict: """Calculate the loss based on the features extracted by the detection head. @@ -504,7 +500,7 @@ def loss_by_feat( # anchor number of multi levels num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] # concat all level anchors and flags to a single tensor - concat_anchor_list = [cat_boxes(anchor) for anchor in anchor_list] + concat_anchor_list = [torch.cat(anchor) for anchor in anchor_list] all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors) losses_cls, losses_bbox = multi_apply( diff --git a/src/otx/algo/detection/heads/base_head.py b/src/otx/algo/detection/heads/base_head.py index b513d90c4b0..b50ddcb0236 100644 --- a/src/otx/algo/detection/heads/base_head.py +++ b/src/otx/algo/detection/heads/base_head.py @@ -11,16 +11,14 @@ import torch from mmcv.ops import batched_nms -from mmdet.models.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances -from mmdet.structures.bbox import cat_boxes, get_box_tensor, get_box_wh, scale_boxes from mmengine.model import constant_init from mmengine.structures import InstanceData from torch import Tensor, nn +from otx.algo.detection.utils.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances + if TYPE_CHECKING: - from mmdet.structures import SampleList - from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig - from mmengine.config import ConfigDict + from 
mmengine import ConfigDict # This class and its supporting functions below lightly adapted from the mmdet BaseDenseHead available at: @@ -62,7 +60,7 @@ class BaseDenseHead(nn.Module): loss_and_predict(): forward() -> loss_by_feat() -> predict_by_feat() """ - def __init__(self, init_cfg: OptMultiConfig = None) -> None: + def __init__(self, init_cfg: ConfigDict | list[ConfigDict] | dict | list[dict] | None = None) -> None: super().__init__() self._is_init = False @@ -82,7 +80,7 @@ def init_weights(self) -> None: if hasattr(m, "conv_offset"): constant_init(m.conv_offset, 0) - def get_positive_infos(self) -> InstanceList: + def get_positive_infos(self) -> list[InstanceData] | None: """Get positive information from sampling results. Returns: @@ -105,7 +103,7 @@ def get_positive_infos(self) -> InstanceList: positive_infos.append(pos_info) return positive_infos - def loss(self, x: tuple[Tensor], batch_data_samples: SampleList) -> dict: + def loss(self, x: tuple[Tensor], batch_data_samples: list[InstanceData]) -> dict: """Perform forward propagation and loss calculation of the detection head. Args: @@ -131,18 +129,18 @@ def loss_by_feat( self, cls_scores: list[Tensor], bbox_preds: list[Tensor], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, ) -> dict: """Calculate the loss based on the features extracted by the detection head.""" def loss_and_predict( self, x: tuple[Tensor], - batch_data_samples: SampleList, + batch_data_samples: list[InstanceData], proposal_cfg: ConfigDict | None = None, - ) -> tuple[dict, InstanceList]: + ) -> tuple[dict, list[InstanceData]]: """Perform forward propagation of the head, then calculate loss and predictions. 
Args: @@ -172,7 +170,12 @@ def loss_and_predict( predictions = self.predict_by_feat(cls_scores, bbox_preds, batch_img_metas=batch_img_metas, cfg=proposal_cfg) return losses, predictions - def predict(self, x: tuple[Tensor], batch_data_samples: SampleList, rescale: bool = False) -> InstanceList: + def predict( + self, + x: tuple[Tensor], + batch_data_samples: list[InstanceData], + rescale: bool = False, + ) -> list[InstanceData]: """Perform forward propagation of the detection head and predict detection results. Args: @@ -203,7 +206,7 @@ def predict_by_feat( cfg: ConfigDict | None = None, rescale: bool = False, with_nms: bool = True, - ) -> InstanceList: + ) -> list[InstanceData]: """Transform a batch of output features extracted from the head into bbox results. Note: When score_factors is not None, the cls_scores are @@ -241,8 +244,6 @@ def predict_by_feat( - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). """ - with_score_factors = score_factors is not None - num_levels = len(cls_scores) featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] @@ -258,7 +259,7 @@ def predict_by_feat( img_meta = batch_img_metas[img_id] cls_score_list = select_single_mlvl(cls_scores, img_id, detach=True) bbox_pred_list = select_single_mlvl(bbox_preds, img_id, detach=True) - if with_score_factors: + if score_factors is not None: score_factor_list = select_single_mlvl(score_factors, img_id, detach=True) else: score_factor_list = [None for _ in range(num_levels)] @@ -369,8 +370,13 @@ def _predict_by_feat_single( # `nms_pre` than before. 
score_thr = cfg.get("score_thr", 0) - results = filter_scores_and_topk(scores, score_thr, nms_pre, {"bbox_pred": bbox_pred, "priors": priors}) - scores, labels, keep_idxs, filtered_results = results + filtered_results: dict + scores, labels, keep_idxs, filtered_results = filter_scores_and_topk( # type: ignore[assignment] + scores, + score_thr, + nms_pre, + {"bbox_pred": bbox_pred, "priors": priors}, + ) bbox_pred = filtered_results["bbox_pred"] # noqa: PLW2901 priors = filtered_results["priors"] # noqa: PLW2901 @@ -387,7 +393,7 @@ def _predict_by_feat_single( mlvl_score_factors.append(score_factor) bbox_pred = torch.cat(mlvl_bbox_preds) - priors = cat_boxes(mlvl_valid_priors) + priors = torch.cat(mlvl_valid_priors) bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) results = InstanceData() @@ -437,7 +443,9 @@ def _bbox_post_process( """ if rescale: scale_factor = [1 / s for s in img_meta["scale_factor"]] - results.bboxes = scale_boxes(results.bboxes, scale_factor) + results.bboxes = results.bboxes * results.bboxes.new_tensor(scale_factor).repeat( + (1, int(results.bboxes.size(-1) / 2)), + ) if hasattr(results, "score_factors"): score_factors = results.pop("score_factors") @@ -445,13 +453,14 @@ def _bbox_post_process( # filter small size bboxes if cfg.get("min_bbox_size", -1) >= 0: - w, h = get_box_wh(results.bboxes) + w = results.bboxes[:, 2] - results.bboxes[:, 0] + h = results.bboxes[:, 3] - results.bboxes[:, 1] valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) if not valid_mask.all(): results = results[valid_mask] if with_nms and results.bboxes.numel() > 0: - bboxes = get_box_tensor(results.bboxes) + bboxes = results.bboxes det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, results.labels, cfg.nms) results = results[keep_idxs] # some nms would reweight the score, such as softnms diff --git a/src/otx/algo/detection/heads/base_sampler.py b/src/otx/algo/detection/heads/base_sampler.py index 4fe04e131c4..1c75c310466 100644 
--- a/src/otx/algo/detection/heads/base_sampler.py +++ b/src/otx/algo/detection/heads/base_sampler.py @@ -4,11 +4,10 @@ from abc import ABCMeta, abstractmethod import torch -from mmdet.models.task_modules.assigners import AssignResult -from mmdet.models.task_modules.samplers.sampling_result import SamplingResult -from mmdet.structures.bbox import BaseBoxes, cat_boxes from mmengine.structures import InstanceData +from otx.algo.detection.utils.structures import AssignResult, SamplingResult + class BaseSampler(metaclass=ABCMeta): """Base class of samplers. @@ -72,26 +71,6 @@ def sample( Returns: :obj:`SamplingResult`: Sampling result. - - Example: - >>> from mmengine.structures import InstanceData - >>> from mmdet.models.task_modules.samplers import RandomSampler, - >>> from mmdet.models.task_modules.assigners import AssignResult - >>> from mmdet.models.task_modules.samplers. - ... sampling_result import ensure_rng, random_boxes - >>> rng = ensure_rng(None) - >>> assign_result = AssignResult.random(rng=rng) - >>> pred_instances = InstanceData() - >>> pred_instances.priors = random_boxes(assign_result.num_preds, - ... rng=rng) - >>> gt_instances = InstanceData() - >>> gt_instances.bboxes = random_boxes(assign_result.num_gts, - ... rng=rng) - >>> gt_instances.labels = torch.randint( - ... 0, 5, (assign_result.num_gts,), dtype=torch.long) - >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, - >>> add_gt_as_proposals=False) - >>> self = self.sample(assign_result, pred_instances, gt_instances) """ gt_bboxes = gt_instances.bboxes priors = pred_instances.priors @@ -101,13 +80,8 @@ def sample( gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8) if self.add_gt_as_proposals and len(gt_bboxes) > 0: - # When `gt_bboxes` and `priors` are all box type, convert - # `gt_bboxes` type to `priors` type. 
- if isinstance(gt_bboxes, BaseBoxes) and isinstance(priors, BaseBoxes): - gt_bboxes_ = gt_bboxes.convert_to(type(priors)) - else: - gt_bboxes_ = gt_bboxes - priors = cat_boxes([gt_bboxes_, priors], dim=0) + gt_bboxes_ = gt_bboxes + priors = torch.cat([gt_bboxes_, priors], dim=0) assign_result.add_gt_(gt_labels) gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8) gt_flags = torch.cat([gt_ones, gt_flags]) diff --git a/src/otx/algo/detection/heads/custom_anchor_generator.py b/src/otx/algo/detection/heads/custom_anchor_generator.py index 8d706477a4c..bff9aa7d59b 100644 --- a/src/otx/algo/detection/heads/custom_anchor_generator.py +++ b/src/otx/algo/detection/heads/custom_anchor_generator.py @@ -10,7 +10,6 @@ import numpy as np import torch from mmdet.registry import TASK_UTILS -from mmdet.structures.bbox import HorizontalBoxes from torch.nn.modules.utils import _pair @@ -44,8 +43,6 @@ class AnchorGenerator: float is given, they will be used to shift the centers of anchors. center_offset (float): The offset of center in proportion to anchors' width and height. By default it is 0 in V2.0. - use_box_type (bool): Whether to warp anchors with the box type data - structure. Defaults to False. Examples: >>> from mmdet.models.task_modules. 
@@ -78,7 +75,6 @@ def __init__( scales_per_octave: int | None = None, centers: list[tuple[float, float]] | None = None, center_offset: float = 0.0, - use_box_type: bool = False, ) -> None: # check center and center_offset if center_offset != 0 and centers is None: @@ -112,7 +108,6 @@ def __init__( self.centers = centers self.center_offset = center_offset self.base_anchors = self.gen_base_anchors() - self.use_box_type = use_box_type @property def num_base_anchors(self) -> list[int]: @@ -278,12 +273,9 @@ def single_level_grid_priors( # shifted anchors (K, A, 4), reshape to (K*A, 4) all_anchors = base_anchors[None, :, :] + shifts[:, None, :] - all_anchors = all_anchors.view(-1, 4) # first A rows correspond to A anchors of (0, 0) in feature map, # then (0, 1), (0, 2), ... - if self.use_box_type: - all_anchors = HorizontalBoxes(all_anchors) - return all_anchors + return all_anchors.view(-1, 4) def sparse_priors( self, @@ -506,7 +498,6 @@ def __init__( self.center_offset = 0 self.gen_base_anchors() - self.use_box_type = False def gen_base_anchors(self) -> None: # type: ignore[override] """Generate base anchor for SSD.""" diff --git a/src/otx/algo/detection/heads/custom_ssd_head.py b/src/otx/algo/detection/heads/custom_ssd_head.py index 22fb788af4c..fe0ccc4bf1a 100644 --- a/src/otx/algo/detection/heads/custom_ssd_head.py +++ b/src/otx/algo/detection/heads/custom_ssd_head.py @@ -7,9 +7,6 @@ from typing import TYPE_CHECKING import torch -from mmdet.models.losses import smooth_l1_loss -from mmdet.models.utils import multi_apply -from mmdet.registry import MODELS from torch import Tensor, nn from otx.algo.detection.heads.anchor_head import AnchorHead @@ -17,10 +14,12 @@ from otx.algo.detection.heads.custom_anchor_generator import SSDAnchorGeneratorClustered from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder from otx.algo.detection.heads.max_iou_assigner import MaxIoUAssigner +from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss 
+from otx.algo.detection.losses.weighted_loss import smooth_l1_loss +from otx.algo.detection.utils.utils import multi_apply if TYPE_CHECKING: - from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList - from mmengine.config import Config + from mmengine.config import ConfigDict, InstanceData # This class and its supporting functions below lightly adapted from the mmdet SSDHead available at: @@ -39,12 +38,6 @@ class SSDHead(AnchorHead): > 0. Defaults to 256. use_depthwise (bool): Whether to use DepthwiseSeparableConv. Defaults to False. - conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct - and config conv layer. Defaults to None. - norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct - and config norm layer. Defaults to None. - act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct - and config activation layer. Defaults to None. anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor generator. bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder. @@ -63,21 +56,18 @@ class SSDHead(AnchorHead): def __init__( self, - anchor_generator: ConfigType, - bbox_coder: ConfigType, - init_cfg: MultiConfig, - act_cfg: ConfigType, + anchor_generator: ConfigDict | dict, + bbox_coder: ConfigDict | dict, + init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict], + act_cfg: ConfigDict | dict, + train_cfg: ConfigDict | dict, num_classes: int = 80, in_channels: tuple[int, ...] 
= (512, 1024, 512, 256, 256, 256), stacked_convs: int = 0, feat_channels: int = 256, use_depthwise: bool = False, - conv_cfg: ConfigType | None = None, - norm_cfg: ConfigType | None = None, reg_decoded_bbox: bool = False, - train_cfg: ConfigType | None = None, - test_cfg: ConfigType | None = None, - loss_cls: Config | dict | None = None, + test_cfg: ConfigDict | dict | None = None, ) -> None: super(AnchorHead, self).__init__(init_cfg=init_cfg) self.num_classes = num_classes @@ -85,9 +75,7 @@ def __init__( self.stacked_convs = stacked_convs self.feat_channels = feat_channels self.use_depthwise = use_depthwise - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg + self.act_cfg = act_cfg # TODO(Jaeguk): act_cfg will be deprecated after implementing export. self.cls_out_channels = num_classes + 1 # add background class anchor_generator.pop("type") @@ -98,14 +86,7 @@ def __init__( # heads but a list of int in SSDHead self.num_base_priors = self.prior_generator.num_base_priors - if loss_cls is None: - loss_cls = { - "type": "CrossEntropyLoss", - "use_sigmoid": False, - "reduction": "none", - "loss_weight": 1.0, - } - self.loss_cls = MODELS.build(loss_cls) + self.loss_cls = CrossEntropyLoss(use_sigmoid=False, reduction="none", loss_weight=1.0) self._init_layers() @@ -218,9 +199,9 @@ def loss_by_feat( self, cls_scores: list[Tensor], bbox_preds: list[Tensor], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, ) -> dict[str, list[Tensor]]: """Compute losses of the head. 
@@ -298,11 +279,9 @@ def _init_layers(self) -> None: self.cls_convs = nn.ModuleList() self.reg_convs = nn.ModuleList() - activation_config = self.act_cfg.copy() - activation_config.setdefault("inplace", True) for in_channel, num_base_priors in zip(self.in_channels, self.num_base_priors): if self.use_depthwise: - activation_layer = MODELS.build(activation_config) + activation_layer = nn.ReLU(inplace=True) self.reg_convs.append( nn.Sequential( diff --git a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py index 126b5de4eef..a46e838e06e 100644 --- a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py +++ b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py @@ -6,7 +6,6 @@ import numpy as np import torch -from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor from torch import Tensor @@ -44,6 +43,7 @@ def __init__( ctr_clamp: int = 32, ) -> None: self.encode_size = encode_size + # TODO(Jaeguk): use_box_type should be deprecated. self.use_box_type = use_box_type self.means = target_means self.stds = target_stds @@ -51,33 +51,31 @@ def __init__( self.add_ctr_clamp = add_ctr_clamp self.ctr_clamp = ctr_clamp - def encode(self, bboxes: Tensor | BaseBoxes, gt_bboxes: Tensor | BaseBoxes) -> Tensor: + def encode(self, bboxes: Tensor, gt_bboxes: Tensor) -> Tensor: """Get box regression transformation deltas that can be used to transform the bboxes into the gt_bboxes. Args: - bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes, + bboxes (torch.Tensor): Source boxes, e.g., object proposals. - gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the + gt_bboxes (torch.Tensor): Target of the transformation, e.g., ground-truth boxes. 
Returns: torch.Tensor: Box transformation deltas """ - bboxes = get_box_tensor(bboxes) - gt_bboxes = get_box_tensor(gt_bboxes) return bbox2delta(bboxes, gt_bboxes, self.means, self.stds) def decode( self, - bboxes: Tensor | BaseBoxes, + bboxes: Tensor, pred_bboxes: Tensor, max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] | None = None, wh_ratio_clip: float = 16 / 1000, - ) -> Tensor | BaseBoxes: + ) -> Tensor: """Apply transformation `pred_bboxes` to `boxes`. Args: - bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape + bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4) pred_bboxes (Tensor): Encoded offsets with respect to each roi. Has shape (B, N, num_classes * 4) or (B, N, 4) or @@ -92,10 +90,9 @@ def decode( width and height. Returns: - Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes. + torch.Tensor: Decoded boxes. """ - bboxes = get_box_tensor(bboxes) - decoded_bboxes = delta2bbox( + return delta2bbox( bboxes, pred_bboxes, self.means, @@ -107,10 +104,6 @@ def decode( self.ctr_clamp, ) - if self.use_box_type: - decoded_bboxes = HorizontalBoxes(decoded_bboxes) - return decoded_bboxes - def bbox2delta( proposals: Tensor, diff --git a/src/otx/algo/detection/heads/iou2d_calculator.py b/src/otx/algo/detection/heads/iou2d_calculator.py index fd46b436b5d..bad8a5ea094 100644 --- a/src/otx/algo/detection/heads/iou2d_calculator.py +++ b/src/otx/algo/detection/heads/iou2d_calculator.py @@ -5,7 +5,8 @@ from __future__ import annotations import torch -from mmdet.structures.bbox import BaseBoxes, bbox_overlaps, get_box_tensor + +from otx.algo.detection.utils.bbox_overlaps import bbox_overlaps # This class and its supporting functions below lightly adapted from the mmdet BboxOverlaps2D available at: @@ -19,18 +20,18 @@ def __init__(self, scale: float = 1.0, dtype: str | None = None): def __call__( self, - bboxes1: torch.Tensor | BaseBoxes, - bboxes2: torch.Tensor | BaseBoxes, + bboxes1: torch.Tensor, + bboxes2: torch.Tensor, mode: 
str = "iou", is_aligned: bool = False, ) -> torch.Tensor: """Calculate IoU between 2D bboxes. Args: - bboxes1 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4) + bboxes1 (Tensor): bboxes have shape (m, 4) in format, or shape (m, 5) in format. - bboxes2 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4) + bboxes2 (Tensor): bboxes have shape (m, 4) in format, shape (m, 5) in format, or be empty. If ``is_aligned `` is ``True``, then m and n must be equal. @@ -43,8 +44,6 @@ def __call__( Returns: Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,) """ - bboxes1 = get_box_tensor(bboxes1) - bboxes2 = get_box_tensor(bboxes2) if bboxes2.size(-1) == 5: bboxes2 = bboxes2[..., :4] if bboxes1.size(-1) == 5: diff --git a/src/otx/algo/detection/heads/max_iou_assigner.py b/src/otx/algo/detection/heads/max_iou_assigner.py index 7bcbb0353ab..f95f44585a1 100644 --- a/src/otx/algo/detection/heads/max_iou_assigner.py +++ b/src/otx/algo/detection/heads/max_iou_assigner.py @@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Callable import torch -from mmdet.models.task_modules.assigners.assign_result import AssignResult from torch import Tensor from otx.algo.detection.heads.iou2d_calculator import BboxOverlaps2D +from otx.algo.detection.utils.structures import AssignResult if TYPE_CHECKING: from mmengine.structures import InstanceData diff --git a/src/otx/algo/detection/losses/cross_entropy_loss.py b/src/otx/algo/detection/losses/cross_entropy_loss.py new file mode 100644 index 00000000000..81c3be1a1b1 --- /dev/null +++ b/src/otx/algo/detection/losses/cross_entropy_loss.py @@ -0,0 +1,272 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Base Cross Entropy Loss implementation from mmdet.""" + +from __future__ import annotations + +import torch +from torch import nn + +from otx.algo.detection.losses.weighted_loss import weight_reduce_loss + + +# All of the methods and classes below come from mmdet, and are slightly modified. 
+# https://github.com/open-mmlab/mmdetection/blob/ecac3a77becc63f23d9f6980b2a36f86acd00a8a/mmdet/models/losses/cross_entropy_loss.py +def cross_entropy( + pred: torch.Tensor, + label: torch.Tensor, + weight: torch.Tensor | None = None, + reduction: str = "mean", + avg_factor: int | None = None, + class_weight: list[float] | None = None, + ignore_index: int = -100, + avg_non_ignore: bool = False, +) -> torch.Tensor: + """Calculate the CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. + Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. 
+ + Returns: + torch.Tensor: The calculated loss + """ + loss = nn.functional.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index) + + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = label.numel() - (label == ignore_index).sum().item() + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + return weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + +def _expand_onehot_labels( + labels: torch.Tensor, + label_weights: torch.Tensor, + label_channels: int, + ignore_index: int, +) -> tuple[torch.Tensor, ...]: + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_full((labels.size(0), label_channels), 0) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask & (labels < label_channels), as_tuple=False) + + if inds.numel() > 0: + bin_labels[inds, labels[inds]] = 1 + + valid_mask = valid_mask.view(-1, 1).expand(labels.size(0), label_channels).float() + bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels) + bin_label_weights *= valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy( + pred: torch.Tensor, + label: torch.Tensor, + weight: torch.Tensor | None = None, + reduction: str = "mean", + avg_factor: int | None = None, + class_weight: list[float] | None = None, + ignore_index: int = -100, + avg_non_ignore: bool = False, +) -> torch.Tensor: + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1) or (N, ). 
+ When the shape of pred is (N, 1), label will be expanded to + one-hot format, and when the shape of pred is (N, ), label + will not be expanded to one-hot format. + label (torch.Tensor): The learning label of the prediction, + with shape (N, ). + weight (torch.Tensor, None): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. + Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + + Returns: + torch.Tensor: The calculated loss. + """ + if pred.dim() != label.dim(): + label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.size(-1), ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + # The inplace writing method will have a mismatched broadcast + # shape error if the weight and valid_mask dimensions + # are inconsistent such as (B,N,1) and (B,N,C). 
+ weight = weight * valid_mask if weight is not None else valid_mask + + # average loss over non-ignored elements + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = valid_mask.sum().item() + + # weighted element-wise losses + weight = weight.float() + loss = nn.functional.binary_cross_entropy_with_logits( + pred, + label.float(), + pos_weight=class_weight, + reduction="none", + ) + # do the reduction for the weighted loss + return weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor) + + +def mask_cross_entropy( + pred: torch.Tensor, + target: torch.Tensor, + label: torch.Tensor, + class_weight: list[float] | None = None, + **kwargs, # noqa: ARG001 +) -> torch.Tensor: + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C, *), C is the + number of classes. The trailing * indicates arbitrary shape. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + class_weight (list[float], None): The weight for each class. 
+ + Returns: + torch.Tensor: The calculated loss + + Example: + >>> N, C = 3, 11 + >>> H, W = 2, 2 + >>> pred = torch.randn(N, C, H, W) * 1000 + >>> target = torch.rand(N, H, W) + >>> label = torch.randint(0, C, size=(N,)) + >>> reduction = 'mean' + >>> avg_factor = None + >>> class_weights = None + >>> loss = mask_cross_entropy(pred, target, label, reduction, + >>> avg_factor, class_weights) + >>> assert loss.shape == (1,) + """ + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return nn.functional.binary_cross_entropy_with_logits( + pred_slice, + target, + weight=class_weight, + reduction="mean", + )[None] + + +class CrossEntropyLoss(nn.Module): + """Base Cross Entropy Loss implementation from mmdet.""" + + def __init__( + self, + use_sigmoid: bool = False, + use_mask: bool = False, + reduction: str = "mean", + class_weight: list[float] | None = None, + loss_weight: float = 1.0, + avg_non_ignore: bool = False, + ): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float): Weight of the loss. Defaults to 1.0. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. 
+ """ + super().__init__() + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + self.avg_non_ignore = avg_non_ignore + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy # type: ignore[assignment] + else: + self.cls_criterion = cross_entropy + + def extra_repr(self) -> str: + """Extra repr.""" + return f"avg_non_ignore={self.avg_non_ignore}" + + def forward( + self, + cls_score: torch.Tensor, + label: torch.Tensor, + weight: torch.Tensor | None = None, + avg_factor: int | None = None, + reduction_override: str | None = None, + ignore_index: int = -100, + **kwargs, + ) -> torch.Tensor: + """Forward function. + + Args: + cls_score (torch.Tensor): The prediction. + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, None): Sample-wise loss weight. + avg_factor (int, None): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, None): The method used to reduce the + loss. Options are "none", "mean" and "sum". + ignore_index (int): The label index to be ignored. + Default: -100. + + Returns: + torch.Tensor: The calculated loss. 
+ """ + reduction = reduction_override if reduction_override else self.reduction + + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight, device=cls_score.device) + else: + class_weight = None + return self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + ignore_index=ignore_index, + avg_non_ignore=self.avg_non_ignore, + **kwargs, + ) diff --git a/src/otx/algo/detection/losses/weighted_loss.py b/src/otx/algo/detection/losses/weighted_loss.py new file mode 100644 index 00000000000..69fe1d2696d --- /dev/null +++ b/src/otx/algo/detection/losses/weighted_loss.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Weighted loss from mmdet.""" + +from __future__ import annotations + +import functools +from typing import Callable + +import torch +from torch import Tensor +from torch.nn import functional + + +def reduce_loss(loss: Tensor, reduction: str) -> Tensor: + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = functional._Reduction.get_enum(reduction) # noqa: SLF001 + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + if reduction_enum == 1: + return loss.mean() + if reduction_enum == 2: + return loss.sum() + msg = f"reduction_enum: {reduction_enum} is invalid" + raise ValueError(msg) + + +def weight_reduce_loss( + loss: Tensor, + weight: Tensor | None = None, + reduction: str = "mean", + avg_factor: float | None = None, +) -> Tensor: + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor | None): Element-wise weights. + Defaults to None. + reduction (str): Same as built-in losses of PyTorch. + Defaults to 'mean'. 
+ avg_factor (float | None): Average factor when + computing the mean of losses. Defaults to None. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + # if reduction is mean, then average the loss by avg_factor + elif reduction == "mean": + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != "none": + msg = "avg_factor can not be used with reduction='sum'" + raise ValueError(msg) + return loss + + +def weighted_loss(loss_func: Callable) -> Callable: + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) 
+ >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper( + pred: Tensor, + target: Tensor, + weight: Tensor | None = None, + reduction: str = "mean", + avg_factor: int | None = None, + **kwargs, + ) -> Tensor: + """Wrapper for weighted loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): Target bboxes. + weight (Tensor | None): The weight of loss for each + prediction. Defaults to None. + reduction (str): Options are "none", "mean" and "sum". + Defaults to 'mean'. + avg_factor (int | None): Average factor that is used + to average the loss. Defaults to None. + + Returns: + Tensor: Loss tensor. + """ + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + return weight_reduce_loss(loss, weight, reduction, avg_factor) + + return wrapper + + +@weighted_loss +def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor: + """Smooth L1 loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + beta (float): The threshold in the piecewise function. + Defaults to 1.0. 
+ + Returns: + Tensor: Calculated loss + """ + if target.numel() == 0: + return pred.sum() * 0 + + diff = torch.abs(pred - target) + return torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta) diff --git a/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml b/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml index 93b40f0df85..a6685ee4cac 100644 --- a/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml +++ b/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml @@ -50,8 +50,6 @@ bbox_head: - 96 - 320 use_depthwise: true - norm_cfg: - type: BN act_cfg: type: ReLU init_cfg: diff --git a/src/otx/algo/detection/ops/__init__.py b/src/otx/algo/detection/ops/__init__.py new file mode 100644 index 00000000000..8bed33ebe94 --- /dev/null +++ b/src/otx/algo/detection/ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Custom operation implementations for detection task.""" diff --git a/src/otx/algo/detection/ops/nms.py b/src/otx/algo/detection/ops/nms.py new file mode 100644 index 00000000000..968b0ec3f1a --- /dev/null +++ b/src/otx/algo/detection/ops/nms.py @@ -0,0 +1,221 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""NMS implementations for detection task.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch +from torch import Tensor +from torchvision.ops.boxes import nms as torch_nms + + +# This class is from NMSop in mmcv and slightly modified +# https://github.com/open-mmlab/mmcv/blob/265531fa9fe9e071c7d80df549d680ed257d9a16/mmcv/ops/nms.py +class NMSop(torch.autograd.Function): + """NMS operation.""" + + @staticmethod + def forward( + ctx: Any, # noqa: ARG004, ANN401 + bboxes: Tensor, + scores: Tensor, + iou_threshold: float, + offset: int, # noqa: ARG004 + score_threshold: float, + max_num: int, + ) -> Tensor: + """Forward function.""" + is_filtering_by_score = score_threshold > 0 + if 
is_filtering_by_score: + valid_mask = scores > score_threshold + bboxes, scores = bboxes[valid_mask], scores[valid_mask] + valid_inds = torch.nonzero(valid_mask, as_tuple=False).squeeze(dim=1) + inds = torch_nms(bboxes, scores, iou_threshold) + + if max_num > 0: + inds = inds[:max_num] + if is_filtering_by_score: + inds = valid_inds[inds] + return inds + + +# This method is from nms in mmcv +# https://github.com/open-mmlab/mmcv/blob/265531fa9fe9e071c7d80df549d680ed257d9a16/mmcv/ops/nms.py +def nms( + boxes: Tensor | np.ndarray, + scores: Tensor | np.ndarray, + iou_threshold: float, + offset: int = 0, + score_threshold: float = 0, + max_num: int = -1, +) -> tuple[Tensor | np.ndarray, Tensor | np.ndarray]: + """Dispatch to either CPU or GPU NMS implementations. + + The input can be either torch tensor or numpy array. GPU NMS will be used + if the input is gpu tensor, otherwise CPU NMS + will be used. The returned type will always be the same as inputs. + + Arguments: + boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4). + scores (torch.Tensor or np.ndarray): scores in shape (N, ). + iou_threshold (float): IoU threshold for NMS. + offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset). + score_threshold (float): score threshold for NMS. + max_num (int): maximum number of boxes after NMS. + + Returns: + tuple: kept dets (boxes and scores) and indice, which always have + the same data type as the input. 
+ + Example: + >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9], + >>> [49.3, 32.9, 51.0, 35.3], + >>> [49.2, 31.8, 51.0, 35.4], + >>> [35.1, 11.5, 39.1, 15.7], + >>> [35.6, 11.8, 39.3, 14.2], + >>> [35.3, 11.5, 39.9, 14.5], + >>> [35.2, 11.7, 39.7, 15.7]], dtype=np.float32) + >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\ + dtype=np.float32) + >>> iou_threshold = 0.6 + >>> dets, inds = nms(boxes, scores, iou_threshold) + >>> assert len(inds) == len(dets) == 3 + """ + is_numpy = False + if isinstance(boxes, np.ndarray): + is_numpy = True + boxes = torch.from_numpy(boxes) + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + + inds = NMSop.apply(boxes, scores, iou_threshold, offset, score_threshold, max_num) + dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1) + if is_numpy: + dets = dets.cpu().numpy() + inds = inds.cpu().numpy() + return dets, inds + + +# This method is from batched_nms in mmcv +# https://github.com/open-mmlab/mmcv/blob/265531fa9fe9e071c7d80df549d680ed257d9a16/mmcv/ops/nms.py +def batched_nms( + boxes: Tensor, + scores: Tensor, + idxs: Tensor, + nms_cfg: dict | None = None, + class_agnostic: bool = False, +) -> tuple[Tensor, Tensor]: + r"""Performs non-maximum suppression in a batched fashion. + + Modified from `torchvision/ops/boxes.py#L39 + `_. + In order to perform NMS independently per class, we add an offset to all + the boxes. The offset is dependent only on the class idx, and is large + enough so that boxes from different classes do not overlap. + + Note: + In v1.4.1 and later, ``batched_nms`` supports skipping the NMS and + returns sorted raw results when `nms_cfg` is None. + + Args: + boxes (torch.Tensor): boxes in shape (N, 4) or (N, 5). + scores (torch.Tensor): scores in shape (N, ). + idxs (torch.Tensor): each index value correspond to a bbox cluster, + and NMS will not be applied between elements of different idxs, + shape (N, ). 
+ nms_cfg (dict | optional): Supports skipping the nms when `nms_cfg` + is None, otherwise it should specify nms type and other + parameters like `iou_thr`. Possible keys includes the following. + + - iou_threshold (float): IoU threshold used for NMS. + - split_thr (float): threshold number of boxes. In some cases the + number of boxes is large (e.g., 200k). To avoid OOM during + training, the users could set `split_thr` to a small value. + If the number of boxes is greater than the threshold, it will + perform NMS on each group of boxes separately and sequentially. + Defaults to 10000. + class_agnostic (bool): if true, nms is class agnostic, + i.e. IoU thresholding happens over all boxes, + regardless of the predicted class. Defaults to False. + + Returns: + tuple: kept dets and indice. + + - boxes (Tensor): Bboxes with score after nms, has shape + (num_bboxes, 5). last dimension 5 arrange as + (x1, y1, x2, y2, score) + - keep (Tensor): The indices of remaining boxes in input + boxes. + """ + # skip nms when nms_cfg is None + if nms_cfg is None: + scores, inds = scores.sort(descending=True) + boxes = boxes[inds] + return torch.cat([boxes, scores[:, None]], -1), inds + + nms_cfg_ = nms_cfg.copy() + class_agnostic = nms_cfg_.pop("class_agnostic", class_agnostic) + if class_agnostic: + boxes_for_nms = boxes + # When using rotated boxes, only apply offsets on center. + elif boxes.size(-1) == 5: + # Strictly, the maximum coordinates of the rotating box + # (x,y,w,h,a) should be calculated by polygon coordinates. + # But the conversion from rotated box to polygon will + # slow down the speed. 
+ # So we use max(x,y) + max(w,h) as max coordinate + # which is larger than polygon max coordinate + # max(x1, y1, x2, y2,x3, y3, x4, y4) + max_coordinate = boxes[..., :2].max() + boxes[..., 2:4].max() + offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) + boxes_ctr_for_nms = boxes[..., :2] + offsets[:, None] + boxes_for_nms = torch.cat([boxes_ctr_for_nms, boxes[..., 2:5]], dim=-1) + else: + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) + boxes_for_nms = boxes + offsets[:, None] + + nms_op = nms_cfg_.pop("type", "nms") + if isinstance(nms_op, str): + nms_op = eval(nms_op) # noqa: S307, PGH001 + + split_thr = nms_cfg_.pop("split_thr", 10000) + # Won't split to multiple nms nodes when exporting to onnx + if boxes_for_nms.shape[0] < split_thr: + dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_) + boxes = boxes[keep] + + # This assumes `dets` has arbitrary dimensions where + # the last dimension is score. + # Currently it supports bounding boxes [x1, y1, x2, y2, score] or + # rotated boxes [cx, cy, w, h, angle_radian, score]. 
+ + scores = dets[:, -1] + else: + max_num = nms_cfg_.pop("max_num", -1) + total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) + # Some type of nms would reweight the score, such as SoftNMS + scores_after_nms = scores.new_zeros(scores.size()) + for idx in torch.unique(idxs): + mask = (idxs == idx).nonzero(as_tuple=False).view(-1) + dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_) + total_mask[mask[keep]] = True + scores_after_nms[mask[keep]] = dets[:, -1] + keep = total_mask.nonzero(as_tuple=False).view(-1) + + scores, inds = scores_after_nms[keep].sort(descending=True) + keep = keep[inds] + boxes = boxes[keep] + + if max_num > 0: + keep = keep[:max_num] + boxes = boxes[:max_num] + scores = scores[:max_num] + + boxes = torch.cat([boxes, scores[:, None]], -1) + return boxes, keep diff --git a/src/otx/algo/detection/rtmdet.py b/src/otx/algo/detection/rtmdet.py index d3531dcf88e..2965c32c22d 100644 --- a/src/otx/algo/detection/rtmdet.py +++ b/src/otx/algo/detection/rtmdet.py @@ -5,14 +5,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from copy import deepcopy +from typing import TYPE_CHECKING, Literal from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.mmdeploy import MMdeployExporter from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.detection import MMDetCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -46,16 +50,27 @@ def __init__( self.tile_image_size = self.image_size @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - export_params = 
super()._export_parameters - export_params["deploy_cfg"] = "otx.algo.detection.mmdeploy.rtmdet" - export_params["input_size"] = self.image_size - export_params["resize_mode"] = "fit_to_window_letterbox" - export_params["pad_value"] = 114 - export_params["swap_rgb"] = False - - return export_params + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.image_size is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.rtmdet", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" diff --git a/src/otx/algo/detection/ssd.py b/src/otx/algo/detection/ssd.py index 7912e9744b1..ec7c74b4296 100644 --- a/src/otx/algo/detection/ssd.py +++ b/src/otx/algo/detection/ssd.py @@ -18,18 +18,21 @@ from otx.algo.detection.heads.custom_ssd_head import SSDHead from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.mmdeploy import MMdeployExporter from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.detection import MMDetCompatibleModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.utils.build import modify_num_classes from otx.core.utils.config import 
convert_conf_to_mmconfig_dict +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: import torch from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable - from mmdet.structures import DetDataSample, OptSampleList, SampleList - from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig + from mmengine import ConfigDict + from mmengine.structures import InstanceData from omegaconf import DictConfig from torch import Tensor, device @@ -48,12 +51,12 @@ class SingleStageDetector(nn.Module): def __init__( self, - backbone: ConfigType, - bbox_head: OptConfigType = None, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - data_preprocessor: OptConfigType = None, - init_cfg: OptMultiConfig = None, + backbone: ConfigDict | dict, + bbox_head: ConfigDict | dict, + train_cfg: ConfigDict | dict | None = None, + test_cfg: ConfigDict | dict | None = None, + data_preprocessor: ConfigDict | dict | None = None, + init_cfg: ConfigDict | list[ConfigDict] | dict | list[dict] = None, ) -> None: super().__init__() self._is_init = False @@ -153,9 +156,9 @@ def init_weights(self) -> None: def forward( self, inputs: torch.Tensor, - data_samples: OptSampleList = None, + data_samples: list[InstanceData], mode: str = "tensor", - ) -> dict[str, torch.Tensor] | list[DetDataSample] | tuple[torch.Tensor] | torch.Tensor: + ) -> dict[str, torch.Tensor] | list[InstanceData] | tuple[torch.Tensor] | torch.Tensor: """The unified entry for a forward process in both training and test. The method should accept three modes: "tensor", "predict" and "loss": @@ -163,7 +166,7 @@ def forward( - "tensor": Forward the whole network and return tensor or tuple of tensor without any post-processing, same as a common nn.Module. - "predict": Forward and return the predictions, which are fully - processed to a list of :obj:`DetDataSample`. + processed to a list of :obj:`InstanceData`. 
- "loss": Forward and return a dict of losses according to the given inputs and data samples. @@ -173,7 +176,7 @@ def forward( Args: inputs (torch.Tensor): The input tensor with shape (N, C, ...) in general. - data_samples (list[:obj:`DetDataSample`], optional): A batch of + data_samples (list[:obj:`InstanceData`], optional): A batch of data samples that contain annotations and predictions. Defaults to None. mode (str): Return what kind of value. Defaults to 'tensor'. @@ -182,7 +185,7 @@ def forward( The return type depends on ``mode``. - If ``mode="tensor"``, return a tensor or a tuple of tensor. - - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="predict"``, return a list of :obj:`InstanceData`. - If ``mode="loss"``, return a dict of tensor. """ if mode == "loss": @@ -198,14 +201,14 @@ def forward( def loss( self, batch_inputs: Tensor, - batch_data_samples: SampleList, + batch_data_samples: list[InstanceData], ) -> dict | list: """Calculate losses from a batch of inputs and data samples. Args: batch_inputs (Tensor): Input images of shape (N, C, H, W). These should usually be mean centered and std scaled. - batch_data_samples (list[:obj:`DetDataSample`]): The batch + batch_data_samples (list[:obj:`InstanceData`]): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. @@ -215,20 +218,25 @@ def loss( x = self.extract_feat(batch_inputs) return self.bbox_head.loss(x, batch_data_samples) - def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: bool = True) -> SampleList: + def predict( + self, + batch_inputs: Tensor, + batch_data_samples: list[InstanceData], + rescale: bool = True, + ) -> list[InstanceData]: """Predict results from a batch of inputs and data samples with post-processing. Args: batch_inputs (Tensor): Inputs with shape (N, C, H, W). 
- batch_data_samples (List[:obj:`DetDataSample`]): The Data + batch_data_samples (List[:obj:`InstanceData`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. rescale (bool): Whether to rescale the results. Defaults to True. Returns: - list[:obj:`DetDataSample`]: Detection results of the - input images. Each DetDataSample usually contain + list[:obj:`InstanceData`]: Detection results of the + input images. Each InstanceData usually contain 'pred_instances'. And the ``pred_instances`` usually contains following keys. @@ -246,13 +254,13 @@ def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: def _forward( self, batch_inputs: Tensor, - batch_data_samples: OptSampleList = None, + batch_data_samples: list[InstanceData] | None = None, ) -> tuple[list[Tensor], list[Tensor]]: """Network forward process. Args: batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + batch_data_samples (list[:obj:`InstanceData`]): Each item contains the meta information of each image and corresponding annotations. @@ -277,18 +285,22 @@ def extract_feat(self, batch_inputs: Tensor) -> tuple[Tensor]: x = self.neck(x) return x - def add_pred_to_datasample(self, data_samples: SampleList, results_list: InstanceList) -> SampleList: - """Add predictions to `DetDataSample`. + def add_pred_to_datasample( + self, + data_samples: list[InstanceData], + results_list: list[InstanceData], + ) -> list[InstanceData]: + """Add predictions to `InstanceData`. Args: - data_samples (list[:obj:`DetDataSample`], optional): A batch of + data_samples (list[:obj:`InstanceData`], optional): A batch of data samples that contain annotations and predictions. results_list (list[:obj:`InstanceData`]): Detection results of each image. Returns: - list[:obj:`DetDataSample`]: Detection results of the - input images. 
Each DetDataSample usually contain + list[:obj:`InstanceData`]: Detection results of the + input images. Each InstanceData usually contain 'pred_instances'. And the ``pred_instances`` usually contains following keys. @@ -355,7 +367,6 @@ def __init__( ) self.image_size = (1, 3, 864, 864) self.tile_image_size = self.image_size - self._register_load_state_dict_pre_hook(self._set_anchors_hook) def _create_model(self) -> nn.Module: from mmdet.models.data_preprocessors import ( @@ -407,6 +418,10 @@ def setup(self, stage: str) -> None: anchor_generator.widths = new_anchors[0] anchor_generator.heights = new_anchors[1] anchor_generator.gen_base_anchors() + self.hparams["ssd_anchors"] = { + "heights": anchor_generator.heights, + "widths": anchor_generator.widths, + } def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: SSDAnchorGeneratorClustered) -> tuple | None: """Get new anchors for SSD from OTXDataset.""" @@ -518,19 +533,6 @@ def get_classification_layers( classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors} return classification_layers - def state_dict(self, *args, **kwargs) -> dict[str, Any]: - """Return state dictionary of model entity with anchor information. - - Returns: - A dictionary containing SSD state. 
- - """ - state_dict = super().state_dict(*args, **kwargs) - anchor_generator = self.model.bbox_head.anchor_generator - anchors = {"heights": anchor_generator.heights, "widths": anchor_generator.widths} - state_dict["model.model.anchors"] = anchors - return state_dict - def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None: """Modify input state_dict according to class name matching before weight loading.""" model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes) @@ -563,26 +565,38 @@ def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: state_dict[prefix + param_name] = model_param @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - export_params = super()._export_parameters - export_params["deploy_cfg"] = "otx.algo.detection.mmdeploy.ssd_mobilenetv2" - export_params["input_size"] = self.image_size - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - - return export_params - - def _set_anchors_hook(self, state_dict: dict[str, Any], *args, **kwargs) -> None: - """Pre hook for pop anchor statistics from checkpoint state_dict.""" - anchors = state_dict.pop("model.model.anchors", None) - if anchors is not None: + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.image_size is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.ssd_mobilenetv2", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] 
if self.explain_mode else None, + ) + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on load checkpoint.""" + if (hparams := checkpoint.get("hyper_parameters")) and (anchors := hparams.get("ssd_anchors", None)): anchor_generator = self.model.bbox_head.anchor_generator anchor_generator.widths = anchors["widths"] anchor_generator.heights = anchors["heights"] anchor_generator.gen_base_anchors() + return super().on_load_checkpoint(checkpoint) + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_ssd_ckpt(state_dict, add_prefix) diff --git a/src/otx/algo/detection/utils/__init__.py b/src/otx/algo/detection/utils/__init__.py new file mode 100644 index 00000000000..2ab46a64ac4 --- /dev/null +++ b/src/otx/algo/detection/utils/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""utils for detection task.""" + +from .mmcv_patched_ops import monkey_patched_nms, monkey_patched_roi_align + +__all__ = ["monkey_patched_nms", "monkey_patched_roi_align"] diff --git a/src/otx/algo/detection/utils/bbox_overlaps.py b/src/otx/algo/detection/utils/bbox_overlaps.py new file mode 100644 index 00000000000..d250e3102c1 --- /dev/null +++ b/src/otx/algo/detection/utils/bbox_overlaps.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+"""Overlap between bboxes calculation function.""" + +from __future__ import annotations + +import torch + + +def fp16_clamp(x: torch.Tensor, min_value: int | None = None, max_value: int | None = None) -> torch.Tensor: + """Clamp for cpu float16, tensor fp16 has no clamp implementation.""" + if not x.is_cuda and x.dtype == torch.float16: + return x.float().clamp(min_value, max_value).half() + + return x.clamp(min_value, max_value) + + +def bbox_overlaps( + bboxes1: torch.Tensor, + bboxes2: torch.Tensor, + mode: str = "iou", + is_aligned: bool = False, + eps: float = 1e-6, +) -> torch.Tensor: + """Calculate overlap between two set of bboxes. + + FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889 + Note: + Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou', + there are some new generated variable when calculating IOU + using bbox_overlaps function: + + 1) is_aligned is False + area1: M x 1 + area2: N x 1 + lt: M x N x 2 + rb: M x N x 2 + wh: M x N x 2 + overlap: M x N x 1 + union: M x N x 1 + ious: M x N x 1 + + Total memory: + S = (9 x N x M + N + M) * 4 Byte, + + When using FP16, we can reduce: + R = (9 x N x M + N + M) * 4 / 2 Byte + R large than (N + M) * 4 * 2 is always true when N and M >= 1. + Obviously, N + M <= N * M < 3 * N * M, when N >=2 and M >=2, + N + 1 < 3 * N, when N or M is 1. + + Given M = 40 (ground truth), N = 400000 (three anchor boxes + in per grid, FPN, R-CNNs), + R = 275 MB (one times) + + A special case (dense detection), M = 512 (ground truth), + R = 3516 MB = 3.43 GB + + When the batch size is B, reduce: + B x R + + Therefore, CUDA memory runs out frequently. 
+ + Experiments on GeForce RTX 2080Ti (11019 MiB): + + | dtype | M | N | Use | Real | Ideal | + |:----:|:----:|:----:|:----:|:----:|:----:| + | FP32 | 512 | 400000 | 8020 MiB | -- | -- | + | FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB | + | FP32 | 40 | 400000 | 1540 MiB | -- | -- | + | FP16 | 40 | 400000 | 1264 MiB | 276MiB | 275 MiB | + + 2) is_aligned is True + area1: N x 1 + area2: N x 1 + lt: N x 2 + rb: N x 2 + wh: N x 2 + overlap: N x 1 + union: N x 1 + ious: N x 1 + + Total memory: + S = 11 x N * 4 Byte + + When using FP16, we can reduce: + R = 11 x N * 4 / 2 Byte + + So do the 'giou' (large than 'iou'). + + Time-wise, FP16 is generally faster than FP32. + + When gpu_assign_thr is not -1, it takes more time on cpu + but not reduce memory. + There, we can reduce half the memory and keep the speed. + + If ``is_aligned`` is ``False``, then calculate the overlaps between each + bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned + pair of bboxes1 and bboxes2. + + Args: + bboxes1 (Tensor): shape (B, m, 4) in format or empty. + bboxes2 (Tensor): shape (B, n, 4) in format or empty. + B indicates the batch dim, in shape (B1, B2, ..., Bn). + If ``is_aligned`` is ``True``, then m and n must be equal. + mode (str): "iou" (intersection over union), "iof" (intersection over + foreground) or "giou" (generalized intersection over union). + Default "iou". + is_aligned (bool, optional): If True, then m and n must be equal. + Default False. + eps (float, optional): A value added to the denominator for numerical + stability. Default 1e-6. 
+ + Returns: + Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) + + Example: + >>> bboxes1 = torch.FloatTensor([ + >>> [0, 0, 10, 10], + >>> [10, 10, 20, 20], + >>> [32, 32, 38, 42], + >>> ]) + >>> bboxes2 = torch.FloatTensor([ + >>> [0, 0, 10, 20], + >>> [0, 10, 10, 19], + >>> [10, 10, 20, 20], + >>> ]) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2) + >>> assert overlaps.shape == (3, 3) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True) + >>> assert overlaps.shape == (3, ) + + Example: + >>> empty = torch.empty(0, 4) + >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]]) + >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) + >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) + >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) + """ + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.size(-2) + cols = bboxes2.size(-2) + + if rows * cols == 0: + if is_aligned: + return bboxes1.new((*batch_shape, rows)) + return bboxes1.new((*batch_shape, rows, cols)) + + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) + + if is_aligned: + lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2] + rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2] + + wh = fp16_clamp(rb - lt, min_value=0) + overlap = wh[..., 0] * wh[..., 1] + + union = area1 + area2 - overlap if mode in ["iou", "giou"] else area1 + if mode == "giou": + enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2]) + enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:]) + else: + lt = torch.max(bboxes1[..., :, None, :2], bboxes2[..., None, :, :2]) # [B, rows, cols, 2] + rb = torch.min(bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:]) # [B, rows, cols, 2] + + wh = fp16_clamp(rb - lt, min_value=0) + overlap = wh[..., 0] * wh[..., 1] + + union = area1[..., None] + area2[..., None, :] - overlap if mode in 
["iou", "giou"] else area1[..., None] + if mode == "giou": + enclosed_lt = torch.min(bboxes1[..., :, None, :2], bboxes2[..., None, :, :2]) + enclosed_rb = torch.max(bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:]) + + eps = union.new_tensor([eps]) + union = torch.max(union, eps) + ious = overlap / union + if mode in ["iou", "iof"]: + return ious + # calculate gious + enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min_value=0) + enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] + enclose_area = torch.max(enclose_area, eps) + return ious - (enclose_area - union) / enclose_area diff --git a/src/otx/algo/detection/utils/mmcv_patched_ops.py b/src/otx/algo/detection/utils/mmcv_patched_ops.py new file mode 100644 index 00000000000..ec3a884232d --- /dev/null +++ b/src/otx/algo/detection/utils/mmcv_patched_ops.py @@ -0,0 +1,73 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""utils for detection task.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from mmcv.utils import ext_loader +from torchvision.ops import nms as tv_nms +from torchvision.ops import roi_align as tv_roi_align + +if TYPE_CHECKING: + from mmcv.ops.nms import NMSop + from mmcv.ops.roi_align import RoIAlign + +ext_module = ext_loader.load_ext("_ext", ["nms", "softnms", "nms_match", "nms_rotated", "nms_quadri"]) + + +def monkey_patched_nms( + ctx: NMSop, + bboxes: torch.Tensor, + scores: torch.Tensor, + iou_threshold: float, + offset: float, + score_threshold: float, + max_num: int, +) -> torch.Tensor: + """Runs MMCVs NMS with torchvision.nms, or forces NMS from MMCV to run on CPU.""" + _ = ctx + is_filtering_by_score = score_threshold > 0 + if is_filtering_by_score: + valid_mask = scores > score_threshold + bboxes, scores = bboxes[valid_mask], scores[valid_mask] + valid_inds = torch.nonzero(valid_mask, as_tuple=False).squeeze(dim=1) + + if bboxes.dtype == torch.bfloat16: + bboxes = bboxes.to(torch.float32) + if 
scores.dtype == torch.bfloat16: + scores = scores.to(torch.float32) + + if offset == 0: + inds = tv_nms(bboxes, scores, float(iou_threshold)) + else: + device = bboxes.device + bboxes = bboxes.to("cpu") + scores = scores.to("cpu") + inds = ext_module.nms(bboxes, scores, iou_threshold=float(iou_threshold), offset=offset) + bboxes = bboxes.to(device) + scores = scores.to(device) + + if max_num > 0: + inds = inds[:max_num] + if is_filtering_by_score: + inds = valid_inds[inds] + return inds + + +def monkey_patched_roi_align(self: RoIAlign, _input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor: + """Replaces MMCVs roi align with the one from torchvision. + + Args: + self: patched instance + _input: NCHW images + rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy. + """ + if "aligned" in tv_roi_align.__code__.co_varnames: + return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned) + if self.aligned: + rois -= rois.new_tensor([0.0] + [0.5 / self.spatial_scale] * 4) + return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio) diff --git a/src/otx/algo/detection/utils/structures.py b/src/otx/algo/detection/utils/structures.py new file mode 100644 index 00000000000..5a3b57c011a --- /dev/null +++ b/src/otx/algo/detection/utils/structures.py @@ -0,0 +1,172 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Data structures for detection task.""" + +from __future__ import annotations + +from typing import Any + +import torch +from torch import Tensor + + +class AssignResult: + """Stores assignments between predicted and truth boxes. + + Attributes: + num_gts (int): the number of truth boxes considered when computing this + assignment + gt_inds (Tensor): for each predicted box indicates the 1-based + index of the assigned truth box. 0 means unassigned and -1 means + ignore. 
+ max_overlaps (Tensor): the iou between the predicted box and its + assigned truth box. + labels (Tensor): If specified, for each predicted box + indicates the category label of the assigned truth box. + """ + + def __init__(self, num_gts: int, gt_inds: Tensor, max_overlaps: Tensor, labels: Tensor) -> None: + self.num_gts = num_gts + self.gt_inds = gt_inds + self.max_overlaps = max_overlaps + self.labels = labels + # Interface for possible user-defined properties + self._extra_properties: dict[str, Any] = {} + + @property + def num_preds(self) -> int: + """int: the number of predictions in this assignment.""" + return len(self.gt_inds) + + def set_extra_property(self, key: str, value: Any) -> None: # noqa: ANN401 + """Set user-defined new property.""" + self._extra_properties[key] = value + + def get_extra_property(self, key: str) -> Any: # noqa: ANN401 + """Get user-defined property.""" + return self._extra_properties.get(key, None) + + @property + def info(self) -> dict: + """Return a dictionary of info about the object.""" + basic_info = { + "num_gts": self.num_gts, + "num_preds": self.num_preds, + "gt_inds": self.gt_inds, + "max_overlaps": self.max_overlaps, + "labels": self.labels, + } + basic_info.update(self._extra_properties) + return basic_info + + def add_gt_(self, gt_labels: Tensor) -> None: + """Add ground truth as assigned results. + + Args: + gt_labels (torch.Tensor): Labels of gt boxes + """ + self_inds = torch.arange(1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) + self.gt_inds = torch.cat([self_inds, self.gt_inds]) + + self.max_overlaps = torch.cat([self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps]) + + self.labels = torch.cat([gt_labels, self.labels]) + + +class SamplingResult: + """Bbox sampling result. + + Args: + pos_inds (Tensor): Indices of positive samples. + neg_inds (Tensor): Indices of negative samples. + priors (Tensor): The priors can be anchors or points, + or the bboxes predicted by the previous stage. 
+ gt_bboxes (Tensor): Ground truth of bboxes. + assign_result (:obj:`AssignResult`): Assigning results. + gt_flags (Tensor): The Ground truth flags. + avg_factor_with_neg (bool): If True, ``avg_factor`` equal to + the number of total priors; Otherwise, it is the number of + positive priors. Defaults to True. + """ + + def __init__( + self, + pos_inds: Tensor, + neg_inds: Tensor, + priors: Tensor, + gt_bboxes: Tensor, + assign_result: AssignResult, + gt_flags: Tensor, + avg_factor_with_neg: bool = True, + ) -> None: + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.num_pos = max(pos_inds.numel(), 1) + self.num_neg = max(neg_inds.numel(), 1) + self.avg_factor_with_neg = avg_factor_with_neg + self.avg_factor = self.num_pos + self.num_neg if avg_factor_with_neg else self.num_pos + self.pos_priors = priors[pos_inds] + self.neg_priors = priors[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + self.pos_gt_labels = assign_result.labels[pos_inds] + box_dim = 4 + if gt_bboxes.numel() == 0: + self.pos_gt_bboxes = gt_bboxes.view(-1, box_dim) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, box_dim) + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long()] + + @property + def priors(self) -> Tensor: + """torch.Tensor: concatenated positive and negative priors.""" + return torch.cat([self.pos_priors, self.neg_priors]) + + @property + def bboxes(self) -> Tensor: + """torch.Tensor: concatenated positive and negative boxes.""" + return self.priors + + @property + def pos_bboxes(self) -> Tensor: + """Return positive box pairs.""" + return self.pos_priors + + @property + def neg_bboxes(self) -> Tensor: + """Return negative box pairs.""" + return self.neg_priors + + def to(self, device: str | torch.device) -> SamplingResult: + """Change the device of the data inplace. 
+ + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + @property + def info(self) -> dict: + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_priors": self.pos_priors, + "neg_priors": self.neg_priors, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + "num_pos": self.num_pos, + "num_neg": self.num_neg, + "avg_factor": self.avg_factor, + } diff --git a/src/otx/algo/detection/utils/utils.py b/src/otx/algo/detection/utils/utils.py new file mode 100644 index 00000000000..5a869cde8e6 --- /dev/null +++ b/src/otx/algo/detection/utils/utils.py @@ -0,0 +1,212 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Utils for otx detection algo.""" + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Callable + +import torch +from torch import Tensor + +if TYPE_CHECKING: + from mmengine.structures import InstanceData + + +# Methods below come from mmdet.utils and slightly modified. +# https://github.com/open-mmlab/mmdetection/blob/3.x/mmdet/models/utils/misc.py +def multi_apply(func: Callable, *args, **kwargs) -> tuple: + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. 
+ + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) # type: ignore[call-overload] + return tuple(map(list, zip(*map_results))) + + +def anchor_inside_flags( + flat_anchors: Tensor, + valid_flags: Tensor, + img_shape: tuple[int, ...], + allowed_border: int = 0, +) -> Tensor: + """Check whether the anchors are inside the border. + + Args: + flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4). + valid_flags (torch.Tensor): An existing valid flags of anchors. + img_shape (tuple(int)): Shape of current image. + allowed_border (int): The border to allow the valid anchor. + Defaults to 0. + + Returns: + torch.Tensor: Flags indicating whether the anchors are inside a \ + valid range. + """ + img_h, img_w = img_shape[:2] + if allowed_border >= 0: + inside_flags = ( + valid_flags + & (flat_anchors[:, 0] >= -allowed_border) + & (flat_anchors[:, 1] >= -allowed_border) + & (flat_anchors[:, 2] < img_w + allowed_border) + & (flat_anchors[:, 3] < img_h + allowed_border) + ) + else: + inside_flags = valid_flags + return inside_flags + + +def images_to_levels(target: list[Tensor], num_levels: list[int]) -> list[Tensor]: + """Convert targets by image to targets by feature level. + + [target_img0, target_img1] -> [target_level0, target_level1, ...] 
+ """ + stacked_target = torch.stack(target, 0) + level_targets = [] + start = 0 + for n in num_levels: + end = start + n + # level_targets.append(target[:, start:end].squeeze(0)) + level_targets.append(stacked_target[:, start:end]) + start = end + return level_targets + + +def unmap(data: Tensor, count: int, inds: Tensor, fill: int = 0) -> Tensor: + """Unmap a subset of item (data) back to the original set of items (of size count).""" + if data.dim() == 1: + ret = data.new_full((count,), fill) + ret[inds.type(torch.bool)] = data + else: + new_size = (count,) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds.type(torch.bool), :] = data + return ret + + +def filter_scores_and_topk( + scores: Tensor, + score_thr: float, + topk: int, + results: dict | list | Tensor | None = None, +) -> tuple[Tensor, Tensor, Tensor, dict | list | Tensor | None]: + """Filter results using score threshold and topk candidates. + + Args: + scores (Tensor): The scores, shape (num_bboxes, K). + score_thr (float): The score filter threshold. + topk (int): The number of topk candidates. + results (dict or list or Tensor, Optional): The results to + which the filtering rule is to be applied. The shape + of each item is (num_bboxes, N). + + Returns: + tuple: Filtered results + - scores (Tensor): The scores after being filtered, \ + shape (num_bboxes_filtered, ). + - labels (Tensor): The class labels, shape \ + (num_bboxes_filtered, ). + - anchor_idxs (Tensor): The anchor indexes, shape \ + (num_bboxes_filtered, ). + - filtered_results (dict or list or Tensor, Optional): \ + The filtered results. The shape of each item is \ + (num_bboxes_filtered, N). 
+ """ + valid_mask = scores > score_thr + scores = scores[valid_mask] + valid_idxs = torch.nonzero(valid_mask) + + num_topk = min(topk, valid_idxs.size(0)) + # torch.sort is actually faster than .topk (at least on GPUs) + scores, idxs = scores.sort(descending=True) + scores = scores[:num_topk] + topk_idxs = valid_idxs[idxs[:num_topk]] + keep_idxs, labels = topk_idxs.unbind(dim=1) + + filtered_results: dict | list | Tensor | None = None + if results is not None: + if isinstance(results, dict): + filtered_results = {k: v[keep_idxs] for k, v in results.items()} + elif isinstance(results, list): + filtered_results = [result[keep_idxs] for result in results] + elif isinstance(results, torch.Tensor): + filtered_results = results[keep_idxs] + else: + msg = f"Only supports dict or list or Tensor, but get {type(results)}." + raise NotImplementedError(msg) + return scores, labels, keep_idxs, filtered_results + + +def select_single_mlvl(mlvl_tensors: list[Tensor], batch_id: int, detach: bool = True) -> list[Tensor]: + """Extract a multi-scale single image tensor from a multi-scale batch tensor based on batch index. + + Note: The default value of detach is True, because the proposal gradient + needs to be detached during the training of the two-stage model. E.g + Cascade Mask R-CNN. + + Args: + mlvl_tensors (list[Tensor]): Batch tensor for all scale levels, + each is a 4D-tensor. + batch_id (int): Batch index. + detach (bool): Whether detach gradient. Default True. + + Returns: + list[Tensor]: Multi-scale single image tensor. + """ + num_levels = len(mlvl_tensors) + + if detach: + mlvl_tensor_list = [mlvl_tensors[i][batch_id].detach() for i in range(num_levels)] + else: + mlvl_tensor_list = [mlvl_tensors[i][batch_id] for i in range(num_levels)] + return mlvl_tensor_list + + +def unpack_gt_instances(batch_data_samples: list[InstanceData]) -> tuple: + """Unpack gt_instances, gt_instances_ignore and img_metas based on batch_data_samples. 
+ + Args: + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple: + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + - batch_gt_instances_ignore (list[:obj:`InstanceData`]): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + - batch_img_metas (list[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + """ + batch_gt_instances = [] + batch_gt_instances_ignore = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if "ignored_instances" in data_sample: + batch_gt_instances_ignore.append(data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + + return batch_gt_instances, batch_gt_instances_ignore, batch_img_metas diff --git a/src/otx/algo/detection/yolox.py b/src/otx/algo/detection/yolox.py index 02e9ca6ff9a..15b56b2bbae 100644 --- a/src/otx/algo/detection/yolox.py +++ b/src/otx/algo/detection/yolox.py @@ -5,14 +5,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from copy import deepcopy +from typing import TYPE_CHECKING, Literal from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.mmdeploy import MMdeployExporter from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.detection import MMDetCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if 
TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -46,16 +50,27 @@ def __init__( self.tile_image_size = self.image_size @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - export_params = super()._export_parameters - export_params["resize_mode"] = "fit_to_window_letterbox" - export_params["pad_value"] = 114 - export_params["swap_rgb"] = True - export_params["input_size"] = self.image_size - export_params["deploy_cfg"] = "otx.algo.detection.mmdeploy.yolox" - - return export_params + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.image_size is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.yolox", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=True, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" @@ -87,13 +102,24 @@ def __init__( self.tile_image_size = self.image_size @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - export_params = super()._export_parameters - export_params["resize_mode"] = "fit_to_window_letterbox" - export_params["pad_value"] = 114 - export_params["swap_rgb"] = False - export_params["input_size"] = self.image_size - export_params["deploy_cfg"] = "otx.algo.detection.mmdeploy.yolox_tiny" - - return export_params + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object 
that can export the model.""" + if self.image_size is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.yolox_tiny", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) diff --git a/src/otx/algo/explain/__init__.py b/src/otx/algo/explain/__init__.py new file mode 100644 index 00000000000..dc7d46c3365 --- /dev/null +++ b/src/otx/algo/explain/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Module for OTX XAI algorithms.""" diff --git a/src/otx/algo/hooks/recording_forward_hook.py b/src/otx/algo/explain/explain_algo.py similarity index 81% rename from src/otx/algo/hooks/recording_forward_hook.py rename to src/otx/algo/explain/explain_algo.py index ae51d69f8c5..54ab8bccc2a 100644 --- a/src/otx/algo/hooks/recording_forward_hook.py +++ b/src/otx/algo/explain/explain_algo.py @@ -1,20 +1,19 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # -"""Hooks for recording/updating model internal activations.""" +"""Algorithms for calculcalating XAI branch for Explainable AI.""" from __future__ import annotations from typing import TYPE_CHECKING, Callable -import numpy as np import torch from otx.core.types.explain import FeatureMapType if TYPE_CHECKING: + import numpy as np from mmengine.structures.instance_data import InstanceData - from torch.utils.hooks import RemovableHandle HeadForwardFn = Callable[[FeatureMapType], torch.Tensor] ExplainerForwardFn = HeadForwardFn @@ -34,7 +33,7 @@ def get_feature_vector(feature_map: 
FeatureMapType) -> torch.Tensor: return torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1)).flatten(start_dim=1) -class BaseRecordingForwardHook: +class BaseExplainAlgo: """While registered with the designated PyTorch module, this class caches feature vector during forward pass. Args: @@ -43,19 +42,8 @@ class BaseRecordingForwardHook: def __init__(self, head_forward_fn: HeadForwardFn | None = None, normalize: bool = True) -> None: self._head_forward_fn = head_forward_fn - self.handle: RemovableHandle | None = None - self._records: list[torch.Tensor] = [] self._norm_saliency_maps = normalize - @property - def records(self) -> list[torch.Tensor]: - """Return records.""" - return self._records - - def reset(self) -> None: - """Clear all history of records.""" - self._records.clear() - def func(self, feature_map: torch.Tensor, fpn_idx: int = -1) -> torch.Tensor: """This method get the feature vector or saliency map from the output of the module. @@ -69,25 +57,6 @@ def func(self, feature_map: torch.Tensor, fpn_idx: int = -1) -> torch.Tensor: """ raise NotImplementedError - def recording_forward( - self, - _: torch.nn.Module, - x: torch.Tensor, - output: torch.Tensor, - ) -> None: # pylint: disable=unused-argument - """Record the XAI result during executing model forward function.""" - tensors = self.func(output) - if isinstance(tensors, torch.Tensor): - tensors_np = tensors.detach().cpu().numpy() - elif isinstance(tensors, np.ndarray): - tensors_np = tensors - else: - self._torch_to_numpy_from_list(tensors) - tensors_np = tensors - - for tensor in tensors_np: - self._records.append(tensor) - def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): if self._head_forward_fn: @@ -96,14 +65,6 @@ def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor: x = torch.tensor(x) return x - def _torch_to_numpy_from_list(self, tensor_list: list[torch.Tensor | None]) -> None: - for i in range(len(tensor_list)): - tensor = 
tensor_list[i] - if isinstance(tensor, list): - self._torch_to_numpy_from_list(tensor) - elif isinstance(tensor, torch.Tensor): - tensor_list[i] = tensor.detach().cpu().numpy() - @staticmethod def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor: """Normalize saliency maps.""" @@ -116,18 +77,8 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor: return saliency_map.to(torch.uint8) -class ActivationMapHook(BaseRecordingForwardHook): - """ActivationMapHook. Mean of the feature map along the channel dimension.""" - - @classmethod - def create_and_register_hook( - cls, - backbone: torch.nn.Module, - ) -> BaseRecordingForwardHook: - """Create this object and register it to the module forward hook.""" - hook = cls() - hook.handle = backbone.register_forward_hook(hook.recording_forward) - return hook +class ActivationMap(BaseExplainAlgo): + """ActivationMap. Mean of the feature map along the channel dimension.""" def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor: """Generate the saliency map by average feature maps then normalizing to (0, 255).""" @@ -144,7 +95,7 @@ def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor: return activation_map.reshape((batch_size, h, w)) -class ReciproCAMHook(BaseRecordingForwardHook): +class ReciproCAM(BaseExplainAlgo): """Implementation of Recipro-CAM for class-wise saliency map. 
Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf) @@ -161,23 +112,6 @@ def __init__( self._num_classes = num_classes self._optimize_gap = optimize_gap - @classmethod - def create_and_register_hook( - cls, - backbone: torch.nn.Module, - head_forward_fn: HeadForwardFn, - num_classes: int, - optimize_gap: bool, - ) -> BaseRecordingForwardHook: - """Create this object and register it to the module forward hook.""" - hook = cls( - head_forward_fn, - num_classes=num_classes, - optimize_gap=optimize_gap, - ) - hook.handle = backbone.register_forward_hook(hook.recording_forward) - return hook - def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor: """Generate the class-wise saliency maps using Recipro-CAM and then normalizing to (0, 255). @@ -226,7 +160,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: return mosaic_feature_map -class ViTReciproCAMHook(BaseRecordingForwardHook): +class ViTReciproCAM(BaseExplainAlgo): """Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers. Args: @@ -251,21 +185,6 @@ def __init__( self._use_gaussian = use_gaussian self._cls_token = cls_token - @classmethod - def create_and_register_hook( - cls, - target_layernorm: torch.nn.Module, - head_forward_fn: HeadForwardFn, - num_classes: int, - ) -> BaseRecordingForwardHook: - """Create this object and register it to the module forward hook.""" - hook = cls( - head_forward_fn, - num_classes=num_classes, - ) - hook.handle = target_layernorm.register_forward_hook(hook.recording_forward) - return hook - def func(self, feature_map: torch.Tensor, _: int = -1) -> torch.Tensor: """Generate the class-wise saliency maps using ViTRecipro-CAM and then normalizing to (0, 255). 
@@ -328,8 +247,8 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor) -> torch.Tensor: return mosaic_feature_map -class DetClassProbabilityMapHook(BaseRecordingForwardHook): - """Saliency map hook for object detection models.""" +class DetClassProbabilityMap(BaseExplainAlgo): + """Saliency map generation algo for object detection models.""" def __init__( self, @@ -392,8 +311,8 @@ def func( return saliency_map.reshape((batch_size, self._num_classes, height, width)) -class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook): - """Dummy saliency map hook for Mask R-CNN model.""" +class MaskRCNNExplainAlgo(BaseExplainAlgo): + """Dummy saliency map algo for Mask R-CNN model.""" def __init__(self, num_classes: int) -> None: super().__init__() @@ -422,19 +341,22 @@ def func( @classmethod def average_and_normalize( cls, - pred: InstanceData, + pred: InstanceData | dict[str, torch.Tensor], num_classes: int, ) -> np.array: """Average and normalize masks in prediction per-class. Args: - preds (InstanceData): Predictions of Instance Segmentation model. + preds (InstanceData | dict): Predictions of Instance Segmentation model. num_classes (int): Num classes that model can predict. Returns: np.array: Class-wise Saliency Maps. 
One saliency map per each class - [class_id, H, W] """ - masks, scores, labels = (pred.masks, pred.scores, pred.labels) + if isinstance(pred, dict): + masks, scores, labels = pred["masks"], pred["scores"], pred["labels"] + else: + masks, scores, labels = (pred.masks, pred.scores, pred.labels) _, height, width = masks.shape saliency_map = torch.zeros((num_classes, height, width), dtype=torch.float32, device=labels.device) diff --git a/src/otx/algo/hooks/__init__.py b/src/otx/algo/hooks/__init__.py deleted file mode 100644 index de1331bdce2..00000000000 --- a/src/otx/algo/hooks/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# -"""Module for OTX custom hooks.""" diff --git a/src/otx/algo/instance_segmentation/maskrcnn.py b/src/otx/algo/instance_segmentation/maskrcnn.py index 8df321b8f0d..a25e4f86634 100644 --- a/src/otx/algo/instance_segmentation/maskrcnn.py +++ b/src/otx/algo/instance_segmentation/maskrcnn.py @@ -5,14 +5,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from copy import deepcopy +from typing import TYPE_CHECKING, Literal from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.mmdeploy import MMdeployExporter from otx.core.metrics.mean_ap import MaskRLEMeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.instance_segmentation import MMDetInstanceSegCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -46,16 +50,27 @@ def __init__( self.tile_image_size = (1, 3, 512, 512) @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - 
export_params = super()._export_parameters - export_params["deploy_cfg"] = "otx.algo.instance_segmentation.mmdeploy.maskrcnn" - export_params["input_size"] = self.image_size - export_params["resize_mode"] = "standard" # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - - return export_params + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.image_size is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" @@ -87,13 +102,24 @@ def __init__( self.tile_image_size = (1, 3, 512, 512) @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - export_params = super()._export_parameters - export_params["deploy_cfg"] = "otx.algo.instance_segmentation.mmdeploy.maskrcnn_swint" - export_params["input_size"] = self.image_size - export_params["resize_mode"] = "standard" # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - - return export_params + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.image_size 
is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn_swint", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) diff --git a/src/otx/algo/instance_segmentation/rtmdet_inst.py b/src/otx/algo/instance_segmentation/rtmdet_inst.py index d54eab3091b..7a751bfe3a0 100644 --- a/src/otx/algo/instance_segmentation/rtmdet_inst.py +++ b/src/otx/algo/instance_segmentation/rtmdet_inst.py @@ -5,13 +5,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from copy import deepcopy +from typing import TYPE_CHECKING, Literal from otx.algo.utils.mmconfig import read_mmconfig +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.mmdeploy import MMdeployExporter from otx.core.metrics.mean_ap import MaskRLEMeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.instance_segmentation import MMDetInstanceSegCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -45,13 +49,24 @@ def __init__( self.tile_image_size = self.image_size @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - export_params = super()._export_parameters - export_params["deploy_cfg"] = "otx.algo.instance_segmentation.mmdeploy.rtmdet_inst" - 
export_params["input_size"] = self.image_size - export_params["resize_mode"] = "fit_to_window_letterbox" - export_params["pad_value"] = 114 - export_params["swap_rgb"] = False - - return export_params + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + if self.image_size is None: + raise ValueError(self.image_size) + + mean, std = get_mean_std_from_data_processing(self.config) + + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.instance_segmentation.mmdeploy.rtmdet_inst", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) diff --git a/src/otx/algo/plugins/__init__.py b/src/otx/algo/plugins/__init__.py new file mode 100644 index 00000000000..91be640aea6 --- /dev/null +++ b/src/otx/algo/plugins/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Plugin for mixed-precision training on XPU.""" + +from .xpu_precision import MixedPrecisionXPUPlugin + +__all__ = ["MixedPrecisionXPUPlugin"] diff --git a/src/otx/algo/plugins/xpu_precision.py b/src/otx/algo/plugins/xpu_precision.py new file mode 100644 index 00000000000..fb2a08eb182 --- /dev/null +++ b/src/otx/algo/plugins/xpu_precision.py @@ -0,0 +1,117 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Plugin for mixed-precision training on XPU.""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable, Generator + +import torch +from lightning.pytorch.plugins.precision.precision import Precision +from lightning.pytorch.utilities import GradClipAlgorithmType 
+from lightning.pytorch.utilities.exceptions import MisconfigurationException +from torch import Tensor +from torch.optim import LBFGS, Optimizer + +if TYPE_CHECKING: + import lightning.pytorch as pl + from lightning_fabric.utilities.types import Optimizable + + +class MixedPrecisionXPUPlugin(Precision): + """Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``. + + Args: + scaler: An optional :class:`torch.cuda.amp.GradScaler` to use. + """ + + def __init__(self, scaler: torch.cuda.amp.GradScaler | None = None) -> None: + self.scaler = scaler + + def pre_backward(self, tensor: Tensor, module: pl.LightningModule) -> Tensor: + """Apply grad scaler before backward.""" + if self.scaler is not None: + tensor = self.scaler.scale(tensor) + return super().pre_backward(tensor, module) + + def optimizer_step( # type: ignore[override] + self, + optimizer: Optimizable, + model: pl.LightningModule, + closure: Callable, + **kwargs: dict, + ) -> None | dict: + """Make an optimizer step using scaler if it was passed.""" + if self.scaler is None: + # skip scaler logic, as bfloat16 does not require scaler + return super().optimizer_step( + optimizer, + model=model, + closure=closure, + **kwargs, + ) + if isinstance(optimizer, LBFGS): + msg = "Native AMP and the LBFGS optimizer are not compatible." + raise MisconfigurationException( + msg, + ) + closure_result = closure() + + if not _optimizer_handles_unscaling(optimizer): + # Unscaling needs to be performed here in case we are going to apply gradient clipping. + # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam). + # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook. 
+ self.scaler.unscale_(optimizer) + + self._after_closure(model, optimizer) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if not model.automatic_optimization or not skipped_backward: + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found + step_output = self.scaler.step(optimizer, **kwargs) + self.scaler.update() + return step_output + return closure_result + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: int | float = 0.0, + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + ) -> None: + """Handle grad clipping with scaler.""" + if clip_val > 0 and _optimizer_handles_unscaling(optimizer): + msg = f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping" + " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?" + raise RuntimeError( + msg, + ) + super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm) + + @contextmanager + def forward_context(self) -> Generator[None, None, None]: + """Enable autocast context.""" + with torch.xpu.autocast(True): + yield + + def state_dict(self) -> dict[str, Any]: + """Returns state dict of the plugin.""" + if self.scaler is not None: + return self.scaler.state_dict() + return {} + + def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: + """Loads state dict to the plugin.""" + if self.scaler is not None: + self.scaler.load_state_dict(state_dict) + + +def _optimizer_handles_unscaling(optimizer: torch.optim.Optimizer) -> bool: + """Determines if a PyTorch optimizer handles unscaling gradients in the step method rather than through the scaler. + + Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return + value will only be reliable for built-in PyTorch optimizers. 
+ """ + return getattr(optimizer, "_step_supports_amp_scaling", False) diff --git a/src/otx/algo/segmentation/dino_v2_seg.py b/src/otx/algo/segmentation/dino_v2_seg.py index 257ff2dd9cc..cdbee2987bd 100644 --- a/src/otx/algo/segmentation/dino_v2_seg.py +++ b/src/otx/algo/segmentation/dino_v2_seg.py @@ -7,10 +7,13 @@ from typing import TYPE_CHECKING, Any from otx.algo.utils.mmconfig import read_mmconfig +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics.dice import SegmCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.segmentation import MMSegCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -39,13 +42,25 @@ def __init__( metric=metric, torch_compile=torch_compile, ) + self.image_size = (1, 3, 560, 560) @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - parent_parameters = super()._export_parameters - parent_parameters["input_size"] = (1, 3, 560, 560) - return parent_parameters + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + mean, std = get_mean_std_from_data_processing(self.config) + + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=None, + ) @property def _optimization_config(self) -> dict[str, Any]: diff --git a/src/otx/algo/segmentation/litehrnet.py b/src/otx/algo/segmentation/litehrnet.py index 0d9f3e019b5..1116ae74c26 100644 --- 
a/src/otx/algo/segmentation/litehrnet.py +++ b/src/otx/algo/segmentation/litehrnet.py @@ -11,10 +11,13 @@ from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics.dice import SegmCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.segmentation import MMSegCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -46,17 +49,22 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - parent_parameters = super()._export_parameters - parent_parameters.update( - { - "onnx_export_configuration": {"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK}, - "via_onnx": True, - }, - ) + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + mean, std = get_mean_std_from_data_processing(self.config) - return parent_parameters + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, + onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK}, + output_names=None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" diff --git a/src/otx/algo/segmentation/segnext.py b/src/otx/algo/segmentation/segnext.py index 26ff0152612..03c2afaaffa 100644 --- a/src/otx/algo/segmentation/segnext.py +++ 
b/src/otx/algo/segmentation/segnext.py @@ -8,10 +8,13 @@ from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics.dice import SegmCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.segmentation import MMSegCompatibleModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -42,6 +45,24 @@ def __init__( torch_compile=torch_compile, ) + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + mean, std = get_mean_std_from_data_processing(self.config) + + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=None, + ) + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_seg_segnext_ckpt(state_dict, add_prefix) diff --git a/src/otx/algo/strategies/__init__.py b/src/otx/algo/strategies/__init__.py new file mode 100644 index 00000000000..392a1b82b22 --- /dev/null +++ b/src/otx/algo/strategies/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Lightning strategy for single XPU device.""" + +from .xpu_single import SingleXPUStrategy + +__all__ = ["SingleXPUStrategy"] diff --git a/src/otx/algo/strategies/xpu_single.py b/src/otx/algo/strategies/xpu_single.py new file mode 100644 index 00000000000..4b9501dd36f --- 
/dev/null +++ b/src/otx/algo/strategies/xpu_single.py @@ -0,0 +1,73 @@ +"""Lightning strategy for single XPU device.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from lightning.pytorch.strategies import StrategyRegistry +from lightning.pytorch.strategies.single_device import SingleDeviceStrategy +from lightning.pytorch.utilities.exceptions import MisconfigurationException + +from otx.utils.utils import is_xpu_available + +if TYPE_CHECKING: + import lightning.pytorch as pl + from lightning.pytorch.plugins.precision import PrecisionPlugin + from lightning_fabric.plugins import CheckpointIO + from lightning_fabric.utilities.types import _DEVICE + + +class SingleXPUStrategy(SingleDeviceStrategy): + """Strategy for training on single XPU device.""" + + strategy_name = "xpu_single" + + def __init__( + self, + device: _DEVICE = "xpu:0", + accelerator: pl.accelerators.Accelerator | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, + ): + if not is_xpu_available(): + msg = "`SingleXPUStrategy` requires XPU devices to run" + raise MisconfigurationException(msg) + + super().__init__( + accelerator=accelerator, + device=device, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) + + @property + def is_distributed(self) -> bool: + """Returns true if the strategy supports distributed training.""" + return False + + def setup_optimizers(self, trainer: pl.Trainer) -> None: + """Sets up optimizers.""" + super().setup_optimizers(trainer) + if len(self.optimizers) > 1: # type: ignore[has-type] + msg = "XPU strategy doesn't support multiple optimizers" + raise RuntimeError(msg) + if trainer.task != "SEMANTIC_SEGMENTATION": + if len(self.optimizers) == 1: # type: ignore[has-type] + model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0]) # type: 
ignore[has-type] + self.optimizers = [optimizer] + self.model = model + else: # for inference + trainer.model.eval() + self.model = torch.xpu.optimize(trainer.model) + + +StrategyRegistry.register( + SingleXPUStrategy.strategy_name, + SingleXPUStrategy, + description="Strategy that enables training on single XPU", +) diff --git a/src/otx/cli/install.py b/src/otx/cli/install.py index 7d0fd49d45f..ae9454089dd 100644 --- a/src/otx/cli/install.py +++ b/src/otx/cli/install.py @@ -64,10 +64,15 @@ def add_install_parser(subcommands_action: _ActionSubCommands) -> None: help="Do not install PyTorch. Choose this option if you already install PyTorch.", action="store_true", ) + subcommands_action.add_subcommand("install", parser, help="Install OTX requirements.") -def otx_install(option: str | None = None, verbose: bool = False, do_not_install_torch: bool = False) -> int: +def otx_install( + option: str | None = None, + verbose: bool = False, + do_not_install_torch: bool = False, +) -> int: """Install OTX requirements. 
Args: diff --git a/src/otx/core/config/hpo.py b/src/otx/core/config/hpo.py index 87efdcda849..ff763e443a9 100644 --- a/src/otx/core/config/hpo.py +++ b/src/otx/core/config/hpo.py @@ -19,14 +19,13 @@ class HpoConfig: save_path: str | None = None mode: Literal["max", "min"] = "max" num_trials: int | None = None - num_workers: int = 1 + num_workers: int = torch.cuda.device_count() if torch.cuda.is_available() else 1 expected_time_ratio: int | float | None = 4 maximum_resource: int | float | None = None - subset_ratio: float | int | None = None - min_subset_size: int = 500 prior_hyper_parameters: dict | list[dict] | None = None acceptable_additional_time_ratio: float | int = 1.0 minimum_resource: int | float | None = None reduction_factor: int = 3 asynchronous_bracket: bool = True asynchronous_sha: bool = torch.cuda.device_count() != 1 + metric_name: str | None = None diff --git a/src/otx/core/data/dataset/anomaly/dataset.py b/src/otx/core/data/dataset/anomaly/dataset.py index 1d9d6bc443a..8a7a90d5a69 100644 --- a/src/otx/core/data/dataset/anomaly/dataset.py +++ b/src/otx/core/data/dataset/anomaly/dataset.py @@ -26,6 +26,7 @@ from otx.core.data.entity.base import ImageInfo from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase from otx.core.types.image import ImageColorChannel +from otx.core.types.label import LabelInfo from otx.core.types.task import OTXTaskType @@ -53,6 +54,7 @@ def __init__( image_color_channel, stack_images, ) + self.label_info = LabelInfo(label_names=["Normal", "Anomaly"], label_groups=[["Normal", "Anomaly"]]) def _get_item_impl( self, diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 4ece94158e4..775546b51ed 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -26,6 +26,7 @@ ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingDataEntity, ) +from otx.core.types.label import NullLabelInfo 
from otx.core.utils.mask_util import polygon_to_bitmap from .base import OTXDataset, Transforms @@ -61,6 +62,8 @@ def __init__( # if using only point prompt self.prob = 0.0 + self.label_info = NullLabelInfo() + def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name) img = item.media_as(dmImage) @@ -189,6 +192,8 @@ def __init__( # if using only point prompt self.prob = 0.0 + self.label_info = NullLabelInfo() + def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None: item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name) img = item.media_as(dmImage) diff --git a/src/otx/core/data/entity/tile.py b/src/otx/core/data/entity/tile.py index e407328ed53..911cd8b70db 100644 --- a/src/otx/core/data/entity/tile.py +++ b/src/otx/core/data/entity/tile.py @@ -6,7 +6,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Sequence, TypeVar +from typing import TYPE_CHECKING, Generic, Sequence from otx.core.types.task import OTXTaskType @@ -20,12 +20,6 @@ from torchvision import tv_tensors -T_OTXTileBatchDataEntity = TypeVar( - "T_OTXTileBatchDataEntity", - bound="OTXTileBatchDataEntity", -) - - @dataclass class TileDataEntity(Generic[T_OTXDataEntity]): """Base data entity for tile task. @@ -66,6 +60,9 @@ def task(self) -> OTXTaskType: return OTXTaskType.DETECTION +TileAttrDictList = list[dict[str, int | str]] + + @dataclass class OTXTileBatchDataEntity(Generic[T_OTXBatchDataEntity]): """Base batch data entity for tile task. 
@@ -82,10 +79,10 @@ class OTXTileBatchDataEntity(Generic[T_OTXBatchDataEntity]): batch_size: int batch_tiles: list[list[tv_tensors.Image]] batch_tile_img_infos: list[list[ImageInfo]] - batch_tile_attr_list: list[list[dict[str, int | str]]] + batch_tile_attr_list: list[TileAttrDictList] imgs_info: list[ImageInfo] - def unbind(self) -> list[T_OTXBatchDataEntity]: + def unbind(self) -> list[tuple[TileAttrDictList, T_OTXBatchDataEntity]]: """Unbind batch data entity.""" raise NotImplementedError @@ -102,7 +99,7 @@ class TileBatchDetDataEntity(OTXTileBatchDataEntity): bboxes: list[tv_tensors.BoundingBoxes] labels: list[LongTensor] - def unbind(self) -> list[tuple[list[dict[str, int | str]], DetBatchDataEntity]]: + def unbind(self) -> list[tuple[TileAttrDictList, DetBatchDataEntity]]: """Unbind batch data entity for detection task.""" tiles = [tile for tiles in self.batch_tiles for tile in tiles] tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] @@ -194,7 +191,7 @@ class TileBatchInstSegDataEntity(OTXTileBatchDataEntity): masks: list[tv_tensors.Mask] polygons: list[list[Polygon]] - def unbind(self) -> list[tuple[list[dict[str, int | str]], InstanceSegBatchDataEntity]]: + def unbind(self) -> list[tuple[TileAttrDictList, InstanceSegBatchDataEntity]]: """Unbind batch data entity for instance segmentation task.""" tiles = [tile for tiles in self.batch_tiles for tile in tiles] tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] diff --git a/src/otx/core/exporter/base.py b/src/otx/core/exporter/base.py index 5224cf53c5e..dcefc5fd902 100644 --- a/src/otx/core/exporter/base.py +++ b/src/otx/core/exporter/base.py @@ -16,7 +16,7 @@ from openvino.model_api.models import Model from otx.core.exporter.exportable_code import demo -from otx.core.types.export import OTXExportFormatType +from otx.core.types.export import OTXExportFormatType, TaskLevelExportParameters from otx.core.types.precision 
import OTXPrecisionType if TYPE_CHECKING: @@ -29,6 +29,8 @@ class OTXModelExporter: """Base class for the model exporters used in OTX. Args: + task_level_export_parameters (TaskLevelExportParameters): Collection of export parameters + which can be defined at a task level. input_size (tuple[int, ...]): Input shape. mean (tuple[float, float, float], optional): Mean values of 3 channels. Defaults to (0.0, 0.0, 0.0). std (tuple[float, float, float], optional): Std values of 3 channels. Defaults to (1.0, 1.0, 1.0). @@ -38,20 +40,19 @@ class OTXModelExporter: "fit_to_window_letterbox" resizes images and pads images to fit the size. Defaults to "standard". pad_value (int, optional): Padding value. Defaults to 0. swap_rgb (bool, optional): Whether to convert the image from BGR to RGB Defaults to False. - metadata (dict[tuple[str, str],str] | None, optional): metadata to embed to the exported model. output_names (list[str] | None, optional): Names for model's outputs, which would be embedded into resulting model. """ def __init__( self, + task_level_export_parameters: TaskLevelExportParameters, input_size: tuple[int, ...], mean: tuple[float, float, float] = (0.0, 0.0, 0.0), std: tuple[float, float, float] = (1.0, 1.0, 1.0), resize_mode: Literal["crop", "standard", "fit_to_window", "fit_to_window_letterbox"] = "standard", pad_value: int = 0, swap_rgb: bool = False, - metadata: dict[tuple[str, str], str] | None = None, output_names: list[str] | None = None, ) -> None: self.input_size = input_size @@ -60,9 +61,17 @@ def __init__( self.resize_mode = resize_mode self.pad_value = pad_value self.swap_rgb = swap_rgb - self.metadata = metadata + self.task_level_export_parameters = task_level_export_parameters self.output_names = output_names + @property + def metadata(self) -> dict[tuple[str, str], str]: + """Collection of metadata to be stored in OpenVINO Intermediate Representation or ONNX. + + This metadata is mainly used to support ModelAPI. 
+ """ + return self.task_level_export_parameters.to_metadata() + def export( self, model: torch.nn.Module, @@ -197,7 +206,6 @@ def to_exportable_code( arch.write(str(path_to_model), Path("model") / "model.xml") arch.write(path_to_model.with_suffix(".bin"), Path("model") / "model.bin") - arch.writestr( str(Path("model") / "config.json"), json.dumps(parameters, ensure_ascii=False, indent=4), diff --git a/src/otx/core/exporter/exportable_code/demo/demo_package/__init__.py b/src/otx/core/exporter/exportable_code/demo/demo_package/__init__.py index 3afe9c9f203..604e5969747 100644 --- a/src/otx/core/exporter/exportable_code/demo/demo_package/__init__.py +++ b/src/otx/core/exporter/exportable_code/demo/demo_package/__init__.py @@ -6,10 +6,22 @@ from .executors import AsyncExecutor, SyncExecutor from .model_wrapper import ModelWrapper from .utils import create_visualizer +from .visualizers import ( + BaseVisualizer, + ClassificationVisualizer, + InstanceSegmentationVisualizer, + ObjectDetectionVisualizer, + SemanticSegmentationVisualizer, +) __all__ = [ "SyncExecutor", "AsyncExecutor", "create_visualizer", "ModelWrapper", + "BaseVisualizer", + "ClassificationVisualizer", + "SemanticSegmentationVisualizer", + "InstanceSegmentationVisualizer", + "ObjectDetectionVisualizer", ] diff --git a/src/otx/core/exporter/exportable_code/demo/demo_package/executors/asynchronous.py b/src/otx/core/exporter/exportable_code/demo/demo_package/executors/asynchronous.py index fc0cd328131..4549922ef5e 100644 --- a/src/otx/core/exporter/exportable_code/demo/demo_package/executors/asynchronous.py +++ b/src/otx/core/exporter/exportable_code/demo/demo_package/executors/asynchronous.py @@ -14,6 +14,7 @@ import numpy as np from demo_package.model_wrapper import ModelWrapper + from demo_package.streamer import get_streamer from demo_package.visualizers import BaseVisualizer, dump_frames diff --git a/src/otx/core/exporter/exportable_code/demo/demo_package/executors/synchronous.py 
b/src/otx/core/exporter/exportable_code/demo/demo_package/executors/synchronous.py index ea280841aad..06fd9035f8d 100644 --- a/src/otx/core/exporter/exportable_code/demo/demo_package/executors/synchronous.py +++ b/src/otx/core/exporter/exportable_code/demo/demo_package/executors/synchronous.py @@ -12,7 +12,7 @@ from demo_package.model_wrapper import ModelWrapper from demo_package.visualizers import BaseVisualizer -from demo_package.streamer import get_streamer +from demo_package.streamer.streamer import get_streamer from demo_package.visualizers import dump_frames diff --git a/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/vis_utils.py b/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/vis_utils.py index 0edea1cbb94..b8d9662f541 100644 --- a/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/vis_utils.py +++ b/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/vis_utils.py @@ -95,7 +95,7 @@ def __init__(self, num_classes: int, rng: random.Random | None = None) -> None: Returns: None """ - if num_classes == 0: + if num_classes <= 0: msg = "ColorPalette accepts only the positive number of colors" raise ValueError(msg) if rng is None: diff --git a/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/visualizer.py b/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/visualizer.py index 764eb23cf1d..dc9ee8f90ba 100644 --- a/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/visualizer.py +++ b/src/otx/core/exporter/exportable_code/demo/demo_package/visualizers/visualizer.py @@ -126,6 +126,10 @@ def draw( Output image with annotations. 
""" predictions = predictions.top_labels + if not any(predictions): + log.warning("There are no predictions.") + return frame + class_label = predictions[0][1] font_scale = 0.7 label_height = cv2.getTextSize(class_label, cv2.FONT_HERSHEY_COMPLEX, font_scale, 2)[0][1] diff --git a/src/otx/core/exporter/mmdeploy.py b/src/otx/core/exporter/mmdeploy.py index 811021ce24d..fd2abe90bf3 100644 --- a/src/otx/core/exporter/mmdeploy.py +++ b/src/otx/core/exporter/mmdeploy.py @@ -21,6 +21,7 @@ from mmengine.registry.default_scope import DefaultScope from otx.core.exporter.base import OTXModelExporter +from otx.core.types.export import TaskLevelExportParameters from otx.core.types.precision import OTXPrecisionType from otx.core.utils.config import convert_conf_to_mmconfig_dict, to_tuple @@ -38,6 +39,8 @@ class MMdeployExporter(OTXModelExporter): model_cfg (DictConfig): Model config for mm framework. deploy_cfg (str | MMConfig): Deployment config module path or MMEngine Config object. test_pipeline (list[dict]): A pipeline for test dataset. + task_level_export_parameters (TaskLevelExportParameters): Collection of export parameters + which can be defined at a task level. input_size (tuple[int, ...]): Input shape. mean (tuple[float, float, float], optional): Mean values of 3 channels. Defaults to (0.0, 0.0, 0.0). std (tuple[float, float, float], optional): Std values of 3 channels. Defaults to (1.0, 1.0, 1.0). 
@@ -57,17 +60,26 @@ def __init__( model_cfg: DictConfig, deploy_cfg: str | MMConfig, test_pipeline: list[dict], + task_level_export_parameters: TaskLevelExportParameters, input_size: tuple[int, ...], mean: tuple[float, float, float] = (0.0, 0.0, 0.0), std: tuple[float, float, float] = (1.0, 1.0, 1.0), resize_mode: Literal["crop", "standard", "fit_to_window", "fit_to_window_letterbox"] = "standard", pad_value: int = 0, swap_rgb: bool = False, - metadata: dict[tuple[str, str], str] | None = None, max_num_detections: int = 0, output_names: list[str] | None = None, ) -> None: - super().__init__(input_size, mean, std, resize_mode, pad_value, swap_rgb, metadata, output_names) + super().__init__( + task_level_export_parameters=task_level_export_parameters, + input_size=input_size, + mean=mean, + std=std, + resize_mode=resize_mode, + pad_value=pad_value, + swap_rgb=swap_rgb, + output_names=output_names, + ) self._model_builder = model_builder model_cfg = convert_conf_to_mmconfig_dict(model_cfg, "list") self._model_cfg = MMConfig({"model": model_cfg, "test_pipeline": list(map(to_tuple, test_pipeline))}) diff --git a/src/otx/core/exporter/native.py b/src/otx/core/exporter/native.py index aba87dbabed..27a7e60eee2 100644 --- a/src/otx/core/exporter/native.py +++ b/src/otx/core/exporter/native.py @@ -15,6 +15,7 @@ import torch from otx.core.exporter.base import OTXModelExporter +from otx.core.types.export import TaskLevelExportParameters from otx.core.types.precision import OTXPrecisionType @@ -23,18 +24,27 @@ class OTXNativeModelExporter(OTXModelExporter): def __init__( self, + task_level_export_parameters: TaskLevelExportParameters, input_size: tuple[int, ...], mean: tuple[float, float, float] = (0.0, 0.0, 0.0), std: tuple[float, float, float] = (1.0, 1.0, 1.0), resize_mode: Literal["crop", "standard", "fit_to_window", "fit_to_window_letterbox"] = "standard", pad_value: int = 0, swap_rgb: bool = False, - metadata: dict[tuple[str, str], str] | None = None, via_onnx: bool = 
False, onnx_export_configuration: dict[str, Any] | None = None, output_names: list[str] | None = None, ) -> None: - super().__init__(input_size, mean, std, resize_mode, pad_value, swap_rgb, metadata, output_names) + super().__init__( + task_level_export_parameters=task_level_export_parameters, + input_size=input_size, + mean=mean, + std=std, + resize_mode=resize_mode, + pad_value=pad_value, + swap_rgb=swap_rgb, + output_names=output_names, + ) self.via_onnx = via_onnx self.onnx_export_configuration = onnx_export_configuration if onnx_export_configuration is not None else {} if output_names is not None: diff --git a/src/otx/core/metrics/fmeasure.py b/src/otx/core/metrics/fmeasure.py index 56bcb9853b7..e761eccefdd 100644 --- a/src/otx/core/metrics/fmeasure.py +++ b/src/otx/core/metrics/fmeasure.py @@ -657,7 +657,7 @@ def __init__( self._f_measure_per_nms: dict | None = None self._best_confidence_threshold: float | None = None self._best_nms_threshold: float | None = None - self._f_measure = 0.0 + self._f_measure = float("-inf") self.reset() diff --git a/src/otx/core/model/action_classification.py b/src/otx/core/model/action_classification.py index b5b4c8c0b64..28b37a50061 100644 --- a/src/otx/core/model/action_classification.py +++ b/src/otx/core/model/action_classification.py @@ -12,12 +12,12 @@ from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsBatchPredEntity from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput from otx.core.metrics.accuracy import MultiClassClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.types.export import TaskLevelExportParameters from otx.core.utils.config import inplace_num_classes from 
otx.core.utils.utils import get_mean_std_from_data_processing @@ -31,13 +31,7 @@ from otx.core.metrics import MetricCallable -class OTXActionClsModel( - OTXModel[ - ActionClsBatchDataEntity, - ActionClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXActionClsModel(OTXModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity]): """Base class for the action classification models used in OTX.""" def __init__( @@ -57,16 +51,12 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - parameters["metadata"].update( - { - ("model_info", "model_type"): "Action Classification", - ("model_info", "task_type"): "action classification", - }, + return super()._export_parameters.wrap( + model_type="Action Classification", + task_type="action classification", ) - return parameters def _convert_pred_entity_to_compute_metric( self, @@ -175,24 +165,23 @@ def _customize_outputs( labels=labels, ) - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - export_params = super()._export_parameters - export_params.update(get_mean_std_from_data_processing(self.config)) - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - export_params["via_onnx"] = False - export_params["input_size"] = self.image_size - export_params["onnx_export_configuration"] = None - - return export_params - @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter(**self._export_parameters) + mean, std = get_mean_std_from_data_processing(self.config) + + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + 
input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=None, + ) class OVActionClsModel( diff --git a/src/otx/core/model/action_detection.py b/src/otx/core/model/action_detection.py index 7529f93602c..b978110ba50 100644 --- a/src/otx/core/model/action_detection.py +++ b/src/otx/core/model/action_detection.py @@ -11,7 +11,6 @@ from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetBatchPredEntity from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.metrics import MetricInput from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel @@ -26,13 +25,7 @@ from otx.core.metrics import MetricCallable -class OTXActionDetModel( - OTXModel[ - ActionDetBatchDataEntity, - ActionDetBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXActionDetModel(OTXModel[ActionDetBatchDataEntity, ActionDetBatchPredEntity]): """Base class for the action detection models used in OTX.""" def __init__( diff --git a/src/otx/core/model/anomaly.py b/src/otx/core/model/anomaly.py index fb1b080a42e..edad71489ae 100644 --- a/src/otx/core/model/anomaly.py +++ b/src/otx/core/model/anomaly.py @@ -27,14 +27,12 @@ AnomalySegmentationDataBatch, ) from otx.core.exporter.base import OTXModelExporter -from otx.core.types.export import OTXExportFormatType -from otx.core.types.label import LabelInfo +from otx.core.types.export import OTXExportFormatType, TaskLevelExportParameters +from otx.core.types.label import LabelInfo, NullLabelInfo from otx.core.types.precision import OTXPrecisionType from otx.core.types.task import OTXTaskType if TYPE_CHECKING: - from collections import OrderedDict - from anomalib.metrics import AnomalibMetricCollection from anomalib.metrics.threshold import 
BaseThreshold from lightning.pytorch import Trainer @@ -65,25 +63,38 @@ def __init__( normalization_scale: float = 1.0, ) -> None: self.orig_height, self.orig_width = image_shape - metadata = { - ("model_info", "image_threshold"): image_threshold, - ("model_info", "pixel_threshold"): pixel_threshold, - ("model_info", "normalization_scale"): normalization_scale, - ("model_info", "orig_height"): image_shape[0], - ("model_info", "orig_width"): image_shape[1], - ("model_info", "image_shape"): image_shape, - ("model_info", "labels"): "Normal Anomaly", - ("model_info", "model_type"): "AnomalyDetection", - ("model_info", "task"): task.value, - } + self.image_threshold = image_threshold + self.pixel_threshold = pixel_threshold + self.task = task + self.normalization_scale = normalization_scale + super().__init__( + task_level_export_parameters=TaskLevelExportParameters( + model_type="anomaly", + task_type="anomaly", + label_info=NullLabelInfo(), + optimization_config={}, + ), input_size=(1, 3, *image_shape), mean=mean_values, std=scale_values, swap_rgb=False, # default value. 
Ideally, modelAPI should pass RGB inputs after the pre-processing step - metadata=metadata, ) + @property + def metadata(self) -> dict[tuple[str, str], str | float | int | tuple[int, int]]: # type: ignore[override] + return { + ("model_info", "image_threshold"): self.image_threshold, + ("model_info", "pixel_threshold"): self.pixel_threshold, + ("model_info", "normalization_scale"): self.normalization_scale, + ("model_info", "orig_height"): self.orig_height, + ("model_info", "orig_width"): self.orig_width, + ("model_info", "image_shape"): (self.orig_height, self.orig_width), + ("model_info", "labels"): "Normal Anomaly", + ("model_info", "model_type"): "AnomalyDetection", + ("model_info", "task"): self.task.value, + } + def to_openvino( self, model: nn.Module, @@ -146,6 +157,22 @@ def __init__(self) -> None: self.image_metrics: AnomalibMetricCollection self.pixel_metrics: AnomalibMetricCollection + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on saving checkpoint.""" + super().on_save_checkpoint(checkpoint) # type: ignore[misc] + + attrs = ["_task_type", "_input_size", "mean_values", "scale_values", "image_threshold", "pixel_threshold"] + + checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs} + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on loading checkpoint.""" + super().on_load_checkpoint(checkpoint) # type: ignore[misc] + + if anomaly_attrs := checkpoint.get("anomaly", None): + for key, value in anomaly_attrs.items(): + setattr(self, key, value) + @property def input_size(self) -> tuple[int, int]: """Returns the input size of the model. 
@@ -225,7 +252,7 @@ def trainable_model(self) -> str | None: def setup(self, stage: str | None = None) -> None: """Setup the model.""" super().setup(stage) # type: ignore[misc] - if hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"): + if stage == "fit" and hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"): if hasattr(self.trainer.datamodule.config, "test_subset"): self._extract_mean_scale_from_transforms(self.trainer.datamodule.config.test_subset.transforms) elif hasattr(self.trainer.datamodule.config, "val_subset"): @@ -314,24 +341,6 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[torch. return optimizer(params=params) return super().configure_optimizers() # type: ignore[misc] - def state_dict(self) -> dict[str, Any]: - """Return state dictionary of model entity with meta information. - - Returns: - A dictionary containing datamodule state. - - """ - state_dict = super().state_dict() # type: ignore[misc] - # This is defined in OTXModel - state_dict["label_info"] = self.label_info # type: ignore[attr-defined] - return state_dict - - def load_state_dict(self, ckpt: OrderedDict[str, Any], *args, **kwargs) -> None: - """Pass the checkpoint to the anomaly model.""" - ckpt = ckpt.get("state_dict", ckpt) - ckpt.pop("label_info", None) # [TODO](ashwinvaidya17): Revisit this method when OTXModel is the lightning model - return super().load_state_dict(ckpt, *args, **kwargs) # type: ignore[misc] - def forward( self, inputs: AnomalyModelInputs, diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index 5b781209dea..aa0c56fbbb1 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -3,6 +3,8 @@ # """Class definition for base model entity used in OTX.""" +# mypy: disable-error-code="arg-type" + from __future__ import annotations import contextlib @@ -25,19 +27,21 @@ from torch.optim.sgd import SGD from torchmetrics import Metric, MetricCollection 
+from otx import __version__ +from otx.core.config.data import TileConfig from otx.core.data.entity.base import ( OTXBatchLossEntity, T_OTXBatchDataEntity, T_OTXBatchPredEntity, ) -from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity +from otx.core.data.entity.tile import OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput, NullMetricCallable from otx.core.optimizer.callable import OptimizerCallableSupportHPO from otx.core.schedulers import LRSchedulerListCallable, PicklableLRSchedulerCallable from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler -from otx.core.types.export import OTXExportFormatType +from otx.core.types.export import OTXExportFormatType, TaskLevelExportParameters from otx.core.types.label import LabelInfo, NullLabelInfo from otx.core.types.precision import OTXPrecisionType from otx.core.utils.build import get_default_num_async_infer_requests @@ -77,7 +81,7 @@ def _default_scheduler_callable( DefaultSchedulerCallable = _default_scheduler_callable -class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXTileBatchDataEntity]): +class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]): """Base class for the models used in OTX. 
Args: @@ -111,6 +115,8 @@ def __init__( self.torch_compile = torch_compile self._explain_mode = False + self._tile_config: TileConfig | None = None + # this line allows to access init params with 'self.hparams' attribute # also ensures init params will be stored in ckpt self.save_hyperparameters(logger=False, ignore=["model", "optimizer", "scheduler", "metric"]) @@ -334,16 +340,54 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa self.log(log_metric_name, value, sync_dist=True, prog_bar=True) - def state_dict(self) -> dict[str, Any]: - """Return state dictionary of model entity with meta information. + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on saving checkpoint.""" + super().on_save_checkpoint(checkpoint) - Returns: - A dictionary containing datamodule state. + checkpoint["label_info"] = self.label_info + checkpoint["otx_version"] = __version__ - """ - state_dict = super().state_dict() - state_dict["label_info"] = self.label_info - return state_dict + if self._tile_config: + checkpoint["tile_config"] = self._tile_config + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on loading checkpoint.""" + super().on_load_checkpoint(checkpoint) + + if ckpt_label_info := checkpoint.get("label_info", None): + self._label_info = ckpt_label_info + + if ckpt_tile_config := checkpoint.get("tile_config", None): + self._tile_config = ckpt_tile_config + + def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None: + """Load state dict incrementally.""" + ckpt_label_info: LabelInfo | None = ckpt.get("label_info", None) + + if ckpt_label_info is None: + msg = "Checkpoint should have `label_info`." 
+ raise ValueError(msg, ckpt_label_info) + + if ckpt_label_info != self.label_info: + msg = ( + "Load model state dictionary incrementally: " + f"Label info from checkpoint: {ckpt_label_info} -> " + f"Label info from training data: {self.label_info}" + ) + logger.info(msg) + self.register_load_state_dict_pre_hook( + self.label_info.label_names, + ckpt_label_info.label_names, + ) + + # Model weights + state_dict: dict[str, Any] = ckpt.get("state_dict", None) + + if state_dict is None: + msg = "Checkpoint should have `state_dict`." + raise ValueError(msg, ckpt_label_info) + + self.load_state_dict(state_dict, *args, **kwargs) def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: """Load state dictionary from checkpoint state dictionary. @@ -362,23 +406,6 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: else: state_dict = ckpt - ckpt_label_info = state_dict.pop("label_info", None) - - if ckpt_label_info and self.label_info is None: - msg = ( - "`state_dict` to load has `label_info`, but the current model has no `label_info`. " - "It is recommended to set proper `label_info` for the incremental learning case."
- ) - warnings.warn(msg, stacklevel=2) - if ckpt_label_info and self.label_info and ckpt_label_info != self.label_info: - logger.warning( - f"Data classes from checkpoint: {ckpt_label_info.label_names} -> " - f"Data classes from training data: {self.label_info.label_names}", - ) - self.register_load_state_dict_pre_hook( - self.label_info.label_names, - ckpt_label_info.label_names, - ) return super().load_state_dict(state_dict, *args, **kwargs) def load_from_otx_v1_ckpt(self, ckpt: dict[str, Any]) -> dict: @@ -488,7 +515,7 @@ def _restore_model_forward(self) -> None: def forward_tiles( self, - inputs: T_OTXTileBatchDataEntity, + inputs: OTXTileBatchDataEntity[T_OTXBatchDataEntity], ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: """Model forward function for tile task.""" raise NotImplementedError @@ -604,36 +631,41 @@ def _exporter(self) -> OTXModelExporter: raise NotImplementedError(msg) @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation. + def _export_parameters(self) -> TaskLevelExportParameters: + """Defines export parameters sharable at a task level. + + To export OTXModel which is compatible with ModelAPI, + you should define an appropriate export parameters for each task. + This property is usually defined at the task level classes defined in `otx.core.model.*`. + Please refer to `TaskLevelExportParameters` for more details. - To export OTXModel, you should define an appropriate parameters." - "This is used in the constructor of `self._exporter`. " - "For example, `self._exporter = SomeExporter(**self.export_parameters)`. " - "Please refer to `otx.core.exporter.*` for detailed examples." Returns: - dict[str, Any]: parameters of exporter. 
- """ - parameters: dict[str, Any] = {} - all_labels = "" - all_label_ids = "" - for lbl in self.label_info.label_names: - all_labels += lbl.replace(" ", "_") + " " - all_label_ids += lbl.replace(" ", "_") + " " - - # not every model requires ptq_config - optimization_config = self._optimization_config - parameters["metadata"] = { - ("model_info", "labels"): all_labels.strip(), - ("model_info", "label_ids"): all_label_ids.strip(), - ("model_info", "optimization_config"): json.dumps(optimization_config), - ("model_info", "label_info"): self.label_info.to_json(), - } + Collection of exporter parameters that can be defined at a task level. - if self.explain_mode: - parameters["output_names"] = ["logits", "feature_vector", "saliency_map"] + Examples: + This example shows how this property is used at the new model development - return parameters + ```python + + class MyDetectionModel(OTXDetectionModel): + ... + + @property + def _exporter(self) -> OTXModelExporter: + # `self._export_parameters` defined at `OTXDetectionModel` + # You can redefine it `MyDetectionModel` if you need + return OTXModelExporter( + task_level_export_parameters=self._export_parameters, + ... + ) + ``` + """ + return TaskLevelExportParameters( + model_type="null", + task_type="null", + label_info=self.label_info, + optimization_config=self._optimization_config, + ) def _reset_prediction_layer(self, num_classes: int) -> None: """Reset its prediction layer with a given number of classes. @@ -691,6 +723,20 @@ def patch_optimizer_and_scheduler_for_hpo(self) -> None: if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable): self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable) + @property + def tile_config(self) -> TileConfig: + """Get tiling configurations.""" + if self._tile_config is None: + msg = "This task type does not support tiling." 
+ raise RuntimeError(msg) + + return self._tile_config + + @tile_config.setter + def tile_config(self, tile_config: TileConfig) -> None: + """Set tiling configurations.""" + self._tile_config = tile_config + class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]): """Base class for the OpenVINO model. diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py index e9f3c36e70b..d724d76ff16 100644 --- a/src/otx/core/model/classification.py +++ b/src/otx/core/model/classification.py @@ -5,7 +5,6 @@ from __future__ import annotations -import json from typing import TYPE_CHECKING, Any import numpy as np @@ -21,7 +20,6 @@ MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, ) -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput @@ -32,6 +30,7 @@ ) from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.types.export import TaskLevelExportParameters from otx.core.types.label import HLabelInfo from otx.core.utils.config import inplace_num_classes from otx.core.utils.utils import get_mean_std_from_data_processing @@ -46,13 +45,7 @@ from otx.core.metrics import MetricCallable -class OTXMulticlassClsModel( - OTXModel[ - MulticlassClsBatchDataEntity, - MulticlassClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXMulticlassClsModel(OTXModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity]): """Base class for the classification models used in OTX.""" def __init__( @@ -72,18 +65,14 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - parameters = 
super()._export_parameters - parameters["metadata"].update( - { - ("model_info", "model_type"): "Classification", - ("model_info", "task_type"): "classification", - ("model_info", "multilabel"): str(False), - ("model_info", "hierarchical"): str(False), - }, + return super()._export_parameters.wrap( + model_type="Classification", + task_type="classification", + multilabel=False, + hierarchical=False, ) - return parameters def _convert_pred_entity_to_compute_metric( self, @@ -226,34 +215,26 @@ def _customize_outputs( @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter(**self._export_parameters) - - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - export_params = super()._export_parameters - export_params.update(get_mean_std_from_data_processing(self.config)) - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - export_params["via_onnx"] = False - export_params["input_size"] = self.image_size - export_params["onnx_export_configuration"] = None - - return export_params + mean, std = get_mean_std_from_data_processing(self.config) + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) ### NOTE, currently, although we've made the separate Multi-cls, Multi-label classes ### It'll be integrated after H-label classification integration with more advanced design. 
-class OTXMultilabelClsModel( - OTXModel[ - MultilabelClsBatchDataEntity, - MultilabelClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXMultilabelClsModel(OTXModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity]): """Multi-label classification models used in OTX.""" def __init__( @@ -273,19 +254,15 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - parameters["metadata"].update( - { - ("model_info", "model_type"): "Classification", - ("model_info", "task_type"): "classification", - ("model_info", "multilabel"): str(True), - ("model_info", "hierarchical"): str(False), - ("model_info", "confidence_threshold"): str(0.5), - }, + return super()._export_parameters.wrap( + model_type="Classification", + task_type="classification", + multilabel=True, + hierarchical=False, + confidence_threshold=0.5, ) - return parameters def _convert_pred_entity_to_compute_metric( self, @@ -428,30 +405,22 @@ def _customize_outputs( @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter(**self._export_parameters) + mean, std = get_mean_std_from_data_processing(self.config) + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - export_params = super()._export_parameters - 
export_params.update(get_mean_std_from_data_processing(self.config)) - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - export_params["via_onnx"] = False - export_params["input_size"] = self.image_size - export_params["onnx_export_configuration"] = None - - return export_params - - -class OTXHlabelClsModel( - OTXModel[ - HlabelClsBatchDataEntity, - HlabelClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): + +class OTXHlabelClsModel(OTXModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity]): """H-label classification models used in OTX.""" def __init__( @@ -473,29 +442,15 @@ def __init__( self._label_info = hlabel_info @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - hierarchical_config: dict = {} - - label_info: HLabelInfo = self.label_info # type: ignore[assignment] - hierarchical_config["cls_heads_info"] = label_info.as_dict() - hierarchical_config["label_tree_edges"] = label_info.label_tree_edges - - parameters["metadata"].update( - { - ("model_info", "model_type"): "Classification", - ("model_info", "task_type"): "classification", - ("model_info", "multilabel"): str(False), - ("model_info", "hierarchical"): str(True), - ("model_info", "confidence_threshold"): str(0.5), - ("model_info", "hierarchical_config"): json.dumps(hierarchical_config), - # NOTE: There is currently too many channels for label related metadata. - # This should be clean up afterwards in ModelAPI side. 
- ("model_info", "label_info"): json.dumps(label_info.as_dict()), - }, + return super()._export_parameters.wrap( + model_type="Classification", + task_type="classification", + multilabel=False, + hierarchical=True, + confidence_threshold=0.5, ) - return parameters def _convert_pred_entity_to_compute_metric( self, @@ -653,21 +608,19 @@ def _customize_outputs( @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter(**self._export_parameters) - - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - export_params = super()._export_parameters - export_params.update(get_mean_std_from_data_processing(self.config)) - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - export_params["via_onnx"] = False - export_params["input_size"] = self.image_size - export_params["onnx_export_configuration"] = None - - return export_params + mean, std = get_mean_std_from_data_processing(self.config) + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) class OVMulticlassClassificationModel( diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index ab4c6cc94b0..53e594810aa 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -5,7 +5,6 @@ from __future__ import annotations -import copy import logging as log import types from typing import TYPE_CHECKING, Any, Callable, Literal @@ -18,15 +17,14 @@ from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from 
otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity -from otx.core.data.entity.tile import TileBatchDetDataEntity -from otx.core.exporter.base import OTXModelExporter -from otx.core.metrics import MetricInput +from otx.core.data.entity.tile import OTXTileBatchDataEntity +from otx.core.metrics import MetricCallable, MetricInput from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.types.export import TaskLevelExportParameters from otx.core.utils.config import inplace_num_classes from otx.core.utils.tile_merge import DetectionTileMerge -from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -38,10 +36,8 @@ from torch import nn from torchmetrics import Metric - from otx.core.metrics import MetricCallable - -class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]): +class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity]): """Base class for the detection models used in OTX.""" def __init__( @@ -59,9 +55,9 @@ def __init__( metric=metric, torch_compile=torch_compile, ) - self.tile_config = TileConfig() + self._tile_config = TileConfig() - def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: + def forward_tiles(self, inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) -> DetBatchPredEntity: """Unpack detection tiles. 
Args: @@ -74,11 +70,11 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: tile_attrs: list[list[dict[str, int | str]]] = [] merger = DetectionTileMerge( inputs.imgs_info, - self.tile_config.iou_threshold, - self.tile_config.max_num_instances, + self.num_classes, + self.tile_config, ) for batch_tile_attrs, batch_tile_input in inputs.unbind(): - output = self.forward(batch_tile_input) + output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input) if isinstance(output, OTXBatchLossEntity): msg = "Loss output is not supported for tile merging" raise TypeError(msg) @@ -86,7 +82,7 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: tile_attrs.append(batch_tile_attrs) pred_entities = merger.merge(tile_preds, tile_attrs) - return DetBatchPredEntity( + pred_entity = DetBatchPredEntity( batch_size=inputs.batch_size, images=[pred_entity.image for pred_entity in pred_entities], imgs_info=[pred_entity.img_info for pred_entity in pred_entities], @@ -94,31 +90,22 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: bboxes=[pred_entity.bboxes for pred_entity in pred_entities], labels=[pred_entity.labels for pred_entity in pred_entities], ) + if self.explain_mode: + pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities] + pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities] + + return pred_entity @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - parameters["metadata"].update( - { - ("model_info", "model_type"): "ssd", - ("model_info", "task_type"): "detection", - ("model_info", "confidence_threshold"): str( - self.hparams.get("best_confidence_threshold", 0.0), - ), # it was able to be set in 
OTX 1.X - ("model_info", "iou_threshold"): str(0.5), - }, + return super()._export_parameters.wrap( + model_type="ssd", + task_type="detection", + confidence_threshold=self.hparams.get("best_confidence_threshold", 0.0), + iou_threshold=0.5, + tile_config=self.tile_config if self.tile_config.enable_tiler else None, ) - if self.tile_config.enable_tiler: - parameters["metadata"].update( - { - ("model_info", "tile_size"): str(self.tile_config.tile_size[0]), - ("model_info", "tiles_overlap"): str(self.tile_config.overlap), - ("model_info", "max_pred_number"): str(self.tile_config.max_num_instances), - }, - ) - - return parameters def _convert_pred_entity_to_compute_metric( self, @@ -183,14 +170,35 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa class ExplainableOTXDetModel(OTXDetectionModel): - """OTX detection model which can attach a XAI hook.""" + """OTX detection model which can attach a XAI (Explainable AI) branch.""" - def forward_explain( + def __init__( self, - inputs: DetBatchDataEntity, - ) -> DetBatchPredEntity: + num_classes: int, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MeanAPCallable, + torch_compile: bool = False, + ) -> None: + super().__init__( + num_classes=num_classes, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + from otx.algo.explain.explain_algo import get_feature_vector + + self.model.feature_vector_fn = get_feature_vector + self.model.explain_fn = self.get_explain_fn() + + def forward_explain(self, inputs: DetBatchDataEntity) -> DetBatchPredEntity: """Model forward function.""" - from otx.algo.hooks.recording_forward_hook import get_feature_vector + from otx.algo.explain.explain_algo import get_feature_vector + + if isinstance(inputs, OTXTileBatchDataEntity): + return self.forward_tiles(inputs) self.model.feature_vector_fn = 
get_feature_vector self.model.explain_fn = self.get_explain_fn() @@ -255,11 +263,11 @@ def _forward_explain_detection( def get_explain_fn(self) -> Callable: """Returns explain function.""" from otx.algo.detection.heads.custom_ssd_head import SSDHead - from otx.algo.hooks.recording_forward_hook import DetClassProbabilityMapHook + from otx.algo.explain.explain_algo import DetClassProbabilityMap # SSD-like heads also have background class background_class = isinstance(self.model.bbox_head, SSDHead) - explainer = DetClassProbabilityMapHook( + explainer = DetClassProbabilityMap( num_classes=self.num_classes + background_class, num_anchors=self.get_num_anchors(), ) @@ -302,13 +310,6 @@ def get_num_anchors(self) -> list[int]: return [1] * 10 - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - parameters["output_names"] = ["feature_vector", "saliency_map"] if self.explain_mode else None - return parameters - class MMDetCompatibleModel(ExplainableOTXDetModel): """Detection model compatible for MMDet. @@ -339,21 +340,6 @@ def __init__( torch_compile=torch_compile, ) - @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - if self.image_size is None: - error_msg = "self.image_size shouldn't be None to use mmdeploy." 
- raise ValueError(error_msg) - - export_params = super()._export_parameters - export_params.update(get_mean_std_from_data_processing(self.config)) - export_params["model_builder"] = self._create_model - export_params["model_cfg"] = copy.copy(self.config) - export_params["test_pipeline"] = self._make_fake_test_pipeline() - - return export_params - def _create_model(self) -> nn.Module: from .utils.mmdet import create_model @@ -483,13 +469,6 @@ def _customize_outputs( labels=labels, ) - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - from otx.core.exporter.mmdeploy import MMdeployExporter - - return MMdeployExporter(**self._export_parameters) - class OVDetectionModel(OVModel[DetBatchDataEntity, DetBatchPredEntity]): """Object detection model compatible for OpenVINO IR inference. diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index 572ba044af8..121c6ca585e 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -7,7 +7,6 @@ import logging as log import types -from copy import copy from typing import TYPE_CHECKING, Any, Callable, Literal import numpy as np @@ -17,24 +16,24 @@ from openvino.model_api.tilers import InstanceSegmentationTiler from torchvision import tv_tensors +from otx.algo.explain.explain_algo import get_feature_vector from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity -from otx.core.data.entity.tile import TileBatchInstSegDataEntity -from otx.core.exporter.base import OTXModelExporter +from otx.core.data.entity.tile import OTXTileBatchDataEntity from otx.core.metrics import MetricInput from otx.core.metrics.mean_ap import MaskRLEMeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, 
DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.types.export import TaskLevelExportParameters from otx.core.utils.config import inplace_num_classes from otx.core.utils.mask_util import encode_rle, polygon_to_rle from otx.core.utils.tile_merge import InstanceSegTileMerge -from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from mmdet.models.data_preprocessors import DetDataPreprocessor - from mmdet.models.detectors.base import TwoStageDetector + from mmdet.models.detectors.two_stage import TwoStageDetector from mmdet.structures import OptSampleList from omegaconf import DictConfig from openvino.model_api.models.utils import InstanceSegmentationResult @@ -44,13 +43,7 @@ from otx.core.metrics import MetricCallable -class OTXInstanceSegModel( - OTXModel[ - InstanceSegBatchDataEntity, - InstanceSegBatchPredEntity, - TileBatchInstSegDataEntity, - ], -): +class OTXInstanceSegModel(OTXModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity]): """Base class for the Instance Segmentation models used in OTX.""" def __init__( @@ -68,9 +61,9 @@ def __init__( metric=metric, torch_compile=torch_compile, ) - self.tile_config = TileConfig() + self._tile_config = TileConfig() - def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchPredEntity: + def forward_tiles(self, inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntity]) -> InstanceSegBatchPredEntity: """Unpack instance segmentation tiles. 
Args: @@ -83,11 +76,11 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP tile_attrs: list[list[dict[str, int | str]]] = [] merger = InstanceSegTileMerge( inputs.imgs_info, - self.tile_config.iou_threshold, - self.tile_config.max_num_instances, + self.num_classes, + self.tile_config, ) for batch_tile_attrs, batch_tile_input in inputs.unbind(): - output = self.forward(batch_tile_input) + output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input) if isinstance(output, OTXBatchLossEntity): msg = "Loss output is not supported for tile merging" raise TypeError(msg) @@ -95,7 +88,7 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP tile_attrs.append(batch_tile_attrs) pred_entities = merger.merge(tile_preds, tile_attrs) - return InstanceSegBatchPredEntity( + pred_entity = InstanceSegBatchPredEntity( batch_size=inputs.batch_size, images=[pred_entity.image for pred_entity in pred_entities], imgs_info=[pred_entity.img_info for pred_entity in pred_entities], @@ -105,43 +98,23 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP masks=[pred_entity.masks for pred_entity in pred_entities], polygons=[pred_entity.polygons for pred_entity in pred_entities], ) + if self.explain_mode: + pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities] + pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities] + + return pred_entity @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - parameters["metadata"].update( - { - ("model_info", "model_type"): "MaskRCNN", - ("model_info", "task_type"): "instance_segmentation", - ("model_info", "confidence_threshold"): str( - 
self.hparams.get("best_confidence_threshold", 0.0), - ), # it was able to be set in OTX 1.X - ("model_info", "iou_threshold"): str(0.5), - }, + return super()._export_parameters.wrap( + model_type="MaskRCNN", + task_type="instance_segmentation", + confidence_threshold=self.hparams.get("best_confidence_threshold", 0.0), + iou_threshold=0.5, + tile_config=self.tile_config if self.tile_config.enable_tiler else None, ) - # Instance segmentation needs to add empty label - all_labels = "otx_empty_lbl " - all_label_ids = "None " - for lbl in self.label_info.label_names: - all_labels += lbl.replace(" ", "_") + " " - all_label_ids += lbl.replace(" ", "_") + " " - - parameters["metadata"][("model_info", "labels")] = all_labels.strip() - parameters["metadata"][("model_info", "label_ids")] = all_label_ids.strip() - - if self.tile_config.enable_tiler: - parameters["metadata"].update( - { - ("model_info", "tile_size"): str(self.tile_config.tile_size[0]), - ("model_info", "tiles_overlap"): str(self.tile_config.overlap), - ("model_info", "max_pred_number"): str(self.tile_config.max_num_instances), - }, - ) - - return parameters - def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: """Load state_dict from checkpoint. 
@@ -233,14 +206,33 @@ def _convert_pred_entity_to_compute_metric( class ExplainableOTXInstanceSegModel(OTXInstanceSegModel): - """OTX Instance Segmentation model which can attach a XAI hook.""" + """OTX Instance Segmentation model which can attach a XAI (Explainable AI) branch.""" - def forward_explain( + def __init__( self, - inputs: InstanceSegBatchDataEntity, - ) -> InstanceSegBatchPredEntity: + num_classes: int, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MaskRLEMeanAPCallable, + torch_compile: bool = False, + ) -> None: + super().__init__( + num_classes=num_classes, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + from otx.algo.explain.explain_algo import get_feature_vector + + self.model.feature_vector_fn = get_feature_vector + self.model.explain_fn = self.get_explain_fn() + + def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity: """Model forward function.""" - from otx.algo.hooks.recording_forward_hook import get_feature_vector + if isinstance(inputs, OTXTileBatchDataEntity): + return self.forward_tiles(inputs) self.model.feature_vector_fn = get_feature_vector self.model.explain_fn = self.get_explain_fn() @@ -299,9 +291,9 @@ def _forward_explain_inst_seg( def get_explain_fn(self) -> Callable: """Returns explain function.""" - from otx.algo.hooks.recording_forward_hook import MaskRCNNRecordingForwardHook + from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo - explainer = MaskRCNNRecordingForwardHook(num_classes=self.num_classes) + explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes) return explainer.func def _reset_model_forward(self) -> None: @@ -330,13 +322,6 @@ def _restore_model_forward(self) -> None: self.model.forward = func_type(self.original_model_forward, self.model) self.original_model_forward = None - @property - 
def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - parameters["output_names"] = ["feature_vector", "saliency_map"] if self.explain_mode else None - return parameters - class MMDetInstanceSegCompatibleModel(ExplainableOTXInstanceSegModel): """Instance Segmentation model compatible for MMDet.""" @@ -362,21 +347,6 @@ def __init__( torch_compile=torch_compile, ) - @property - def _export_parameters(self) -> dict[str, Any]: - """Parameters for an exporter.""" - if self.image_size is None: - error_msg = "self.image_size shouldn't be None to use mmdeploy." - raise ValueError(error_msg) - - export_params = super()._export_parameters - export_params.update(get_mean_std_from_data_processing(self.config)) - export_params["model_builder"] = self._create_model - export_params["model_cfg"] = copy(self.config) - export_params["test_pipeline"] = self._make_fake_test_pipeline() - - return export_params - def _create_model(self) -> nn.Module: from .utils.mmdet import create_model @@ -531,13 +501,6 @@ def _customize_outputs( labels=labels, ) - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - from otx.core.exporter.mmdeploy import MMdeployExporter - - return MMdeployExporter(**self._export_parameters) - class OVInstanceSegmentationModel( OVModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity], diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index 93abbb5a105..e5a1f140ecf 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -5,7 +5,6 @@ from __future__ import annotations -import copy import json from typing import TYPE_CHECKING, Any @@ -13,16 +12,13 @@ from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity -from 
otx.core.data.entity.tile import T_OTXTileBatchDataEntity -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput from otx.core.metrics.dice import SegmCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.types.export import TaskLevelExportParameters from otx.core.types.label import SegLabelInfo from otx.core.utils.config import inplace_num_classes -from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -34,7 +30,7 @@ from otx.core.metrics import MetricCallable -class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity, T_OTXTileBatchDataEntity]): +class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity]): """Base class for the detection models used in OTX.""" def __init__( @@ -54,23 +50,15 @@ def __init__( ) @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - parameters = super()._export_parameters - hierarchical_config: dict = {} - hierarchical_config["cls_heads_info"] = {} - hierarchical_config["label_tree_edges"] = [] - - parameters["metadata"].update( - { - ("model_info", "model_type"): "Segmentation", - ("model_info", "task_type"): "segmentation", - ("model_info", "return_soft_prediction"): str(True), - ("model_info", "soft_threshold"): str(0.5), - ("model_info", "blur_strength"): str(-1), - }, + return super()._export_parameters.wrap( + model_type="Segmentation", + task_type="segmentation", + return_soft_prediction=True, + soft_threshold=0.5, + blur_strength=-1, ) - return parameters def _convert_pred_entity_to_compute_metric( self, @@ -176,20 
+164,6 @@ def _customize_outputs( raise TypeError(output) masks.append(output.pred_sem_seg.data) - if hasattr(self, "explain_hook"): - hook_records = self.explain_hook.records - explain_results = copy.deepcopy(hook_records[-len(outputs) :]) - - return SegBatchPredEntity( - batch_size=len(outputs), - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=[], - masks=masks, - saliency_map=explain_results, - feature_vector=[], - ) - return SegBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -198,25 +172,6 @@ def _customize_outputs( masks=masks, ) - @property - def _export_parameters(self) -> dict[str, Any]: - """Defines parameters required to export a particular model implementation.""" - export_params = super()._export_parameters - export_params.update(get_mean_std_from_data_processing(self.config)) - export_params["resize_mode"] = "standard" - export_params["pad_value"] = 0 - export_params["swap_rgb"] = False - export_params["via_onnx"] = False - export_params["input_size"] = self.image_size - export_params["onnx_export_configuration"] = None - - return export_params - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter(**self._export_parameters) - class OVSegmentationModel(OVModel[SegBatchDataEntity, SegBatchPredEntity]): """Semantic segmentation model compatible for OpenVINO IR inference. 
diff --git a/src/otx/core/model/utils/mmpretrain.py b/src/otx/core/model/utils/mmpretrain.py index ff25374352b..0b946cd71ac 100644 --- a/src/otx/core/model/utils/mmpretrain.py +++ b/src/otx/core/model/utils/mmpretrain.py @@ -12,7 +12,7 @@ from mmpretrain.models.utils import ClsDataPreprocessor as _ClsDataPreprocessor from mmpretrain.registry import MODELS -from otx.algo.hooks.recording_forward_hook import get_feature_vector +from otx.algo.explain.explain_algo import ReciproCAM, get_feature_vector from otx.core.data.entity.base import T_OTXBatchDataEntity, T_OTXBatchPredEntity from otx.core.utils.build import build_mm_model, get_classification_layers @@ -134,9 +134,7 @@ def get_explain_fn(self) -> Callable: Note: Can be redefined at the model's level. """ - from otx.algo.hooks.recording_forward_hook import ReciproCAMHook - - explainer = ReciproCAMHook( + explainer = ReciproCAM( self.head_forward_fn, num_classes=self.num_classes, optimize_gap=self.has_gap, diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index bcb1d9608bf..081cc916554 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -26,7 +26,6 @@ from torchvision import tv_tensors from otx.core.data.entity.base import OTXBatchLossEntity, Points -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.data.entity.visual_prompting import ( VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, @@ -39,6 +38,7 @@ from otx.core.metrics.visual_prompting import VisualPromptingMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable +from otx.core.types.export import TaskLevelExportParameters from otx.core.types.label import LabelInfo, NullLabelInfo from otx.core.utils.mask_util import polygon_to_bitmap @@ -168,9 +168,7 @@ def _inference_step_for_zero_shot( ) -class 
OTXVisualPromptingModel( - OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, T_OTXTileBatchDataEntity], -): +class OTXVisualPromptingModel(OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity]): """Base class for the visual prompting models used in OTX.""" def __init__( @@ -188,27 +186,27 @@ def __init__( metric=metric, torch_compile=torch_compile, ) + self._label_info = NullLabelInfo() @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXVisualPromptingModelExporter(via_onnx=True, **self._export_parameters) + return OTXVisualPromptingModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, self.model.image_size, self.model.image_size), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="fit_to_window", + via_onnx=True, + ) @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - export_params = super()._export_parameters - export_params["metadata"].update( - { - ("model_info", "model_type"): "Visual_Prompting", - ("model_info", "task_type"): "visual_prompting", - }, + return super()._export_parameters.wrap( + model_type="Visual_Prompting", + task_type="visual_prompting", ) - export_params["input_size"] = (1, 3, self.model.image_size, self.model.image_size) - export_params["resize_mode"] = "fit_to_window" - export_params["mean"] = (123.675, 116.28, 103.53) - export_params["std"] = (58.395, 57.12, 57.375) - return export_params @property def _optimization_config(self) -> dict[str, Any]: @@ -275,11 +273,7 @@ def _set_label_info(self, _: LabelInfo | list[str]) -> None: class OTXZeroShotVisualPromptingModel( - OTXModel[ - ZeroShotVisualPromptingBatchDataEntity, - ZeroShotVisualPromptingBatchPredEntity, - T_OTXTileBatchDataEntity, - ], + 
OTXModel[ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity], ): """Base class for the visual prompting models used in OTX.""" @@ -298,27 +292,27 @@ def __init__( metric=metric, torch_compile=torch_compile, ) + self._label_info = NullLabelInfo() @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" - return OTXVisualPromptingModelExporter(via_onnx=True, **self._export_parameters) + return OTXVisualPromptingModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, self.model.image_size, self.model.image_size), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="fit_to_window", + via_onnx=True, + ) @property - def _export_parameters(self) -> dict[str, Any]: + def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" - export_params = super()._export_parameters - export_params["metadata"].update( - { - ("model_info", "model_type"): "Visual_Prompting", - ("model_info", "task_type"): "visual_prompting", - }, + return super()._export_parameters.wrap( + model_type="Visual_Prompting", + task_type="visual_prompting", ) - export_params["input_size"] = (1, 3, self.model.image_size, self.model.image_size) - export_params["resize_mode"] = "fit_to_window" - export_params["mean"] = (123.675, 116.28, 103.53) - export_params["std"] = (58.395, 57.12, 57.375) - return export_params @property def _optimization_config(self) -> dict[str, Any]: diff --git a/src/otx/core/types/device.py b/src/otx/core/types/device.py index 0d11e0393f7..bd87a5721df 100644 --- a/src/otx/core/types/device.py +++ b/src/otx/core/types/device.py @@ -10,7 +10,7 @@ class DeviceType(str, Enum): """OTX Device type definition.""" - # ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") + # ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "xpu", "auto") auto = "auto" gpu = "gpu" @@ -19,3 +19,4 @@ 
class DeviceType(str, Enum): ipu = "ipu" hpu = "hpu" mps = "mps" + xpu = "xpu" diff --git a/src/otx/core/types/export.py b/src/otx/core/types/export.py index af439c6ff96..9291769dccc 100644 --- a/src/otx/core/types/export.py +++ b/src/otx/core/types/export.py @@ -5,8 +5,13 @@ from __future__ import annotations +import json +from dataclasses import dataclass, fields from enum import Enum +from otx.core.config.data import TileConfig +from otx.core.types.label import HLabelInfo, LabelInfo + class OTXExportFormatType(str, Enum): """OTX export format type definition.""" @@ -14,3 +19,143 @@ class OTXExportFormatType(str, Enum): ONNX = "ONNX" OPENVINO = "OPENVINO" EXPORTABLE_CODE = "EXPORTABLE_CODE" + + +@dataclass(frozen=True) +class TaskLevelExportParameters: + """Collection of export parameters which can be defined at a task level. + + Attributes: + model_type (str): Model type field used in ModelAPI. + task_type (str): Task type field used in ModelAPI. + label_info (LabelInfo): OTX label info metadata. + It will be parsed into a format compatible with ModelAPI. + optimization_config (dict): Configurations for NNCF PTQ model optimization. + multilabel (bool | None): Whether it is multilabel or not. + Only specified for the classification task. + hierarchical (bool | None): Whether it is hierarchical or not. + Only specified for the classification task. + confidence_threshold (float | None): Confidence threshold for model prediction probability. + It is used only for classification tasks, detection and instance segmentation tasks. + iou_threshold (float | None): The Intersection over Union (IoU) threshold + for Non-Maximum Suppression (NMS) post-processing. + It is used only for models in detection and instance segmentation tasks. + return_soft_prediction (bool | None): Whether to return soft prediction. + It is used only for semantic segmentation tasks. + soft_threshold (float | None): Minimum class confidence for each pixel. 
+ The higher the value, the more strict the segmentation is (usually set to 0.5). + Only specified for semantic segmentation tasks. + blur_strength (int | None): The higher the value, the smoother the + segmentation output will be, but less accurate. + Only specified for semantic segmentation tasks. + tile_config (TileConfig | None): Configuration for tiling models + If None, the model is not trained with tiling. + """ + + # Common + model_type: str + task_type: str + label_info: LabelInfo + optimization_config: dict + + # (Optional) Classification tasks + multilabel: bool | None = None + hierarchical: bool | None = None + + # (Optional) Classification tasks, detection and instance segmentation task + confidence_threshold: float | None = None + + # (Optional) Detection and instance segmentation task + iou_threshold: float | None = None + + # (Optional) Semantic segmentation task + return_soft_prediction: bool | None = None + soft_threshold: float | None = None + blur_strength: int | None = None + + # (Optional) Tasks with tiling + tile_config: TileConfig | None = None + + def wrap(self, **kwargs_to_update) -> TaskLevelExportParameters: + """Create a new instance by wrapping it with the given keyword arguments. + + Args: + kwargs_to_update (dict): Keyword arguments to update. + + Returns: + TaskLevelExportParameters: A new instance with updated attributes. + """ + updated_kwargs = {field.name: getattr(self, field.name) for field in fields(self)} + updated_kwargs.update(kwargs_to_update) + return TaskLevelExportParameters(**updated_kwargs) + + def to_metadata(self) -> dict[tuple[str, str], str]: + """Convert this dataclass to dictionary format compatible with ModelAPI. + + Returns: + dict[tuple[str, str], str]: It will be directly delivered to + OpenVINO IR's `rt_info` or ONNX metadata slot. 
+ """ + if self.task_type == "instance_segmentation": + # Instance segmentation needs to add empty label + all_labels = "otx_empty_lbl " + all_label_ids = "None " + for lbl in self.label_info.label_names: + all_labels += lbl.replace(" ", "_") + " " + all_label_ids += lbl.replace(" ", "_") + " " + else: + all_labels = "" + all_label_ids = "" + for lbl in self.label_info.label_names: + all_labels += lbl.replace(" ", "_") + " " + all_label_ids += lbl.replace(" ", "_") + " " + + metadata = { + # Common + ("model_info", "model_type"): self.model_type, + ("model_info", "task_type"): self.task_type, + ("model_info", "label_info"): self.label_info.to_json(), + ("model_info", "labels"): all_labels.strip(), + ("model_info", "label_ids"): all_label_ids.strip(), + ("model_info", "optimization_config"): json.dumps(self.optimization_config), + } + + if isinstance(self.label_info, HLabelInfo): + metadata[("model_info", "hierarchical_config")] = json.dumps( + { + "cls_heads_info": self.label_info.as_dict(), + "label_tree_edges": self.label_info.label_tree_edges, + }, + ) + + if self.multilabel is not None: + metadata[("model_info", "multilabel")] = str(self.multilabel) + + if self.hierarchical is not None: + metadata[("model_info", "hierarchical")] = str(self.hierarchical) + + if self.confidence_threshold is not None: + metadata[("model_info", "confidence_threshold")] = str(self.confidence_threshold) + + if self.iou_threshold is not None: + metadata[("model_info", "iou_threshold")] = str(self.iou_threshold) + + if self.return_soft_prediction is not None: + metadata[("model_info", "return_soft_prediction")] = str(self.return_soft_prediction) + + if self.soft_threshold is not None: + metadata[("model_info", "soft_threshold")] = str(self.soft_threshold) + + if self.blur_strength is not None: + metadata[("model_info", "blur_strength")] = str(self.blur_strength) + + if self.tile_config is not None: + metadata.update( + { + ("model_info", "tile_size"): 
str(self.tile_config.tile_size[0]), + ("model_info", "tiles_overlap"): str(self.tile_config.overlap), + ("model_info", "max_pred_number"): str(self.tile_config.max_num_instances), + }, + ) + + return metadata diff --git a/src/otx/core/utils/tile_merge.py b/src/otx/core/utils/tile_merge.py index 41fe707cbd3..a99cb5d24aa 100644 --- a/src/otx/core/utils/tile_merge.py +++ b/src/otx/core/utils/tile_merge.py @@ -9,10 +9,13 @@ from collections import defaultdict from typing import Generic +import cv2 +import numpy as np import torch from torchvision import tv_tensors from torchvision.ops import batched_nms +from otx.core.config.data import TileConfig from otx.core.data.entity.base import ImageInfo, T_OTXBatchPredEntity, T_OTXDataEntity from otx.core.data.entity.detection import DetBatchPredEntity, DetPredEntity from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegPredEntity @@ -31,20 +34,28 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]): def __init__( self, img_infos: list[ImageInfo], - iou_threshold: float = 0.45, - max_num_instances: int = 500, + num_classes: int, + tile_config: TileConfig, ) -> None: self.img_infos = img_infos - self.iou_threshold = iou_threshold - self.max_num_instances = max_num_instances + self.num_classes = num_classes + self.tile_size = tile_config.tile_size + self.iou_threshold = tile_config.iou_threshold + self.max_num_instances = tile_config.max_num_instances @abstractmethod - def _merge_entities(self, img_info: ImageInfo, entities: list[T_OTXDataEntity]) -> T_OTXDataEntity: + def _merge_entities( + self, + img_info: ImageInfo, + entities: list[T_OTXDataEntity], + explain_mode: bool = False, + ) -> T_OTXDataEntity: """Merge tile predictions to one single full-size prediction data entity. Args: img_info (ImageInfo): Image information about the original image before tiling. entities (list[T_OTXDataEntity]): List of tile prediction entities. 
+ explain_mode (bool): Whether or not tiles have explain features. Default: False. Returns: T_OTXDataEntity: Merged prediction entity. @@ -102,14 +113,20 @@ def merge( """ entities_to_merge = defaultdict(list) img_ids = [] + explain_mode = len(batch_tile_preds[0].feature_vector) > 0 for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs): - for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores in zip( + batch_size = tile_preds.batch_size + saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)] + feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)] + for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_s_map, tile_f_vect in zip( tile_attrs, tile_preds.imgs_info, tile_preds.bboxes, tile_preds.labels, tile_preds.scores, + saliency_maps, + feature_vectors, ): offset_x, offset_y, _, _ = tile_attr["roi"] tile_bboxes[:, 0::2] += offset_x @@ -120,26 +137,36 @@ def merge( img_ids.append(tile_id) tile_img_info.padding = tile_attr["roi"] - entities_to_merge[tile_id].append( - DetPredEntity( - image=torch.empty(tile_img_info.ori_shape), - img_info=tile_img_info, - bboxes=tile_bboxes, - labels=tile_labels, - score=tile_scores, - ), + det_pred_entity = DetPredEntity( + image=torch.empty(tile_img_info.ori_shape), + img_info=tile_img_info, + bboxes=tile_bboxes, + labels=tile_labels, + score=tile_scores, ) + + if explain_mode: + det_pred_entity.feature_vector = tile_f_vect + det_pred_entity.saliency_map = tile_s_map + entities_to_merge[tile_id].append(det_pred_entity) + return [ - self._merge_entities(image_info, entities_to_merge[img_id]) + self._merge_entities(image_info, entities_to_merge[img_id], explain_mode) for img_id, image_info in zip(img_ids, self.img_infos) ] - def _merge_entities(self, img_info: ImageInfo, entities: list[DetPredEntity]) -> DetPredEntity: + def _merge_entities( + self, + img_info: ImageInfo, + entities: list[DetPredEntity], + 
explain_mode: bool = False, + ) -> DetPredEntity: """Merge tile predictions to one single prediction. Args: img_info (ImageInfo): Image information about the original image before tiling. entities (list[DetPredEntity]): List of tile prediction entities. + explain_mode (bool): Whether or not tiles have explain features. Default: False. Returns: DetPredEntity: Merged prediction entity. @@ -147,6 +174,9 @@ def _merge_entities(self, img_info: ImageInfo, entities: list[DetPredEntity]) -> bboxes: list | torch.Tensor = [] labels: list | torch.Tensor = [] scores: list | torch.Tensor = [] + feature_vectors = [] + saliency_maps = [] + tiles_coords = [] img_size = img_info.ori_shape for tile_entity in entities: num_preds = len(tile_entity.bboxes) @@ -154,29 +184,110 @@ def _merge_entities(self, img_info: ImageInfo, entities: list[DetPredEntity]) -> bboxes.extend(tile_entity.bboxes) labels.extend(tile_entity.labels) scores.extend(tile_entity.score) + if explain_mode: + tiles_coords.append(tile_entity.img_info.padding) + feature_vectors.append(tile_entity.feature_vector) + saliency_maps.append(tile_entity.saliency_map) bboxes = torch.stack(bboxes) if len(bboxes) > 0 else torch.empty((0, 4), device=img_info.device) labels = torch.stack(labels) if len(labels) > 0 else torch.empty((0,), device=img_info.device) scores = torch.stack(scores) if len(scores) > 0 else torch.empty((0,), device=img_info.device) - bboxes, labels, scores, _ = self.nms_postprocess( - bboxes, - scores, - labels, - ) + bboxes, labels, scores, _ = self.nms_postprocess(bboxes, scores, labels) - return DetPredEntity( + det_pred_entity = DetPredEntity( image=torch.empty(img_size), img_info=img_info, score=scores, - bboxes=tv_tensors.BoundingBoxes( - bboxes, - canvas_size=img_size, - format="XYXY", - ), + bboxes=tv_tensors.BoundingBoxes(bboxes, canvas_size=img_size, format="XYXY"), labels=labels, ) + if explain_mode: + merged_vector = np.mean(feature_vectors, axis=0) + merged_saliency_map = 
self._merge_saliency_maps(saliency_maps, img_size, tiles_coords) + det_pred_entity.feature_vector = merged_vector + det_pred_entity.saliency_map = merged_saliency_map + + return det_pred_entity + + def _merge_saliency_maps( + self, + saliency_maps: list[np.array], + shape: tuple[int, int], + tiles_coords: list[tuple[int, int, int, int]], + ) -> np.ndarray: + """Merging saliency maps from each tile for PyTorch implementation. + + OV implementation is on ModelAPI side. Unlike ModelAPI implementation, + it doesn't have the first tile with resized untiled image. + + Args: + saliency_maps: list of saliency maps, shape of each map is (Nc, H, W) + shape: shape of the original image + tiles_coords: coordinates of tiles + + Returns: + Merged saliency map with shape (Nc, H, W) + """ + if len(saliency_maps) == 1: + return saliency_maps[0] + + if len(saliency_maps[0].shape) == 1: + return np.ndarray([]) + + num_classes = saliency_maps[0].shape[0] + map_h, map_w = saliency_maps[0].shape[1:] + + image_h, image_w = shape + ratio = map_h / min(image_h, self.tile_size[0]), map_w / min(image_w, self.tile_size[1]) + + image_map_h = int(image_h * ratio[0]) + image_map_w = int(image_w * ratio[1]) + merged_map = np.zeros((num_classes, image_map_h, image_map_w)) + + for i, saliency_map in enumerate(saliency_maps): + for class_idx in range(num_classes): + cls_map = saliency_map[class_idx] + + x_1, y_1, map_w, map_h = tiles_coords[i] + x_2, y_2 = x_1 + map_w, y_1 + map_h + + y_1, x_1 = int(y_1 * ratio[0]), int(x_1 * ratio[1]) + y_2, x_2 = int(y_2 * ratio[0]), int(x_2 * ratio[1]) + + map_h, map_w = cls_map.shape + + if (map_h > y_2 - y_1 > 0) and (map_w > x_2 - x_1 > 0): + cls_map = cv2.resize(cls_map, (x_2 - x_1, y_2 - y_1)) + + map_h, map_w = y_2 - y_1, x_2 - x_1 + + for hi, wi in [(h_, w_) for h_ in range(map_h) for w_ in range(map_w)]: + map_pixel = cls_map[hi, wi] + merged_pixel = merged_map[class_idx][y_1 + hi, x_1 + wi] + if merged_pixel != 0: + merged_map[class_idx][y_1 + hi, x_1 + 
wi] = 0.5 * (map_pixel + merged_pixel) + else: + merged_map[class_idx][y_1 + hi, x_1 + wi] = map_pixel + + for class_idx in range(num_classes): + merged_map[class_idx] = _non_linear_normalization(merged_map[class_idx]) + + return merged_map.astype(np.uint8) + + +def _non_linear_normalization(saliency_map: np.ndarray) -> np.ndarray: + """Use non-linear normalization y=x**1.5 for 2D saliency maps.""" + min_soft_score = np.min(saliency_map) + # Make merged_map distribution positive to perform non-linear normalization y=x**1.5 + saliency_map = (saliency_map - min_soft_score) ** 1.5 + + max_soft_score = np.max(saliency_map) + saliency_map = 255.0 / (max_soft_score + 1e-12) * saliency_map + + return np.floor(saliency_map) + class InstanceSegTileMerge(TileMerge): """Instance segmentation tile merge.""" @@ -195,15 +306,18 @@ def merge( """ entities_to_merge = defaultdict(list) img_ids = [] + explain_mode = len(batch_tile_preds[0].feature_vector) > 0 for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs): - for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks in zip( + feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(tile_preds.batch_size)] + for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks, tile_f_vect in zip( tile_attrs, tile_preds.imgs_info, tile_preds.bboxes, tile_preds.labels, tile_preds.scores, tile_preds.masks, + feature_vectors, ): keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0 keep_indices = keep_indices.nonzero(as_tuple=True)[0] @@ -221,24 +335,32 @@ def merge( img_ids.append(tile_id) tile_img_info.padding = tile_attr["roi"] - entities_to_merge[tile_id].append( - InstanceSegPredEntity( - image=torch.empty(tile_img_info.ori_shape), - img_info=tile_img_info, - bboxes=_bboxes, - labels=_labels, - score=_scores, - masks=_masks.to_sparse(), - polygons=[], - ), + inst_seg_pred_entity = InstanceSegPredEntity( + 
image=torch.empty(tile_img_info.ori_shape), + img_info=tile_img_info, + bboxes=_bboxes, + labels=_labels, + score=_scores, + masks=_masks.to_sparse(), + polygons=[], ) + if explain_mode: + inst_seg_pred_entity.feature_vector = tile_f_vect + inst_seg_pred_entity.saliency_map = [] + entities_to_merge[tile_id].append(inst_seg_pred_entity) + return [ - self._merge_entities(image_info, entities_to_merge[img_id]) + self._merge_entities(image_info, entities_to_merge[img_id], explain_mode) for img_id, image_info in zip(img_ids, self.img_infos) ] - def _merge_entities(self, img_info: ImageInfo, entities: list[InstanceSegPredEntity]) -> InstanceSegPredEntity: + def _merge_entities( + self, + img_info: ImageInfo, + entities: list[InstanceSegPredEntity], + explain_mode: bool = False, + ) -> InstanceSegPredEntity: """Merge tile predictions to one single prediction. Args: @@ -252,6 +374,7 @@ def _merge_entities(self, img_info: ImageInfo, entities: list[InstanceSegPredEnt labels: list | torch.Tensor = [] scores: list | torch.Tensor = [] masks: list | torch.Tensor = [] + feature_vectors = [] img_size = img_info.ori_shape for tile_entity in entities: num_preds = len(tile_entity.bboxes) @@ -268,6 +391,8 @@ def _merge_entities(self, img_info: ImageInfo, entities: list[InstanceSegPredEnt masks.extend( torch.sparse_coo_tensor(mask_indices, mask_values, (num_preds, *img_size)), ) + if explain_mode: + feature_vectors.append(tile_entity.feature_vector) bboxes = torch.stack(bboxes) if len(bboxes) > 0 else torch.empty((0, 4), device=img_info.device) labels = torch.stack(labels) if len(labels) > 0 else torch.empty((0,), device=img_info.device) @@ -275,16 +400,41 @@ def _merge_entities(self, img_info: ImageInfo, entities: list[InstanceSegPredEnt masks = masks if len(masks) > 0 else torch.empty((0, *img_size)) bboxes, labels, scores, masks = self.nms_postprocess(bboxes, scores, labels, masks) - return InstanceSegPredEntity( + + inst_seg_pred_entity = InstanceSegPredEntity( 
image=torch.empty(img_size), img_info=img_info, score=scores, - bboxes=tv_tensors.BoundingBoxes( - bboxes, - canvas_size=img_size, - format="XYXY", - ), + bboxes=tv_tensors.BoundingBoxes(bboxes, canvas_size=img_size, format="XYXY"), labels=labels, masks=tv_tensors.Mask(masks, dtype=bool), polygons=[], ) + + if explain_mode: + merged_vector = np.mean(feature_vectors, axis=0) + merged_saliency_map = self.get_saliency_maps_from_masks(labels, scores, masks, self.num_classes) + inst_seg_pred_entity.feature_vector = merged_vector + inst_seg_pred_entity.saliency_map = merged_saliency_map + + return inst_seg_pred_entity + + def get_saliency_maps_from_masks( + self, + labels: torch.Tensor, + scores: torch.Tensor, + masks: None | torch.Tensor, + num_classes: int, + ) -> np.ndarray: + """Average and normalize predicted masks in per-class. + + Returns: + np.array: Class-wise Saliency Maps. One saliency map per each class - [class_id, H, W] + """ + from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo + + if masks is None: + return np.ndarray([]) + + pred = {"labels": labels, "scores": scores, "masks": masks} + return MaskRCNNExplainAlgo.average_and_normalize(pred, num_classes) diff --git a/src/otx/core/utils/utils.py b/src/otx/core/utils/utils.py index 006279e5cbb..fe93a6bfd8e 100644 --- a/src/otx/core/utils/utils.py +++ b/src/otx/core/utils/utils.py @@ -7,7 +7,7 @@ from collections import defaultdict from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import torch from datumaro.components.annotation import AnnotationType, LabelCategories @@ -41,19 +41,25 @@ def is_ckpt_for_finetuning(ckpt: dict) -> bool: return "state_dict" in ckpt -def get_mean_std_from_data_processing(config: DictConfig) -> dict[str, Any]: +def get_mean_std_from_data_processing( + config: DictConfig, +) -> tuple[tuple[float, float, float], tuple[float, float, float]]: """Get mean and std value from data_processing. 
Args: config (DictConfig): MM framework model config. Returns: - dict[str, Any]: Dictionary with mean and std value. + tuple[tuple[float, float, float], tuple[float, float, float]]: + Tuple of mean and std values. + + Examples: + >>> mean, std = get_mean_std_from_data_processing(config) """ - return { - "mean": config["data_preprocessor"]["mean"], - "std": config["data_preprocessor"]["std"], - } + return ( + config["data_preprocessor"]["mean"], + config["data_preprocessor"]["std"], + ) def get_adaptive_num_workers(num_dataloader: int = 1) -> int | None: diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index f1e636c60ba..ae4294e71fd 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -16,6 +16,7 @@ import torch from lightning import Trainer, seed_everything +from otx.algo.plugins import MixedPrecisionXPUPlugin from otx.core.config.device import DeviceConfig from otx.core.config.explain import ExplainConfig from otx.core.config.hpo import HpoConfig @@ -27,6 +28,7 @@ from otx.core.types.precision import OTXPrecisionType from otx.core.types.task import OTXTaskType from otx.core.utils.cache import TrainerArgumentsCache +from otx.utils.utils import is_xpu_available from .hpo import execute_hpo, update_hyper_parameter from .utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator @@ -179,7 +181,8 @@ def train( resume: bool = False, metric: MetricCallable | None = None, run_hpo: bool = False, - hpo_config: HpoConfig | None = None, + hpo_config: HpoConfig = HpoConfig(), # noqa: B008 https://github.com/omni-us/jsonargparse/issues/423 + checkpoint: PathLike | None = None, **kwargs, ) -> dict[str, Any]: """Trains the model using the provided LightningModule and OTXDataModule. @@ -199,6 +202,7 @@ def train( metric callable. It will temporarilly change the evaluation metric for the validation and test. run_hpo (bool, optional): If True, optimizer hyper parameters before training a model. 
hpo_config (HpoConfig | None, optional): Configuration for HPO. + checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None. **kwargs: Additional keyword arguments for pl.Trainer configuration. Returns: @@ -234,14 +238,14 @@ def train( otx train --data_root --config ``` """ + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + if run_hpo: - if hpo_config is None: - hpo_config = HpoConfig() best_config, best_trial_weight = execute_hpo(engine=self, **locals()) if best_config is not None: update_hyper_parameter(self, best_config) if best_trial_weight is not None: - self.checkpoint = best_trial_weight + checkpoint = best_trial_weight resume = True if seed is not None: @@ -258,7 +262,7 @@ def train( ) fit_kwargs: dict[str, Any] = {} - # NOTE Model's label info should be converted datamodule's label info before ckpt loading + # NOTE: Model's label info should be converted datamodule's label info before ckpt loading # This is due to smart weight loading check label name as well as number of classes. if self.model.label_info != self.datamodule.label_info: # TODO (vinnamki): Revisit label_info logic to make it cleaner @@ -269,12 +273,17 @@ def train( logging.warning(msg) self.model.label_info = self.datamodule.label_info - if resume: - fit_kwargs["ckpt_path"] = self.checkpoint - elif self.checkpoint is not None: - loaded_checkpoint = torch.load(self.checkpoint) - # loaded checkpoint have keys (OTX1.5): model, config, labels, input_size, VERSION - self.model.load_state_dict(loaded_checkpoint) + if resume and checkpoint: + # NOTE: If both `resume` and `checkpoint` are provided, + # load the entire model state from the checkpoint using the pl.Trainer's API. + fit_kwargs["ckpt_path"] = checkpoint + elif not resume and checkpoint: + # NOTE: If `resume` is not enabled but `checkpoint` is provided, + # load the model state from the checkpoint incrementally. + # This means only the model weights are loaded. 
If there is a mismatch in label_info, + # perform incremental weight loading for the model's classification layer. + ckpt = torch.load(checkpoint) + self.model.load_state_dict_incrementally(ckpt) with override_metric_callable(model=self.model, new_metric_callable=metric) as model: self.trainer.fit( @@ -333,20 +342,6 @@ def test( otx test --config --checkpoint ``` """ - # NOTE Model's label info should be converted datamodule's label info before ckpt loading - # This is due to smart weight loading check label name as well as number of classes. - if self.model.label_info != self.datamodule.label_info: - # TODO (vinnamki): Revisit label_info logic to make it cleaner - msg = ( - "Model label_info is not equal to the Datamodule label_info. " - f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" - ) - logging.warning(msg) - self.model.label_info = self.datamodule.label_info - - # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test - # raise ValueError() - model = self.model checkpoint = checkpoint if checkpoint is not None else self.checkpoint datamodule = datamodule if datamodule is not None else self.datamodule @@ -366,8 +361,18 @@ def test( # NOTE, trainer.test takes only lightning based checkpoint. # So, it can't take the OTX1.x checkpoint. if checkpoint is not None and not is_ir_ckpt: - loaded_checkpoint = torch.load(checkpoint) - model.load_state_dict(loaded_checkpoint) + model_cls = self.model.__class__ + model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint) + + if model.label_info != self.datamodule.label_info: + msg = ( + "To launch a test pipeline, the label information should be same " + "between the training and testing datasets. 
" + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) self._build_trainer(**kwargs) @@ -423,20 +428,6 @@ def predict( """ from otx.algo.utils.xai_utils import process_saliency_maps_in_pred_entity - # NOTE Model's label info should be converted datamodule's label info before ckpt loading - # This is due to smart weight loading check label name as well as number of classes. - if self.model.label_info != self.datamodule.label_info: - # TODO (vinnamki): Revisit label_info logic to make it cleaner - msg = ( - "Model label_info is not equal to the Datamodule label_info. " - f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" - ) - logging.warning(msg) - self.model.label_info = self.datamodule.label_info - - # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test - # raise ValueError() - model = self.model checkpoint = checkpoint if checkpoint is not None else self.checkpoint @@ -451,8 +442,18 @@ def predict( datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test") if checkpoint is not None and not is_ir_ckpt: - loaded_checkpoint = torch.load(checkpoint) - model.load_state_dict(loaded_checkpoint) + model_cls = self.model.__class__ + model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint) + + if model.label_info != self.datamodule.label_info: + msg = ( + "To launch a predict pipeline, the label information should be same " + "between the training and testing datasets. 
" + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) self._build_trainer(**kwargs) @@ -516,11 +517,12 @@ def export( --checkpoint --export_precision FP16 --export_format ONNX ``` """ - ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint - if ckpt_path is None: + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + + if checkpoint is None: msg = "To make export, checkpoint must be specified." raise RuntimeError(msg) - is_ir_ckpt = Path(ckpt_path).suffix in [".xml"] + is_ir_ckpt = Path(checkpoint).suffix in [".xml"] if is_ir_ckpt and export_format != OTXExportFormatType.EXPORTABLE_CODE: msg = ( @@ -538,10 +540,9 @@ def export( ) if not is_ir_ckpt: + model_cls = self.model.__class__ + self.model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, map_location="cpu") self.model.eval() - loaded_checkpoint = torch.load(ckpt_path) - self.model.label_info = loaded_checkpoint["state_dict"]["label_info"] - self.model.load_state_dict(loaded_checkpoint) self.model.explain_mode = explain exported_model_path = self.model.export( @@ -679,11 +680,19 @@ def explain( datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test") model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info) - model.label_info = datamodule.label_info - if checkpoint is not None and not is_ir_ckpt: - loaded_checkpoint = torch.load(checkpoint) - model.load_state_dict(loaded_checkpoint) + model_cls = model.__class__ + model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint) + + if model.label_info != self.datamodule.label_info: + msg = ( + "To launch a explain pipeline, the label information should be same " + "between the training and testing datasets. 
" + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) model.explain_mode = True @@ -726,7 +735,7 @@ def from_config( Defaults to None. If work_dir is None, use the work_dir from the configuration file. kwargs: Arguments that can override the engine's arguments. - Returns:s + Returns: Engine: An instance of the Engine class. Example: @@ -765,10 +774,24 @@ def from_config( ) warn(msg, stacklevel=1) + if (datamodule := instantiated_config.get("data")) is None: + msg = "Cannot instantiate datamodule from config." + raise ValueError(msg) + if not isinstance(datamodule, OTXDataModule): + raise TypeError(datamodule) + + if (model := instantiated_config.get("model")) is None: + msg = "Cannot instantiate model from config." + raise ValueError(msg) + if not isinstance(model, OTXModel): + raise TypeError(model) + + model.label_info = datamodule.label_info + return cls( work_dir=instantiated_config.get("work_dir", work_dir), - datamodule=instantiated_config.get("data"), - model=instantiated_config.get("model"), + datamodule=datamodule, + model=model, **engine_kwargs, ) @@ -856,6 +879,8 @@ def device(self) -> DeviceConfig: @device.setter def device(self, device: DeviceType) -> None: + if is_xpu_available() and device == DeviceType.auto: + device = DeviceType.xpu self._device = DeviceConfig(accelerator=device) self._cache.update(accelerator=self._device.accelerator, devices=self._device.devices) self._cache.is_trainer_args_identical = False @@ -878,6 +903,14 @@ def _build_trainer(self, **kwargs) -> None: """Instantiate the trainer based on the model parameters.""" if self._cache.requires_update(**kwargs) or self._trainer is None: self._cache.update(**kwargs) + # set up xpu device + if self._device.accelerator == DeviceType.xpu: + self._cache.update(strategy="xpu_single") + # add plugin for Automatic Mixed Precision on XPU + if 
self._cache.args.get("precision", 32) == 16: + self._cache.update(plugins=[MixedPrecisionXPUPlugin()]) + self._cache.args["precision"] = None + kwargs = self._cache.args self._trainer = Trainer(**kwargs) self._cache.is_trainer_args_identical = True diff --git a/src/otx/engine/hpo/hpo_api.py b/src/otx/engine/hpo/hpo_api.py index 60a8a0e637c..c974e19b998 100644 --- a/src/otx/engine/hpo/hpo_api.py +++ b/src/otx/engine/hpo/hpo_api.py @@ -22,9 +22,10 @@ from otx.utils.utils import get_decimal_point, get_using_dot_delimited_key, remove_matched_files from .hpo_trial import run_hpo_trial -from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir +from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir, get_metric if TYPE_CHECKING: + from lightning import Callback from lightning.pytorch.cli import OptimizerCallable from otx.engine.engine import Engine @@ -34,16 +35,17 @@ AVAILABLE_HP_NAME_MAP = { "data.config.train_subset.batch_size": "datamodule.config.train_subset.batch_size", - "optimizer": "optimizer.keywords", - "scheduler": "scheduler.keywords", + "optimizer": "optimizer_callable.optimizer_kwargs", + # "scheduler": "scheduler.keywords", NOTE need to revisit after SchedulerCallableSupportHPO is implemted } def execute_hpo( engine: Engine, max_epochs: int, - hpo_config: HpoConfig | None = None, + hpo_config: HpoConfig, progress_update_callback: Callable[[int | float], None] | None = None, + callbacks: list[Callback] | Callback | None = None, **train_args, ) -> tuple[dict[str, Any] | None, Path | None]: """Execute HPO. @@ -51,9 +53,10 @@ def execute_hpo( Args: engine (Engine): engine instnace. max_epochs (int): max epochs to train. - hpo_config (HpoConfig | None, optional): Configuration for HPO. + hpo_config (HpoConfig): Configuration for HPO. progress_update_callback (Callable[[int | float], None] | None, optional): callback to update progress. If it's given, it's called with progress every second. 
Defaults to None. + callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None. Returns: tuple[dict[str, Any] | None, Path | None]: @@ -61,18 +64,22 @@ def execute_hpo( return None. """ if engine.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: # type: ignore[has-type] - logger.warning("Zero shot visual prompting task doesn't support HPO.") - return None, None + msg = "Zero shot visual prompting task doesn't support HPO." + raise RuntimeError(msg) + if "anomaly.padim" in str(type(engine.model)).lower(): + msg = "Padim doesn't need HPO. HPO is skipped." + raise RuntimeError(msg) engine.model.patch_optimizer_and_scheduler_for_hpo() hpo_workdir = Path(engine.work_dir) / "hpo" hpo_workdir.mkdir(exist_ok=True) hpo_configurator = HPOConfigurator( - engine, - max_epochs, - hpo_workdir, - hpo_config, + engine=engine, + max_epochs=max_epochs, + hpo_config=hpo_config, + hpo_workdir=hpo_workdir, + callbacks=callbacks, ) if (hpo_algo := hpo_configurator.get_hpo_algo()) is None: logger.warning("HPO is skipped.") @@ -88,9 +95,12 @@ def execute_hpo( hpo_workdir=hpo_workdir, engine=engine, max_epochs=max_epochs, + callbacks=callbacks, + metric_name=hpo_config.metric_name, **_adjust_train_args(train_args), ), "gpu" if torch.cuda.is_available() else "cpu", + num_parallel_trial=hpo_configurator.hpo_config["num_workers"], ) best_trial = hpo_algo.get_best_config() @@ -113,21 +123,24 @@ class HPOConfigurator: Args: engine (Engine): engine instance. - max_epoch (int): max epochs to train. + max_epochs (int): max epochs to train. + hpo_config (HpoConfig): Configuration for HPO. hpo_workdir (Path | None, optional): HPO work directory. Defaults to None. - hpo_config (HpoConfig | None, optional): Configuration for HPO. + callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None. 
""" def __init__( self, engine: Engine, - max_epoch: int, + max_epochs: int, + hpo_config: HpoConfig, hpo_workdir: Path | None = None, - hpo_config: HpoConfig | None = None, + callbacks: list[Callback] | Callback | None = None, ) -> None: self._engine = engine - self._max_epoch = max_epoch + self._max_epochs = max_epochs self._hpo_workdir = hpo_workdir if hpo_workdir is not None else Path(engine.work_dir) / "hpo" + self._callbacks = callbacks self.hpo_config: dict[str, Any] = hpo_config # type: ignore[assignment] @property @@ -136,19 +149,40 @@ def hpo_config(self) -> dict[str, Any]: return self._hpo_config @hpo_config.setter - def hpo_config(self, hpo_config: HpoConfig | None) -> None: - train_dataset_size = len(self._engine.datamodule.train_dataloader()) + def hpo_config(self, hpo_config: HpoConfig) -> None: + train_dataset_size = len( + self._engine.datamodule.subsets[self._engine.datamodule.config.train_subset.subset_name], + ) + + if hpo_config.metric_name is None: + if self._callbacks is None: + msg = ( + "HPOConfigurator can't find the metric because callback doesn't exist. " + "Please set hpo_config.metric_name." 
+ ) + raise RuntimeError(msg) + hpo_config.metric_name = get_metric(self._callbacks) + + if "loss" in hpo_config.metric_name and hpo_config.mode == "max": + logger.warning( + f"Because metric for HPO is {hpo_config.metric_name}, hpo_config.mode is changed from max to min.", + ) + hpo_config.mode = "min" self._hpo_config: dict[str, Any] = { # default setting "save_path": str(self._hpo_workdir), - "num_full_iterations": self._max_epoch, + "num_full_iterations": self._max_epochs, "full_dataset_size": train_dataset_size, } - if hpo_config is not None: - self._hpo_config.update( - {key: val for key, val in dataclasses.asdict(hpo_config).items() if val is not None}, - ) + hb_arg_names = get_callable_args_name(HyperBand) + self._hpo_config.update( + { + key: val + for key, val in dataclasses.asdict(hpo_config).items() + if val is not None and key in hb_arg_names + }, + ) if "search_space" not in self._hpo_config: self._hpo_config["search_space"] = self._get_default_search_space() diff --git a/src/otx/engine/hpo/hpo_trial.py b/src/otx/engine/hpo/hpo_trial.py index 6b519af504c..a4e0f8ac435 100644 --- a/src/otx/engine/hpo/hpo_trial.py +++ b/src/otx/engine/hpo/hpo_trial.py @@ -16,7 +16,7 @@ from otx.hpo import TrialStatus from otx.utils.utils import find_file_recursively, remove_matched_files, set_using_dot_delimited_key -from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir +from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir, get_metric if TYPE_CHECKING: from lightning import LightningModule, Trainer @@ -51,6 +51,7 @@ def run_hpo_trial( hpo_workdir: Path, engine: Engine, callbacks: list[Callback] | Callback | None = None, + metric_name: str | None = None, **train_args, ) -> None: """Run HPO trial. After it's done, best weight and last weight are saved for later use. @@ -61,6 +62,8 @@ def run_hpo_trial( hpo_workdir (Path): HPO work directory. engine (Engine): engine instance. 
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None. + metric_name (str | None, optional): + metric name to determine trial performance. If it's None, get it from ModelCheckpoint callback. train_args: Arugments for 'engine.train'. """ trial_id = hp_config["id"] @@ -69,10 +72,10 @@ def run_hpo_trial( _set_trial_hyper_parameter(hp_config["configuration"], engine, train_args) if (checkpoint := _find_last_weight(hpo_weight_dir)) is not None: - engine.checkpoint = checkpoint + train_args["checkpoint"] = checkpoint train_args["resume"] = True - callbacks = _register_hpo_callback(report_func, callbacks) + callbacks = _register_hpo_callback(report_func, callbacks, metric_name) _set_to_validate_every_epoch(callbacks, train_args) with TemporaryDirectory(prefix="OTX-HPO-") as temp_dir: @@ -93,23 +96,19 @@ def _find_last_weight(weight_dir: Path) -> Path | None: return find_file_recursively(weight_dir, "last.ckpt") -def _register_hpo_callback(report_func: Callable, callbacks: list[Callback] | Callback | None) -> list[Callback]: +def _register_hpo_callback( + report_func: Callable, + callbacks: list[Callback] | Callback | None = None, + metric_name: str | None = None, +) -> list[Callback]: if isinstance(callbacks, Callback): callbacks = [callbacks] elif callbacks is None: callbacks = [] - callbacks.append(HPOCallback(report_func, _get_metric(callbacks))) + callbacks.append(HPOCallback(report_func, get_metric(callbacks) if metric_name is None else metric_name)) return callbacks -def _get_metric(callbacks: list[Callback]) -> str: - for callback in callbacks: - if isinstance(callback, ModelCheckpoint): - return callback.monitor - error_msg = "Failed to find a metric. There is no ModelCheckpoint in callback list." 
- raise RuntimeError(error_msg) - - def _set_to_validate_every_epoch(callbacks: list[Callback], train_args: dict[str, Any]) -> None: for callback in callbacks: if isinstance(callback, AdaptiveTrainScheduling): diff --git a/src/otx/engine/hpo/utils.py b/src/otx/engine/hpo/utils.py index b2c43846f8d..28586406ca4 100644 --- a/src/otx/engine/hpo/utils.py +++ b/src/otx/engine/hpo/utils.py @@ -5,14 +5,19 @@ from __future__ import annotations +import inspect import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable + +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from otx.utils.utils import find_file_recursively if TYPE_CHECKING: from pathlib import Path + from lightning import Callback + def find_trial_file(hpo_workdir: Path, trial_id: str) -> Path | None: """Find a trial file which store trial record. @@ -78,3 +83,37 @@ def get_hpo_weight_dir(hpo_workdir: Path, trial_id: str) -> Path: if not hpo_weight_dir.exists(): hpo_weight_dir.mkdir(parents=True) return hpo_weight_dir + + +def get_callable_args_name(module: Callable) -> list[str]: + """Get arguments name list from callable. + + Args: + module (Callable): callable to get arguments name from. + + Returns: + list[str]: arguments name list. + """ + return list(inspect.signature(module).parameters) + + +def get_metric(callbacks: list[Callback] | Callback) -> str: + """Find a metric name from ModelCheckpoint callback. + + Args: + callbacks (list[Callback] | Callback): Callback list. + + Raises: + RuntimeError: If ModelCheckpoint doesn't exist, the error is raised. + + Returns: + str: metric name. + """ + if not isinstance(callbacks, list): + callbacks = [callbacks] + + for callback in callbacks: + if isinstance(callback, ModelCheckpoint): + return callback.monitor + msg = "Failed to find a metric. There is no ModelCheckpoint in callback list." 
+ raise RuntimeError(msg) diff --git a/src/otx/hpo/hpo_base.py b/src/otx/hpo/hpo_base.py index dfb0e412b62..1d9948a7f3c 100644 --- a/src/otx/hpo/hpo_base.py +++ b/src/otx/hpo/hpo_base.py @@ -42,10 +42,6 @@ class HpoBase(ABC): HPO use time about exepected_time_ratio * train time after HPO times. maximum_resource (int | float | None, optional): Maximum resource to use for training each trial. - subset_ratio (float | int | None, optional): ratio to how many train dataset to use for each trial. - The lower value is, the faster the speed is. - But If it's too low, HPO can be unstable. - min_subset_size (int, optional) : Minimum size of subset. Default value is 500. resume (bool, optional): resume flag decide to use previous HPO results. If HPO completed, you can just use optimized hyper parameters. If HPO stopped in middle, you can resume in middle. @@ -66,8 +62,6 @@ def __init__( full_dataset_size: int = 0, expected_time_ratio: int | float | None = None, maximum_resource: int | float | None = None, - subset_ratio: float | int | None = None, - min_subset_size: int = 500, resume: bool = False, prior_hyper_parameters: dict | list[dict] | None = None, acceptable_additional_time_ratio: float | int = 1.0, @@ -81,11 +75,6 @@ def __init__( if num_trials is not None: check_positive(num_trials, "num_trials") check_positive(num_workers, "num_workers") - if subset_ratio is not None and not 0 < subset_ratio <= 1: - error_msg = ( - f"subset_ratio should be greater than 0 and lesser than or equal to 1. 
Your value is {subset_ratio}" - ) - raise ValueError(error_msg) if save_path is None: save_path = tempfile.mkdtemp(prefix="OTX-hpo-") @@ -98,8 +87,6 @@ def __init__( self.full_dataset_size = full_dataset_size self.expected_time_ratio = expected_time_ratio self.maximum_resource: int | float | None = maximum_resource - self.subset_ratio = subset_ratio - self.min_subset_size = min_subset_size self.resume = resume self.hpo_status: dict = {} self.acceptable_additional_time_ratio = acceptable_additional_time_ratio diff --git a/src/otx/hpo/hpo_runner.py b/src/otx/hpo/hpo_runner.py index f43064aeeed..03a745a324f 100644 --- a/src/otx/hpo/hpo_runner.py +++ b/src/otx/hpo/hpo_runner.py @@ -48,8 +48,6 @@ class HpoLoop: It's used for CPUResourceManager. Defaults to None. num_gpu_for_single_trial (int | None, optional): How many GPUs are used for a single trial. It's used for GPUResourceManager. Defaults to None. - available_gpu (str | None, optional): How many GPUs are available. It's used for GPUResourceManager. - Defaults to None. """ def __init__( @@ -59,7 +57,6 @@ def __init__( resource_type: Literal["gpu", "cpu"] = "gpu", num_parallel_trial: int | None = None, num_gpu_for_single_trial: int | None = None, - available_gpu: str | None = None, ) -> None: self._hpo_algo = hpo_algo self._train_func = train_func @@ -71,7 +68,6 @@ def __init__( resource_type, num_parallel_trial, num_gpu_for_single_trial, - available_gpu, ) self._main_pid = os.getpid() @@ -245,7 +241,6 @@ def run_hpo_loop( resource_type: Literal["gpu", "cpu"] = "gpu", num_parallel_trial: int | None = None, num_gpu_for_single_trial: int | None = None, - available_gpu: str | None = None, ) -> None: """Run the HPO loop. @@ -258,8 +253,6 @@ def run_hpo_loop( It's used for CPUResourceManager. Defaults to None. num_gpu_for_single_trial (int | None, optional): How many GPUs are used for a single trial. It's used for GPUResourceManager. Defaults to None. - available_gpu (str | None, optional): How many GPUs are available. 
It's used for GPUResourceManager. - Defaults to None. """ - hpo_loop = HpoLoop(hpo_algo, train_func, resource_type, num_parallel_trial, num_gpu_for_single_trial, available_gpu) + hpo_loop = HpoLoop(hpo_algo, train_func, resource_type, num_parallel_trial, num_gpu_for_single_trial) hpo_loop.run() diff --git a/src/otx/hpo/hyperband.py b/src/otx/hpo/hyperband.py index 82eb46d0036..9282266f55e 100644 --- a/src/otx/hpo/hyperband.py +++ b/src/otx/hpo/hyperband.py @@ -502,6 +502,24 @@ class HyperBand(HpoBase): https://arxiv.org/abs/1810.05934 Args: + search_space (dict[str, dict[str, Any]]): hyper parameter search space to find. + save_path (str | None, optional): path where result of HPO is saved. + mode ("max" | "min", optional): One of {min, max}. Determines whether objective is + minimizing or maximizing the score. + num_trials (int | None, optional): How many training to conduct for HPO. + num_workers (int, optional): How many trains are executed in parallel. + num_full_iterations (int, optional): epoch for traninig after HPO. + full_dataset_size (int, optional): train dataset size + expected_time_ratio (int | float | None, optional): Time to use for HPO. + If HPO is configured automatically, + HPO use time about exepected_time_ratio * + train time after HPO times. + maximum_resource (int | float | None, optional): Maximum resource to use for training each trial. + resume (bool, optional): resume flag decide to use previous HPO results. + If HPO completed, you can just use optimized hyper parameters. + If HPO stopped in middle, you can resume in middle. + prior_hyper_parameters (dict | list[dict] | None, optional) = Hyper parameters to try first. + acceptable_additional_time_ratio (float | int, optional) = Decide how much additional time can be acceptable. minimum_resource (float | int | None, optional): Minimum resource to use for training a trial. Defaults to None. reduction_factor (int, optional): Decicdes how many trials to promote to next rung. 
Only top 1 / reduction_factor of rung trials can be promoted. Defaults to 3. @@ -514,13 +532,37 @@ class HyperBand(HpoBase): def __init__( self, + search_space: dict[str, dict[str, Any]], + save_path: str | None = None, + mode: Literal["max", "min"] = "max", + num_trials: int | None = None, + num_workers: int = 1, + num_full_iterations: int | float = 1, + full_dataset_size: int = 0, + expected_time_ratio: int | float | None = None, + maximum_resource: int | float | None = None, + resume: bool = False, + prior_hyper_parameters: dict | list[dict] | None = None, + acceptable_additional_time_ratio: float | int = 1.0, minimum_resource: int | float | None = None, reduction_factor: int = 3, asynchronous_sha: bool = True, asynchronous_bracket: bool = False, - **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__( + search_space, + save_path, + mode, + num_trials, + num_workers, + num_full_iterations, + full_dataset_size, + expected_time_ratio, + maximum_resource, + resume, + prior_hyper_parameters, + acceptable_additional_time_ratio, + ) if minimum_resource is not None: check_positive(minimum_resource, "minimum_resource") @@ -673,7 +715,7 @@ def _get_random_hyper_parameter(self, num_samples: int) -> list[AshaTrial]: def _make_trial(self, hyper_parameter: dict) -> AshaTrial: trial_id = self._get_new_trial_id() - trial = AshaTrial(trial_id, hyper_parameter, self._get_train_environment()) + trial = AshaTrial(trial_id, hyper_parameter) self._trials[trial_id] = trial return trial @@ -682,9 +724,6 @@ def _get_new_trial_id(self) -> str: self._next_trial_id += 1 return str(trial_id) - def _get_train_environment(self) -> dict: - return {"subset_ratio": self.subset_ratio} - def get_next_sample(self) -> AshaTrial | None: """Get next trial to train. 
diff --git a/src/otx/hpo/resource_manager.py b/src/otx/hpo/resource_manager.py index 653da358b97..5321714b4dd 100644 --- a/src/otx/hpo/resource_manager.py +++ b/src/otx/hpo/resource_manager.py @@ -93,26 +93,27 @@ class GPUResourceManager(BaseResourceManager): Args: num_gpu_for_single_trial (int, optional): How many GPUs is used for a single trial. Defaults to 1. - available_gpu (str | None, optional): How many GPUs are available. Defaults to None. + num_parallel_trial (int, optional): How many trials to run in parallel. Defaults to 4. """ - def __init__(self, num_gpu_for_single_trial: int = 1, available_gpu: str | None = None) -> None: + def __init__(self, num_gpu_for_single_trial: int = 1, num_parallel_trial: int | None = None) -> None: check_positive(num_gpu_for_single_trial, "num_gpu_for_single_trial") + if num_parallel_trial is not None: + check_positive(num_parallel_trial, "num_parallel_trial") self._num_gpu_for_single_trial = num_gpu_for_single_trial - self._available_gpu = self._set_available_gpu(available_gpu) + self._available_gpu = self._set_available_gpu(num_parallel_trial) self._usage_status: dict[Any, list] = {} - def _set_available_gpu(self, available_gpu: str | None = None) -> list[int]: - if available_gpu is None: - cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") - if cuda_visible_devices is not None: - available_gpu_arr = self._transform_gpu_format_from_string_to_arr(cuda_visible_devices) - else: - num_gpus = torch.cuda.device_count() - available_gpu_arr = list(range(num_gpus)) + def _set_available_gpu(self, num_parallel_trial: int | None = None) -> list[int]: + cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if cuda_visible_devices is not None: + available_gpu_arr = self._transform_gpu_format_from_string_to_arr(cuda_visible_devices) else: - available_gpu_arr = self._transform_gpu_format_from_string_to_arr(available_gpu) + num_gpus = torch.cuda.device_count() + available_gpu_arr = list(range(num_gpus)) + if num_parallel_trial is 
not None: + available_gpu_arr = available_gpu_arr[:num_parallel_trial] return available_gpu_arr @@ -168,7 +169,6 @@ def get_resource_manager( resource_type: Literal["gpu", "cpu"], num_parallel_trial: int | None = None, num_gpu_for_single_trial: int | None = None, - available_gpu: str | None = None, ) -> BaseResourceManager: """Get an appropriate resource manager depending on current environment. @@ -179,8 +179,6 @@ def get_resource_manager( Defaults to None. num_gpu_for_single_trial (int | None, optional): How many GPUs is used for a single trial. It's used for GPUResourceManager. Defaults to None. - available_gpu (str | None, optional): How many GPUs are available. It's used for GPUResourceManager. - Defaults to None. Raises: ValueError: If resource_type is neither 'gpu' nor 'cpu', then raise an error. @@ -197,7 +195,7 @@ def get_resource_manager( args = _remove_none_from_dict(args) return CPUResourceManager(**args) # type: ignore[arg-type] if resource_type == "gpu": - args = {"num_gpu_for_single_trial": num_gpu_for_single_trial, "available_gpu": available_gpu} # type: ignore[dict-item] + args = {"num_gpu_for_single_trial": num_gpu_for_single_trial, "num_parallel_trial": num_parallel_trial} # type: ignore[dict-item] args = _remove_none_from_dict(args) return GPUResourceManager(**args) # type: ignore[arg-type] error_msg = f"Available resource type is cpu, gpu. Your value is {resource_type}." diff --git a/src/otx/utils/utils.py b/src/otx/utils/utils.py index 89cf03a2c79..797274aa634 100644 --- a/src/otx/utils/utils.py +++ b/src/otx/utils/utils.py @@ -8,10 +8,19 @@ from decimal import Decimal from typing import TYPE_CHECKING, Any +import torch + if TYPE_CHECKING: from pathlib import Path +XPU_AVAILABLE = None +try: + import intel_extension_for_pytorch # noqa: F401 +except ImportError: + XPU_AVAILABLE = False + + def get_using_dot_delimited_key(key: str, target: Any) -> Any: # noqa: ANN401 """Get values of attribute in target object using dot delimited key. 
@@ -114,3 +123,11 @@ def remove_matched_files(directory: Path, pattern: str, file_to_leave: Path | No for weight in directory.rglob(pattern): if weight != file_to_leave: weight.unlink() + + +def is_xpu_available() -> bool: + """Checks if XPU device is available.""" + global XPU_AVAILABLE # noqa: PLW0603 + if XPU_AVAILABLE is None: + XPU_AVAILABLE = hasattr(torch, "xpu") and torch.xpu.is_available() + return XPU_AVAILABLE diff --git a/tests/conftest.py b/tests/conftest.py index b54e5c6a73f..3b9717b53ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,10 +4,21 @@ import pytest import torch +from datumaro import Polygon +from mmdet.structures import DetDataSample +from mmengine.structures import InstanceData from otx.core.data.entity.base import ImageInfo +from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity, DetDataEntity +from otx.core.data.entity.instance_segmentation import ( + InstanceSegBatchDataEntity, + InstanceSegBatchPredEntity, + InstanceSegDataEntity, +) from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity, SegDataEntity from otx.core.data.mem_cache import MemCacheHandlerSingleton from otx.core.types.task import OTXTaskType +from torch import LongTensor +from torchvision import tv_tensors from torchvision.tv_tensors import Image, Mask @@ -112,6 +123,93 @@ def pytest_addoption(parser: pytest.Parser): ) +@pytest.fixture(scope="session") +def fxt_data_sample() -> list[DetDataSample]: + data_sample = DetDataSample( + metainfo={ + "img_shape": (480, 480), + "ori_shape": (480, 480), + "scale_factor": (1.0, 1.0), + "pad_shape": (480, 480), + "ignored_labels": [], + }, + gt_instances=InstanceData( + bboxes=torch.Tensor([[0.0, 0.0, 240, 240], [240, 240, 480, 480]]), + labels=torch.LongTensor([0, 1]), + ), + ) + return [data_sample] + + +@pytest.fixture(scope="session") +def fxt_det_data_entity() -> tuple[tuple, DetDataEntity, DetBatchDataEntity]: + img_size = (64, 64) + fake_image = 
torch.zeros(size=(3, *img_size), dtype=torch.uint8).numpy() + fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) + fake_bboxes = tv_tensors.BoundingBoxes(data=torch.Tensor([0, 0, 5, 5]), format="xyxy", canvas_size=(10, 10)) + fake_labels = LongTensor([1]) + # define data entity + single_data_entity = DetDataEntity(fake_image, fake_image_info, fake_bboxes, fake_labels) + batch_data_entity = DetBatchDataEntity( + batch_size=1, + images=[Image(data=torch.from_numpy(fake_image))], + imgs_info=[fake_image_info], + bboxes=[fake_bboxes], + labels=[fake_labels], + ) + batch_pred_data_entity = DetBatchPredEntity( + batch_size=1, + images=[Image(data=torch.from_numpy(fake_image))], + imgs_info=[fake_image_info], + bboxes=[fake_bboxes], + labels=[fake_labels], + scores=[], + ) + + return single_data_entity, batch_pred_data_entity, batch_data_entity + + +@pytest.fixture(scope="session") +def fxt_inst_seg_data_entity() -> tuple[tuple, InstanceSegDataEntity, InstanceSegBatchDataEntity]: + img_size = (64, 64) + fake_image = torch.zeros(size=(3, *img_size), dtype=torch.uint8).numpy() + fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) + fake_bboxes = tv_tensors.BoundingBoxes(data=torch.Tensor([0, 0, 5, 5]), format="xyxy", canvas_size=(10, 10)) + fake_labels = LongTensor([1]) + fake_masks = Mask(torch.randint(low=0, high=255, size=(1, *img_size), dtype=torch.uint8)) + fake_polygons = [Polygon(points=[1, 1, 2, 2, 3, 3, 4, 4])] + # define data entity + single_data_entity = InstanceSegDataEntity( + image=fake_image, + img_info=fake_image_info, + bboxes=fake_bboxes, + masks=fake_masks, + labels=fake_labels, + polygons=fake_polygons, + ) + batch_data_entity = InstanceSegBatchDataEntity( + batch_size=1, + images=[Image(data=torch.from_numpy(fake_image))], + imgs_info=[fake_image_info], + bboxes=[fake_bboxes], + labels=[fake_labels], + masks=[fake_masks], + polygons=[fake_polygons], + ) + batch_pred_data_entity = 
InstanceSegBatchPredEntity( + batch_size=1, + images=[Image(data=torch.from_numpy(fake_image))], + imgs_info=[fake_image_info], + bboxes=[fake_bboxes], + labels=[fake_labels], + masks=[fake_masks], + scores=[], + polygons=[fake_polygons], + ) + + return single_data_entity, batch_pred_data_entity, batch_data_entity + + @pytest.fixture(scope="session") def fxt_seg_data_entity() -> tuple[tuple, SegDataEntity, SegBatchDataEntity]: img_size = (32, 32) @@ -142,7 +240,7 @@ def fxt_seg_data_entity() -> tuple[tuple, SegDataEntity, SegBatchDataEntity]: @pytest.fixture(autouse=True) -def fxt_clean_up_mem_cache() -> None: +def fxt_clean_up_mem_cache(): """Clean up the mem-cache instance at the end of the test. It is required for everyone who tests model training pipeline. @@ -153,7 +251,7 @@ def fxt_clean_up_mem_cache() -> None: # TODO(Jaeguk): Add cpu param when OTX can run integration test parallelly for each task. -@pytest.fixture(params=[pytest.param("gpu", marks=pytest.mark.gpu)]) +@pytest.fixture(scope="module", params=[pytest.param("gpu", marks=pytest.mark.gpu)]) def fxt_accelerator(request: pytest.FixtureRequest) -> str: return request.param diff --git a/tests/e2e/cli/test_cli.py b/tests/e2e/cli/test_cli.py index 18c6bda4d28..d65908fab33 100644 --- a/tests/e2e/cli/test_cli.py +++ b/tests/e2e/cli/test_cli.py @@ -7,6 +7,7 @@ import numpy as np import pytest import yaml +from otx.core.types.task import OTXTaskType from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK from tests.e2e.cli.utils import run_main @@ -327,38 +328,38 @@ def test_otx_explain_e2e_cli( reference_sal_vals = { # Classification "multi_label_cls_efficientnet_v2_light": ( - np.array([66, 97, 84, 33, 42, 79, 0], dtype=np.uint8), - "Slide6_class_0_saliency_map.png", + np.array([201, 209, 196, 158, 157, 119, 77], dtype=np.uint8), + "American_Crow_0031_25433_class_0_saliency_map.png", ), "h_label_cls_efficientnet_v2_light": ( - np.array([152, 193, 144, 132, 149, 204, 217], dtype=np.uint8), - 
"092_class_5_saliency_map.png", + np.array([102, 141, 134, 79, 66, 92, 84], dtype=np.uint8), + "108_class_4_saliency_map.png", ), # Detection "detection_yolox_tiny": ( - np.array([111, 163, 141, 141, 146, 147, 158, 169, 184, 193], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", + np.array([182, 194, 187, 179, 188, 206, 215, 207, 177, 130], dtype=np.uint8), + "img_371_jpg_rf_a893e0bdc6fda0ba1b2a7f07d56cec23_class_0_saliency_map.png", ), "detection_ssd_mobilenetv2": ( - np.array([135, 80, 74, 34, 27, 32, 47, 42, 32, 34], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", + np.array([118, 188, 241, 213, 160, 120, 86, 94, 111, 138], dtype=np.uint8), + "img_371_jpg_rf_a893e0bdc6fda0ba1b2a7f07d56cec23_class_0_saliency_map.png", ), "detection_atss_mobilenetv2": ( - np.array([22, 62, 64, 0, 27, 60, 59, 53, 37, 45], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", + np.array([29, 39, 55, 69, 80, 88, 92, 86, 100, 88], dtype=np.uint8), + "img_371_jpg_rf_a893e0bdc6fda0ba1b2a7f07d56cec23_class_0_saliency_map.png", ), # Instance Segmentation "instance_segmentation_maskrcnn_efficientnetb2b": ( - np.array([54, 54, 54, 54, 0, 0, 0, 54, 0, 0], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", + np.array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=np.uint8), + "CDY_2018_class_0_saliency_map.png", ), } test_case_name = task + "_" + model_name if test_case_name in reference_sal_vals: actual_sal_vals = cv2.imread(str(latest_dir / "saliency_maps" / reference_sal_vals[test_case_name][1])) if test_case_name == "instance_segmentation_maskrcnn_efficientnetb2b": - # Take corner values due to map sparsity of InstSeg - actual_sal_vals = (actual_sal_vals[-10:, -1, -1]).astype(np.uint16) + # Take lower corner values due to map sparsity of InstSeg + actual_sal_vals = (actual_sal_vals[-10:, -1, 0]).astype(np.uint16) else: actual_sal_vals = (actual_sal_vals[:10, 0, 0]).astype(np.uint16) ref_sal_vals = reference_sal_vals[test_case_name][0] @@ -456,6 +457,17 @@ def test_otx_hpo_e2e_cli( 
""" if task not in DEFAULT_CONFIG_PER_TASK: pytest.skip(f"Task {task} is not supported in the auto-configuration.") + if task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: + pytest.skip("ZERO_SHOT_VISUAL_PROMPTING doesn't support HPO.") + + # Need to change model to stfpm because default anomaly model is 'padim' which doesn't support HPO + model_cfg = [] + if task in { + OTXTaskType.ANOMALY_CLASSIFICATION, + OTXTaskType.ANOMALY_DETECTION, + OTXTaskType.ANOMALY_SEGMENTATION, + }: + model_cfg = ["--config", str(DEFAULT_CONFIG_PER_TASK[task].parent / "stfpm.yaml")] task = task.lower() tmp_path_hpo = tmp_path / f"otx_hpo_{task}" @@ -464,6 +476,7 @@ def test_otx_hpo_e2e_cli( command_cfg = [ "otx", "train", + *model_cfg, "--task", task.upper(), "--data_root", @@ -485,10 +498,6 @@ def test_otx_hpo_e2e_cli( run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) - # zero_shot_visual_prompting doesn't support HPO. Check just there is no error. - if task in ("zero_shot_visual_prompting"): - return - latest_dir = max( (p for p in tmp_path_hpo.iterdir() if p.is_dir() and p.name != ".latest"), key=lambda p: p.stat().st_mtime, diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index ac61a5d3c7d..dd42c4fd41a 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -16,7 +16,7 @@ @pytest.fixture(scope="session") def fxt_ci_data_root() -> Path: - data_root = Path(os.environ.get("CI_DATA_ROOT", "/home/validation/data/v2")) + data_root = Path(os.environ.get("CI_DATA_ROOT", "/home/validation/data")) if not Path.is_dir(data_root): msg = f"cannot find {data_root}" raise FileNotFoundError(msg) @@ -87,22 +87,22 @@ def pytest_configure(config): @pytest.fixture() def fxt_target_dataset_per_task(fxt_ci_data_root) -> dict: return { - "multi_class_cls": Path(fxt_ci_data_root / "multiclass_classification/multiclass_CUB_small/1"), - "multi_label_cls": Path(fxt_ci_data_root / "multilabel_classification/multilabel_CUB_small/1"), - "h_label_cls": Path(fxt_ci_data_root / 
"hlabel_classification/hlabel_CUB_small/1"), - "detection": Path(fxt_ci_data_root / "detection/pothole_small/1"), - "rotated_detection": Path(fxt_ci_data_root / "detection/pothole_small/1"), - "instance_segmentation": Path(fxt_ci_data_root / "instance_seg/wgisd_small/1"), - "semantic_segmentation": Path(fxt_ci_data_root / "semantic_seg/kvasir_small/1"), - "action_classification": Path(fxt_ci_data_root / "action/action_classification/ucf_kinetics_5percent_small"), - "action_detection": Path(fxt_ci_data_root / "action/action_detection/UCF101_ava_5percent"), - "visual_prompting": Path(fxt_ci_data_root / "visual_prompting/wgisd_small/1"), + "multi_class_cls": Path(fxt_ci_data_root / "v2/multiclass_classification/multiclass_CUB_small/1"), + "multi_label_cls": Path(fxt_ci_data_root / "v2/multilabel_classification/multilabel_CUB_small/1"), + "h_label_cls": Path(fxt_ci_data_root / "v2/hlabel_classification/hlabel_CUB_small/1"), + "detection": Path(fxt_ci_data_root / "v2/detection/pothole_small/1"), + "rotated_detection": Path(fxt_ci_data_root / "v2/detection/pothole_small/1"), + "instance_segmentation": Path(fxt_ci_data_root / "v2/instance_seg/wgisd_small/1"), + "semantic_segmentation": Path(fxt_ci_data_root / "v2/semantic_seg/kvasir_small/1"), + "action_classification": Path(fxt_ci_data_root / "v2/action/action_classification/ucf_kinetics_5percent_small"), + "action_detection": Path(fxt_ci_data_root / "v2/action/action_detection/UCF101_ava_5percent"), + "visual_prompting": Path(fxt_ci_data_root / "v2/visual_prompting/wgisd_small/1"), "zero_shot_visual_prompting": Path( - fxt_ci_data_root / "zero_shot_visual_prompting/coco_car_person_medium_datumaro", + fxt_ci_data_root / "v2/zero_shot_visual_prompting/coco_car_person_medium", ), - "anomaly_classification": Path(fxt_ci_data_root / "anomaly/mvtec/bottle_small/1"), - "anomaly_detection": Path(fxt_ci_data_root / "anomaly/mvtec/hazelnut_large"), - "anomaly_segmentation": Path(fxt_ci_data_root / "anomaly/mvtec/hazelnut_large"), 
+ "anomaly_classification": Path(fxt_ci_data_root / "v2/anomaly/mvtec/bottle_small/1"), + "anomaly_detection": Path(fxt_ci_data_root / "v2/anomaly/mvtec/hazelnut_large"), + "anomaly_segmentation": Path(fxt_ci_data_root / "v2/anomaly/mvtec/hazelnut_large"), } @@ -122,10 +122,7 @@ def fxt_cli_override_command_per_task() -> dict: "3", ], "visual_prompting": [], - "zero_shot_visual_prompting": [ - "--data.config.data_format", - "datumaro", - ], + "zero_shot_visual_prompting": [], "anomaly_classification": [], "anomaly_detection": [], "anomaly_segmentation": [], diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py index d7af44d4db6..593a11c58fd 100644 --- a/tests/integration/api/test_engine_api.py +++ b/tests/integration/api/test_engine_api.py @@ -8,6 +8,7 @@ import pytest from openvino.model_api.tilers import Tiler from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls +from otx.core.config.hpo import HpoConfig from otx.core.data.module import OTXDataModule from otx.core.model.base import OTXModel from otx.core.types.task import OTXTaskType @@ -156,24 +157,9 @@ def test_engine_from_tile_recipe( assert engine.datamodule.config.tile_config.overlap == ov_model.model.tiles_overlap -REASON = """ -Traceback (most recent call last): - File "/home/vinnamki/miniconda3/envs/otx-v2/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap - self.run() - File "/home/vinnamki/miniconda3/envs/otx-v2/lib/python3.11/multiprocessing/process.py", line 108, in run - self._target(*self._args, **self._kwargs) - File "/home/vinnamki/otx/training_extensions/src/otx/hpo/hpo_runner.py", line 200, in _run_train - train_func(hp_config, report_func) - File "/home/vinnamki/otx/training_extensions/src/otx/engine/hpo/hpo_trial.py", line 75, in run_hpo_trial - callbacks = _register_hpo_callback(report_func, callbacks) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File 
"/home/vinnamki/otx/training_extensions/src/otx/engine/hpo/hpo_trial.py", line 101, in _register_hpo_callback - callbacks.append(HPOCallback(report_func, _get_metric(callbacks))) - ^^^^^^^^^^^^^^^^^^^^^^ - File "/home/vinnamki/otx/training_extensions/src/otx/engine/hpo/hpo_trial.py", line 110, in _get_metric - raise RuntimeError(error_msg) -RuntimeError: Failed to find a metric. There is no ModelCheckpoint in callback list. -""" +METRIC_NAME = { + OTXTaskType.MULTI_CLASS_CLS: "val/accuracy", +} @pytest.mark.parametrize("task", pytest.TASK_LIST) @@ -182,9 +168,12 @@ def test_otx_hpo( tmp_path: Path, fxt_target_dataset_per_task: dict, ) -> None: - pytest.xfail(reason=REASON) + if task not in METRIC_NAME: + reason = f"test_otx_hpo for {task} isn't prepared yet." + pytest.xfail(reason=reason) - model = EfficientNetB0ForMulticlassCls(num_classes=3) + model = EfficientNetB0ForMulticlassCls(num_classes=2) + hpo_config = HpoConfig(metric_name=METRIC_NAME[task], expected_time_ratio=2, num_workers=1) work_dir = str(tmp_path) engine = Engine( data_root=fxt_target_dataset_per_task[task.lower()], @@ -192,4 +181,4 @@ def test_otx_hpo( work_dir=work_dir, model=model, ) - engine.train(run_hpo=True) + engine.train(max_epochs=1, run_hpo=True, hpo_config=hpo_config) diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index bfb17020f0f..07acab85b16 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -4,7 +4,7 @@ from pathlib import Path -import numpy as np +import cv2 import pytest import yaml from otx.core.types.task import OTXTaskType @@ -13,34 +13,19 @@ from tests.utils import run_main -@pytest.mark.parametrize( - "recipe", - pytest.RECIPE_LIST, +@pytest.fixture( + params=pytest.RECIPE_LIST, ids=lambda x: "/".join(Path(x).parts[-2:]), ) -def test_otx_e2e( - recipe: str, - tmp_path: Path, +def fxt_trained_model( fxt_accelerator: str, fxt_target_dataset_per_task: dict, fxt_cli_override_command_per_task: dict, 
fxt_open_subprocess: bool, -) -> None: - """ - Test OTX CLI e2e commands. - - - 'otx train' with 2 epochs training - - 'otx test' with output checkpoint from 'otx train' - - 'otx export' with output checkpoint from 'otx train' - - 'otx test' with the exported to ONNX/IR model - - Args: - recipe (str): The recipe to use for training. (eg. 'classification/otx_mobilenet_v3_large.yaml') - tmp_path (Path): The temporary path for storing the training outputs. - - Returns: - None - """ + request: pytest.FixtureRequest, + tmp_path, +): + recipe = request.param task = recipe.split("/")[-2] model_name = recipe.split("/")[-1].split(".")[0] @@ -64,6 +49,34 @@ def test_otx_e2e( run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) + return recipe, task, model_name, tmp_path_train + + +def test_otx_e2e( + fxt_trained_model, + fxt_accelerator: str, + fxt_target_dataset_per_task: dict, + fxt_cli_override_command_per_task: dict, + fxt_open_subprocess: bool, + tmp_path: Path, +) -> None: + """ + Test OTX CLI e2e commands. + + - 'otx train' with 2 epochs training + - 'otx test' with output checkpoint from 'otx train' + - 'otx export' with output checkpoint from 'otx train' + - 'otx test' with the exported to ONNX/IR model + + Args: + recipe (str): The recipe to use for training. (eg. 'classification/otx_mobilenet_v3_large.yaml') + tmp_path (Path): The temporary path for storing the training outputs. 
+ + Returns: + None + """ + recipe, task, model_name, tmp_path_train = fxt_trained_model + outputs_dir = tmp_path_train / "outputs" latest_dir = max( (p for p in outputs_dir.iterdir() if p.is_dir() and p.name != ".latest"), @@ -79,9 +92,8 @@ def test_otx_e2e( assert "data" in train_output_config assert "engine" in train_output_config assert (latest_dir / "csv").exists() - assert (latest_dir / "checkpoints").exists() - ckpt_files = list((latest_dir / "checkpoints").glob(pattern="epoch_*.ckpt")) - assert len(ckpt_files) > 0 + ckpt_file = latest_dir / "best_checkpoint.ckpt" + assert ckpt_file.exists() # 2) otx test tmp_path_test = tmp_path / f"otx_test_{model_name}" @@ -98,7 +110,7 @@ def test_otx_e2e( fxt_accelerator, *fxt_cli_override_command_per_task[task], "--checkpoint", - str(ckpt_files[-1]), + str(ckpt_file), ] run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) @@ -154,7 +166,7 @@ def test_otx_e2e( str(tmp_path_test / "outputs" / fmt), *overrides, "--checkpoint", - str(ckpt_files[-1]), + str(ckpt_file), "--export_format", f"{fmt}", ] @@ -232,7 +244,7 @@ def test_otx_e2e( str(tmp_path_test / "outputs" / fmt), *fxt_cli_override_command_per_task[task], "--checkpoint", - str(ckpt_files[-1]), + str(ckpt_file), "--export_format", f"{fmt}", "--explain", @@ -250,18 +262,13 @@ def test_otx_e2e( assert (fmt_latest_dir / f"{format_to_file[fmt]}").exists() -@pytest.mark.parametrize( - "recipe", - pytest.RECIPE_LIST, - ids=lambda x: "/".join(Path(x).parts[-2:]), -) def test_otx_explain_e2e( - recipe: str, - tmp_path: Path, + fxt_trained_model, fxt_accelerator: str, fxt_target_dataset_per_task: dict, fxt_cli_override_command_per_task: dict, fxt_open_subprocess: bool, + tmp_path: Path, ) -> None: """ Test OTX CLI explain e2e command. 
@@ -273,13 +280,16 @@ def test_otx_explain_e2e( Returns: None """ - if "tile" in recipe: - pytest.skip("Explain is not supported for tiling yet.") - import cv2 + recipe, task, model_name, tmp_path_train = fxt_trained_model - task = recipe.split("/")[-2] - model_name = recipe.split("/")[-1].split(".")[0] + outputs_dir = tmp_path_train / "outputs" + latest_dir = outputs_dir / ".latest" + ckpt_file = latest_dir / "train" / "best_checkpoint.ckpt" + assert ckpt_file.exists() + + if "tile" in recipe: + pytest.skip("Explain is not supported for tiling yet.") if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): pytest.skip("Supported only for classification, detection and instance segmentation task.") @@ -302,10 +312,10 @@ def test_otx_explain_e2e( fxt_accelerator, "--seed", "0", - "--deterministic", - "True", "--dump", "True", + "--checkpoint", + str(ckpt_file), *fxt_cli_override_command_per_task[task], ] @@ -322,47 +332,6 @@ def test_otx_explain_e2e( assert sal_map.shape[0] > 0 assert sal_map.shape[1] > 0 - sal_diff_thresh = 3 - reference_sal_vals = { - # Classification - "multi_label_cls_efficientnet_v2_light": ( - np.array([66, 97, 84, 33, 42, 79, 0], dtype=np.uint8), - "Slide6_class_0_saliency_map.png", - ), - "h_label_cls_efficientnet_v2_light": ( - np.array([152, 193, 144, 132, 149, 204, 217], dtype=np.uint8), - "092_class_5_saliency_map.png", - ), - # Detection - "detection_yolox_tiny": ( - np.array([111, 163, 141, 141, 146, 147, 158, 169, 184, 193], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", - ), - "detection_ssd_mobilenetv2": ( - np.array([135, 80, 74, 34, 27, 32, 47, 42, 32, 34], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", - ), - "detection_atss_mobilenetv2": ( - np.array([22, 62, 64, 0, 27, 60, 59, 53, 37, 45], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", - ), - # Instance Segmentation - "instance_segmentation_maskrcnn_efficientnetb2b": ( - np.array([54, 54, 54, 54, 0, 0, 0, 54, 0, 0], dtype=np.uint8), 
- "Slide3_class_0_saliency_map.png", - ), - } - test_case_name = task + "_" + model_name - if test_case_name in reference_sal_vals: - actual_sal_vals = cv2.imread(str(latest_dir / "saliency_map" / reference_sal_vals[test_case_name][1])) - if test_case_name == "instance_segmentation_maskrcnn_efficientnetb2b": - # Take corner values due to map sparsity of InstSeg - actual_sal_vals = (actual_sal_vals[-10:, -1, -1]).astype(np.uint16) - else: - actual_sal_vals = (actual_sal_vals[:10, 0, 0]).astype(np.uint16) - ref_sal_vals = reference_sal_vals[test_case_name][0] - assert np.max(np.abs(actual_sal_vals - ref_sal_vals) <= sal_diff_thresh) - # @pytest.mark.skipif(len(pytest.RECIPE_OV_LIST) < 1, reason="No OV recipe found.") @pytest.mark.parametrize( @@ -436,20 +405,6 @@ def test_otx_ov_test( assert len(metric_result) > 0 -REASON = ''' -self = - - def finalize(self) -> None: - """Set done as True.""" - if not self.score: - error_msg = f"Trial{self.id} didn't report any score but tries to be done." -> raise RuntimeError(error_msg) -E RuntimeError: Trial0 didn't report any score but tries to be done. 
- -src/otx/hpo/hpo_base.py:274: RuntimeError -''' - - @pytest.mark.parametrize("task", pytest.TASK_LIST) def test_otx_hpo_e2e( task: OTXTaskType, @@ -471,12 +426,17 @@ def test_otx_hpo_e2e( """ if task not in DEFAULT_CONFIG_PER_TASK: pytest.skip(f"Task {task} is not supported in the auto-configuration.") + if task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: + pytest.skip("ZERO_SHOT_VISUAL_PROMPTING doesn't support HPO.") + + # Need to change model to stfpm because default anomaly model is 'padim' which doesn't support HPO + model_cfg = [] if task in { OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION, }: - pytest.xfail(reason=REASON) + model_cfg = ["--config", str(DEFAULT_CONFIG_PER_TASK[task].parent / "stfpm.yaml")] task = task.lower() tmp_path_hpo = tmp_path / f"otx_hpo_{task}" @@ -485,6 +445,7 @@ def test_otx_hpo_e2e( command_cfg = [ "otx", "train", + *model_cfg, "--task", task.upper(), "--data_root", @@ -494,20 +455,18 @@ def test_otx_hpo_e2e( "--engine.device", fxt_accelerator, "--max_epochs", - "1" if task in ("zero_shot_visual_prompting") else "2", + "1", "--run_hpo", "true", "--hpo_config.expected_time_ratio", "2", + "--hpo_config.num_workers", + "1", *fxt_cli_override_command_per_task[task], ] run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) - # zero_shot_visual_prompting doesn't support HPO. Check just there is no error. - if task in ("zero_shot_visual_prompting"): - return - latest_dir = max( (p for p in tmp_path_hpo.iterdir() if p.is_dir() and p.name != ".latest"), key=lambda p: p.stat().st_mtime, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a6223ab4661..0c7d1ed56be 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -103,7 +103,7 @@ def fxt_rtmdet_tiny_config(fxt_asset_dir: Path) -> MMConfig: # [TODO]: This is a temporary approach. 
-@pytest.fixture() +@pytest.fixture(scope="module") def fxt_target_dataset_per_task() -> dict: return { "multi_class_cls": "tests/assets/classification_dataset", @@ -123,7 +123,7 @@ def fxt_target_dataset_per_task() -> dict: } -@pytest.fixture() +@pytest.fixture(scope="module") def fxt_cli_override_command_per_task() -> dict: return { "multi_class_cls": [], diff --git a/tests/unit/algo/accelerators/__init__.py b/tests/unit/algo/accelerators/__init__.py new file mode 100644 index 00000000000..9996ffc6523 --- /dev/null +++ b/tests/unit/algo/accelerators/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests of accelerators of OTX algo.""" diff --git a/tests/unit/algo/accelerators/test_xpu.py b/tests/unit/algo/accelerators/test_xpu.py new file mode 100644 index 00000000000..793bbe18331 --- /dev/null +++ b/tests/unit/algo/accelerators/test_xpu.py @@ -0,0 +1,62 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Test for otx.algo.accelerators.xpu""" + + +import pytest +import torch +from otx.algo.accelerators import XPUAccelerator +from otx.utils.utils import is_xpu_available + + +class TestXPUAccelerator: + @pytest.fixture() + def accelerator(self, mocker): + mock_torch = mocker.patch("otx.algo.accelerators.xpu.torch") + mocker.patch.object(XPUAccelerator, "patch_packages_xpu") + mocker.patch.object(XPUAccelerator, "teardown") + return XPUAccelerator(), mock_torch + + def test_setup_device(self, accelerator): + accelerator, mock_torch = accelerator + device = torch.device("xpu") + accelerator.setup_device(device) + assert mock_torch.xpu.set_device.called + + def test_parse_devices(self, accelerator): + accelerator, _ = accelerator + devices = [1, 2, 3] + parsed_devices = accelerator.parse_devices(devices) + assert isinstance(parsed_devices, list) + assert parsed_devices == devices + + def test_get_parallel_devices(self, accelerator, mocker): + accelerator, _ = 
accelerator + devices = [1, 2, 3] + parallel_devices = accelerator.get_parallel_devices(devices) + assert isinstance(parallel_devices, list) + for device in parallel_devices: + assert isinstance(device, mocker.MagicMock) + + def test_auto_device_count(self, accelerator, mocker): + accelerator, mock_torch = accelerator + count = accelerator.auto_device_count() + assert isinstance(count, mocker.MagicMock) + assert mock_torch.xpu.device_count.called + + def test_is_available(self, accelerator): + accelerator, _ = accelerator + available = accelerator.is_available() + assert isinstance(available, bool) + assert available == is_xpu_available() + + def test_get_device_stats(self, accelerator): + accelerator, _ = accelerator + device = torch.device("xpu") + stats = accelerator.get_device_stats(device) + assert isinstance(stats, dict) + + def test_teardown(self, accelerator): + accelerator, _ = accelerator + accelerator.teardown() diff --git a/tests/unit/algo/classification/test_torchvision_model.py b/tests/unit/algo/classification/test_torchvision_model.py index 295080fd40c..46da5d0c265 100644 --- a/tests/unit/algo/classification/test_torchvision_model.py +++ b/tests/unit/algo/classification/test_torchvision_model.py @@ -3,6 +3,7 @@ from otx.algo.classification.torchvision_model import OTXTVModel, TVModelWithLossComputation from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.classification import MulticlassClsBatchPredEntity +from otx.core.types.export import TaskLevelExportParameters @pytest.fixture() @@ -31,16 +32,10 @@ def test_customize_outputs(self, fxt_tv_model, fxt_multiclass_cls_batch_data_ent assert isinstance(preds, MulticlassClsBatchPredEntity) def test_export_parameters(self, fxt_tv_model): - params = fxt_tv_model._export_parameters - assert isinstance(params, dict) - assert "input_size" in params - assert "resize_mode" in params - assert "pad_value" in params - assert "swap_rgb" in params - assert "via_onnx" in params - assert 
"onnx_export_configuration" in params - assert "mean" in params - assert "std" in params + export_parameters = fxt_tv_model._export_parameters + assert isinstance(export_parameters, TaskLevelExportParameters) + assert export_parameters.model_type == "Classification" + assert export_parameters.task_type == "classification" @pytest.mark.parametrize("explain_mode", [True, False]) def test_predict_step(self, fxt_tv_model: OTXTVModel, fxt_multiclass_cls_batch_data_entity, explain_mode): diff --git a/tests/unit/algo/detection/conftest.py b/tests/unit/algo/detection/conftest.py new file mode 100644 index 00000000000..3d5cd06fbf1 --- /dev/null +++ b/tests/unit/algo/detection/conftest.py @@ -0,0 +1,34 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Test of custom algo modules of OTX Detection task.""" +import pytest +from otx.core.config.data import DataModuleConfig, SubsetConfig +from otx.core.data.module import OTXDataModule +from otx.core.types.task import OTXTaskType +from torchvision.transforms.v2 import Resize + + +@pytest.fixture() +def fxt_data_module(): + return OTXDataModule( + task=OTXTaskType.DETECTION, + config=DataModuleConfig( + data_format="coco_instances", + data_root="tests/assets/car_tree_bug", + train_subset=SubsetConfig( + batch_size=2, + subset_name="train", + transforms=[Resize(320)], + ), + val_subset=SubsetConfig( + batch_size=2, + subset_name="val", + transforms=[Resize(320)], + ), + test_subset=SubsetConfig( + batch_size=2, + subset_name="test", + transforms=[Resize(320)], + ), + ), + ) diff --git a/tests/unit/algo/detection/heads/test_custom_ssd_head.py b/tests/unit/algo/detection/heads/test_custom_ssd_head.py index 7d95e3f4591..e2f23e6a019 100644 --- a/tests/unit/algo/detection/heads/test_custom_ssd_head.py +++ b/tests/unit/algo/detection/heads/test_custom_ssd_head.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 """Test of CustomSSDHead.""" -from mmdet.models.losses.cross_entropy_loss import 
CrossEntropyLoss from otx.algo.detection.heads.custom_ssd_head import SSDHead +from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss class TestSSDHead: @@ -25,6 +25,20 @@ def test_init(self, mocker) -> None: "target_means": [0.0, 0.0, 0.0, 0.0], "target_stds": [0.1, 0.1, 0.1, 0.1], }, + train_cfg={ + "assigner": { + "type": "MaxIoUAssigner", + "pos_iou_thr": 0.4, + "neg_iou_thr": 0.4, + }, + "smoothl1_beta": 1.0, + "allowed_border": -1, + "pos_weight": -1, + "neg_pos_ratio": 3, + "debug": False, + "use_giou": False, + "use_focal": False, + }, ) assert isinstance(self.head.loss_cls, CrossEntropyLoss) diff --git a/tests/unit/algo/detection/test_atss.py b/tests/unit/algo/detection/test_atss.py index 1564a7bee15..aa1370a84e4 100644 --- a/tests/unit/algo/detection/test_atss.py +++ b/tests/unit/algo/detection/test_atss.py @@ -5,6 +5,8 @@ import pytest from otx.algo.detection.atss import ATSS from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.exporter.mmdeploy import MMdeployExporter +from otx.core.types.export import TaskLevelExportParameters class TestATSS: @@ -20,4 +22,5 @@ def test(self, model, mocker) -> None: model.load_from_otx_v1_ckpt({}) mock_load_ckpt.assert_called_once_with({}, "model.model.") - assert isinstance(model._export_parameters, dict) + assert isinstance(model._export_parameters, TaskLevelExportParameters) + assert isinstance(model._exporter, MMdeployExporter) diff --git a/tests/unit/algo/detection/test_ssd.py b/tests/unit/algo/detection/test_ssd.py index 9a21a1a570d..53466f1806e 100644 --- a/tests/unit/algo/detection/test_ssd.py +++ b/tests/unit/algo/detection/test_ssd.py @@ -2,7 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 """Test of OTX SSD architecture.""" +from pathlib import Path + import pytest +from lightning import Trainer from otx.algo.detection.ssd import SSD @@ -11,16 +14,25 @@ class TestSSD: def fxt_model(self) -> SSD: return SSD(num_classes=3, variant="mobilenetv2") - def 
test_save_and_load_anchors(self, fxt_model) -> None: - anchor_widths = fxt_model.model.bbox_head.anchor_generator.widths - anchor_heights = fxt_model.model.bbox_head.anchor_generator.heights - state_dict = fxt_model.state_dict() - assert anchor_widths == state_dict["model.model.anchors"]["widths"] - assert anchor_heights == state_dict["model.model.anchors"]["heights"] + @pytest.fixture() + def fxt_checkpoint(self, fxt_model, fxt_data_module, tmpdir, monkeypatch: pytest.MonkeyPatch): + trainer = Trainer(max_steps=0) + + monkeypatch.setattr(trainer.strategy, "_lightning_module", fxt_model) + monkeypatch.setattr(trainer, "datamodule", fxt_data_module) + monkeypatch.setattr(fxt_model, "_trainer", trainer) + fxt_model.setup("fit") + + fxt_model.hparams["ssd_anchors"]["widths"][0][0] = 40 + fxt_model.hparams["ssd_anchors"]["heights"][0][0] = 50 + + checkpoint_path = Path(tmpdir) / "checkpoint.ckpt" + trainer.save_checkpoint(checkpoint_path) + + return checkpoint_path - state_dict["model.model.anchors"]["widths"][0][0] = 40 - state_dict["model.model.anchors"]["heights"][0][0] = 50 + def test_save_and_load_anchors(self, fxt_checkpoint) -> None: + loaded_model = SSD.load_from_checkpoint(checkpoint_path=fxt_checkpoint) - fxt_model.load_state_dict(state_dict) - assert fxt_model.model.bbox_head.anchor_generator.widths[0][0] == 40 - assert fxt_model.model.bbox_head.anchor_generator.heights[0][0] == 50 + assert loaded_model.model.bbox_head.anchor_generator.widths[0][0] == 40 + assert loaded_model.model.bbox_head.anchor_generator.heights[0][0] == 50 diff --git a/tests/unit/algo/detection/utils/__init__.py b/tests/unit/algo/detection/utils/__init__.py new file mode 100644 index 00000000000..a3c91b9065c --- /dev/null +++ b/tests/unit/algo/detection/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Test of utils for OTX Detection task.""" diff --git a/tests/unit/algo/detection/utils/test_mmcv_patched_ops.py 
b/tests/unit/algo/detection/utils/test_mmcv_patched_ops.py new file mode 100644 index 00000000000..09daa1b2cab --- /dev/null +++ b/tests/unit/algo/detection/utils/test_mmcv_patched_ops.py @@ -0,0 +1,139 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Test of mmcv_patched_ops.""" + +import pytest +import torch +from mmcv.ops import nms +from otx.algo.detection.utils.mmcv_patched_ops import monkey_patched_nms + + +class TestMonkeyPatchedNMS: + @pytest.fixture() + def setup(self): + self.ctx = None + self.bboxes = torch.tensor( + [[0.324, 0.422, 0.469, 0.123], [0.324, 0.422, 0.469, 0.123], [0.314, 0.423, 0.469, 0.123]], + ) + self.scores = torch.tensor([0.9, 0.2, 0.3]) + self.iou_threshold = 0.5 + self.offset = 0 + self.score_threshold = 0 + self.max_num = 0 + + def test_case1(self, setup): + # Testing when is_filtering_by_score is False + result = monkey_patched_nms( + self.ctx, + self.bboxes, + self.scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + assert torch.equal(result, torch.tensor([0, 2, 1])) + + def test_case2(self, setup): + # Testing when is_filtering_by_score is True + self.score_threshold = 0.8 + result = monkey_patched_nms( + self.ctx, + self.bboxes, + self.scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + assert torch.equal(result, torch.tensor([0])) + + def test_case3(self, setup): + # Testing when bboxes and scores have torch.bfloat16 dtype + self.bboxes = torch.tensor( + [[0.324, 0.422, 0.469, 0.123], [0.324, 0.422, 0.469, 0.123], [0.314, 0.423, 0.469, 0.123]], + dtype=torch.bfloat16, + ) + self.scores = torch.tensor([0.9, 0.2, 0.3], dtype=torch.bfloat16) + result1 = monkey_patched_nms( + self.ctx, + self.bboxes, + self.scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + assert torch.equal(result1, torch.tensor([0, 2, 1])) + + def test_case4(self, setup): + # Testing when offset is not 0 + 
self.offset = 1 + result = monkey_patched_nms( + self.ctx, + self.bboxes, + self.scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + assert torch.equal(result, torch.tensor([0])) + + def test_case5(self, setup): + # Testing when max_num is greater than 0 + self.max_num = 1 + result = monkey_patched_nms( + self.ctx, + self.bboxes, + self.scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + assert torch.equal(result, torch.tensor([0])) + + def test_case6(self, setup): + # Testing that monkey_patched_nms equals mmcv nms + self.score_threshold = 0.7 + result1 = monkey_patched_nms( + self.ctx, + self.bboxes, + self.scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + result2 = nms(self.bboxes, self.scores, self.iou_threshold, score_threshold=self.score_threshold) + assert torch.equal(result1, result2[1]) + # test random bboxes and scores + bboxes = torch.rand((100, 4)) + scores = torch.rand(100) + result1 = monkey_patched_nms( + self.ctx, + bboxes, + scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + result2 = nms(bboxes, scores, self.iou_threshold, score_threshold=self.score_threshold) + assert torch.equal(result1, result2[1]) + # no score threshold + self.iou_threshold = 0.7 + self.score_threshold = 0.0 + result1 = monkey_patched_nms( + self.ctx, + bboxes, + scores, + self.iou_threshold, + self.offset, + self.score_threshold, + self.max_num, + ) + result2 = nms(bboxes, scores, self.iou_threshold) + assert torch.equal(result1, result2[1]) diff --git a/tests/unit/algo/hooks/__init__.py b/tests/unit/algo/explain/__init__.py similarity index 100% rename from tests/unit/algo/hooks/__init__.py rename to tests/unit/algo/explain/__init__.py diff --git a/tests/unit/algo/hooks/test_saliency_map_dumping.py b/tests/unit/algo/explain/test_saliency_map_dumping.py similarity index 100% rename from 
tests/unit/algo/hooks/test_saliency_map_dumping.py rename to tests/unit/algo/explain/test_saliency_map_dumping.py diff --git a/tests/unit/algo/hooks/test_saliency_map_processing.py b/tests/unit/algo/explain/test_saliency_map_processing.py similarity index 100% rename from tests/unit/algo/hooks/test_saliency_map_processing.py rename to tests/unit/algo/explain/test_saliency_map_processing.py diff --git a/tests/unit/algo/hooks/test_xai_hooks.py b/tests/unit/algo/explain/test_xai_algorithms.py similarity index 52% rename from tests/unit/algo/hooks/test_xai_hooks.py rename to tests/unit/algo/explain/test_xai_algorithms.py index 55d9c63f829..141b52a00de 100644 --- a/tests/unit/algo/hooks/test_xai_hooks.py +++ b/tests/unit/algo/explain/test_xai_algorithms.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import torch from datumaro import Polygon -from otx.algo.hooks.recording_forward_hook import ( - ActivationMapHook, - DetClassProbabilityMapHook, - MaskRCNNRecordingForwardHook, - ReciproCAMHook, - ViTReciproCAMHook, +from otx.algo.explain.explain_algo import ( + ActivationMap, + DetClassProbabilityMap, + MaskRCNNExplainAlgo, + ReciproCAM, + ViTReciproCAM, ) from otx.core.data.entity.base import ImageInfo from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity @@ -16,22 +16,14 @@ def test_activationmap() -> None: - hook = ActivationMapHook() + explain_algo = ActivationMap() - assert hook.handle is None - assert hook.records == [] - assert hook._norm_saliency_maps + assert explain_algo._norm_saliency_maps feature_map = torch.zeros((1, 10, 5, 5)) - saliency_map = hook.func(feature_map) - assert saliency_map.size() == torch.Size([1, 5, 5]) - - hook.recording_forward(None, None, feature_map) - assert len(hook.records) == 1 - - hook.reset() - assert hook.records == [] + saliency_maps = explain_algo.func(feature_map) + assert saliency_maps.size() == torch.Size([1, 5, 5]) def test_reciprocam() -> None: @@ -40,26 +32,18 @@ def 
cls_head_forward_fn(_) -> None: num_classes = 2 optimize_gap = False - hook = ReciproCAMHook( + explain_algo = ReciproCAM( cls_head_forward_fn, num_classes=num_classes, optimize_gap=optimize_gap, ) - assert hook.handle is None - assert hook.records == [] - assert hook._norm_saliency_maps + assert explain_algo._norm_saliency_maps feature_map = torch.zeros((1, 10, 5, 5)) - saliency_map = hook.func(feature_map) - assert saliency_map.size() == torch.Size([1, 2, 5, 5]) - - hook.recording_forward(None, None, feature_map) - assert len(hook.records) == 1 - - hook.reset() - assert hook.records == [] + saliency_maps = explain_algo.func(feature_map) + assert saliency_maps.size() == torch.Size([1, 2, 5, 5]) def test_vitreciprocam() -> None: @@ -67,54 +51,42 @@ def cls_head_forward_fn(_) -> None: return torch.zeros((196, 2)) num_classes = 2 - hook = ViTReciproCAMHook( + explain_algo = ViTReciproCAM( cls_head_forward_fn, num_classes=num_classes, ) - assert hook.handle is None - assert hook.records == [] - assert hook._norm_saliency_maps + assert explain_algo._norm_saliency_maps feature_map = torch.zeros((1, 197, 192)) - saliency_map = hook.func(feature_map) - assert saliency_map.size() == torch.Size([1, 2, 14, 14]) - - hook.recording_forward(None, None, feature_map) - assert len(hook.records) == 1 - - hook.reset() - assert hook.records == [] + saliency_maps = explain_algo.func(feature_map) + assert saliency_maps.size() == torch.Size([1, 2, 14, 14]) def test_detclassprob() -> None: num_classes = 2 num_anchors = [1] * 10 - hook = DetClassProbabilityMapHook( + explain_algo = DetClassProbabilityMap( num_classes=num_classes, num_anchors=num_anchors, ) - assert hook.handle is None - assert hook.records == [] - assert hook._norm_saliency_maps + assert explain_algo._norm_saliency_maps backbone_out = torch.zeros((1, 5, 2, 2, 2)) - saliency_map = hook.func(backbone_out) - assert saliency_map.size() == torch.Size([5, 2, 2, 2]) + saliency_maps = explain_algo.func(backbone_out) + assert 
saliency_maps.size() == torch.Size([5, 2, 2, 2]) def test_maskrcnn() -> None: num_classes = 2 - hook = MaskRCNNRecordingForwardHook( + explain_algo = MaskRCNNExplainAlgo( num_classes=num_classes, ) - assert hook.handle is None - assert hook.records == [] - assert hook._norm_saliency_maps + assert explain_algo._norm_saliency_maps # One image, 3 masks to aggregate pred = InstanceSegBatchPredEntity( @@ -137,6 +109,6 @@ def test_maskrcnn() -> None: ) # 2 images - saliency_map = hook.func([pred, pred]) - assert len(saliency_map) == 2 - assert saliency_map[0].shape == (2, 10, 10) + saliency_maps = explain_algo.func([pred, pred]) + assert len(saliency_maps) == 2 + assert saliency_maps[0].shape == (2, 10, 10) diff --git a/tests/unit/algo/instance_segmentation/heads/test_custom_rtmdet_ins_head.py b/tests/unit/algo/instance_segmentation/heads/test_custom_rtmdet_ins_head.py index 12f4926dc91..ebee4e4cc4d 100644 --- a/tests/unit/algo/instance_segmentation/heads/test_custom_rtmdet_ins_head.py +++ b/tests/unit/algo/instance_segmentation/heads/test_custom_rtmdet_ins_head.py @@ -1,7 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import tempfile from pathlib import Path import torch @@ -68,12 +67,11 @@ def test_mask_pred(self, mocker) -> None: cfg=test_cfg, ) - def test_predict_by_feat_ov(self) -> None: - with tempfile.TemporaryDirectory() as tmpdirname: - lit_module = RTMDetInst(num_classes=1, variant="tiny") - exported_model_path = lit_module.export( - output_dir=Path(tmpdirname), - base_name="exported_model", - export_format=OTXExportFormatType.OPENVINO, - ) - Path.exists(exported_model_path) + def test_predict_by_feat_ov(self, tmpdir) -> None: + lit_module = RTMDetInst(num_classes=1, variant="tiny") + exported_model_path = lit_module.export( + output_dir=Path(tmpdir), + base_name="exported_model", + export_format=OTXExportFormatType.OPENVINO, + ) + Path.exists(exported_model_path) diff --git a/tests/unit/algo/plugins/__init__.py 
b/tests/unit/algo/plugins/__init__.py new file mode 100644 index 00000000000..be7fe475146 --- /dev/null +++ b/tests/unit/algo/plugins/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests of plugins of OTX algo.""" diff --git a/tests/unit/algo/plugins/test_plugins.py b/tests/unit/algo/plugins/test_plugins.py new file mode 100644 index 00000000000..a84f4ec18d6 --- /dev/null +++ b/tests/unit/algo/plugins/test_plugins.py @@ -0,0 +1,56 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Test for otx.algo.plugins.xpu_precision""" + + +import pytest +import torch +from otx.algo.plugins.xpu_precision import MixedPrecisionXPUPlugin +from torch.optim import Optimizer + + +class TestMixedPrecisionXPUPlugin: + @pytest.fixture() + def plugin(self): + return MixedPrecisionXPUPlugin() + + def test_init(self, plugin): + assert plugin.scaler is None + + def test_pre_backward(self, plugin, mocker): + tensor = torch.zeros(1) + module = mocker.MagicMock() + output = plugin.pre_backward(tensor, module) + assert output == tensor + + def test_optimizer_step_no_scaler(self, plugin, mocker): + optimizer = mocker.MagicMock(Optimizer) + model = mocker.MagicMock() + closure = mocker.MagicMock() + kwargs = {} + mock_optimizer_step = mocker.patch( + "otx.algo.plugins.xpu_precision.Precision.optimizer_step", + ) + out = plugin.optimizer_step(optimizer, model, closure, **kwargs) + assert isinstance(out, mocker.MagicMock) + mock_optimizer_step.assert_called_once() + + def test_optimizer_step_with_scaler(self, plugin, mocker): + optimizer = mocker.MagicMock(Optimizer) + model = mocker.MagicMock() + closure = mocker.MagicMock() + plugin.scaler = mocker.MagicMock() + kwargs = {} + out = plugin.optimizer_step(optimizer, model, closure, **kwargs) + assert isinstance(out, mocker.MagicMock) + + def test_clip_gradients(self, plugin, mocker): + optimizer = mocker.MagicMock(Optimizer) + clip_val = 
0.1 + gradient_clip_algorithm = "norm" + mock_clip_gradients = mocker.patch( + "otx.algo.plugins.xpu_precision.Precision.clip_gradients", + ) + plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm) + mock_clip_gradients.assert_called_once() diff --git a/tests/unit/algo/segmentation/test_dino_v2_seg.py b/tests/unit/algo/segmentation/test_dino_v2_seg.py index e3f7430c126..f3fe10a55af 100644 --- a/tests/unit/algo/segmentation/test_dino_v2_seg.py +++ b/tests/unit/algo/segmentation/test_dino_v2_seg.py @@ -4,10 +4,11 @@ import pytest from otx.algo.segmentation.dino_v2_seg import DinoV2Seg +from otx.core.exporter.base import OTXModelExporter class TestDinoV2Seg: - @pytest.fixture() + @pytest.fixture(scope="class") def fxt_dino_v2_seg(self) -> DinoV2Seg: return DinoV2Seg(num_classes=10) @@ -15,11 +16,10 @@ def test_dino_v2_seg_init(self, fxt_dino_v2_seg): assert isinstance(fxt_dino_v2_seg, DinoV2Seg) assert fxt_dino_v2_seg.num_classes == 10 - def test_export_parameters(self, fxt_dino_v2_seg): - parameters = fxt_dino_v2_seg._export_parameters - assert isinstance(parameters, dict) - assert "input_size" in parameters - assert parameters["input_size"] == (1, 3, 560, 560) + def test_exporter(self, fxt_dino_v2_seg): + exporter = fxt_dino_v2_seg._exporter + assert isinstance(exporter, OTXModelExporter) + assert exporter.input_size == (1, 3, 560, 560) def test_optimization_config(self, fxt_dino_v2_seg): config = fxt_dino_v2_seg._optimization_config diff --git a/tests/unit/algo/strategies/__init__.py b/tests/unit/algo/strategies/__init__.py new file mode 100644 index 00000000000..8830174eb83 --- /dev/null +++ b/tests/unit/algo/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests of strategies of OTX algo.""" diff --git a/tests/unit/algo/strategies/test_strategies.py b/tests/unit/algo/strategies/test_strategies.py new file mode 100644 index 00000000000..0ef457351ff --- /dev/null +++ 
b/tests/unit/algo/strategies/test_strategies.py @@ -0,0 +1,54 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests the XPU strategy.""" + + +import pytest +import pytorch_lightning as pl +import torch +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from otx.algo.strategies.xpu_single import SingleXPUStrategy + + +class TestSingleXPUStrategy: + def test_init(self, mocker): + with pytest.raises(MisconfigurationException): + strategy = SingleXPUStrategy(device="xpu:0") + mocked_is_xpu_available = mocker.patch( + "otx.algo.strategies.xpu_single.is_xpu_available", + return_value=True, + ) + strategy = SingleXPUStrategy(device="xpu:0") + assert mocked_is_xpu_available.call_count == 1 + assert strategy._root_device.type == "xpu" + assert strategy.accelerator is None + + @pytest.fixture() + def strategy(self, mocker): + mocker.patch( + "otx.algo.strategies.xpu_single.is_xpu_available", + return_value=True, + ) + return SingleXPUStrategy(device="xpu:0", accelerator="xpu") + + def test_is_distributed(self, strategy): + assert not strategy.is_distributed + + def test_setup_optimizers(self, strategy, mocker): + from otx.algo.strategies.xpu_single import SingleDeviceStrategy + + mocker.patch("otx.algo.strategies.xpu_single.torch") + mocker.patch( + "otx.algo.strategies.xpu_single.torch.xpu.optimize", + return_value=(mocker.MagicMock(), mocker.MagicMock()), + ) + mocker.patch.object(SingleDeviceStrategy, "setup_optimizers") + trainer = pl.Trainer() + trainer.task = "CLASSIFICATION" + # Create mock optimizers and models for testing + model = torch.nn.Linear(10, 2) + strategy._optimizers = [torch.optim.Adam(model.parameters(), lr=0.001)] + strategy._model = model + strategy.setup_optimizers(trainer) + assert len(strategy.optimizers) == 1 diff --git a/tests/unit/core/data/test_tiling.py b/tests/unit/core/data/test_tiling.py index 35fccb43052..c11eddeb2c9 100644 --- a/tests/unit/core/data/test_tiling.py +++ 
b/tests/unit/core/data/test_tiling.py @@ -115,6 +115,122 @@ def fxt_instseg_data_config(self, fxt_mmcv_det_transform_config) -> OTXDataModul vpm_config=VisualPromptingConfig(), ) + def det_dummy_forward(self, x: DetBatchDataEntity) -> DetBatchPredEntity: + """Dummy detection forward function for testing. + + This function creates random bounding boxes for each image in the batch. + Args: + x (DetBatchDataEntity): Input batch data entity. + + Returns: + DetBatchPredEntity: Output batch prediction entity. + """ + bboxes = [] + labels = [] + scores = [] + saliency_maps = [] + feature_vectors = [] + for img_info in x.imgs_info: + img_h, img_w = img_info.ori_shape + img_bboxes = generate_random_bboxes( + image_width=img_w, + image_height=img_h, + num_boxes=100, + ) + bboxes.append( + tv_tensors.BoundingBoxes( + img_bboxes, + canvas_size=img_info.ori_shape, + format=tv_tensors.BoundingBoxFormat.XYXY, + dtype=torch.float64, + ), + ) + labels.append( + torch.LongTensor(len(img_bboxes)).random_(3), + ) + scores.append( + torch.rand(len(img_bboxes), dtype=torch.float64), + ) + if self.explain_mode: + saliency_maps.append(np.zeros((3, 7, 7))) + feature_vectors.append(np.zeros((1, 32))) + + pred_entity = DetBatchPredEntity( + batch_size=x.batch_size, + images=x.images, + imgs_info=x.imgs_info, + scores=scores, + bboxes=bboxes, + labels=labels, + ) + if self.explain_mode: + pred_entity.saliency_map = saliency_maps + pred_entity.feature_vector = feature_vectors + + return pred_entity + + def inst_seg_dummy_forward(self, x: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity: + """Dummy instance segmantation forward function for testing. + + This function creates random bounding boxes/masks for each image in the batch. + Args: + x (InstanceSegBatchDataEntity): Input batch data entity. + + Returns: + InstanceSegBatchPredEntity: Output batch prediction entity. 
+ """ + bboxes = [] + labels = [] + scores = [] + masks = [] + feature_vectors = [] + + for img_info in x.imgs_info: + img_h, img_w = img_info.ori_shape + img_bboxes = generate_random_bboxes( + image_width=img_w, + image_height=img_h, + num_boxes=100, + ) + bboxes.append( + tv_tensors.BoundingBoxes( + img_bboxes, + canvas_size=img_info.ori_shape, + format=tv_tensors.BoundingBoxFormat.XYXY, + dtype=torch.float64, + ), + ) + labels.append( + torch.LongTensor(len(img_bboxes)).random_(3), + ) + scores.append( + torch.rand(len(img_bboxes), dtype=torch.float64), + ) + masks.append( + tv_tensors.Mask( + torch.randint(0, 2, (len(img_bboxes), img_h, img_w)), + dtype=torch.bool, + ), + ) + if self.explain_mode: + feature_vectors.append(np.zeros((1, 32))) + + pred_entity = InstanceSegBatchPredEntity( + batch_size=x.batch_size, + images=x.images, + imgs_info=x.imgs_info, + scores=scores, + bboxes=bboxes, + labels=labels, + masks=masks, + polygons=x.polygons, + ) + if self.explain_mode: + pred_entity.saliency_map = [] + pred_entity.feature_vector = feature_vectors + + return pred_entity + def test_tile_transform(self): dataset = DmDataset.import_from("tests/assets/car_tree_bug", format="coco_instances") first_item = next(iter(dataset), None) @@ -216,124 +332,71 @@ def test_val_dataloader(self, fxt_det_data_config) -> None: assert isinstance(batch, TileBatchDetDataEntity) def test_det_tile_merge(self, fxt_det_data_config): - def dummy_forward(x: DetBatchDataEntity) -> DetBatchPredEntity: - """Dummy forward function for testing. - - This function creates random bounding boxes for each image in the batch. - Args: - x (DetBatchDataEntity): Input batch data entity. - - Returns: - DetBatchPredEntity: Output batch prediction entity. 
- """ - bboxes = [] - labels = [] - scores = [] - for img_info in x.imgs_info: - img_h, img_w = img_info.ori_shape - img_bboxes = generate_random_bboxes( - image_width=img_w, - image_height=img_h, - num_boxes=100, - ) - bboxes.append( - tv_tensors.BoundingBoxes( - img_bboxes, - canvas_size=img_info.ori_shape, - format=tv_tensors.BoundingBoxFormat.XYXY, - dtype=torch.float64, - ), - ) - labels.append( - torch.LongTensor(len(img_bboxes)).random_(3), - ) - scores.append( - torch.rand(len(img_bboxes), dtype=torch.float64), - ) - - return DetBatchPredEntity( - batch_size=x.batch_size, - images=x.images, - imgs_info=x.imgs_info, - scores=scores, - bboxes=bboxes, - labels=labels, - ) - model = OTXDetectionModel(num_classes=3) + # Enable tile adapter fxt_det_data_config.tile_config.enable_tiler = True tile_datamodule = OTXDataModule( task=OTXTaskType.DETECTION, config=fxt_det_data_config, ) - model.forward = dummy_forward + + self.explain_mode = False + model.forward = self.det_dummy_forward tile_datamodule.prepare_data() for batch in tile_datamodule.val_dataloader(): model.forward_tiles(batch) - def test_instseg_tile_merge(self, fxt_instseg_data_config): - def dummy_forward(x: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity: - """Dummy forward function for testing. - - This function creates random bounding boxes/masks for each image in the batch. - Args: - x (InstanceSegBatchDataEntity): Input batch data entity. - - Returns: - InstanceSegBatchPredEntity: Output batch prediction entity. 
- """ - bboxes = [] - labels = [] - scores = [] - masks = [] - for img_info in x.imgs_info: - img_h, img_w = img_info.ori_shape - img_bboxes = generate_random_bboxes( - image_width=img_w, - image_height=img_h, - num_boxes=100, - ) - bboxes.append( - tv_tensors.BoundingBoxes( - img_bboxes, - canvas_size=img_info.ori_shape, - format=tv_tensors.BoundingBoxFormat.XYXY, - dtype=torch.float64, - ), - ) - labels.append( - torch.LongTensor(len(img_bboxes)).random_(3), - ) - scores.append( - torch.rand(len(img_bboxes), dtype=torch.float64), - ) - masks.append( - tv_tensors.Mask( - torch.randint(0, 2, (len(img_bboxes), img_h, img_w)), - dtype=torch.bool, - ), - ) - - return InstanceSegBatchPredEntity( - batch_size=x.batch_size, - images=x.images, - imgs_info=x.imgs_info, - scores=scores, - bboxes=bboxes, - masks=masks, - labels=labels, - polygons=x.polygons, - ) + def test_explain_det_tile_merge(self, fxt_det_data_config): + model = OTXDetectionModel(num_classes=3) + # Enable tile adapter + fxt_det_data_config.tile_config.enable_tiler = True + fxt_det_data_config.tile_config.enable_adaptive_tiling = False + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + + self.explain_mode = model.explain_mode = True + model.forward_explain = self.det_dummy_forward + + tile_datamodule.prepare_data() + for batch in tile_datamodule.val_dataloader(): + prediction = model.forward_tiles(batch) + assert prediction.saliency_map[0].ndim == 3 + self.explain_mode = False + def test_instseg_tile_merge(self, fxt_instseg_data_config): model = OTXInstanceSegModel(num_classes=3) + # Enable tile adapter fxt_instseg_data_config.tile_config.enable_tiler = True tile_datamodule = OTXDataModule( task=OTXTaskType.INSTANCE_SEGMENTATION, config=fxt_instseg_data_config, ) - model.forward = dummy_forward + + self.explain_mode = False + model.forward = self.inst_seg_dummy_forward tile_datamodule.prepare_data() for batch in tile_datamodule.val_dataloader(): 
model.forward_tiles(batch) + + def test_explain_instseg_tile_merge(self, fxt_instseg_data_config): + model = OTXInstanceSegModel(num_classes=3) + # Enable tile adapter + fxt_instseg_data_config.tile_config.enable_tiler = True + fxt_instseg_data_config.tile_config.enable_adaptive_tiling = False + tile_datamodule = OTXDataModule( + task=OTXTaskType.INSTANCE_SEGMENTATION, + config=fxt_instseg_data_config, + ) + + self.explain_mode = model.explain_mode = True + model.forward_explain = self.inst_seg_dummy_forward + + tile_datamodule.prepare_data() + for batch in tile_datamodule.val_dataloader(): + prediction = model.forward_tiles(batch) + assert prediction.saliency_map[0].ndim == 3 + self.explain_mode = False diff --git a/tests/unit/core/exporter/exportable_code/demo/demo_package/visualizers/test_vis_utils.py b/tests/unit/core/exporter/exportable_code/demo/demo_package/visualizers/test_vis_utils.py new file mode 100644 index 00000000000..04f0e6580c1 --- /dev/null +++ b/tests/unit/core/exporter/exportable_code/demo/demo_package/visualizers/test_vis_utils.py @@ -0,0 +1,127 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import cv2 +import numpy as np +import pytest +from numpy.random import PCG64, Generator + + +@pytest.fixture(scope="module", autouse=True) +def fxt_import_module(): + global ColorPalette, get_actmap, dump_frames # noqa: PLW0603 + from otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils import ( + ColorPalette as _ColorPalette, + ) + from otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils import ( + dump_frames as _dump_frames, + ) + from otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils import ( + get_actmap as _get_actmap, + ) + + ColorPalette = _ColorPalette + get_actmap = _get_actmap + dump_frames = _dump_frames + + +def test_activation_map_shape(): + generator = Generator(PCG64()) + saliency_map = (generator.random((100, 100)) * 255).astype(np.uint8) + 
output_res = (50, 50) + result = get_actmap(saliency_map, output_res) + assert result.shape == (50, 50, 3) + + +def test_no_saved_frames(): + output = "output" + input_path = "input" + capture = MagicMock() + saved_frames = [] + dump_frames(saved_frames, output, input_path, capture) + assert not Path(output).exists() + + +@patch("otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils.cv2.VideoWriter_fourcc") +@patch("otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils.cv2.VideoWriter") +@patch("otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils.get_input_names_list") +def test_video_input(mock_get_input_names_list, mock_video_writer, mock_video_writer_fourcc, tmp_path): + output = str(tmp_path / "output") + input_path = "input" + capture = MagicMock(spec=cv2.VideoCapture) + capture.get_type = lambda: "VIDEO" + capture.fps = lambda: 30 + saved_frames = [MagicMock(shape=(100, 100, 3))] + filenames = ["video.mp4"] + mock_get_input_names_list.return_value = filenames + mock_video_writer_fourcc.return_value = "mp4v" + dump_frames(saved_frames, output, input_path, capture) + mock_video_writer_fourcc.assert_called_once_with(*"mp4v") + mock_video_writer.assert_called_once() + + +@patch("otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils.cv2.imwrite") +@patch("otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils.get_input_names_list") +@patch("otx.core.exporter.exportable_code.demo.demo_package.visualizers.vis_utils.cv2.cvtColor") +def test_image_input(mock_imwrite, mock_get_input_names_list, mock_cvtcolor, tmp_path): + output = str(tmp_path / "output") + input_path = "input" + capture = MagicMock(spec=cv2.VideoCapture) + capture.get_type = lambda: "IMAGE" + saved_frames = [MagicMock(), MagicMock()] + filenames = ["image1.jpeg", "image2.jpeg"] + mock_get_input_names_list.return_value = filenames + dump_frames(saved_frames, output, input_path, capture) + assert 
mock_cvtcolor.call_count == 2 + assert mock_imwrite.call_count == 2 + + +class TestColorPalette: + def test_colorpalette_init_with_zero_classes(self): + expected_msg = "ColorPalette accepts only the positive number of colors" + with pytest.raises(ValueError, match=expected_msg): + ColorPalette(num_classes=0) + with pytest.raises(ValueError, match=expected_msg): + ColorPalette(num_classes=-5) + + def test_colorpalette_length(self): + num_classes = 5 + palette = ColorPalette(num_classes) + assert len(palette) == num_classes + + def test_colorpalette_getitem(self): + num_classes = 3 + palette = ColorPalette(num_classes) + color = palette[1] # assuming 0-based indexing + assert isinstance(color, tuple) + assert len(color) == 3 + + def test_colorpalette_getitem_out_of_range(self): + num_classes = 3 + palette = ColorPalette(num_classes) + color = palette[num_classes + 2] # out-of-range index + assert color == palette[2] # because it should wrap around + + def test_colorpalette_to_numpy_array(self): + num_classes = 2 + palette = ColorPalette(num_classes) + np_array = palette.to_numpy_array() + assert isinstance(np_array, np.ndarray) + assert np_array.shape == (num_classes, 3) + + def test_colorpalette_hsv2rgb_known_values(self): + h, s, v = 0.5, 1, 1 # Cyan in HSV + expected_rgb = (0, 255, 255) # Cyan in RGB + assert ColorPalette.hsv2rgb(h, s, v) == expected_rgb + + def test_dist_same_color(self): + # Colors that are the same should have a distance of 0 + color = (0.5, 0.5, 0.5) + assert ColorPalette._dist(color, color) == 0 + + def test_dist_different_colors(self): + # Test distance between two different colors + color1 = (0.1, 0.2, 0.3) + color2 = (0.4, 0.5, 0.6) + expected_distance = 0.54 + assert ColorPalette._dist(color1, color2) == expected_distance diff --git a/tests/unit/core/exporter/exportable_code/demo/demo_package/visualizers/test_visualizers.py b/tests/unit/core/exporter/exportable_code/demo/demo_package/visualizers/test_visualizers.py new file mode 100644 
index 00000000000..331aa1cc3e1 --- /dev/null +++ b/tests/unit/core/exporter/exportable_code/demo/demo_package/visualizers/test_visualizers.py @@ -0,0 +1,272 @@ +from unittest.mock import Mock, patch + +import numpy as np +import pytest +from numpy.random import PCG64, Generator +from openvino.model_api.models.utils import ( + ClassificationResult, + Detection, + DetectionResult, + ImageResultWithSoftPrediction, + InstanceSegmentationResult, + SegmentedObject, +) + + +@pytest.fixture(scope="module", autouse=True) +def fxt_import_module(): + global BaseVisualizer, ClassificationVisualizer, InstanceSegmentationVisualizer, ObjectDetectionVisualizer, SemanticSegmentationVisualizer # noqa: PLW0603 + from otx.core.exporter.exportable_code.demo.demo_package import ( + BaseVisualizer as _BaseVisualizer, + ) + from otx.core.exporter.exportable_code.demo.demo_package import ( + ClassificationVisualizer as _ClassificationVisualizer, + ) + from otx.core.exporter.exportable_code.demo.demo_package import ( + InstanceSegmentationVisualizer as _InstanceSegmentationVisualizer, + ) + from otx.core.exporter.exportable_code.demo.demo_package import ( + ObjectDetectionVisualizer as _ObjectDetectionVisualizer, + ) + from otx.core.exporter.exportable_code.demo.demo_package import ( + SemanticSegmentationVisualizer as _SemanticSegmentationVisualizer, + ) + + BaseVisualizer = _BaseVisualizer + ClassificationVisualizer = _ClassificationVisualizer + InstanceSegmentationVisualizer = _InstanceSegmentationVisualizer + ObjectDetectionVisualizer = _ObjectDetectionVisualizer + SemanticSegmentationVisualizer = _SemanticSegmentationVisualizer + + +class TestBaseVisualizer: + def test_init(self): + visualizer = BaseVisualizer(window_name="TestWindow", no_show=True, delay=10, output="test_output") + assert visualizer.window_name == "TestWindow" + assert visualizer.no_show is True + assert visualizer.delay == 10 + assert visualizer.output == "test_output" + + # Test show method without displaying the 
window + @patch("cv2.imshow") + def test_show_no_display(self, mock_imshow): + visualizer = BaseVisualizer(no_show=True) + test_image = np.zeros((100, 100, 3), dtype=np.uint8) + visualizer.show(test_image) + mock_imshow.assert_not_called() + + # Test show method with displaying the window + @patch("cv2.imshow") + def test_show_display(self, mock_imshow): + visualizer = BaseVisualizer(no_show=False) + test_image = np.zeros((100, 100, 3), dtype=np.uint8) + visualizer.show(test_image) + mock_imshow.assert_called_once_with(visualizer.window_name, test_image) + + # Test is_quit method + @patch("cv2.waitKey", return_value=ord("q")) + def test_is_quit(self, mock_waitkey): + visualizer = BaseVisualizer(no_show=False) + assert visualizer.is_quit() is True + + # Test video_delay method + @patch("time.sleep") + def test_video_delay(self, mock_sleep): + streamer = Mock() + streamer.get_type.return_value = "VIDEO" + streamer.fps.return_value = 30 + visualizer = BaseVisualizer(no_show=False) + visualizer.video_delay(0.02, streamer) + mock_sleep.assert_called_once_with(1 / 30 - 0.02) + + +class TestClassificationVisualizer: + @pytest.fixture() + def visualizer(self): + return ClassificationVisualizer(window_name="TestWindow", no_show=True, delay=10, output="test_output") + + @pytest.fixture() + def frame(self): + return np.zeros((100, 100, 3), dtype=np.uint8) + + @pytest.fixture() + def predictions(self): + return ClassificationResult( + top_labels=[(0, "cat", 0.9)], + saliency_map=None, + feature_vector=None, + raw_scores=[0.9], + ) + + def test_draw_one_prediction(self, frame, predictions, visualizer): + # test one prediction + copied_frame = frame.copy() + output = visualizer.draw(frame, predictions) + assert output.shape == (100, 100, 3) + assert np.any(output != copied_frame) + + def test_draw_multiple_predictions(self, frame, predictions, visualizer): + # test multiple predictions + copied_frame = frame.copy() + predictions.top_labels.extend([(1, "dog", 0.8), (2, "bird", 
0.7)]) + output = visualizer.draw(frame, predictions) + assert output.shape == (100, 100, 3) + assert np.any(output != copied_frame) + + def test_label_overflow(self, frame, predictions, visualizer): + # test multiple predictions + copied_frame = frame.copy() + predictions.top_labels.extend([(1, "dog", 0.8), (2, "bird", 0.7), (3, "cat", 0.6)]) + output = visualizer.draw(frame, predictions) + assert output.shape == (100, 100, 3) + assert np.any(output != copied_frame) + + def test_draw_no_predictions(self, frame, visualizer): + # test no predictions + copied_frame = frame.copy() + predictions = ClassificationResult(top_labels=[()], saliency_map=None, feature_vector=None, raw_scores=[]) + output = visualizer.draw(frame, predictions) + assert output.shape == (100, 100, 3) + assert np.equal(output, copied_frame).all() + + +class TestDetectionVisualizer: + @pytest.fixture() + def visualizer(self): + return ObjectDetectionVisualizer( + labels=["Pedestrian", "Car"], + window_name="TestWindow", + no_show=True, + delay=10, + output="test_output", + ) + + def test_draw_no_predictions(self, visualizer): + frame = np.zeros((100, 100, 3), dtype=np.uint8) + predictions = DetectionResult([], saliency_map=None, feature_vector=None) + output_frame = visualizer.draw(frame, predictions) + assert np.array_equal(frame, output_frame) + + def test_draw_with_predictions(self, visualizer): + frame = np.zeros((100, 100, 3), dtype=np.uint8) + predictions = DetectionResult( + [Detection(10, 40, 30, 80, 0.7, 2, "Car")], + saliency_map=None, + feature_vector=None, + ) + copied_frame = frame.copy() + output_frame = visualizer.draw(frame, predictions) + assert np.any(output_frame != copied_frame) + + +class TestInstanceSegmentationVisualizer: + @pytest.fixture() + def rand_generator(self): + return Generator(PCG64()) + + @pytest.fixture() + def visualizer(self): + return InstanceSegmentationVisualizer( + labels=["person", "car"], + window_name="TestWindow", + no_show=True, + delay=10, + 
output="test_output", + ) + + def test_draw_multiple_objects(self, visualizer, rand_generator): + # Create a frame + frame = np.zeros((100, 100, 3), dtype=np.uint8) + copied_frame = frame.copy() + + # Create instance segmentation results with multiple objects + predictions = InstanceSegmentationResult( + segmentedObjects=[ + SegmentedObject( + xmin=10, + ymin=10, + xmax=30, + ymax=30, + score=0.9, + id=0, + mask=rand_generator.integers(2, size=(100, 100), dtype=np.uint8), + str_label="person", + ), + SegmentedObject( + xmin=40, + ymin=40, + xmax=60, + ymax=60, + score=0.8, + id=1, + mask=rand_generator.integers(2, size=(100, 100), dtype=np.uint8), + str_label="car", + ), + ], + saliency_map=None, + feature_vector=None, + ) + + drawn_frame = visualizer.draw(frame, predictions) + assert np.any(drawn_frame != copied_frame) + + # Assertion checks for the drawn frame + + def test_draw_no_objects(self, visualizer): + # Create a frame + frame = np.zeros((100, 100, 3), dtype=np.uint8) + copied_frame = frame.copy() + + # Create instance segmentation results with no objects + predictions = InstanceSegmentationResult(segmentedObjects=[], saliency_map=None, feature_vector=None) + + drawn_frame = visualizer.draw(frame, predictions) + assert np.array_equal(drawn_frame, copied_frame) + + +class TestSemanticSegmentationVisualizer: + @pytest.fixture() + def labels(self): + return ["background", "object1", "object2"] + + @pytest.fixture() + def visualizer(self, labels): + return SemanticSegmentationVisualizer( + labels=labels, + window_name="TestWindow", + no_show=True, + delay=10, + output="test_output", + ) + + @pytest.fixture() + def rand_generator(self): + return Generator(PCG64()) + + def test_initialization(self, visualizer): + assert isinstance(visualizer.color_palette, np.ndarray) + assert visualizer.color_map.shape == (256, 1, 3) + assert visualizer.color_map.dtype == np.uint8 + + def test_create_color_map(self, visualizer): + color_map = visualizer._create_color_map() + 
assert color_map.shape == (256, 1, 3) + assert color_map.dtype == np.uint8 + + def test_apply_color_map(self, visualizer, labels, rand_generator): + input_2d_mask = rand_generator.integers(0, len(labels), size=(10, 10)) + colored_mask = visualizer._apply_color_map(input_2d_mask) + assert colored_mask.shape == (10, 10, 3) + + def test_draw(self, visualizer, rand_generator): + frame = rand_generator.integers(0, 255, size=(10, 10, 3), dtype=np.uint8) + copied_frame = frame.copy() + masks = ImageResultWithSoftPrediction( + resultImage=rand_generator.integers(0, 255, size=(10, 10), dtype=np.uint8), + soft_prediction=rand_generator.random((10, 10)), + saliency_map=None, + feature_vector=None, + ) + output_image = visualizer.draw(frame, masks) + assert output_image.shape == frame.shape + assert np.any(output_image != copied_frame) diff --git a/tests/unit/core/exporter/test_base.py b/tests/unit/core/exporter/test_base.py new file mode 100644 index 00000000000..e40922362ad --- /dev/null +++ b/tests/unit/core/exporter/test_base.py @@ -0,0 +1,97 @@ +from unittest.mock import MagicMock, patch + +import pytest +from onnx import ModelProto +from onnxconverter_common import float16 +from otx.core.exporter.base import OTXExportFormatType, OTXModelExporter, OTXPrecisionType, ZipFile +from otx.core.types.export import TaskLevelExportParameters + + +class MockModelExporter(OTXModelExporter): + def to_openvino(self, model, output_dir, base_model_name, precision): + return output_dir / f"{base_model_name}.xml" + + def to_onnx(self, model, output_dir, base_model_name, precision): + return output_dir / f"{base_model_name}.onnx" + + +@pytest.fixture() +def mock_model(): + return MagicMock() + + +@pytest.fixture() +def exporter(mocker): + ZipFile.write = MagicMock() + mocker.patch("otx.core.exporter.base.json") + return MockModelExporter( + task_level_export_parameters=MagicMock(TaskLevelExportParameters), + input_size=(224, 224), + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), 
+ ) + + +class TestOTXModelExporter: + def test_to_openvino(self, mock_model, exporter, tmp_path): + output_dir = tmp_path + base_model_name = "test_model" + precision = OTXPrecisionType.FP32 + result = exporter.export(mock_model, output_dir, base_model_name, OTXExportFormatType.OPENVINO, precision) + assert result == output_dir / f"{base_model_name}.xml" + + def test_to_onnx(self, mock_model, exporter, tmp_path): + output_dir = tmp_path + base_model_name = "test_model" + precision = OTXPrecisionType.FP32 + result = exporter.export(mock_model, output_dir, base_model_name, OTXExportFormatType.ONNX, precision) + assert result == output_dir / f"{base_model_name}.onnx" + + def test_export_unsupported_format_raises(self, exporter, mock_model, tmp_path): + export_format = "unsupported_format" + with pytest.raises(ValueError, match=f"Unsupported export format: {export_format}"): + exporter.export(mock_model, tmp_path, export_format=export_format) + + def test_to_exportable_code(self, mock_model, exporter, tmp_path): + from otx.core.exporter.base import ZipFile + + ZipFile.writestr = MagicMock() + + base_model_name = "test_model" + output_dir = tmp_path / "exportable_code" + precision = OTXPrecisionType.FP32 + + with patch("builtins.open", new_callable=MagicMock): + exporter.to_openvino = MagicMock() + result = exporter.to_exportable_code(mock_model, output_dir, base_model_name, precision) + + assert result == output_dir / "exportable_code.zip" + + def test_postprocess_openvino_model(self, mock_model, exporter): + # test output names do not match exporter parameters + exporter.output_names = ["output1"] + with pytest.raises(RuntimeError): + exporter._postprocess_openvino_model(mock_model) + # test output names match exporter parameters + exporter.output_names = ["output1", "output2"] + mock_model.outputs = [] + for output_name in exporter.output_names: + output = MagicMock() + output.get_names.return_value = output_name + mock_model.outputs.append(output) + processed_model 
= exporter._postprocess_openvino_model(mock_model) + # Verify the processed model is returned and the names are set correctly + assert processed_model is mock_model + for output, name in zip(processed_model.outputs, exporter.output_names): + output.tensor.set_names.assert_called_once_with({name}) + + def test_embed_metadata_true_precision_fp16(self, exporter): + onnx_model = ModelProto() + exporter._embed_onnx_metadata = MagicMock(return_value=onnx_model) + convert_float_to_float16_mock = MagicMock(return_value=onnx_model) + with pytest.MonkeyPatch.context() as m: + m.setattr(float16, "convert_float_to_float16", convert_float_to_float16_mock) + result = exporter._postprocess_onnx_model(onnx_model, embed_metadata=True, precision=OTXPrecisionType.FP16) + exporter._embed_onnx_metadata.assert_called_once() + convert_float_to_float16_mock.assert_called_once_with(onnx_model) + assert result is onnx_model diff --git a/tests/unit/core/exporter/test_mmdeploy.py b/tests/unit/core/exporter/test_mmdeploy.py index 33260e02dee..8dcaa4033b4 100644 --- a/tests/unit/core/exporter/test_mmdeploy.py +++ b/tests/unit/core/exporter/test_mmdeploy.py @@ -34,6 +34,7 @@ def get_exporter( model_builder=MagicMock(), model_cfg=MagicMock(), deploy_cfg=self.DEFAULT_MMDEPLOY_CFG, + task_level_export_parameters=MagicMock(), test_pipeline=MagicMock(), input_size=(1, 3, 256, 256), max_num_detections=max_num_detections, diff --git a/tests/unit/core/exporter/test_native.py b/tests/unit/core/exporter/test_native.py new file mode 100644 index 00000000000..4c9f481e63c --- /dev/null +++ b/tests/unit/core/exporter/test_native.py @@ -0,0 +1,75 @@ +import onnx +import pytest +import torch +from otx.core.exporter.native import OTXNativeModelExporter +from otx.core.types.export import TaskLevelExportParameters +from otx.core.types.precision import OTXPrecisionType + + +class TestOTXNativeModelExporter: + @pytest.fixture() + def exporter(self, mocker): + # Create an instance of OTXNativeModelExporter with 
default params + return OTXNativeModelExporter( + task_level_export_parameters=mocker.MagicMock(TaskLevelExportParameters), + input_size=(3, 224, 224), + ) + + @pytest.fixture() + def dummy_model(self): + # Define a simple dummy torch model for testing + return torch.nn.Sequential( + torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(), + ) + + def test_to_openvino_export(self, exporter, dummy_model, tmp_path): + # Use tmp_path provided by pytest for temporary file creation + output_dir = tmp_path / "model_export" + output_dir.mkdir() + + # Call the to_openvino method + exported_path = exporter.to_openvino( + model=dummy_model, + output_dir=output_dir, + base_model_name="test_model", + precision=OTXPrecisionType.FP32, + ) + + # Check that the exported files exist + assert exported_path.exists() + assert (output_dir / "test_model.xml").exists() + assert (output_dir / "test_model.bin").exists() + + exporter.via_onnx = True + exported_path = exporter.to_openvino( + model=dummy_model, + output_dir=output_dir, + base_model_name="test_model", + precision=OTXPrecisionType.FP32, + ) + + assert exported_path.exists() + assert (output_dir / "test_model.xml").exists() + assert (output_dir / "test_model.bin").exists() + + def test_to_onnx_export(self, exporter, dummy_model, tmp_path): + # Use tmp_path provided by pytest for temporary file creation + output_dir = tmp_path / "onnx_export" + output_dir.mkdir() + + # Call the to_onnx method + exported_path = exporter.to_onnx( + model=dummy_model, + output_dir=output_dir, + base_model_name="test_onnx_model", + precision=OTXPrecisionType.FP32, + ) + + # Check that the exported ONNX file exists + assert exported_path.exists() + assert (output_dir / "test_onnx_model.onnx").exists() + + # Load the model to verify it's a valid ONNX file + onnx_model = onnx.load(str(exported_path)) + onnx.checker.check_model(onnx_model) diff --git a/tests/unit/core/exporter/test_visual_prompting.py 
b/tests/unit/core/exporter/test_visual_prompting.py index e0249ab8f16..9050a3e28d6 100644 --- a/tests/unit/core/exporter/test_visual_prompting.py +++ b/tests/unit/core/exporter/test_visual_prompting.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 """Unit tests of visual prompting exporter.""" +from unittest.mock import MagicMock + import pytest from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter from otx.core.types.export import OTXExportFormatType @@ -22,7 +24,11 @@ def forward(self, x): class TestOTXVisualPromptingModelExporter: @pytest.fixture() def otx_visual_prompting_model_exporter(self) -> OTXVisualPromptingModelExporter: - return OTXVisualPromptingModelExporter(input_size=(10, 10), via_onnx=True) + return OTXVisualPromptingModelExporter( + task_level_export_parameters=MagicMock(), + input_size=(10, 10), + via_onnx=True, + ) def test_export_openvino(self, mocker, tmpdir, otx_visual_prompting_model_exporter) -> None: """Test export for OPENVINO.""" diff --git a/tests/unit/core/model/test_base.py b/tests/unit/core/model/test_base.py index 83891686f43..558ae998691 100644 --- a/tests/unit/core/model/test_base.py +++ b/tests/unit/core/model/test_base.py @@ -34,7 +34,9 @@ def test_smart_weight_loading(self, mocker) -> None: "model.head.bias": {"stride": 1, "num_extra_classes": 0}, } current_model.label_info = ["car", "bus", "truck"] - current_model.load_state_dict(prev_state_dict) + current_model.load_state_dict_incrementally( + {"state_dict": prev_state_dict, "label_info": prev_model.label_info}, + ) curr_state_dict = current_model.state_dict() indices = torch.Tensor([0, 2]).to(torch.int32) diff --git a/tests/unit/core/model/test_detection.py b/tests/unit/core/model/test_detection.py index 2e0c7f29907..06ab2f8b52a 100644 --- a/tests/unit/core/model/test_detection.py +++ b/tests/unit/core/model/test_detection.py @@ -5,14 +5,23 @@ from __future__ import annotations +from typing import TYPE_CHECKING from unittest.mock import 
create_autospec import pytest +import torch +from importlib_resources import files from lightning.pytorch.cli import ReduceLROnPlateau +from omegaconf import OmegaConf +from otx.algo.explain.explain_algo import get_feature_vector from otx.core.metrics.fmeasure import FMeasureCallable -from otx.core.model.detection import OTXDetectionModel +from otx.core.model.detection import MMDetCompatibleModel, OTXDetectionModel +from otx.core.types.export import TaskLevelExportParameters from torch.optim import Optimizer +if TYPE_CHECKING: + from omegaconf.dictconfig import DictConfig + class TestOTXDetectionModel: @pytest.fixture() @@ -39,6 +48,15 @@ def mock_scheduler(self): def mock_ckpt(self, request): return request.param + @pytest.fixture() + def config(self) -> DictConfig: + cfg_path = files("otx") / "algo" / "detection" / "mmconfigs" / "yolox_tiny.yaml" + return OmegaConf.load(cfg_path) + + @pytest.fixture() + def otx_model(self, config) -> MMDetCompatibleModel: + return MMDetCompatibleModel(num_classes=1, config=config) + def test_configure_metric_with_ckpt( self, mock_optimizer, @@ -56,3 +74,62 @@ def test_configure_metric_with_ckpt( model.load_state_dict(mock_ckpt) assert model.hparams["best_confidence_threshold"] == 0.35 + + def test_create_model(self, otx_model) -> None: + mmdet_model = otx_model._create_model() + assert mmdet_model is not None + assert isinstance(mmdet_model, torch.nn.Module) + + def test_get_num_anchors(self, otx_model): + num_anchors = otx_model.get_num_anchors() + assert isinstance(num_anchors, list) + assert all(isinstance(n, int) for n in num_anchors) + + def test_get_explain_fn(self, otx_model): + otx_model.explain_mode = True + explain_fn = otx_model.get_explain_fn() + assert callable(explain_fn) + + def test_forward_explain_detection(self, otx_model, fxt_data_sample): + inputs = torch.randn(1, 3, 224, 224) + otx_model.model.feature_vector_fn = get_feature_vector + otx_model.model.explain_fn = otx_model.get_explain_fn() + result = 
otx_model._forward_explain_detection(otx_model.model, inputs, fxt_data_sample, mode="predict") + + assert "predictions" in result + assert "feature_vector" in result + assert "saliency_map" in result + + def test_customize_inputs(self, otx_model, fxt_det_data_entity) -> None: + output_data = otx_model._customize_inputs(fxt_det_data_entity[2]) + assert output_data is not None + assert "gt_instances" in output_data["data_samples"][-1] + assert "bboxes" in output_data["data_samples"][-1].gt_instances + assert output_data["data_samples"][-1].metainfo["pad_shape"] == output_data["inputs"].shape[-2:] + + def test_forward_explain(self, otx_model, fxt_det_data_entity): + inputs = fxt_det_data_entity[2] + otx_model.training = False + otx_model.explain_mode = True + outputs = otx_model.forward_explain(inputs) + + assert outputs.has_xai_outputs + assert outputs.feature_vector is not None + assert outputs.saliency_map is not None + + def test_reset_restore_model_forward(self, otx_model): + otx_model.explain_mode = True + initial_model_forward = otx_model.model.forward + + otx_model._reset_model_forward() + assert otx_model.original_model_forward is not None + assert str(otx_model.model.forward) != str(otx_model.original_model_forward) + + otx_model._restore_model_forward() + assert otx_model.original_model_forward is None + assert str(otx_model.model.forward) == str(initial_model_forward) + + def test_export_parameters(self, otx_model): + parameters = otx_model._export_parameters + assert isinstance(parameters, TaskLevelExportParameters) + assert parameters.task_type == "detection" diff --git a/tests/unit/core/model/test_inst_segmentation.py b/tests/unit/core/model/test_inst_segmentation.py new file mode 100644 index 00000000000..d1d9387aa2d --- /dev/null +++ b/tests/unit/core/model/test_inst_segmentation.py @@ -0,0 +1,71 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests for instance segmentation model entity.""" + +import 
pytest +import torch +from otx.algo.explain.explain_algo import get_feature_vector +from otx.algo.instance_segmentation.maskrcnn import MaskRCNN +from otx.core.model.instance_segmentation import MMDetInstanceSegCompatibleModel +from otx.core.types.export import TaskLevelExportParameters + + +class TestOTXInstanceSegModel: + @pytest.fixture() + def otx_model(self) -> MMDetInstanceSegCompatibleModel: + return MaskRCNN(num_classes=1, variant="efficientnetb2b") + + def test_create_model(self, otx_model) -> None: + mmdet_model = otx_model._create_model() + assert mmdet_model is not None + assert isinstance(mmdet_model, torch.nn.Module) + + def test_get_explain_fn(self, otx_model): + otx_model.explain_mode = True + explain_fn = otx_model.get_explain_fn() + assert callable(explain_fn) + + def test_forward_explain_inst_seg(self, otx_model, fxt_data_sample): + inputs = torch.randn(1, 3, 224, 224) + otx_model.model.feature_vector_fn = get_feature_vector + otx_model.model.explain_fn = otx_model.get_explain_fn() + result = otx_model._forward_explain_inst_seg(otx_model.model, inputs, fxt_data_sample, mode="predict") + + assert "predictions" in result + assert "feature_vector" in result + assert "saliency_map" in result + + def test_customize_inputs(self, otx_model, fxt_inst_seg_data_entity) -> None: + output_data = otx_model._customize_inputs(fxt_inst_seg_data_entity[2]) + assert output_data is not None + assert "gt_instances" in output_data["data_samples"][-1] + assert "masks" in output_data["data_samples"][-1].gt_instances + assert output_data["data_samples"][-1].metainfo["pad_shape"] == output_data["inputs"].shape[-2:] + + def test_forward_explain(self, otx_model, fxt_inst_seg_data_entity): + inputs = fxt_inst_seg_data_entity[2] + otx_model.training = False + otx_model.explain_mode = True + outputs = otx_model.forward_explain(inputs) + + assert outputs.has_xai_outputs + assert outputs.feature_vector is not None + assert outputs.saliency_map is not None + + def 
test_reset_restore_model_forward(self, otx_model): + otx_model.explain_mode = True + initial_model_forward = otx_model.model.forward + + otx_model._reset_model_forward() + assert otx_model.original_model_forward is not None + assert str(otx_model.model.forward) != str(otx_model.original_model_forward) + + otx_model._restore_model_forward() + assert otx_model.original_model_forward is None + assert str(otx_model.model.forward) == str(initial_model_forward) + + def test_export_parameters(self, otx_model): + parameters = otx_model._export_parameters + assert isinstance(parameters, TaskLevelExportParameters) + assert parameters.task_type == "instance_segmentation" diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py index e8693454aa8..36d7f7d7a95 100644 --- a/tests/unit/core/model/test_visual_prompting.py +++ b/tests/unit/core/model/test_visual_prompting.py @@ -26,19 +26,24 @@ _inference_step, _inference_step_for_zero_shot, ) +from otx.core.types.export import TaskLevelExportParameters from torchvision import tv_tensors @pytest.fixture() def otx_visual_prompting_model(mocker) -> OTXVisualPromptingModel: mocker.patch.object(OTXVisualPromptingModel, "_create_model") - return OTXVisualPromptingModel(num_classes=1) + model = OTXVisualPromptingModel(num_classes=1) + model.model.image_size = 1024 + return model @pytest.fixture() def otx_zero_shot_visual_prompting_model(mocker) -> OTXZeroShotVisualPromptingModel: mocker.patch.object(OTXZeroShotVisualPromptingModel, "_create_model") - return OTXZeroShotVisualPromptingModel(num_classes=1) + model = OTXZeroShotVisualPromptingModel(num_classes=1) + model.model.image_size = 1024 + return model def test_inference_step(mocker, otx_visual_prompting_model, fxt_vpm_data_entity) -> None: @@ -136,18 +141,20 @@ def test_inference_step_for_zero_shot_with_more_target( class TestOTXVisualPromptingModel: def test_exporter(self, otx_visual_prompting_model) -> None: """Test _exporter.""" - 
assert isinstance(otx_visual_prompting_model._exporter, OTXVisualPromptingModelExporter) + exporter = otx_visual_prompting_model._exporter + assert isinstance(exporter, OTXVisualPromptingModelExporter) + assert exporter.input_size == (1, 3, 1024, 1024) + assert exporter.resize_mode == "fit_to_window" + assert exporter.mean == (123.675, 116.28, 103.53) + assert exporter.std == (58.395, 57.12, 57.375) def test_export_parameters(self, otx_visual_prompting_model) -> None: """Test _export_parameters.""" - otx_visual_prompting_model.model.image_size = 1024 - export_parameters = otx_visual_prompting_model._export_parameters - assert export_parameters["input_size"] == (1, 3, 1024, 1024) - assert export_parameters["resize_mode"] == "fit_to_window" - assert export_parameters["mean"] == (123.675, 116.28, 103.53) - assert export_parameters["std"] == (58.395, 57.12, 57.375) + assert isinstance(export_parameters, TaskLevelExportParameters) + assert export_parameters.model_type == "Visual_Prompting" + assert export_parameters.task_type == "visual_prompting" def test_optimization_config(self, otx_visual_prompting_model) -> None: """Test _optimization_config.""" @@ -175,18 +182,20 @@ def test_optimization_config(self, otx_visual_prompting_model) -> None: class TestOTXZeroShotVisualPromptingModel: def test_exporter(self, otx_zero_shot_visual_prompting_model) -> None: """Test _exporter.""" - assert isinstance(otx_zero_shot_visual_prompting_model._exporter, OTXVisualPromptingModelExporter) + exporter = otx_zero_shot_visual_prompting_model._exporter + assert isinstance(exporter, OTXVisualPromptingModelExporter) + assert exporter.input_size == (1, 3, 1024, 1024) + assert exporter.resize_mode == "fit_to_window" + assert exporter.mean == (123.675, 116.28, 103.53) + assert exporter.std == (58.395, 57.12, 57.375) def test_export_parameters(self, otx_zero_shot_visual_prompting_model) -> None: """Test _export_parameters.""" - otx_zero_shot_visual_prompting_model.model.image_size = 1024 - 
export_parameters = otx_zero_shot_visual_prompting_model._export_parameters - assert export_parameters["input_size"] == (1, 3, 1024, 1024) - assert export_parameters["resize_mode"] == "fit_to_window" - assert export_parameters["mean"] == (123.675, 116.28, 103.53) - assert export_parameters["std"] == (58.395, 57.12, 57.375) + assert isinstance(export_parameters, TaskLevelExportParameters) + assert export_parameters.model_type == "Visual_Prompting" + assert export_parameters.task_type == "visual_prompting" def test_optimization_config(self, otx_zero_shot_visual_prompting_model) -> None: """Test _optimization_config.""" diff --git a/tests/unit/core/types/conftest.py b/tests/unit/core/types/conftest.py new file mode 100644 index 00000000000..b03fcc719c5 --- /dev/null +++ b/tests/unit/core/types/conftest.py @@ -0,0 +1,17 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from otx.core.types.label import LabelInfo + + +@pytest.fixture( + params=[ + "fxt_multiclass_labelinfo", + "fxt_hlabel_multilabel_info", + "fxt_null_label_info", + "fxt_seg_label_info", + ], +) +def fxt_label_info(request: pytest.FixtureRequest) -> LabelInfo: + return request.getfixturevalue(request.param) diff --git a/tests/unit/core/types/test_export.py b/tests/unit/core/types/test_export.py new file mode 100644 index 00000000000..5a85f3b6fc8 --- /dev/null +++ b/tests/unit/core/types/test_export.py @@ -0,0 +1,51 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from otx.core.config.data import TileConfig +from otx.core.types.export import TaskLevelExportParameters + + +@pytest.mark.parametrize("task_type", ["instance_segmentation", "classification"]) +def test_wrap(fxt_label_info, task_type): + params = TaskLevelExportParameters( + model_type="dummy model", + task_type=task_type, + label_info=fxt_label_info, + optimization_config={}, + ) + + multilabel = False + hierarchical = False + 
confidence_threshold = 0.0 + iou_threshold = 0.0 + return_soft_prediction = False + soft_threshold = 0.0 + blur_strength = 0 + tile_config = TileConfig() + + params = params.wrap( + multilabel=multilabel, + hierarchical=hierarchical, + confidence_threshold=confidence_threshold, + iou_threshold=iou_threshold, + return_soft_prediction=return_soft_prediction, + soft_threshold=soft_threshold, + blur_strength=blur_strength, + tile_config=tile_config, + ) + + metadata = params.to_metadata() + + assert metadata[("model_info", "multilabel")] == str(multilabel) + assert metadata[("model_info", "hierarchical")] == str(hierarchical) + assert metadata[("model_info", "confidence_threshold")] == str(confidence_threshold) + assert metadata[("model_info", "iou_threshold")] == str(iou_threshold) + assert metadata[("model_info", "return_soft_prediction")] == str(return_soft_prediction) + assert metadata[("model_info", "soft_threshold")] == str(soft_threshold) + assert metadata[("model_info", "blur_strength")] == str(blur_strength) + + # Tile config + assert ("model_info", "tile_size") in metadata + assert ("model_info", "tiles_overlap") in metadata + assert ("model_info", "max_pred_number") in metadata diff --git a/tests/unit/core/types/test_label.py b/tests/unit/core/types/test_label.py index 9b3ab223667..16732a6402d 100644 --- a/tests/unit/core/types/test_label.py +++ b/tests/unit/core/types/test_label.py @@ -1,21 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import pytest -from otx.core.types.label import LabelInfo - - -@pytest.fixture( - params=[ - "fxt_multiclass_labelinfo", - "fxt_hlabel_multilabel_info", - "fxt_null_label_info", - "fxt_seg_label_info", - ], -) -def fxt_label_info(request: pytest.FixtureRequest) -> LabelInfo: - return request.getfixturevalue(request.param) - def test_as_json(fxt_label_info): serialized = fxt_label_info.to_json() diff --git a/tests/unit/core/utils/test_utils.py b/tests/unit/core/utils/test_utils.py index 
fdaaba25d24..90737defc33 100644 --- a/tests/unit/core/utils/test_utils.py +++ b/tests/unit/core/utils/test_utils.py @@ -63,9 +63,9 @@ def test_get_mean_std_from_data_processing(): "std": 0.1, }, } - result = get_mean_std_from_data_processing(config) - assert result["mean"] == 0.5 - assert result["std"] == 0.1 + mean, std = get_mean_std_from_data_processing(config) + assert mean == 0.5 + assert std == 0.1 @pytest.fixture() diff --git a/tests/unit/engine/test_engine.py b/tests/unit/engine/test_engine.py index e5fe19fc5b8..2aa8449a663 100644 --- a/tests/unit/engine/test_engine.py +++ b/tests/unit/engine/test_engine.py @@ -2,16 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from unittest.mock import create_autospec import pytest from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls from otx.algo.classification.torchvision_model import OTXTVModel from otx.core.config.device import DeviceConfig -from otx.core.model.base import OVModel +from otx.core.model.base import OTXModel, OVModel from otx.core.types.export import OTXExportFormatType +from otx.core.types.label import NullLabelInfo from otx.core.types.precision import OTXPrecisionType from otx.engine import Engine +from pytest_mock import MockerFixture @pytest.fixture() @@ -73,13 +74,30 @@ def test_training_with_override_args(self, fxt_engine, mocker) -> None: assert fxt_engine._cache.args["max_epochs"] == 100 mock_seed_everything.assert_called_once_with(1234, workers=True) - def test_training_with_checkpoint(self, fxt_engine, mocker) -> None: - mock_torch_load = mocker.patch("torch.load") - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") + @pytest.mark.parametrize("resume", [True, False]) + def test_training_with_checkpoint(self, fxt_engine, resume: bool, mocker: MockerFixture, tmpdir) -> None: + checkpoint = "path/to/checkpoint.ckpt" + + mock_trainer = mocker.patch("otx.engine.engine.Trainer") + mock_trainer.return_value.default_root_dir = 
Path(tmpdir) + mock_trainer_fit = mock_trainer.return_value.fit + + mock_torch_load = mocker.patch("otx.engine.engine.torch.load") + mock_load_state_dict_incrementally = mocker.patch.object(fxt_engine.model, "load_state_dict_incrementally") + + trained_checkpoint = Path(tmpdir) / "best.ckpt" + trained_checkpoint.touch() + mock_trainer.return_value.checkpoint_callback.best_model_path = trained_checkpoint + + fxt_engine.train(resume=resume, checkpoint=checkpoint) - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.train() - mock_torch_load.assert_called_once_with("path/to/checkpoint") + if resume: + assert mock_trainer_fit.call_args.kwargs.get("ckpt_path") == checkpoint + else: + assert "ckpt_path" not in mock_trainer_fit.call_args.kwargs + + mock_torch_load.assert_called_once() + mock_load_state_dict_incrementally.assert_called_once() def test_training_with_run_hpo(self, fxt_engine, mocker) -> None: mocker.patch("pathlib.Path.symlink_to") @@ -93,94 +111,97 @@ def test_training_with_run_hpo(self, fxt_engine, mocker) -> None: mock_update_hyper_parameter.assert_called_once_with(fxt_engine, {}) assert mock_fit.call_args[1]["ckpt_path"] == "hpo/best/checkpoint" - def test_training_with_resume(self, fxt_engine, mocker) -> None: - mocker.patch("pathlib.Path.symlink_to") - mock_fit = mocker.patch("otx.engine.engine.Trainer.fit") - - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.train(resume=True) - assert mock_fit.call_args[1]["ckpt_path"] == "path/to/checkpoint" - - def test_testing_after_training(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") + @pytest.mark.parametrize( + "checkpoint", + [ + "path/to/checkpoint.ckpt", + "path/to/checkpoint.xml", + ], + ) + def test_test(self, fxt_engine, checkpoint, mocker: MockerFixture) -> None: mock_test = mocker.patch("otx.engine.engine.Trainer.test") - mock_torch_load = mocker.patch("torch.load") + _ = 
mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") + mock_get_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") - # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.test() - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_test.assert_called_once() + ext = Path(checkpoint).suffix - fxt_engine.test(checkpoint="path/to/new/checkpoint") - mock_torch_load.assert_called_with("path/to/new/checkpoint") + if ext == ".ckpt": + mock_model = mocker.create_autospec(OTXModel) - def test_testing_with_ov_model(self, fxt_engine, mocker) -> None: - mock_test = mocker.patch("otx.engine.engine.Trainer.test") - mock_torch_load = mocker.patch("torch.load") - mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") - mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint.return_value = mock_model + else: + mock_model = mocker.create_autospec(OVModel) - fxt_engine.test(checkpoint="path/to/model.xml") - mock_test.assert_called_once() - mock_torch_load.assert_not_called() + mock_get_ov_model.return_value = mock_model - fxt_engine.model = create_autospec(OVModel) - fxt_engine.test(checkpoint="path/to/model.xml") + # Correct label_info from the checkpoint + mock_model.label_info = fxt_engine.datamodule.label_info + fxt_engine.test(checkpoint=checkpoint) + mock_test.assert_called_once() - def test_prediction_after_training(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") + mock_model.label_info = NullLabelInfo() + # Incorrect label_info from the checkpoint + with pytest.raises( + ValueError, + match="To launch a test pipeline, the label information should be same (.*)", + ): + fxt_engine.test(checkpoint=checkpoint) + + @pytest.mark.parametrize("explain", [True, False]) + 
@pytest.mark.parametrize( + "checkpoint", + [ + "path/to/checkpoint.ckpt", + "path/to/checkpoint.xml", + ], + ) + def test_predict(self, fxt_engine, checkpoint, explain, mocker: MockerFixture) -> None: mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") - mock_torch_load = mocker.patch("torch.load") + _ = mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") + mock_get_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") + mock_process_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") - # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.predict() - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_predict.assert_called_once() + ext = Path(checkpoint).suffix - fxt_engine.predict(checkpoint="path/to/new/checkpoint") - mock_torch_load.assert_called_with("path/to/new/checkpoint") + if ext == ".ckpt": + mock_model = mocker.create_autospec(OTXModel) - fxt_engine.model = create_autospec(OVModel) - fxt_engine.predict(checkpoint="path/to/model.xml") + mock_load_from_checkpoint.return_value = mock_model + else: + mock_model = mocker.create_autospec(OVModel) - def test_prediction_with_ov_model(self, fxt_engine, mocker) -> None: - mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") - mock_torch_load = mocker.patch("torch.load") - mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") - mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_get_ov_model.return_value = mock_model - fxt_engine.predict(checkpoint="path/to/model.xml") + # Correct label_info from the checkpoint + mock_model.label_info = fxt_engine.datamodule.label_info + fxt_engine.predict(checkpoint=checkpoint, explain=explain) mock_predict.assert_called_once() - mock_torch_load.assert_not_called() - - def 
test_prediction_explain_mode(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") - mock_explain = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") - mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") - mock_torch_load = mocker.patch("torch.load") + assert mock_process_saliency_maps.called == explain - # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.predict(explain=True) - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_explain.assert_called_once() - mock_predict.assert_called_once() + mock_model.label_info = NullLabelInfo() + # Incorrect label_info from the checkpoint + with pytest.raises( + ValueError, + match="To launch a predict pipeline, the label information should be same (.*)", + ): + fxt_engine.predict(checkpoint=checkpoint) def test_exporting(self, fxt_engine, mocker) -> None: with pytest.raises(RuntimeError, match="To make export, checkpoint must be specified."): fxt_engine.export() - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") - mocker.patch("otx.engine.engine.OTXModel.label_info") mock_export = mocker.patch("otx.engine.engine.OTXModel.export") - mock_torch_load = mocker.patch("torch.load") + + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") + mock_load_from_checkpoint.return_value = fxt_engine.model # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" + checkpoint = "path/to/checkpoint.ckpt" + fxt_engine.checkpoint = checkpoint fxt_engine.export() - mock_torch_load.assert_called_once_with("path/to/checkpoint") + mock_load_from_checkpoint.assert_called_once_with(checkpoint_path=checkpoint, map_location="cpu") mock_export.assert_called_once_with( output_dir=Path(fxt_engine.work_dir), base_name="exported_model", @@ -242,32 +263,47 @@ def test_optimizing_model(self, fxt_engine, mocker) -> None: fxt_engine.optimize(export_demo_package=True) 
mocker_export.assert_called_once() - def test_explain(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") - mock_process_explain = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") - - mock_torch_load = mocker.patch("torch.load") + @pytest.mark.parametrize("dump", [True, False]) + @pytest.mark.parametrize( + "checkpoint", + [ + "path/to/checkpoint.ckpt", + "path/to/checkpoint.xml", + ], + ) + def test_explain(self, fxt_engine, checkpoint, dump, mocker) -> None: mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") + _ = mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") + mock_get_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") + mock_process_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") + mock_dump_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.dump_saliency_maps") - fxt_engine.explain(checkpoint="path/to/checkpoint") - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_predict.assert_called_once() - mock_process_explain.assert_called_once() + ext = Path(checkpoint).suffix - mock_dump_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.dump_saliency_maps") - fxt_engine.explain(checkpoint="path/to/checkpoint", dump=True) - mock_torch_load.assert_called_with("path/to/checkpoint") - mock_predict.assert_called() - mock_process_explain.assert_called() - mock_dump_saliency_maps.assert_called_once() + if ext == ".ckpt": + mock_model = mocker.create_autospec(OTXModel) - mock_ov_pipeline = mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") - mock_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") - fxt_engine.explain(checkpoint="path/to/model.xml") - mock_predict.assert_called() - 
mock_process_explain.assert_called() - mock_ov_model.assert_called_once() - mock_ov_pipeline.assert_called_once() + mock_load_from_checkpoint.return_value = mock_model + else: + mock_model = mocker.create_autospec(OVModel) + + mock_get_ov_model.return_value = mock_model + + # Correct label_info from the checkpoint + mock_model.label_info = fxt_engine.datamodule.label_info + fxt_engine.explain(checkpoint=checkpoint, dump=dump) + mock_predict.assert_called_once() + mock_process_saliency_maps.assert_called_once() + assert mock_dump_saliency_maps.called == dump + + mock_model.label_info = NullLabelInfo() + # Incorrect label_info from the checkpoint + with pytest.raises( + ValueError, + match="To launch a explain pipeline, the label information should be same (.*)", + ): + fxt_engine.explain(checkpoint=checkpoint) def test_from_config_with_model_name(self, tmp_path) -> None: model_name = "efficientnet_b0_light" diff --git a/tests/unit/hpo/test_hyperband.py b/tests/unit/hpo/test_hyperband.py index 971f319005b..ec89817bd51 100644 --- a/tests/unit/hpo/test_hyperband.py +++ b/tests/unit/hpo/test_hyperband.py @@ -675,11 +675,6 @@ def test_get_best_config_before_train(self, hyper_band): best_config = hyper_band.get_best_config() assert best_config is None - def test_train_option_exists(self, hyper_band): - trial = hyper_band.get_next_sample() - train_config = trial.get_train_configuration() - assert "subset_ratio" in train_config["train_environment"] - def test_prior_hyper_parameters(self, good_hyperband_args): prior1 = {"hp1": 1, "hp2": 2} prior2 = {"hp1": 100, "hp2": 200} diff --git a/tests/unit/hpo/test_resource_manager.py b/tests/unit/hpo/test_resource_manager.py index f1688931750..298e80fc5dd 100644 --- a/tests/unit/hpo/test_resource_manager.py +++ b/tests/unit/hpo/test_resource_manager.py @@ -14,7 +14,7 @@ def cpu_resource_manager(): @pytest.fixture() def gpu_resource_manager(): - return GPUResourceManager(num_gpu_for_single_trial=1, available_gpu="0,1,2,3") + return 
GPUResourceManager(num_gpu_for_single_trial=1, num_parallel_trial=4) class TestCPUResourceManager: @@ -66,25 +66,25 @@ def setupt_test(self, mocker): mock_torch_cuda.device_count.return_value = 4 def test_init(self): - GPUResourceManager(num_gpu_for_single_trial=1, available_gpu="0,1,2") + GPUResourceManager(num_gpu_for_single_trial=1, num_parallel_trial=3) @pytest.mark.parametrize("num_gpu_for_single_trial", [-1, 0]) def test_init_not_positive_num_gpu(self, num_gpu_for_single_trial): with pytest.raises(ValueError): # noqa: PT011 GPUResourceManager(num_gpu_for_single_trial=num_gpu_for_single_trial) - @pytest.mark.parametrize("available_gpu", [",", "a,b", "0,a", ""]) - def test_init_wrong_available_gpu_value(self, available_gpu): + @pytest.mark.parametrize("num_parallel_trial", [-1, 0]) + def test_init_wrong_available_gpu_value(self, num_parallel_trial): with pytest.raises(ValueError): # noqa: PT011 - GPUResourceManager(available_gpu=available_gpu) + GPUResourceManager(num_parallel_trial=num_parallel_trial) def test_reserve_resource(self): num_gpu_for_single_trial = 2 gpu_resource_manager = GPUResourceManager( num_gpu_for_single_trial=num_gpu_for_single_trial, - available_gpu=",".join([str(val) for val in range(8)]), + num_parallel_trial=8, ) - num_gpus = len(gpu_resource_manager._available_gpu) + num_gpus = 4 max_parallel = num_gpus // num_gpu_for_single_trial for i in range(max_parallel): @@ -112,9 +112,9 @@ def test_have_available_resource(self): num_gpu_for_single_trial = 2 gpu_resource_manager = GPUResourceManager( num_gpu_for_single_trial=num_gpu_for_single_trial, - available_gpu=",".join([str(val) for val in range(8)]), + num_parallel_trial=8, ) - num_gpus = len(gpu_resource_manager._available_gpu) + num_gpus = 4 max_parallel = num_gpus // num_gpu_for_single_trial for i in range(max_parallel): @@ -133,11 +133,11 @@ def test_get_resource_manager_cpu(): def test_get_resource_manager_gpu(mocker): mocker.patch("otx.hpo.resource_manager.torch.cuda.is_available", 
return_value=True) num_gpu_for_single_trial = 1 - available_gpu = "0,1,2,3" + num_parallel_trial = 4 manager = get_resource_manager( resource_type="gpu", num_gpu_for_single_trial=num_gpu_for_single_trial, - available_gpu=available_gpu, + num_parallel_trial=num_parallel_trial, ) assert isinstance(manager, GPUResourceManager)