Add SchedulerCallableSupportHPO (#3334)
Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
vinnamkim authored Apr 18, 2024
1 parent 2e6d225 commit ac7eedd
Showing 8 changed files with 326 additions and 78 deletions.
15 changes: 11 additions & 4 deletions src/otx/core/model/base.py
@@ -38,8 +38,12 @@
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.metrics import MetricInput, NullMetricCallable
from otx.core.optimizer.callable import OptimizerCallableSupportHPO
from otx.core.schedulers import LRSchedulerListCallable, PicklableLRSchedulerCallable
from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler
from otx.core.schedulers import (
LinearWarmupScheduler,
LinearWarmupSchedulerCallable,
LRSchedulerListCallable,
SchedulerCallableSupportHPO,
)
from otx.core.types.export import OTXExportFormatType, TaskLevelExportParameters
from otx.core.types.label import LabelInfo, NullLabelInfo
from otx.core.types.precision import OTXPrecisionType
@@ -730,8 +734,11 @@ def patch_optimizer_and_scheduler_for_hpo(self) -> None:
if not isinstance(self.optimizer_callable, OptimizerCallableSupportHPO):
self.optimizer_callable = OptimizerCallableSupportHPO.from_callable(self.optimizer_callable)

if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable):
self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable)
if not isinstance(self.scheduler_callable, SchedulerCallableSupportHPO) and not isinstance(
self.scheduler_callable,
LinearWarmupSchedulerCallable, # LinearWarmupSchedulerCallable natively supports HPO
):
self.scheduler_callable = SchedulerCallableSupportHPO.from_callable(self.scheduler_callable)

@property
def tile_config(self) -> TileConfig:
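With this change, `patch_optimizer_and_scheduler_for_hpo()` wraps any scheduler callable that is not already HPO-aware, and leaves `LinearWarmupSchedulerCallable` alone because it supports HPO natively. A minimal sketch of the resulting behaviour, assuming a plain lambda scheduler callable (the values are illustrative, not part of the diff):

```python
# Illustrative sketch: a plain scheduler callable gets wrapped so its hyper-parameters
# become attributes the HPO algorithm can read and overwrite.
from torch.optim.lr_scheduler import StepLR

from otx.core.schedulers import LinearWarmupSchedulerCallable, SchedulerCallableSupportHPO

scheduler_callable = lambda optimizer: StepLR(optimizer, step_size=10, gamma=0.5)

if not isinstance(scheduler_callable, SchedulerCallableSupportHPO) and not isinstance(
    scheduler_callable,
    LinearWarmupSchedulerCallable,  # already HPO-aware, so it is left untouched
):
    scheduler_callable = SchedulerCallableSupportHPO.from_callable(scheduler_callable)

print(scheduler_callable.step_size, scheduler_callable.gamma)  # 10 0.5
```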
15 changes: 1 addition & 14 deletions src/otx/core/optimizer/callable.py
@@ -5,7 +5,7 @@
from __future__ import annotations

import importlib
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any

from torch import nn
from torch.optim.optimizer import Optimizer
@@ -23,7 +23,6 @@ class OptimizerCallableSupportHPO:
Args:
optimizer_cls: Optimizer class type or string class import path. See examples for details.
optimizer_kwargs: Keyword arguments used for the initialization of the given `optimizer_cls`.
search_hparams: Sequence of optimizer hyperparameter names which can be tuned by the OTX HPO algorithm.
Examples:
This is an example to create `MobileNetV3ForMulticlassCls` with a `SGD` optimizer and
@@ -69,7 +68,6 @@ def __init__(
self,
optimizer_cls: type[Optimizer] | str,
optimizer_kwargs: dict[str, int | float | bool],
search_hparams: Sequence[str] = ("lr",),
):
if isinstance(optimizer_cls, str):
splited = optimizer_cls.split(".")
@@ -84,15 +82,6 @@ def __init__(
else:
raise TypeError(optimizer_cls)

for search_hparam in search_hparams:
if search_hparam not in optimizer_kwargs:
msg = (
f"Search hyperparamter={search_hparam} should be existed in "
f"optimizer keyword arguments={optimizer_kwargs} as well."
)
raise ValueError(msg)

self.search_hparams = list(search_hparams)
self.optimizer_kwargs = optimizer_kwargs
self.__dict__.update(optimizer_kwargs)

@@ -137,14 +126,12 @@ def __init__(
OptimizerCallableSupportHPO,
optimizer_cls=self.optimizer_path,
optimizer_kwargs=self.optimizer_kwargs,
search_hparams=self.search_hparams,
)

def __reduce__(self) -> str | tuple[Any, ...]:
return self.__class__, (
self.optimizer_path,
self.optimizer_kwargs,
self.search_hparams,
)

@classmethod
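The `search_hparams` argument is removed entirely; every key in `optimizer_kwargs` is mirrored as an attribute of the callable (via `self.__dict__.update(optimizer_kwargs)`), as the updated tests further down exercise. A small sketch of the resulting interface, with illustrative values:

```python
# Illustrative sketch: without search_hparams, all optimizer keyword arguments are
# exposed directly as attributes, so HPO can tune any of them.
from torch.optim import SGD

from otx.core.optimizer.callable import OptimizerCallableSupportHPO

optimizer_callable = OptimizerCallableSupportHPO(
    optimizer_cls=SGD,
    optimizer_kwargs={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4},
)

assert optimizer_callable.lr == 0.1
assert optimizer_callable.momentum == 0.9
assert optimizer_callable.weight_decay == 1e-4
```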
22 changes: 3 additions & 19 deletions src/otx/core/schedulers/__init__.py
@@ -5,37 +5,21 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Callable
from typing import Callable

import dill
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch.cli import ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

from otx.core.schedulers.callable import SchedulerCallableSupportHPO
from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler, LinearWarmupSchedulerCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable

__all__ = [
"LRSchedulerListCallable",
"LinearWarmupScheduler",
"LinearWarmupSchedulerCallable",
"SchedulerCallableSupportHPO",
]


LRSchedulerListCallable = Callable[[Optimizer], list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]]


class PicklableLRSchedulerCallable:
"""It converts unpicklable lr scheduler callable such as lambda function to picklable."""

def __init__(self, scheduler_callable: LRSchedulerCallable | LRSchedulerListCallable):
self.dumped_scheduler_callable = dill.dumps(scheduler_callable)

def __call__(
self,
optimizer: Optimizer,
) -> _TORCH_LRSCHEDULER | ReduceLROnPlateau | list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]:
scheduler_callable = dill.loads(self.dumped_scheduler_callable) # noqa: S301
return scheduler_callable(optimizer)
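`PicklableLRSchedulerCallable`, and the `dill` dependency it required, is removed; `SchedulerCallableSupportHPO.__reduce__` stores only the class import path and keyword arguments, so the standard library pickle is enough. A hedged sketch of the round trip, with illustrative values:

```python
# Illustrative sketch: __reduce__ keeps only the import path and kwargs,
# so plain pickle works; no dill required.
import pickle

from otx.core.schedulers import SchedulerCallableSupportHPO

scheduler_callable = SchedulerCallableSupportHPO(
    scheduler_cls="torch.optim.lr_scheduler.StepLR",
    scheduler_kwargs={"step_size": 10, "gamma": 0.5},
)

restored = pickle.loads(pickle.dumps(scheduler_callable))
assert restored.scheduler_path == scheduler_callable.scheduler_path
assert restored.scheduler_kwargs == scheduler_callable.scheduler_kwargs
```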
160 changes: 160 additions & 0 deletions src/otx/core/schedulers/callable.py
@@ -0,0 +1,160 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Scheduler callable to support hyper-parameter optimization (HPO) algorithm."""

from __future__ import annotations

import importlib
import inspect
from typing import TYPE_CHECKING, Any

from lightning.pytorch.cli import ReduceLROnPlateau
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau as TorchReduceLROnPlateau

from otx.core.utils.jsonargparse import ClassType, lazy_instance

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable


class SchedulerCallableSupportHPO:
"""LR scheduler callable supports OTX hyper-parameter optimization (HPO) algorithm.
Args:
scheduler_cls: `LRScheduler` class type or string class import path. See examples for details.
scheduler_kwargs: Keyword arguments used for the initialization of the given `scheduler_cls`.
Examples:
This is an example of creating `MobileNetV3ForMulticlassCls` with a `StepLR` LR scheduler and
custom configurations.
```python
from torch.optim.lr_scheduler import StepLR
from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls
model = MobileNetV3ForMulticlassCls(
num_classes=3,
scheduler=SchedulerCallableSupportHPO(
scheduler_cls=StepLR,
scheduler_kwargs={
"step_size": 10,
"gamma": 0.5,
},
),
)
```
It can also be created from a string class import path, for example:
```python
from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls
model = MobileNetV3ForMulticlassCls(
num_classes=3,
scheduler=SchedulerCallableSupportHPO(
scheduler_cls="torch.optim.lr_scheduler.StepLR",
scheduler_kwargs={
"step_size": 10,
"gamma": 0.5,
},
),
)
```
"""

def __init__(
self,
scheduler_cls: type[LRScheduler] | str,
scheduler_kwargs: dict[str, int | float | bool | str],
):
if isinstance(scheduler_cls, str):
splited = scheduler_cls.split(".")
module_path, class_name = ".".join(splited[:-1]), splited[-1]
module = importlib.import_module(module_path)

self.scheduler_init: type[LRScheduler] = getattr(module, class_name)
self.scheduler_path = scheduler_cls
elif issubclass(scheduler_cls, LRScheduler | ReduceLROnPlateau):
self.scheduler_init = scheduler_cls
self.scheduler_path = scheduler_cls.__module__ + "." + scheduler_cls.__qualname__
else:
raise TypeError(scheduler_cls)

self.scheduler_kwargs = scheduler_kwargs
self.__dict__.update(scheduler_kwargs)

def __call__(self, optimizer: Optimizer) -> LRScheduler:
"""Create `torch.optim.LRScheduler` instance for the given parameters."""
return self.scheduler_init(optimizer, **self.scheduler_kwargs)

def to_lazy_instance(self) -> ClassType:
"""Return lazy instance of this class.
Because OTX is rely on jsonargparse library,
the default value of class initialization
argument should be the lazy instance.
Please refer to https://jsonargparse.readthedocs.io/en/stable/#default-values
for more details.
Examples:
This is an example of implementing a new model with a `StepLR` scheduler and
custom configurations as a default.
```python
class MyAwesomeMulticlassClsModel(OTXMulticlassClsModel):
def __init__(
self,
num_classes: int,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = SchedulerCallableSupportHPO(
scheduler_cls=StepLR,
scheduler_kwargs={
"step_size": 10,
"gamma": 0.5,
},
).to_lazy_instance(),
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
...
```
"""
return lazy_instance(
SchedulerCallableSupportHPO,
scheduler_cls=self.scheduler_path,
scheduler_kwargs=self.scheduler_kwargs,
)

def __reduce__(self) -> str | tuple[Any, ...]:
return self.__class__, (
self.scheduler_path,
self.scheduler_kwargs,
)

@classmethod
def from_callable(cls, func: LRSchedulerCallable) -> SchedulerCallableSupportHPO:
"""Create this class instance from an existing optimizer callable."""
dummy_params = [nn.Parameter()]
optimizer = Optimizer(dummy_params, {"lr": 1.0})
scheduler = func(optimizer)

allow_names = set(inspect.signature(scheduler.__class__).parameters)

if isinstance(scheduler, ReduceLROnPlateau):
# NOTE: Arguments other than "monitor" (e.g. "patience") are not in the signature of
# Lightning's ReduceLROnPlateau.__init__(), so pull them from the torch parent class.
allow_names.update(key for key in inspect.signature(TorchReduceLROnPlateau).parameters)

block_names = {"optimizer", "last_epoch"}

scheduler_kwargs = {
key: value for key, value in scheduler.state_dict().items() if key in allow_names and key not in block_names
}

return SchedulerCallableSupportHPO(
scheduler_cls=scheduler.__class__,
scheduler_kwargs=scheduler_kwargs,
)
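`from_callable` reconstructs keyword arguments from the scheduler's `state_dict()`, filtered by the `__init__` signature, with a special case for Lightning's `ReduceLROnPlateau`. A sketch of both paths, assuming illustrative values:

```python
# Illustrative sketch: recovering kwargs from existing scheduler callables.
from lightning.pytorch.cli import ReduceLROnPlateau
from torch.optim.lr_scheduler import StepLR

from otx.core.schedulers.callable import SchedulerCallableSupportHPO

# Plain torch scheduler: kwargs come from state_dict(), filtered by the __init__ signature.
step_lr_callable = SchedulerCallableSupportHPO.from_callable(
    lambda optimizer: StepLR(optimizer, step_size=10, gamma=0.5),
)
assert step_lr_callable.step_size == 10 and step_lr_callable.gamma == 0.5

# Lightning's ReduceLROnPlateau: "monitor" comes from the wrapper's signature; the rest
# (e.g. "patience") come from the torch parent class merged into the allowed names.
plateau_callable = SchedulerCallableSupportHPO.from_callable(
    lambda optimizer: ReduceLROnPlateau(optimizer, monitor="val/loss", patience=3),
)
assert plateau_callable.monitor == "val/loss"
assert plateau_callable.patience == 3
```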
4 changes: 3 additions & 1 deletion src/otx/core/schedulers/warmup_schedulers.py
@@ -8,6 +8,8 @@

from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from otx.core.schedulers.callable import SchedulerCallableSupportHPO

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer
@@ -65,7 +67,7 @@ def __init__(
warmup_interval: Literal["step", "epoch"] = "step",
monitor: str | None = None,
):
self.main_scheduler_callable = main_scheduler_callable
self.main_scheduler_callable = SchedulerCallableSupportHPO.from_callable(main_scheduler_callable)
self.num_warmup_steps = num_warmup_steps
self.warmup_interval = warmup_interval
self.monitor = monitor
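Wrapping `main_scheduler_callable` at construction time is what makes `LinearWarmupSchedulerCallable` HPO-aware on its own, which is why `base.py` above skips re-wrapping it. A hedged sketch (the warmup step count is illustrative):

```python
# Illustrative sketch: the main scheduler callable is wrapped eagerly, so the warmup
# callable is HPO-aware and picklable out of the box.
from torch.optim.lr_scheduler import StepLR

from otx.core.schedulers import LinearWarmupSchedulerCallable, SchedulerCallableSupportHPO

warmup_callable = LinearWarmupSchedulerCallable(
    main_scheduler_callable=lambda optimizer: StepLR(optimizer, step_size=10, gamma=0.5),
    num_warmup_steps=100,
)

assert isinstance(warmup_callable.main_scheduler_callable, SchedulerCallableSupportHPO)
```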
40 changes: 15 additions & 25 deletions tests/unit/core/optimizer/test_callable.py
@@ -48,30 +48,21 @@ def test_succeed(self, fxt_optimizer_cls, fxt_params):
assert all(param["momentum"] == 0.9 for param in fxt_params)
assert all(param["weight_decay"] == 1e-4 for param in fxt_params)

def test_failure(self, fxt_invaliid_optimizer_cls, fxt_optimizer_cls):
with pytest.raises(TypeError):
OptimizerCallableSupportHPO(
optimizer_cls=fxt_invaliid_optimizer_cls,
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
)

with pytest.raises(
ValueError,
match="Search hyperparamter=(.*) should be existed in optimizer keyword arguments",
):
OptimizerCallableSupportHPO(
optimizer_cls=fxt_optimizer_cls,
search_hparams=("lr", "non_momentum"),
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
)
def test_from_callable(self, fxt_params):
optimizer_callable = OptimizerCallableSupportHPO.from_callable(
func=lambda params: SGD(params, lr=0.1, momentum=0.9, weight_decay=1e-4),
)
optimizer = optimizer_callable(fxt_params)

assert isinstance(optimizer, SGD)

assert optimizer_callable.lr == 0.1
assert optimizer_callable.momentum == 0.9
assert optimizer_callable.weight_decay == 1e-4

assert all(param["lr"] == 0.1 for param in fxt_params)
assert all(param["momentum"] == 0.9 for param in fxt_params)
assert all(param["weight_decay"] == 1e-4 for param in fxt_params)

def test_picklable(self, fxt_optimizer_cls):
optimizer_callable = OptimizerCallableSupportHPO(
@@ -89,7 +80,6 @@ def test_picklable(self, fxt_optimizer_cls):
assert isinstance(unpickled, OptimizerCallableSupportHPO)
assert optimizer_callable.optimizer_path == unpickled.optimizer_path
assert optimizer_callable.optimizer_kwargs == unpickled.optimizer_kwargs
assert optimizer_callable.search_hparams == unpickled.search_hparams

def test_lazy_instance(self, fxt_optimizer_cls):
default_optimizer_callable = OptimizerCallableSupportHPO(