diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py
index 07d9a1d88f5..a8ac0f507ff 100644
--- a/src/otx/core/model/base.py
+++ b/src/otx/core/model/base.py
@@ -38,8 +38,12 @@
 from otx.core.exporter.native import OTXNativeModelExporter
 from otx.core.metrics import MetricInput, NullMetricCallable
 from otx.core.optimizer.callable import OptimizerCallableSupportHPO
-from otx.core.schedulers import LRSchedulerListCallable, PicklableLRSchedulerCallable
-from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler
+from otx.core.schedulers import (
+    LinearWarmupScheduler,
+    LinearWarmupSchedulerCallable,
+    LRSchedulerListCallable,
+    SchedulerCallableSupportHPO,
+)
 from otx.core.types.export import OTXExportFormatType, TaskLevelExportParameters
 from otx.core.types.label import LabelInfo, NullLabelInfo
 from otx.core.types.precision import OTXPrecisionType
@@ -730,8 +734,11 @@ def patch_optimizer_and_scheduler_for_hpo(self) -> None:
         if not isinstance(self.optimizer_callable, OptimizerCallableSupportHPO):
             self.optimizer_callable = OptimizerCallableSupportHPO.from_callable(self.optimizer_callable)
 
-        if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable):
-            self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable)
+        if not isinstance(self.scheduler_callable, SchedulerCallableSupportHPO) and not isinstance(
+            self.scheduler_callable,
+            LinearWarmupSchedulerCallable,  # LinearWarmupSchedulerCallable natively supports HPO
+        ):
+            self.scheduler_callable = SchedulerCallableSupportHPO.from_callable(self.scheduler_callable)
 
     @property
     def tile_config(self) -> TileConfig:
diff --git a/src/otx/core/optimizer/callable.py b/src/otx/core/optimizer/callable.py
index 111e31203c0..4dc590d22da 100644
--- a/src/otx/core/optimizer/callable.py
+++ b/src/otx/core/optimizer/callable.py
@@ -5,7 +5,7 @@
 from __future__ import annotations
 
 import importlib
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any
 
 from torch import nn
 from torch.optim.optimizer import Optimizer
@@ -23,7 +23,6 @@ class OptimizerCallableSupportHPO:
     Args:
         optimizer_cls: Optimizer class type or string class import path. See examples for details.
         optimizer_kwargs: Keyword arguments used for the initialization of the given `optimizer_cls`.
-        search_hparams: Sequence of optimizer hyperparameter names which can be tuned by the OTX HPO algorithm.
 
     Examples:
         This is an example to create `MobileNetV3ForMulticlassCls` with a `SGD` optimizer and
@@ -69,7 +68,6 @@ def __init__(
         self,
         optimizer_cls: type[Optimizer] | str,
         optimizer_kwargs: dict[str, int | float | bool],
-        search_hparams: Sequence[str] = ("lr",),
     ):
         if isinstance(optimizer_cls, str):
             splited = optimizer_cls.split(".")
@@ -84,15 +82,6 @@
         else:
             raise TypeError(optimizer_cls)
 
-        for search_hparam in search_hparams:
-            if search_hparam not in optimizer_kwargs:
-                msg = (
-                    f"Search hyperparamter={search_hparam} should be existed in "
-                    f"optimizer keyword arguments={optimizer_kwargs} as well."
-                )
-                raise ValueError(msg)
-
-        self.search_hparams = list(search_hparams)
         self.optimizer_kwargs = optimizer_kwargs
         self.__dict__.update(optimizer_kwargs)
 
@@ -137,14 +126,12 @@ def __init__(
             OptimizerCallableSupportHPO,
             optimizer_cls=self.optimizer_path,
             optimizer_kwargs=self.optimizer_kwargs,
-            search_hparams=self.search_hparams,
         )
 
     def __reduce__(self) -> str | tuple[Any, ...]:
         return self.__class__, (
             self.optimizer_path,
             self.optimizer_kwargs,
-            self.search_hparams,
         )
 
     @classmethod
diff --git a/src/otx/core/schedulers/__init__.py b/src/otx/core/schedulers/__init__.py
index 7f9bec87add..ed74aaf62cf 100644
--- a/src/otx/core/schedulers/__init__.py
+++ b/src/otx/core/schedulers/__init__.py
@@ -5,37 +5,21 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable
+from typing import Callable
 
-import dill
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
 from lightning.pytorch.cli import ReduceLROnPlateau
 from torch.optim.optimizer import Optimizer
 
+from otx.core.schedulers.callable import SchedulerCallableSupportHPO
 from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler, LinearWarmupSchedulerCallable
 
-if TYPE_CHECKING:
-    from lightning.pytorch.cli import LRSchedulerCallable
-
 __all__ = [
     "LRSchedulerListCallable",
    "LinearWarmupScheduler",
    "LinearWarmupSchedulerCallable",
+    "SchedulerCallableSupportHPO",
 ]
 
 LRSchedulerListCallable = Callable[[Optimizer], list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]]
-
-
-class PicklableLRSchedulerCallable:
-    """It converts unpicklable lr scheduler callable such as lambda function to picklable."""
-
-    def __init__(self, scheduler_callable: LRSchedulerCallable | LRSchedulerListCallable):
-        self.dumped_scheduler_callable = dill.dumps(scheduler_callable)
-
-    def __call__(
-        self,
-        optimizer: Optimizer,
-    ) -> _TORCH_LRSCHEDULER | ReduceLROnPlateau | list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]:
-        scheduler_callable = dill.loads(self.dumped_scheduler_callable)  # noqa: S301
-        return scheduler_callable(optimizer)
diff --git a/src/otx/core/schedulers/callable.py b/src/otx/core/schedulers/callable.py
new file mode 100644
index 00000000000..5667d40a9f1
--- /dev/null
+++ b/src/otx/core/schedulers/callable.py
@@ -0,0 +1,160 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""Scheduler callable to support the hyper-parameter optimization (HPO) algorithm."""
+
+from __future__ import annotations
+
+import importlib
+import inspect
+from typing import TYPE_CHECKING, Any
+
+from lightning.pytorch.cli import ReduceLROnPlateau
+from torch import nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.lr_scheduler import ReduceLROnPlateau as TorchReduceLROnPlateau
+
+from otx.core.utils.jsonargparse import ClassType, lazy_instance
+
+if TYPE_CHECKING:
+    from lightning.pytorch.cli import LRSchedulerCallable
+
+
+class SchedulerCallableSupportHPO:
+    """LR scheduler callable that supports the OTX hyper-parameter optimization (HPO) algorithm.
+
+    Args:
+        scheduler_cls: `LRScheduler` class type or string class import path. See examples for details.
+        scheduler_kwargs: Keyword arguments used for the initialization of the given `scheduler_cls`.
+
+    Examples:
+        This is an example to create `MobileNetV3ForMulticlassCls` with a `StepLR` lr scheduler and
+        custom configurations.
+
+        ```python
+        from torch.optim.lr_scheduler import StepLR
+        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls
+
+        model = MobileNetV3ForMulticlassCls(
+            num_classes=3,
+            scheduler=SchedulerCallableSupportHPO(
+                scheduler_cls=StepLR,
+                scheduler_kwargs={
+                    "step_size": 10,
+                    "gamma": 0.5,
+                },
+            ),
+        )
+        ```
+
+        It can also be created from a string class import path:
+
+        ```python
+        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls
+
+        model = MobileNetV3ForMulticlassCls(
+            num_classes=3,
+            scheduler=SchedulerCallableSupportHPO(
+                scheduler_cls="torch.optim.lr_scheduler.StepLR",
+                scheduler_kwargs={
+                    "step_size": 10,
+                    "gamma": 0.5,
+                },
+            ),
+        )
+        ```
+    """
+
+    def __init__(
+        self,
+        scheduler_cls: type[LRScheduler] | str,
+        scheduler_kwargs: dict[str, int | float | bool | str],
+    ):
+        if isinstance(scheduler_cls, str):
+            splited = scheduler_cls.split(".")
+            module_path, class_name = ".".join(splited[:-1]), splited[-1]
+            module = importlib.import_module(module_path)
+
+            self.scheduler_init: type[LRScheduler] = getattr(module, class_name)
+            self.scheduler_path = scheduler_cls
+        elif issubclass(scheduler_cls, LRScheduler | ReduceLROnPlateau):
+            self.scheduler_init = scheduler_cls
+            self.scheduler_path = scheduler_cls.__module__ + "." + scheduler_cls.__qualname__
+        else:
+            raise TypeError(scheduler_cls)
+
+        self.scheduler_kwargs = scheduler_kwargs
+        self.__dict__.update(scheduler_kwargs)
+
+    def __call__(self, optimizer: Optimizer) -> LRScheduler:
+        """Create `torch.optim.LRScheduler` instance for the given parameters."""
+        return self.scheduler_init(optimizer, **self.scheduler_kwargs)
+
+    def to_lazy_instance(self) -> ClassType:
+        """Return lazy instance of this class.
+
+        Because OTX relies on the jsonargparse library,
+        the default value of a class initialization
+        argument should be a lazy instance.
+        Please refer to https://jsonargparse.readthedocs.io/en/stable/#default-values
+        for more details.
+
+        Examples:
+            This is an example to implement a new model with a `StepLR` scheduler and
+            custom configurations as a default.
+
+            ```python
+            class MyAwesomeMulticlassClsModel(OTXMulticlassClsModel):
+                def __init__(
+                    self,
+                    num_classes: int,
+                    optimizer: OptimizerCallable = DefaultOptimizerCallable,
+                    scheduler: LRSchedulerCallable | LRSchedulerListCallable = SchedulerCallableSupportHPO(
+                        scheduler_cls=StepLR,
+                        scheduler_kwargs={
+                            "step_size": 10,
+                            "gamma": 0.5,
+                        },
+                    ).to_lazy_instance(),
+                    metric: MetricCallable = MultiClassClsMetricCallable,
+                    torch_compile: bool = False,
+                ) -> None:
+                    ...
+            ```
+        """
+        return lazy_instance(
+            SchedulerCallableSupportHPO,
+            scheduler_cls=self.scheduler_path,
+            scheduler_kwargs=self.scheduler_kwargs,
+        )
+
+    def __reduce__(self) -> str | tuple[Any, ...]:
+        return self.__class__, (
+            self.scheduler_path,
+            self.scheduler_kwargs,
+        )
+
+    @classmethod
+    def from_callable(cls, func: LRSchedulerCallable) -> SchedulerCallableSupportHPO:
+        """Create this class instance from an existing scheduler callable."""
+        dummy_params = [nn.Parameter()]
+        optimizer = Optimizer(dummy_params, {"lr": 1.0})
+        scheduler = func(optimizer)
+
+        allow_names = set(inspect.signature(scheduler.__class__).parameters)
+
+        if isinstance(scheduler, ReduceLROnPlateau):
+            # NOTE: Other arguments except "monitor", such as "patience",
+            # are not included in the signature of ReduceLROnPlateau.__init__()
+            allow_names.update(key for key in inspect.signature(TorchReduceLROnPlateau).parameters)
+
+        block_names = {"optimizer", "last_epoch"}
+
+        scheduler_kwargs = {
+            key: value for key, value in scheduler.state_dict().items() if key in allow_names and key not in block_names
+        }
+
+        return SchedulerCallableSupportHPO(
+            scheduler_cls=scheduler.__class__,
+            scheduler_kwargs=scheduler_kwargs,
+        )
diff --git a/src/otx/core/schedulers/warmup_schedulers.py b/src/otx/core/schedulers/warmup_schedulers.py
index 3367def1a15..6de763bb52b 100644
--- a/src/otx/core/schedulers/warmup_schedulers.py
+++ b/src/otx/core/schedulers/warmup_schedulers.py
@@ -8,6 +8,8 @@
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
+from otx.core.schedulers.callable import SchedulerCallableSupportHPO
+
 if TYPE_CHECKING:
     from lightning.pytorch.cli import LRSchedulerCallable, ReduceLROnPlateau
     from torch.optim.optimizer import Optimizer
@@ -65,7 +67,7 @@ def __init__(
         warmup_interval: Literal["step", "epoch"] = "step",
         monitor: str | None = None,
     ):
-        self.main_scheduler_callable = main_scheduler_callable
+        self.main_scheduler_callable = SchedulerCallableSupportHPO.from_callable(main_scheduler_callable)
         self.num_warmup_steps = num_warmup_steps
         self.warmup_interval = warmup_interval
         self.monitor = monitor
diff --git a/tests/unit/core/optimizer/test_callable.py b/tests/unit/core/optimizer/test_callable.py
index dc20ef3fdd6..661504c003d 100644
--- a/tests/unit/core/optimizer/test_callable.py
+++ b/tests/unit/core/optimizer/test_callable.py
@@ -48,30 +48,21 @@ def test_succeed(self, fxt_optimizer_cls, fxt_params):
         assert all(param["momentum"] == 0.9 for param in fxt_params)
         assert all(param["weight_decay"] == 1e-4 for param in fxt_params)
 
-    def test_failure(self, fxt_invaliid_optimizer_cls, fxt_optimizer_cls):
-        with pytest.raises(TypeError):
-            OptimizerCallableSupportHPO(
-                optimizer_cls=fxt_invaliid_optimizer_cls,
-                optimizer_kwargs={
-                    "lr": 0.1,
-                    "momentum": 0.9,
-                    "weight_decay": 1e-4,
-                },
-            )
-
-        with pytest.raises(
-            ValueError,
-            match="Search hyperparamter=(.*) should be existed in optimizer keyword arguments",
-        ):
-            OptimizerCallableSupportHPO(
-                optimizer_cls=fxt_optimizer_cls,
-                search_hparams=("lr", "non_momentum"),
-                optimizer_kwargs={
-                    "lr": 0.1,
-                    "momentum": 0.9,
-                    "weight_decay": 1e-4,
-                },
-            )
+    def test_from_callable(self, fxt_params):
+        optimizer_callable = OptimizerCallableSupportHPO.from_callable(
+            func=lambda params: SGD(params, lr=0.1, momentum=0.9, weight_decay=1e-4),
+        )
+        optimizer = optimizer_callable(fxt_params)
+
+        assert isinstance(optimizer, SGD)
+
+        assert optimizer_callable.lr == 0.1
+        assert optimizer_callable.momentum == 0.9
+        assert optimizer_callable.weight_decay == 1e-4
+
+        assert all(param["lr"] == 0.1 for param in fxt_params)
+        assert all(param["momentum"] == 0.9 for param in fxt_params)
+        assert all(param["weight_decay"] == 1e-4 for param in fxt_params)
 
     def test_picklable(self, fxt_optimizer_cls):
         optimizer_callable = OptimizerCallableSupportHPO(
@@ -89,7 +80,6 @@ def test_picklable(self, fxt_optimizer_cls):
         assert isinstance(unpickled, OptimizerCallableSupportHPO)
         assert optimizer_callable.optimizer_path == unpickled.optimizer_path
         assert optimizer_callable.optimizer_kwargs == unpickled.optimizer_kwargs
-        assert optimizer_callable.search_hparams == unpickled.search_hparams
 
     def test_lazy_instance(self, fxt_optimizer_cls):
         default_optimizer_callable = OptimizerCallableSupportHPO(
diff --git a/tests/unit/core/schedulers/test_callable.py b/tests/unit/core/schedulers/test_callable.py
new file mode 100644
index 00000000000..0106d014f15
--- /dev/null
+++ b/tests/unit/core/schedulers/test_callable.py
@@ -0,0 +1,123 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+import pickle
+
+import pytest
+from lightning.pytorch.cli import ReduceLROnPlateau
+from otx.core.metrics import NullMetricCallable
+from otx.core.model.base import DefaultOptimizerCallable, OTXModel
+from otx.core.schedulers import SchedulerCallableSupportHPO
+from torch import nn
+from torch.optim import SGD
+from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
+
+
+class TestSchedulerCallableSupportHPO:
+    @pytest.fixture()
+    def fxt_optimizer(self):
+        model = nn.Linear(10, 10)
+        return SGD(model.parameters(), lr=1.0)
+
+    @pytest.fixture(
+        params=[
+            (StepLR, {"step_size": 10, "gamma": 0.5}),
+            (CosineAnnealingLR, {"T_max": 10, "eta_min": 0.5}),
+            (ReduceLROnPlateau, {"monitor": "my_metric", "patience": 10}),
+        ],
+        ids=lambda param: param[0].__qualname__,
+    )
+    def fxt_scheduler_cls_and_kwargs(self, request):
+        scheduler_cls, scheduler_kwargs = request.param
+        return scheduler_cls, scheduler_kwargs
+
+    def test_succeed(self, fxt_scheduler_cls_and_kwargs, fxt_optimizer):
+        scheduler_cls, scheduler_kwargs = fxt_scheduler_cls_and_kwargs
+        scheduler_callable = SchedulerCallableSupportHPO(
+            scheduler_cls=scheduler_cls,
+            scheduler_kwargs=scheduler_kwargs,
+        )
+        scheduler = scheduler_callable(fxt_optimizer)
+
+        assert isinstance(scheduler, scheduler_cls)
+
+        for key, value in scheduler_kwargs.items():
+            assert getattr(scheduler_callable, key) == value
+            assert scheduler_callable.scheduler_kwargs.get(key) == value
+
+            assert scheduler.state_dict().get(key) == value
+
+    def test_from_callable(self, fxt_scheduler_cls_and_kwargs, fxt_optimizer):
+        scheduler_cls, scheduler_kwargs = fxt_scheduler_cls_and_kwargs
+        scheduler_callable = SchedulerCallableSupportHPO.from_callable(
+            func=lambda optimizer: scheduler_cls(optimizer, **scheduler_kwargs),
+        )
+        scheduler = scheduler_callable(fxt_optimizer)
+
+        assert isinstance(scheduler, scheduler_cls)
+
+        for key, value in scheduler_kwargs.items():
+            assert getattr(scheduler_callable, key) == value
+            assert scheduler_callable.scheduler_kwargs.get(key) == value
+
+            assert scheduler.state_dict().get(key) == value
+
+    def test_picklable(self, fxt_scheduler_cls_and_kwargs, fxt_optimizer):
+        scheduler_cls, scheduler_kwargs = fxt_scheduler_cls_and_kwargs
+        scheduler_callable = SchedulerCallableSupportHPO(
+            scheduler_cls=scheduler_cls,
+            scheduler_kwargs=scheduler_kwargs,
+        )
+
+        pickled = pickle.dumps(scheduler_callable)
+        unpickled = pickle.loads(pickled)  # noqa: S301
+
+        scheduler = unpickled(fxt_optimizer)
+
+        assert isinstance(scheduler, scheduler_cls)
+
+        for key, value in scheduler_kwargs.items():
+            assert scheduler.state_dict().get(key) == value
+
+    def test_lazy_instance(self, fxt_scheduler_cls_and_kwargs):
+        scheduler_cls, scheduler_kwargs = fxt_scheduler_cls_and_kwargs
+        default_scheduler_callable = SchedulerCallableSupportHPO(
+            scheduler_cls=scheduler_cls,
+            scheduler_kwargs=scheduler_kwargs,
+        ).to_lazy_instance()
+
+        class _TestOTXModel(OTXModel):
+            def __init__(
+                self,
+                num_classes=10,
+                optimizer=DefaultOptimizerCallable,
+                scheduler=default_scheduler_callable,
+                metric=NullMetricCallable,
+                torch_compile: bool = False,
+            ) -> None:
+                super().__init__(num_classes, optimizer, scheduler, metric, torch_compile)
+
+            def _create_model(self) -> nn.Module:
+                return nn.Linear(10, self.num_classes)
+
+        model = _TestOTXModel()
+        _, scheduler_configs = model.configure_optimizers()
+        scheduler = next(iter(scheduler_configs))["scheduler"]
+
+        assert isinstance(scheduler, scheduler_cls)
+
+    def test_lazy_instance_picklable(self, fxt_scheduler_cls_and_kwargs, fxt_optimizer):
+        scheduler_cls, scheduler_kwargs = fxt_scheduler_cls_and_kwargs
+        lazy_instance = SchedulerCallableSupportHPO(
+            scheduler_cls=scheduler_cls,
+            scheduler_kwargs=scheduler_kwargs,
+        ).to_lazy_instance()
+
+        pickled = pickle.dumps(lazy_instance)
+        unpickled = pickle.loads(pickled)  # noqa: S301
+
+        scheduler = unpickled(fxt_optimizer)
+
+        assert isinstance(scheduler, scheduler_cls)
+
+        for key, value in scheduler_kwargs.items():
+            assert scheduler.state_dict().get(key) == value
diff --git a/tests/unit/core/schedulers/test_warmup_schedulers.py b/tests/unit/core/schedulers/test_warmup_schedulers.py
index 345e28fd0cf..e36a8caad36 100644
--- a/tests/unit/core/schedulers/test_warmup_schedulers.py
+++ b/tests/unit/core/schedulers/test_warmup_schedulers.py
@@ -2,10 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
+from lightning.pytorch.cli import ReduceLROnPlateau
 from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler, LinearWarmupSchedulerCallable
-from pytest_mock import MockerFixture
 from torch import nn
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.lr_scheduler import StepLR
 from torch.optim.sgd import SGD
 
 
@@ -28,21 +28,19 @@ def test_activation(self, fxt_optimizer):
 
 class TestLinearWarmupSchedulerCallable:
-    def test_num_warmup_steps(self, fxt_optimizer, mocker: MockerFixture):
-        mock_main_scheduler = mocker.create_autospec(spec=LRScheduler)
-
+    def test_num_warmup_steps(self, fxt_optimizer):
         # No linear warmup scheduler because num_warmup_steps = 0 by default
         scheduler_callable = LinearWarmupSchedulerCallable(
-            main_scheduler_callable=lambda _: mock_main_scheduler,
+            main_scheduler_callable=lambda optimizer: StepLR(optimizer, step_size=10, gamma=0.5),
         )
 
         schedulers = scheduler_callable(fxt_optimizer)
 
         assert len(schedulers) == 1
-        assert schedulers == [mock_main_scheduler]
+        assert isinstance(schedulers[0], StepLR)
 
         # linear warmup scheduler exists because num_warmup_steps > 0
         scheduler_callable = LinearWarmupSchedulerCallable(
-            main_scheduler_callable=lambda _: mock_main_scheduler,
+            main_scheduler_callable=lambda optimizer: StepLR(optimizer, step_size=10, gamma=0.5),
             num_warmup_steps=10,
             warmup_interval="epoch",
         )
@@ -50,18 +48,15 @@ def test_num_warmup_steps(self, fxt_optimizer, mocker: MockerFixture):
         schedulers = scheduler_callable(fxt_optimizer)
 
         assert len(schedulers) == 2
-        assert schedulers[0] == mock_main_scheduler
+        assert isinstance(schedulers[0], StepLR)
         assert isinstance(schedulers[1], LinearWarmupScheduler)
         assert schedulers[1].num_warmup_steps == 10
         assert schedulers[1].interval == "epoch"
 
-    def test_monitor(self, fxt_optimizer, mocker: MockerFixture):
-        mock_main_scheduler = mocker.MagicMock()
-        mock_main_scheduler.monitor = "not_my_metric"
-
+    def test_monitor(self, fxt_optimizer):
        # If monitor None, do not override monitor.
         scheduler_callable = LinearWarmupSchedulerCallable(
-            main_scheduler_callable=lambda _: mock_main_scheduler,
+            main_scheduler_callable=lambda optimizer: ReduceLROnPlateau(optimizer, monitor="not_my_metric"),
             num_warmup_steps=10,
             monitor=None,
         )
@@ -73,7 +68,7 @@
 
         # Set monitor from "not_my_metric" to "my_metric"
         scheduler_callable = LinearWarmupSchedulerCallable(
-            main_scheduler_callable=lambda _: mock_main_scheduler,
+            main_scheduler_callable=lambda optimizer: ReduceLROnPlateau(optimizer, monitor="not_my_metric"),
             num_warmup_steps=10,
             monitor="my_metric",
         )
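
Reviewer note (not part of the patch): the diff replaces the dill-based `PicklableLRSchedulerCallable` wrapper with `SchedulerCallableSupportHPO`, which is picklable by construction because `__reduce__` stores only the scheduler's import path and keyword arguments. A minimal sketch of the intended round trip, assuming this patch is applied and `otx` is importable:

```python
import pickle

from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from otx.core.schedulers import SchedulerCallableSupportHPO

# Build the callable from a string import path; passing the class object works as well.
scheduler_callable = SchedulerCallableSupportHPO(
    scheduler_cls="torch.optim.lr_scheduler.StepLR",
    scheduler_kwargs={"step_size": 10, "gamma": 0.5},
)

# The HPO path pickles the model's optimizer/scheduler callables (which is why the
# old wrapper used dill); __reduce__ rebuilds from (scheduler_path, scheduler_kwargs).
restored = pickle.loads(pickle.dumps(scheduler_callable))

optimizer = SGD(nn.Linear(4, 4).parameters(), lr=0.1)
scheduler = restored(optimizer)
assert isinstance(scheduler, StepLR)
assert scheduler.step_size == 10 and scheduler.gamma == 0.5
```

The kwargs are also mirrored onto the instance via `self.__dict__.update(scheduler_kwargs)`, so hyperparameters such as `gamma` are readable as plain attributes, matching how `OptimizerCallableSupportHPO` already behaves.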
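Reviewer note (not part of the patch): `from_callable` is what `patch_optimizer_and_scheduler_for_hpo` and `LinearWarmupSchedulerCallable` now rely on to convert an arbitrary, often lambda-based and therefore unpicklable, scheduler callable into this HPO-friendly form. It instantiates the scheduler once against a throwaway `Optimizer` and recovers the constructor kwargs from `state_dict()`, filtered by the constructor signature. A sketch of that behavior, again assuming the patched `otx.core.schedulers` module:

```python
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from otx.core.schedulers import SchedulerCallableSupportHPO

# A recipe-style lambda like this cannot be pickled directly, hence the conversion.
wrapped = SchedulerCallableSupportHPO.from_callable(
    lambda optimizer: StepLR(optimizer, step_size=10, gamma=0.5),
)

# The constructor kwargs were recovered from the instantiated scheduler's state_dict(),
# with "optimizer" and "last_epoch" filtered out (other signature keys such as
# "verbose" may also be captured, depending on the torch version).
assert wrapped.scheduler_kwargs["step_size"] == 10
assert wrapped.scheduler_kwargs["gamma"] == 0.5

optimizer = SGD(nn.Linear(4, 4).parameters(), lr=0.1)
assert isinstance(wrapped(optimizer), StepLR)
```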