From bc887edc6adcea533e4029cfe89a232bbbd48ca3 Mon Sep 17 00:00:00 2001 From: Vinnam Kim Date: Tue, 9 Apr 2024 10:15:15 +0900 Subject: [PATCH] Refactor to implant optimizer and scheduler into model code (#3258) * Refactor to implant optimizer and scheduler into model code - Add OptimizerCallableSupportHPO and PicklableLRSchedulerCallable for HPO - Implant optimizer and scheduler into efficientnet_b0 code as a testbed Signed-off-by: Kim, Vinnam * Upgrade jsonargparse to 4.27.7 Signed-off-by: Kim, Vinnam * Fix test installation error Signed-off-by: Kim, Vinnam * Fix test errors Signed-off-by: Kim, Vinnam --------- Signed-off-by: Kim, Vinnam --- pyproject.toml | 3 + .../algo/classification/efficientnet_b0.py | 45 ++++- src/otx/cli/install.py | 13 ++ src/otx/core/model/base.py | 16 +- src/otx/core/optimizer/__init__.py | 7 + src/otx/core/optimizer/callable.py | 161 ++++++++++++++++++ src/otx/core/schedulers/__init__.py | 20 ++- src/otx/core/utils/jsonargparse.py | 99 +++++++++++ src/otx/engine/hpo/hpo_api.py | 20 +-- tests/integration/api/test_engine_api.py | 40 +++++ tests/integration/cli/test_cli.py | 71 ++------ tests/unit/cli/test_install.py | 17 +- tests/unit/core/optimizer/__init__.py | 2 + tests/unit/core/optimizer/test_callable.py | 147 ++++++++++++++++ 14 files changed, 584 insertions(+), 77 deletions(-) create mode 100644 src/otx/core/optimizer/__init__.py create mode 100644 src/otx/core/optimizer/callable.py create mode 100644 src/otx/core/utils/jsonargparse.py create mode 100644 tests/unit/core/optimizer/__init__.py create mode 100644 tests/unit/core/optimizer/test_callable.py diff --git a/pyproject.toml b/pyproject.toml index d0ca8d6ad19..e5027719a66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -274,6 +274,9 @@ exclude = [ # Ruff complains it but don't know how to fix since it literally showed no useful logs. # https://github.com/openvinotoolkit/training_extensions/actions/runs/7176557723/job/19541622452?pr=2718#step:5:170 "tests/regression/*.py", + + # Mostly borrowed from jsonargparse codebase + "src/otx/core/utils/jsonargparse.py" ] # Same as Black. diff --git a/src/otx/algo/classification/efficientnet_b0.py b/src/otx/algo/classification/efficientnet_b0.py index c0c4ba212c4..488149dce24 100644 --- a/src/otx/algo/classification/efficientnet_b0.py +++ b/src/otx/algo/classification/efficientnet_b0.py @@ -6,10 +6,12 @@ from typing import TYPE_CHECKING +import torch +from lightning.pytorch.cli import ReduceLROnPlateau + from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable -from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.classification import ( MMPretrainHlabelClsModel, MMPretrainMulticlassClsModel, @@ -31,8 +33,17 @@ class EfficientNetB0ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlab def __init__( self, hlabel_info: HLabelInfo, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + optimizer: OptimizerCallable = lambda params: torch.optim.SGD( + params=params, + lr=0.0049, + ), + scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau( + optimizer, + mode="max", + factor=0.1, + patience=1, + monitor="val/accuracy", + ), metric: MetricCallable = HLabelClsMetricCallble, torch_compile: bool = False, ) -> None: @@ -59,8 +70,19 @@ def __init__( self, num_classes: int, light: bool = False, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + optimizer: OptimizerCallable = lambda params: torch.optim.SGD( + params=params, + lr=0.0049, + momentum=0.9, + weight_decay=0.0001, + ), + scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau( + optimizer, + mode="max", + factor=0.1, + patience=1, + monitor="val/accuracy", + ), metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, ) -> None: @@ -86,8 +108,17 @@ class EfficientNetB0ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrain def __init__( self, num_classes: int, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + optimizer: OptimizerCallable = lambda params: torch.optim.SGD( + params=params, + lr=0.0049, + ), + scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau( + optimizer, + mode="max", + factor=0.1, + patience=1, + monitor="val/accuracy", + ), metric: MetricCallable = MultiLabelClsMetricCallable, torch_compile: bool = False, ) -> None: diff --git a/src/otx/cli/install.py b/src/otx/cli/install.py index 37523539b87..7d0fd49d45f 100644 --- a/src/otx/cli/install.py +++ b/src/otx/cli/install.py @@ -122,6 +122,9 @@ def otx_install(option: str | None = None, verbose: bool = False, do_not_install status_code = create_command("install").main(install_args) if status_code == 0: console.log(f"Installation Complete: {install_args}") + else: + msg = "Cannot complete installation" + raise RuntimeError(msg) # https://github.com/Madoshakalaka/pipenv-setup/issues/101 os.environ["SETUPTOOLS_USE_DISTUTILS"] = "stdlib" @@ -132,6 +135,16 @@ def otx_install(option: str | None = None, verbose: bool = False, do_not_install status_code = mim_installation(mmcv_install_args) if status_code == 0: console.log(f"MMLab Installation Complete: {mmcv_install_args}") + else: + msg = "Cannot complete installation" + raise RuntimeError(msg) + + # TODO(harimkang): Remove this reinstalling after resolving conflict with anomalib==1.0.1 + # https://github.com/openvinotoolkit/training_extensions/actions/runs/8531851027/job/23372146228?pr=3258#step:5:2587 + status_code = create_command("install").main(["jsonargparse==4.27.7"]) + if status_code != 0: + msg = "Cannot install jsonargparse==4.27.7" + raise RuntimeError(msg) # Patch MMAction2 with src/otx/cli/patches/mmaction2.patch patch_mmaction2() diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index d6792052bde..8084d7aa85e 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -33,7 +33,8 @@ from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.metrics import MetricInput, NullMetricCallable -from otx.core.schedulers import LRSchedulerListCallable +from otx.core.optimizer.callable import OptimizerCallableSupportHPO +from otx.core.schedulers import LRSchedulerListCallable, PicklableLRSchedulerCallable from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler from otx.core.types.export import OTXExportFormatType from otx.core.types.label import LabelInfo, NullLabelInfo @@ -675,6 +676,19 @@ def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Tensor) -> return super().lr_scheduler_step(scheduler=scheduler, metric=metric) + def patch_optimizer_and_scheduler_for_hpo(self) -> None: + """Patch optimizer and scheduler for hyperparameter optimization. + + This is inplace function changing inner states (`optimizer_callable` and `scheduler_callable`). + Both will be changed to be picklable. In addition, `optimizer_callable` is changed + to make its hyperparameters gettable. + """ + if not isinstance(self.optimizer_callable, OptimizerCallableSupportHPO): + self.optimizer_callable = OptimizerCallableSupportHPO.from_callable(self.optimizer_callable) + + if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable): + self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable) + class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]): """Base class for the OpenVINO model. diff --git a/src/otx/core/optimizer/__init__.py b/src/otx/core/optimizer/__init__.py new file mode 100644 index 00000000000..ab8a3ee8472 --- /dev/null +++ b/src/otx/core/optimizer/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Modules related to an optimizer.""" + +from otx.core.optimizer.callable import OptimizerCallableSupportHPO + +__all__ = ["OptimizerCallableSupportHPO"] diff --git a/src/otx/core/optimizer/callable.py b/src/otx/core/optimizer/callable.py new file mode 100644 index 00000000000..111e31203c0 --- /dev/null +++ b/src/otx/core/optimizer/callable.py @@ -0,0 +1,161 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Optimizer callable to support hyper-parameter optimization (HPO) algorithm.""" + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING, Any, Sequence + +from torch import nn +from torch.optim.optimizer import Optimizer + +from otx.core.utils.jsonargparse import ClassType, lazy_instance + +if TYPE_CHECKING: + from lightning.pytorch.cli import OptimizerCallable + from torch.optim.optimizer import params_t + + +class OptimizerCallableSupportHPO: + """Optimizer callable supports OTX hyper-parameter optimization (HPO) algorithm. + + Args: + optimizer_cls: Optimizer class type or string class import path. See examples for details. + optimizer_kwargs: Keyword arguments used for the initialization of the given `optimizer_cls`. + search_hparams: Sequence of optimizer hyperparameter names which can be tuned by the OTX HPO algorithm. + + Examples: + This is an example to create `MobileNetV3ForMulticlassCls` with a `SGD` optimizer and + custom configurations. + + ```python + from torch.optim import SGD + from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls + + model = MobileNetV3ForMulticlassCls( + num_classes=3, + optimizer=OptimizerCallableSupportHPO( + optimizer_cls=SGD, + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ), + ) + ``` + + It can be created from the string class import path such as + + ```python + from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls + + model = MobileNetV3ForMulticlassCls( + num_classes=3, + optimizer=OptimizerCallableSupportHPO( + optimizer_cls="torch.optim.SGD", + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ), + ) + ``` + """ + + def __init__( + self, + optimizer_cls: type[Optimizer] | str, + optimizer_kwargs: dict[str, int | float | bool], + search_hparams: Sequence[str] = ("lr",), + ): + if isinstance(optimizer_cls, str): + splited = optimizer_cls.split(".") + module_path, class_name = ".".join(splited[:-1]), splited[-1] + module = importlib.import_module(module_path) + + self.optimizer_init: type[Optimizer] = getattr(module, class_name) + self.optimizer_path = optimizer_cls + elif issubclass(optimizer_cls, Optimizer): + self.optimizer_init = optimizer_cls + self.optimizer_path = optimizer_cls.__module__ + "." + optimizer_cls.__qualname__ + else: + raise TypeError(optimizer_cls) + + for search_hparam in search_hparams: + if search_hparam not in optimizer_kwargs: + msg = ( + f"Search hyperparamter={search_hparam} should be existed in " + f"optimizer keyword arguments={optimizer_kwargs} as well." + ) + raise ValueError(msg) + + self.search_hparams = list(search_hparams) + self.optimizer_kwargs = optimizer_kwargs + self.__dict__.update(optimizer_kwargs) + + def __call__(self, params: params_t) -> Optimizer: + """Create `torch.optim.Optimizer` instance for the given parameters.""" + return self.optimizer_init(params, **self.optimizer_kwargs) + + def to_lazy_instance(self) -> ClassType: + """Return lazy instance of this class. + + Because OTX is rely on jsonargparse library, + the default value of class initialization + argument should be the lazy instance. + Please refer to https://jsonargparse.readthedocs.io/en/stable/#default-values + for more details. + + Examples: + This is an example to implement a new model with a `SGD` optimizer and + custom configurations as a default. + + ```python + class MyAwesomeMulticlassClsModel(OTXMulticlassClsModel): + def __init__( + self, + num_classes: int, + optimizer: OptimizerCallable = OptimizerCallableSupportHPO( + optimizer_cls=SGD, + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ).to_lazy_instance(), + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MultiClassClsMetricCallable, + torch_compile: bool = False, + ) -> None: + ... + ``` + """ + return lazy_instance( + OptimizerCallableSupportHPO, + optimizer_cls=self.optimizer_path, + optimizer_kwargs=self.optimizer_kwargs, + search_hparams=self.search_hparams, + ) + + def __reduce__(self) -> str | tuple[Any, ...]: + return self.__class__, ( + self.optimizer_path, + self.optimizer_kwargs, + self.search_hparams, + ) + + @classmethod + def from_callable(cls, func: OptimizerCallable) -> OptimizerCallableSupportHPO: + """Create this class instance from an existing optimizer callable.""" + dummy_params = [nn.Parameter()] + optimizer = func(dummy_params) + + param_group = next(iter(optimizer.param_groups)) + + return OptimizerCallableSupportHPO( + optimizer_cls=optimizer.__class__, + optimizer_kwargs={key: value for key, value in param_group.items() if key != "params"}, + ) diff --git a/src/otx/core/schedulers/__init__.py b/src/otx/core/schedulers/__init__.py index 0b54994cd6c..7f9bec87add 100644 --- a/src/otx/core/schedulers/__init__.py +++ b/src/otx/core/schedulers/__init__.py @@ -5,14 +5,18 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable +import dill from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER from lightning.pytorch.cli import ReduceLROnPlateau from torch.optim.optimizer import Optimizer from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler, LinearWarmupSchedulerCallable +if TYPE_CHECKING: + from lightning.pytorch.cli import LRSchedulerCallable + __all__ = [ "LRSchedulerListCallable", "LinearWarmupScheduler", @@ -21,3 +25,17 @@ LRSchedulerListCallable = Callable[[Optimizer], list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]] + + +class PicklableLRSchedulerCallable: + """It converts unpicklable lr scheduler callable such as lambda function to picklable.""" + + def __init__(self, scheduler_callable: LRSchedulerCallable | LRSchedulerListCallable): + self.dumped_scheduler_callable = dill.dumps(scheduler_callable) + + def __call__( + self, + optimizer: Optimizer, + ) -> _TORCH_LRSCHEDULER | ReduceLROnPlateau | list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]: + scheduler_callable = dill.loads(self.dumped_scheduler_callable) # noqa: S301 + return scheduler_callable(optimizer) diff --git a/src/otx/core/utils/jsonargparse.py b/src/otx/core/utils/jsonargparse.py new file mode 100644 index 00000000000..334a2f58298 --- /dev/null +++ b/src/otx/core/utils/jsonargparse.py @@ -0,0 +1,99 @@ +# The MIT License (MIT) + +# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2019-2024, Mauricio Villegas + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# SPDX-License-Identifier: MIT +# +# NOTE: This code contains adaptations from code originally implemented in the +# jsonargparse project (https://github.com/omni-us/jsonargparse), which is +# licensed under the MIT License. However, this specific codebase is licensed +# under the Apache License 2.0. Please note this difference in licensing. +# +"""Jsonargparse functions to adapt OTX project.""" + +from functools import partial +import inspect +from typing import Any, Type +from jsonargparse._common import is_subclass +from jsonargparse._typehints import LazyInitBaseClass as _LazyInitBaseClass +from jsonargparse._util import ClassType + +__all__ = ["lazy_instance", "ClassType"] + + +class LazyInitBaseClass(_LazyInitBaseClass): + """Modifed LazyInitBaseClass to support callable classes. + + See https://github.com/omni-us/jsonargparse/issues/481 for more details. + """ + + def __init__(self, class_type: Type, lazy_kwargs: dict): + self.__pickle_slot__ = { + "class_type": class_type, + "lazy_kwargs": lazy_kwargs, + } + + self._lazy_call_method = None + + for name, member in inspect.getmembers(class_type, predicate=inspect.isfunction): + if name == "__call__": + self._lazy_call_method = partial(member, self) + + super().__init__(class_type=class_type, lazy_kwargs=lazy_kwargs) + + def __call__(self, *args, **kwargs): + if self._lazy_call_method is None: + return None + + self._lazy_init() + return self._lazy_call_method(*args, **kwargs) + + def __reduce__(self) -> str | tuple[Any, ...]: + return self.__class__, tuple(self.__pickle_slot__.values()) + + +def lazy_instance(class_type: Type[ClassType], **kwargs) -> ClassType: + """Instantiates a lazy instance of the given type. + + By lazy it is meant that the __init__ is delayed unit the first time that a + method of the instance is called. It also provides a `lazy_get_init_data` method + useful for serializing. + + Args: + class_type: The class to instantiate. + **kwargs: Any keyword arguments to use for instantiation. + """ + caller_module = inspect.getmodule(inspect.stack()[1][0]) + class_name = f"LazyInstance_{class_type.__name__}" + if hasattr(caller_module, class_name): + lazy_init_class = getattr(caller_module, class_name) + assert is_subclass(lazy_init_class, LazyInitBaseClass) and is_subclass(lazy_init_class, class_type) + else: + lazy_init_class = type( + class_name, + (LazyInitBaseClass, class_type), + {"__doc__": f"Class for lazy instances of {class_type}"}, + ) + if caller_module is not None: + lazy_init_class.__module__ = getattr(caller_module, "__name__", __name__) + setattr(caller_module, lazy_init_class.__qualname__, lazy_init_class) + return lazy_init_class(class_type, kwargs) diff --git a/src/otx/engine/hpo/hpo_api.py b/src/otx/engine/hpo/hpo_api.py index fb5a43184e3..60a8a0e637c 100644 --- a/src/otx/engine/hpo/hpo_api.py +++ b/src/otx/engine/hpo/hpo_api.py @@ -16,6 +16,7 @@ import torch from otx.core.config.hpo import HpoConfig +from otx.core.optimizer.callable import OptimizerCallableSupportHPO from otx.core.types.task import OTXTaskType from otx.hpo import HyperBand, run_hpo_loop from otx.utils.utils import get_decimal_point, get_using_dot_delimited_key, remove_matched_files @@ -63,6 +64,8 @@ def execute_hpo( logger.warning("Zero shot visual prompting task doesn't support HPO.") return None, None + engine.model.patch_optimizer_and_scheduler_for_hpo() + hpo_workdir = Path(engine.work_dir) / "hpo" hpo_workdir.mkdir(exist_ok=True) hpo_configurator = HPOConfigurator( @@ -175,12 +178,9 @@ def _get_default_search_space(self) -> dict[str, Any]: """Set learning rate and batch size as search space.""" search_space = {} - optimizer_conf = self._engine.model.optimizer_callable - - if not callable(optimizer_conf): - raise TypeError(optimizer_conf) - - search_space["model.optimizer_callable.keywords.lr"] = self._make_lr_search_space(optimizer_conf) + search_space["model.optimizer_callable.optimizer_kwargs.lr"] = self._make_lr_search_space( + self._engine.model.optimizer_callable, + ) cur_bs = self._engine.datamodule.config.train_subset.batch_size search_space["datamodule.config.train_subset.batch_size"] = { @@ -194,10 +194,10 @@ def _get_default_search_space(self) -> dict[str, Any]: @staticmethod def _make_lr_search_space(optimizer_callable: OptimizerCallable) -> dict[str, Any]: - params = [torch.nn.Parameter(torch.zeros([0]))] - optimizer = optimizer_callable(params) - param_group = next(iter(optimizer.param_groups)) - cur_lr = param_group["lr"] # type: ignore[union-attr] + if not isinstance(optimizer_callable, OptimizerCallableSupportHPO): + raise TypeError(optimizer_callable) + + cur_lr = optimizer_callable.lr # type: ignore[attr-defined] min_lr = cur_lr / 10 return { "type": "qloguniform", diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py index 446cb500d3d..d7af44d4db6 100644 --- a/tests/integration/api/test_engine_api.py +++ b/tests/integration/api/test_engine_api.py @@ -7,6 +7,7 @@ import pytest from openvino.model_api.tilers import Tiler +from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls from otx.core.data.module import OTXDataModule from otx.core.model.base import OTXModel from otx.core.types.task import OTXTaskType @@ -153,3 +154,42 @@ def test_engine_from_tile_recipe( assert isinstance(ov_model.model, Tiler), "Model should be an instance of Tiler" assert engine.datamodule.config.tile_config.tile_size[0] == ov_model.model.tile_size assert engine.datamodule.config.tile_config.overlap == ov_model.model.tiles_overlap + + +REASON = """ +Traceback (most recent call last): + File "/home/vinnamki/miniconda3/envs/otx-v2/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap + self.run() + File "/home/vinnamki/miniconda3/envs/otx-v2/lib/python3.11/multiprocessing/process.py", line 108, in run + self._target(*self._args, **self._kwargs) + File "/home/vinnamki/otx/training_extensions/src/otx/hpo/hpo_runner.py", line 200, in _run_train + train_func(hp_config, report_func) + File "/home/vinnamki/otx/training_extensions/src/otx/engine/hpo/hpo_trial.py", line 75, in run_hpo_trial + callbacks = _register_hpo_callback(report_func, callbacks) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/home/vinnamki/otx/training_extensions/src/otx/engine/hpo/hpo_trial.py", line 101, in _register_hpo_callback + callbacks.append(HPOCallback(report_func, _get_metric(callbacks))) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/home/vinnamki/otx/training_extensions/src/otx/engine/hpo/hpo_trial.py", line 110, in _get_metric + raise RuntimeError(error_msg) +RuntimeError: Failed to find a metric. There is no ModelCheckpoint in callback list. +""" + + +@pytest.mark.parametrize("task", pytest.TASK_LIST) +def test_otx_hpo( + task: OTXTaskType, + tmp_path: Path, + fxt_target_dataset_per_task: dict, +) -> None: + pytest.xfail(reason=REASON) + + model = EfficientNetB0ForMulticlassCls(num_classes=3) + work_dir = str(tmp_path) + engine = Engine( + data_root=fxt_target_dataset_per_task[task.lower()], + task=task, + work_dir=work_dir, + model=model, + ) + engine.train(run_hpo=True) diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index a5d7a63f9fa..bfb17020f0f 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -437,59 +437,16 @@ def test_otx_ov_test( REASON = ''' -tests/integration/cli/test_cli.py:507: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -tests/utils.py:18: in run_main - _run_main(command_cfg) -tests/utils.py:37: in _run_main - main() -src/otx/cli/__init__.py:17: in main - OTXCLI() -src/otx/cli/cli.py:59: in __init__ - self.run() -src/otx/cli/cli.py:521: in run - fn(**fn_kwargs) -src/otx/engine/engine.py:234: in train - best_config, best_trial_weight = execute_hpo(engine=self, **locals()) -src/otx/engine/hpo/hpo_api.py:67: in execute_hpo - hpo_configurator = HPOConfigurator( -src/otx/engine/hpo/hpo_api.py:127: in __init__ - self.hpo_config: dict[str, Any] = hpo_config # type: ignore[assignment] -src/otx/engine/hpo/hpo_api.py:168: in hpo_config - self._hpo_config["prior_hyper_parameters"] = { -src/otx/engine/hpo/hpo_api.py:169: in - hp: get_using_dot_delimited_key(hp, self._engine) -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -key = 'model.optimizer_callable.keywords.lr' -target = .partial_instance at 0x71faee3b9480> - - def get_using_dot_delimited_key(key: str, target: Any) -> Any: # noqa: ANN401 - """Get values of attribute in target object using dot delimited key. - - For example, if key is "a.b.c", then get a value of 'target.a.b.c'. - Target should be object having attributes, dictionary or list. - To get an element in a list, an integer that is the index of corresponding value can be set as a key. - - Args: - key (str): dot delimited key. - val (Any): value to set. - target (Any): target to set value to. - """ - splited_key = key.split(".") - for each_key in splited_key: - if isinstance(target, dict): - target = target[each_key] - elif isinstance(target, list): - if not each_key.isdigit(): - error_msg = f"Key should be integer but '{each_key}'." - raise ValueError(error_msg) - target = target[int(each_key)] - else: -> target = getattr(target, each_key) -E AttributeError: 'function' object has no attribute 'keywords' - -src/otx/utils/utils.py:37: AttributeError +self = + + def finalize(self) -> None: + """Set done as True.""" + if not self.score: + error_msg = f"Trial{self.id} didn't report any score but tries to be done." +> raise RuntimeError(error_msg) +E RuntimeError: Trial0 didn't report any score but tries to be done. + +src/otx/hpo/hpo_base.py:274: RuntimeError ''' @@ -514,8 +471,12 @@ def test_otx_hpo_e2e( """ if task not in DEFAULT_CONFIG_PER_TASK: pytest.skip(f"Task {task} is not supported in the auto-configuration.") - - pytest.xfail(reason=REASON) + if task in { + OTXTaskType.ANOMALY_CLASSIFICATION, + OTXTaskType.ANOMALY_DETECTION, + OTXTaskType.ANOMALY_SEGMENTATION, + }: + pytest.xfail(reason=REASON) task = task.lower() tmp_path_hpo = tmp_path / f"otx_hpo_{task}" diff --git a/tests/unit/cli/test_install.py b/tests/unit/cli/test_install.py index a55c587a8a9..0e860019a61 100644 --- a/tests/unit/cli/test_install.py +++ b/tests/unit/cli/test_install.py @@ -1,6 +1,7 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + import pytest from _pytest.monkeypatch import MonkeyPatch from jsonargparse import ArgumentParser @@ -36,7 +37,12 @@ def test_install_extra(self, mocker: MockerFixture) -> None: mock_create_command.return_value.main.return_value = 0 status_code = otx_install(option="dev") assert status_code == mock_create_command.return_value.main.return_value - argument_call_list = mock_create_command.return_value.main.call_args_list[-1][0][-1] + + argument_call_list = [] + for call_args in mock_create_command.return_value.main.call_args_list: + for arg in call_args.args: + argument_call_list += arg + assert "pytorchcv" in argument_call_list assert "openvino" not in argument_call_list assert "anomalib" not in argument_call_list @@ -49,8 +55,13 @@ def test_install_full(self, mocker: MockerFixture, monkeypatch: MonkeyPatch) -> status_code = otx_install("full") assert status_code == mock_create_command.return_value.main.return_value - mock_create_command.assert_called_once_with("install") - argument_call_list = mock_create_command.return_value.main.call_args_list[-1][0][-1] + mock_create_command.assert_called_with("install") + + argument_call_list = [] + for call_args in mock_create_command.return_value.main.call_args_list: + for arg in call_args.args: + argument_call_list += arg + assert "openvino" in argument_call_list assert "pytorchcv" in argument_call_list assert "anomalib" in argument_call_list diff --git a/tests/unit/core/optimizer/__init__.py b/tests/unit/core/optimizer/__init__.py new file mode 100644 index 00000000000..916f3a44b27 --- /dev/null +++ b/tests/unit/core/optimizer/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/core/optimizer/test_callable.py b/tests/unit/core/optimizer/test_callable.py new file mode 100644 index 00000000000..dc20ef3fdd6 --- /dev/null +++ b/tests/unit/core/optimizer/test_callable.py @@ -0,0 +1,147 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import pickle + +import pytest +from otx.core.metrics import NullMetricCallable +from otx.core.model.base import DefaultSchedulerCallable, OTXModel +from otx.core.optimizer import OptimizerCallableSupportHPO +from torch import nn +from torch.optim import SGD + + +class TestOptimizerCallableSupportHPO: + @pytest.fixture() + def fxt_params(self): + model = nn.Linear(10, 10) + return model.parameters() + + @pytest.fixture(params=["torch.optim.SGD", SGD]) + def fxt_optimizer_cls(self, request): + return request.param + + @pytest.fixture() + def fxt_invaliid_optimizer_cls(self): + class NotOptimizer: + pass + + return NotOptimizer + + def test_succeed(self, fxt_optimizer_cls, fxt_params): + optimizer_callable = OptimizerCallableSupportHPO( + optimizer_cls=fxt_optimizer_cls, + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ) + optimizer = optimizer_callable(fxt_params) + + assert isinstance(optimizer, SGD) + + assert optimizer_callable.lr == 0.1 + assert optimizer_callable.momentum == 0.9 + assert optimizer_callable.weight_decay == 1e-4 + + assert all(param["lr"] == 0.1 for param in fxt_params) + assert all(param["momentum"] == 0.9 for param in fxt_params) + assert all(param["weight_decay"] == 1e-4 for param in fxt_params) + + def test_failure(self, fxt_invaliid_optimizer_cls, fxt_optimizer_cls): + with pytest.raises(TypeError): + OptimizerCallableSupportHPO( + optimizer_cls=fxt_invaliid_optimizer_cls, + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ) + + with pytest.raises( + ValueError, + match="Search hyperparamter=(.*) should be existed in optimizer keyword arguments", + ): + OptimizerCallableSupportHPO( + optimizer_cls=fxt_optimizer_cls, + search_hparams=("lr", "non_momentum"), + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ) + + def test_picklable(self, fxt_optimizer_cls): + optimizer_callable = OptimizerCallableSupportHPO( + optimizer_cls=fxt_optimizer_cls, + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ) + + pickled = pickle.dumps(optimizer_callable) + unpickled = pickle.loads(pickled) # noqa: S301 + + assert isinstance(unpickled, OptimizerCallableSupportHPO) + assert optimizer_callable.optimizer_path == unpickled.optimizer_path + assert optimizer_callable.optimizer_kwargs == unpickled.optimizer_kwargs + assert optimizer_callable.search_hparams == unpickled.search_hparams + + def test_lazy_instance(self, fxt_optimizer_cls): + default_optimizer_callable = OptimizerCallableSupportHPO( + optimizer_cls=fxt_optimizer_cls, + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ).to_lazy_instance() + + class _TestOTXModel(OTXModel): + def __init__( + self, + num_classes=10, + optimizer=default_optimizer_callable, + scheduler=DefaultSchedulerCallable, + metric=NullMetricCallable, + torch_compile: bool = False, + ) -> None: + super().__init__(num_classes, optimizer, scheduler, metric, torch_compile) + + def _create_model(self) -> nn.Module: + return nn.Linear(10, self.num_classes) + + model = _TestOTXModel() + optimizers, _ = model.configure_optimizers() + optimizer = next(iter(optimizers)) + + assert isinstance(optimizer, SGD) + + def test_lazy_instance_picklable(self, fxt_optimizer_cls, fxt_params): + lazy_instance = OptimizerCallableSupportHPO( + optimizer_cls=fxt_optimizer_cls, + optimizer_kwargs={ + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 1e-4, + }, + ).to_lazy_instance() + + pickled = pickle.dumps(lazy_instance) + unpickled = pickle.loads(pickled) # noqa: S301 + + optimizer = unpickled(fxt_params) + + assert isinstance(optimizer, SGD) + + assert unpickled.lr == 0.1 + assert unpickled.momentum == 0.9 + assert unpickled.weight_decay == 1e-4 + + assert all(param["lr"] == 0.1 for param in fxt_params) + assert all(param["momentum"] == 0.9 for param in fxt_params) + assert all(param["weight_decay"] == 1e-4 for param in fxt_params)