Skip to content

Commit

Permalink
Refactor to implant optimizer and scheduler into model code (#3258)
Browse files Browse the repository at this point in the history
* Refactor to implant optimizer and scheduler into model code

 - Add OptimizerCallableSupportHPO and PicklableLRSchedulerCallable for
   HPO
 - Implant optimizer and scheduler into efficientnet_b0 code as a testbed

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

* Upgrade jsonargparse to 4.27.7

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

* Fix test installation error

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

* Fix test errors

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

---------

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
  • Loading branch information
vinnamkim authored Apr 9, 2024
1 parent d31cb33 commit bc887ed
Show file tree
Hide file tree
Showing 14 changed files with 584 additions and 77 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ exclude = [
# Ruff complains it but don't know how to fix since it literally showed no useful logs.
# https://github.com/openvinotoolkit/training_extensions/actions/runs/7176557723/job/19541622452?pr=2718#step:5:170
"tests/regression/*.py",

# Mostly borrowed from jsonargparse codebase
"src/otx/core/utils/jsonargparse.py"
]

# Same as Black.
Expand Down
45 changes: 38 additions & 7 deletions src/otx/algo/classification/efficientnet_b0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

from typing import TYPE_CHECKING

import torch
from lightning.pytorch.cli import ReduceLROnPlateau

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import (
MMPretrainHlabelClsModel,
MMPretrainMulticlassClsModel,
Expand All @@ -31,8 +33,17 @@ class EfficientNetB0ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlab
def __init__(
self,
hlabel_info: HLabelInfo,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = lambda params: torch.optim.SGD(
params=params,
lr=0.0049,
),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau(
optimizer,
mode="max",
factor=0.1,
patience=1,
monitor="val/accuracy",
),
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
) -> None:
Expand All @@ -59,8 +70,19 @@ def __init__(
self,
num_classes: int,
light: bool = False,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = lambda params: torch.optim.SGD(
params=params,
lr=0.0049,
momentum=0.9,
weight_decay=0.0001,
),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau(
optimizer,
mode="max",
factor=0.1,
patience=1,
monitor="val/accuracy",
),
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand All @@ -86,8 +108,17 @@ class EfficientNetB0ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrain
def __init__(
self,
num_classes: int,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = lambda params: torch.optim.SGD(
params=params,
lr=0.0049,
),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau(
optimizer,
mode="max",
factor=0.1,
patience=1,
monitor="val/accuracy",
),
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down
13 changes: 13 additions & 0 deletions src/otx/cli/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def otx_install(option: str | None = None, verbose: bool = False, do_not_install
status_code = create_command("install").main(install_args)
if status_code == 0:
console.log(f"Installation Complete: {install_args}")
else:
msg = "Cannot complete installation"
raise RuntimeError(msg)

# https://github.com/Madoshakalaka/pipenv-setup/issues/101
os.environ["SETUPTOOLS_USE_DISTUTILS"] = "stdlib"
Expand All @@ -132,6 +135,16 @@ def otx_install(option: str | None = None, verbose: bool = False, do_not_install
status_code = mim_installation(mmcv_install_args)
if status_code == 0:
console.log(f"MMLab Installation Complete: {mmcv_install_args}")
else:
msg = "Cannot complete installation"
raise RuntimeError(msg)

# TODO(harimkang): Remove this reinstalling after resolving conflict with anomalib==1.0.1
# https://github.com/openvinotoolkit/training_extensions/actions/runs/8531851027/job/23372146228?pr=3258#step:5:2587
status_code = create_command("install").main(["jsonargparse==4.27.7"])
if status_code != 0:
msg = "Cannot install jsonargparse==4.27.7"
raise RuntimeError(msg)

# Patch MMAction2 with src/otx/cli/patches/mmaction2.patch
patch_mmaction2()
Expand Down
16 changes: 15 additions & 1 deletion src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.metrics import MetricInput, NullMetricCallable
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.optimizer.callable import OptimizerCallableSupportHPO
from otx.core.schedulers import LRSchedulerListCallable, PicklableLRSchedulerCallable
from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler
from otx.core.types.export import OTXExportFormatType
from otx.core.types.label import LabelInfo, NullLabelInfo
Expand Down Expand Up @@ -675,6 +676,19 @@ def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Tensor) ->

return super().lr_scheduler_step(scheduler=scheduler, metric=metric)

def patch_optimizer_and_scheduler_for_hpo(self) -> None:
"""Patch optimizer and scheduler for hyperparameter optimization.
This is inplace function changing inner states (`optimizer_callable` and `scheduler_callable`).
Both will be changed to be picklable. In addition, `optimizer_callable` is changed
to make its hyperparameters gettable.
"""
if not isinstance(self.optimizer_callable, OptimizerCallableSupportHPO):
self.optimizer_callable = OptimizerCallableSupportHPO.from_callable(self.optimizer_callable)

if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable):
self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable)


class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the OpenVINO model.
Expand Down
7 changes: 7 additions & 0 deletions src/otx/core/optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Modules related to an optimizer."""

from otx.core.optimizer.callable import OptimizerCallableSupportHPO

__all__ = ["OptimizerCallableSupportHPO"]
161 changes: 161 additions & 0 deletions src/otx/core/optimizer/callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Optimizer callable to support hyper-parameter optimization (HPO) algorithm."""

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING, Any, Sequence

from torch import nn
from torch.optim.optimizer import Optimizer

from otx.core.utils.jsonargparse import ClassType, lazy_instance

if TYPE_CHECKING:
from lightning.pytorch.cli import OptimizerCallable
from torch.optim.optimizer import params_t


class OptimizerCallableSupportHPO:
"""Optimizer callable supports OTX hyper-parameter optimization (HPO) algorithm.
Args:
optimizer_cls: Optimizer class type or string class import path. See examples for details.
optimizer_kwargs: Keyword arguments used for the initialization of the given `optimizer_cls`.
search_hparams: Sequence of optimizer hyperparameter names which can be tuned by the OTX HPO algorithm.
Examples:
This is an example to create `MobileNetV3ForMulticlassCls` with a `SGD` optimizer and
custom configurations.
```python
from torch.optim import SGD
from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls
model = MobileNetV3ForMulticlassCls(
num_classes=3,
optimizer=OptimizerCallableSupportHPO(
optimizer_cls=SGD,
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
),
)
```
It can be created from the string class import path such as
```python
from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls
model = MobileNetV3ForMulticlassCls(
num_classes=3,
optimizer=OptimizerCallableSupportHPO(
optimizer_cls="torch.optim.SGD",
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
),
)
```
"""

def __init__(
self,
optimizer_cls: type[Optimizer] | str,
optimizer_kwargs: dict[str, int | float | bool],
search_hparams: Sequence[str] = ("lr",),
):
if isinstance(optimizer_cls, str):
splited = optimizer_cls.split(".")
module_path, class_name = ".".join(splited[:-1]), splited[-1]
module = importlib.import_module(module_path)

self.optimizer_init: type[Optimizer] = getattr(module, class_name)
self.optimizer_path = optimizer_cls
elif issubclass(optimizer_cls, Optimizer):
self.optimizer_init = optimizer_cls
self.optimizer_path = optimizer_cls.__module__ + "." + optimizer_cls.__qualname__
else:
raise TypeError(optimizer_cls)

for search_hparam in search_hparams:
if search_hparam not in optimizer_kwargs:
msg = (
f"Search hyperparamter={search_hparam} should be existed in "
f"optimizer keyword arguments={optimizer_kwargs} as well."
)
raise ValueError(msg)

self.search_hparams = list(search_hparams)
self.optimizer_kwargs = optimizer_kwargs
self.__dict__.update(optimizer_kwargs)

def __call__(self, params: params_t) -> Optimizer:
"""Create `torch.optim.Optimizer` instance for the given parameters."""
return self.optimizer_init(params, **self.optimizer_kwargs)

def to_lazy_instance(self) -> ClassType:
"""Return lazy instance of this class.
Because OTX is rely on jsonargparse library,
the default value of class initialization
argument should be the lazy instance.
Please refer to https://jsonargparse.readthedocs.io/en/stable/#default-values
for more details.
Examples:
This is an example to implement a new model with a `SGD` optimizer and
custom configurations as a default.
```python
class MyAwesomeMulticlassClsModel(OTXMulticlassClsModel):
def __init__(
self,
num_classes: int,
optimizer: OptimizerCallable = OptimizerCallableSupportHPO(
optimizer_cls=SGD,
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
).to_lazy_instance(),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
...
```
"""
return lazy_instance(
OptimizerCallableSupportHPO,
optimizer_cls=self.optimizer_path,
optimizer_kwargs=self.optimizer_kwargs,
search_hparams=self.search_hparams,
)

def __reduce__(self) -> str | tuple[Any, ...]:
return self.__class__, (
self.optimizer_path,
self.optimizer_kwargs,
self.search_hparams,
)

@classmethod
def from_callable(cls, func: OptimizerCallable) -> OptimizerCallableSupportHPO:
"""Create this class instance from an existing optimizer callable."""
dummy_params = [nn.Parameter()]
optimizer = func(dummy_params)

param_group = next(iter(optimizer.param_groups))

return OptimizerCallableSupportHPO(
optimizer_cls=optimizer.__class__,
optimizer_kwargs={key: value for key, value in param_group.items() if key != "params"},
)
20 changes: 19 additions & 1 deletion src/otx/core/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@

from __future__ import annotations

from typing import Callable
from typing import TYPE_CHECKING, Callable

import dill
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch.cli import ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler, LinearWarmupSchedulerCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable

__all__ = [
"LRSchedulerListCallable",
"LinearWarmupScheduler",
Expand All @@ -21,3 +25,17 @@


LRSchedulerListCallable = Callable[[Optimizer], list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]]


class PicklableLRSchedulerCallable:
"""It converts unpicklable lr scheduler callable such as lambda function to picklable."""

def __init__(self, scheduler_callable: LRSchedulerCallable | LRSchedulerListCallable):
self.dumped_scheduler_callable = dill.dumps(scheduler_callable)

def __call__(
self,
optimizer: Optimizer,
) -> _TORCH_LRSCHEDULER | ReduceLROnPlateau | list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]:
scheduler_callable = dill.loads(self.dumped_scheduler_callable) # noqa: S301
return scheduler_callable(optimizer)
Loading

0 comments on commit bc887ed

Please sign in to comment.