Add SchedulerCallableSupportHPO (#3334)
Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
Showing 8 changed files with 326 additions and 78 deletions.
@@ -0,0 +1,160 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Scheduler callable to support hyper-parameter optimization (HPO) algorithm."""

from __future__ import annotations

import importlib
import inspect
from typing import TYPE_CHECKING, Any

from lightning.pytorch.cli import ReduceLROnPlateau
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau as TorchReduceLROnPlateau

from otx.core.utils.jsonargparse import ClassType, lazy_instance

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable

class SchedulerCallableSupportHPO:
    """LR scheduler callable that supports the OTX hyper-parameter optimization (HPO) algorithm.

    Args:
        scheduler_cls: `LRScheduler` class type or string class import path. See examples for details.
        scheduler_kwargs: Keyword arguments used for the initialization of the given `scheduler_cls`.

    Examples:
        This example creates `MobileNetV3ForMulticlassCls` with a `StepLR` LR scheduler and
        custom configurations.

        ```python
        from torch.optim.lr_scheduler import StepLR
        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

        model = MobileNetV3ForMulticlassCls(
            num_classes=3,
            scheduler=SchedulerCallableSupportHPO(
                scheduler_cls=StepLR,
                scheduler_kwargs={
                    "step_size": 10,
                    "gamma": 0.5,
                },
            ),
        )
        ```

        It can also be created from a string class import path, such as

        ```python
        from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

        model = MobileNetV3ForMulticlassCls(
            num_classes=3,
            scheduler=SchedulerCallableSupportHPO(
                scheduler_cls="torch.optim.lr_scheduler.StepLR",
                scheduler_kwargs={
                    "step_size": 10,
                    "gamma": 0.5,
                },
            ),
        )
        ```
    """

    def __init__(
        self,
        scheduler_cls: type[LRScheduler] | str,
        scheduler_kwargs: dict[str, int | float | bool | str],
    ):
        if isinstance(scheduler_cls, str):
            splits = scheduler_cls.split(".")
            module_path, class_name = ".".join(splits[:-1]), splits[-1]
            module = importlib.import_module(module_path)

            self.scheduler_init: type[LRScheduler] = getattr(module, class_name)
            self.scheduler_path = scheduler_cls
        elif issubclass(scheduler_cls, LRScheduler | ReduceLROnPlateau):
            self.scheduler_init = scheduler_cls
            self.scheduler_path = scheduler_cls.__module__ + "." + scheduler_cls.__qualname__
        else:
            raise TypeError(scheduler_cls)

        self.scheduler_kwargs = scheduler_kwargs
        self.__dict__.update(scheduler_kwargs)

    def __call__(self, optimizer: Optimizer) -> LRScheduler:
        """Create a `torch.optim.LRScheduler` instance for the given optimizer."""
        return self.scheduler_init(optimizer, **self.scheduler_kwargs)

    def to_lazy_instance(self) -> ClassType:
        """Return a lazy instance of this class.

        Because OTX relies on the jsonargparse library,
        the default value of a class initialization
        argument should be a lazy instance.
        Please refer to https://jsonargparse.readthedocs.io/en/stable/#default-values
        for more details.

        Examples:
            This example implements a new model with a `StepLR` scheduler and
            custom configurations as its default.

            ```python
            class MyAwesomeMulticlassClsModel(OTXMulticlassClsModel):
                def __init__(
                    self,
                    num_classes: int,
                    optimizer: OptimizerCallable = DefaultOptimizerCallable,
                    scheduler: LRSchedulerCallable | LRSchedulerListCallable = SchedulerCallableSupportHPO(
                        scheduler_cls=StepLR,
                        scheduler_kwargs={
                            "step_size": 10,
                            "gamma": 0.5,
                        },
                    ).to_lazy_instance(),
                    metric: MetricCallable = MultiClassClsMetricCallable,
                    torch_compile: bool = False,
                ) -> None:
                    ...
            ```
        """
        return lazy_instance(
            SchedulerCallableSupportHPO,
            scheduler_cls=self.scheduler_path,
            scheduler_kwargs=self.scheduler_kwargs,
        )

    def __reduce__(self) -> str | tuple[Any, ...]:
        return self.__class__, (
            self.scheduler_path,
            self.scheduler_kwargs,
        )

    @classmethod
    def from_callable(cls, func: LRSchedulerCallable) -> SchedulerCallableSupportHPO:
        """Create this class instance from an existing scheduler callable."""
        dummy_params = [nn.Parameter()]
        optimizer = Optimizer(dummy_params, {"lr": 1.0})
        scheduler = func(optimizer)

        allow_names = set(inspect.signature(scheduler.__class__).parameters)

        if isinstance(scheduler, ReduceLROnPlateau):
            # NOTE: Arguments other than "monitor", such as "patience", are not included
            # in the signature of ReduceLROnPlateau.__init__(), so also collect the
            # parameter names from the torch scheduler's signature.
            allow_names.update(key for key in inspect.signature(TorchReduceLROnPlateau).parameters)

        block_names = {"optimizer", "last_epoch"}

        scheduler_kwargs = {
            key: value
            for key, value in scheduler.state_dict().items()
            if key in allow_names and key not in block_names
        }

        return SchedulerCallableSupportHPO(
            scheduler_cls=scheduler.__class__,
            scheduler_kwargs=scheduler_kwargs,
        )
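
The sketch below is not part of the commit; it is a rough illustration of how the pieces above fit together, assuming `SchedulerCallableSupportHPO` is importable from the module added in this diff. The optimizer, dummy parameter, and `plain_callable` helper are placeholders chosen for the example.

```python
import pickle

import torch
from torch import nn
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR

# Build the callable as in the docstring example; assumes
# SchedulerCallableSupportHPO (defined above) is importable.
scheduler_callable = SchedulerCallableSupportHPO(
    scheduler_cls=StepLR,
    scheduler_kwargs={"step_size": 10, "gamma": 0.5},
)

# scheduler_kwargs are mirrored onto the instance via self.__dict__.update(),
# so the HPO search can read them like ordinary attributes.
assert scheduler_callable.step_size == 10
assert scheduler_callable.gamma == 0.5

# Calling the object with an optimizer builds the actual LR scheduler.
optimizer = SGD([nn.Parameter(torch.zeros(1))], lr=0.01)
assert isinstance(scheduler_callable(optimizer), StepLR)

# __reduce__ serializes the object as (class import path, kwargs), so it
# round-trips through pickle, e.g. when HPO trials run in separate processes.
restored = pickle.loads(pickle.dumps(scheduler_callable))
assert restored.scheduler_path == "torch.optim.lr_scheduler.StepLR"


# from_callable() converts a plain LRSchedulerCallable (such as a function
# used as a jsonargparse default): it builds the scheduler on a dummy
# optimizer and recovers kwargs from state_dict(), filtered by the
# scheduler's __init__ signature.
def plain_callable(opt: Optimizer) -> StepLR:
    return StepLR(opt, step_size=30, gamma=0.1)


hpo_callable = SchedulerCallableSupportHPO.from_callable(plain_callable)
assert hpo_callable.scheduler_kwargs["step_size"] == 30
```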