Refactor to implant optimizer and scheduler into model code #3258

Merged
Changes from 3 commits
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -274,6 +274,9 @@ exclude = [
# Ruff complains about it, but we don't know how to fix it since it shows no useful logs.
# https://github.com/openvinotoolkit/training_extensions/actions/runs/7176557723/job/19541622452?pr=2718#step:5:170
"tests/regression/*.py",

# Mostly borrowed from jsonargparse codebase
"src/otx/core/utils/jsonargparse.py"
]

# Same as Black.
45 changes: 38 additions & 7 deletions src/otx/algo/classification/efficientnet_b0.py
@@ -6,10 +6,12 @@

from typing import TYPE_CHECKING

import torch
from lightning.pytorch.cli import ReduceLROnPlateau

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import (
MMPretrainHlabelClsModel,
MMPretrainMulticlassClsModel,
@@ -30,8 +32,17 @@ class EfficientNetB0ForHLabelCls(MMPretrainHlabelClsModel):
def __init__(
self,
hlabel_info: HLabelInfo,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = lambda params: torch.optim.SGD(
params=params,
lr=0.0049,
),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau(
optimizer,
mode="max",
factor=0.1,
patience=1,
monitor="val/accuracy",
),
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
) -> None:
@@ -58,8 +69,19 @@ def __init__(
self,
num_classes: int,
light: bool = False,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = lambda params: torch.optim.SGD(
params=params,
lr=0.0049,
momentum=0.9,
weight_decay=0.0001,
),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau(
optimizer,
mode="max",
factor=0.1,
patience=1,
monitor="val/accuracy",
),
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
@@ -85,8 +107,17 @@ class EfficientNetB0ForMultilabelCls(MMPretrainMultilabelClsModel):
def __init__(
self,
num_classes: int,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = lambda params: torch.optim.SGD(
params=params,
lr=0.0049,
),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = lambda optimizer: ReduceLROnPlateau(
optimizer,
mode="max",
factor=0.1,
patience=1,
monitor="val/accuracy",
),
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
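As a usage note (not part of the diff): with the defaults embedded as above, a model can now be constructed without passing any optimizer or scheduler, while callers can still inject their own callables. A minimal sketch, assuming only the constructor signature shown in this diff:

```python
import torch
from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls

# Uses the embedded defaults from the diff above:
# SGD(lr=0.0049, momentum=0.9, weight_decay=0.0001) + ReduceLROnPlateau.
model = EfficientNetB0ForMulticlassCls(num_classes=10)

# The callables remain injectable, so callers can still override them.
model = EfficientNetB0ForMulticlassCls(
    num_classes=10,
    optimizer=lambda params: torch.optim.AdamW(params, lr=1e-3),
)
```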
13 changes: 13 additions & 0 deletions src/otx/cli/install.py
@@ -122,6 +122,9 @@ def otx_install(option: str | None = None, verbose: bool = False, do_not_install
status_code = create_command("install").main(install_args)
if status_code == 0:
console.log(f"Installation Complete: {install_args}")
else:
msg = "Cannot complete installation"
raise RuntimeError(msg)

# https://github.com/Madoshakalaka/pipenv-setup/issues/101
os.environ["SETUPTOOLS_USE_DISTUTILS"] = "stdlib"
@@ -132,6 +135,16 @@ def otx_install(option: str | None = None, verbose: bool = False, do_not_install
status_code = mim_installation(mmcv_install_args)
if status_code == 0:
console.log(f"MMLab Installation Complete: {mmcv_install_args}")
else:
msg = "Cannot complete installation"
raise RuntimeError(msg)

# TODO(harimkang): Remove this reinstallation after resolving the conflict with anomalib==1.0.1
# https://github.com/openvinotoolkit/training_extensions/actions/runs/8531851027/job/23372146228?pr=3258#step:5:2587
status_code = create_command("install").main(["jsonargparse==4.27.7"])
if status_code != 0:
msg = "Cannot install jsonargparse==4.27.7"
raise RuntimeError(msg)

# Patch MMAction2 with src/otx/cli/patches/mmaction2.patch
patch_mmaction2()
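A sketch of the effect of this change on callers (not part of the diff): a failed pip or mim step now aborts the install with a `RuntimeError` instead of silently continuing. The `option` value here is illustrative only:

```python
from otx.cli.install import otx_install

try:
    otx_install(option="full", verbose=True)
except RuntimeError as err:
    # With this change, a nonzero pip/mim status code surfaces here
    # rather than leaving a half-installed environment.
    print(f"OTX installation failed: {err}")
```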
16 changes: 15 additions & 1 deletion src/otx/core/model/base.py
@@ -32,7 +32,8 @@
from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.metrics import MetricInput, NullMetricCallable
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.optimizer.callable import OptimizerCallableSupportHPO
from otx.core.schedulers import LRSchedulerListCallable, PicklableLRSchedulerCallable
from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler
from otx.core.types.export import OTXExportFormatType
from otx.core.types.label import LabelInfo, NullLabelInfo
@@ -666,6 +667,19 @@ def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Tensor) ->

return super().lr_scheduler_step(scheduler=scheduler, metric=metric)

def patch_optimizer_and_scheduler_for_hpo(self) -> None:
"""Patch optimizer and scheduler for hyperparameter optimization.

This method changes the inner states (`optimizer_callable` and `scheduler_callable`) in place.
Both become picklable. In addition, `optimizer_callable` is changed
so that its hyperparameters can be read as attributes.
"""
if not isinstance(self.optimizer_callable, OptimizerCallableSupportHPO):
self.optimizer_callable = OptimizerCallableSupportHPO.from_callable(self.optimizer_callable)

if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable):
self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable)


class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the OpenVINO model.
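To illustrate the new `patch_optimizer_and_scheduler_for_hpo` method added above, here is a minimal sketch (not part of the diff) of how an HPO worker might rely on it, using the EfficientNet-B0 model from the earlier example:

```python
import pickle

from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls

model = EfficientNetB0ForMulticlassCls(num_classes=10)
model.patch_optimizer_and_scheduler_for_hpo()

# Both callables now survive a stdlib pickle round-trip, which is what lets
# HPO worker processes receive them.
optimizer_callable, scheduler_callable = pickle.loads(
    pickle.dumps((model.optimizer_callable, model.scheduler_callable)),
)

# The wrapped optimizer callable also exposes its hyperparameters as
# attributes (see `OptimizerCallableSupportHPO` below).
print(optimizer_callable.lr)  # 0.0049 with the EfficientNet-B0 defaults above
```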
7 changes: 7 additions & 0 deletions src/otx/core/optimizer/__init__.py
@@ -0,0 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Modules related to an optimizer."""

from otx.core.optimizer.callable import OptimizerCallableSupportHPO

__all__ = ["OptimizerCallableSupportHPO"]
161 changes: 161 additions & 0 deletions src/otx/core/optimizer/callable.py
@@ -0,0 +1,161 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Optimizer callable to support hyper-parameter optimization (HPO) algorithm."""

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING, Any, Sequence

from torch import nn
from torch.optim.optimizer import Optimizer

from otx.core.utils.jsonargparse import ClassType, lazy_instance

if TYPE_CHECKING:
from lightning.pytorch.cli import OptimizerCallable
from torch.optim.optimizer import params_t


class OptimizerCallableSupportHPO:
"""Optimizer callable supports OTX hyper-parameter optimization (HPO) algorithm.

Args:
optimizer_cls: Optimizer class type or string class import path. See examples for details.
optimizer_kwargs: Keyword arguments used for the initialization of the given `optimizer_cls`.
search_hparams: Sequence of optimizer hyperparameter names which can be tuned by the OTX HPO algorithm.

Examples:
This is an example to create `MobileNetV3ForMulticlassCls` with a `SGD` optimizer and
custom configurations.

```python
from torch.optim import SGD
from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

model = MobileNetV3ForMulticlassCls(
num_classes=3,
optimizer=OptimizerCallableSupportHPO(
optimizer_cls=SGD,
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
),
)
```

It can be created from the string class import path such as

```python
from otx.algo.classification.mobilenet_v3_large import MobileNetV3ForMulticlassCls

model = MobileNetV3ForMulticlassCls(
num_classes=3,
optimizer=OptimizerCallableSupportHPO(
optimizer_cls="torch.optim.SGD",
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
),
)
```
"""

def __init__(
self,
optimizer_cls: type[Optimizer] | str,
optimizer_kwargs: dict[str, int | float | bool],
search_hparams: Sequence[str] = ("lr",),
):
if isinstance(optimizer_cls, str):
parts = optimizer_cls.split(".")
module_path, class_name = ".".join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path)

self.optimizer_init: type[Optimizer] = getattr(module, class_name)
self.optimizer_path = optimizer_cls
elif issubclass(optimizer_cls, Optimizer):
self.optimizer_init = optimizer_cls
self.optimizer_path = optimizer_cls.__module__ + "." + optimizer_cls.__qualname__
else:
raise TypeError(optimizer_cls)

for search_hparam in search_hparams:
if search_hparam not in optimizer_kwargs:
msg = (
f"Search hyperparameter={search_hparam} must also exist in "
f"the optimizer keyword arguments={optimizer_kwargs}."
)
raise ValueError(msg)

self.search_hparams = list(search_hparams)
self.optimizer_kwargs = optimizer_kwargs
self.__dict__.update(optimizer_kwargs)

def __call__(self, params: params_t) -> Optimizer:
"""Create `torch.optim.Optimizer` instance for the given parameters."""
return self.optimizer_init(params, **self.optimizer_kwargs)

def to_lazy_instance(self) -> ClassType:
"""Return lazy instance of this class.

Because OTX relies on the jsonargparse library,
the default value of a class initialization
argument should be a lazy instance.
Please refer to https://jsonargparse.readthedocs.io/en/stable/#default-values
for more details.

Examples:
This is an example to implement a new model with a `SGD` optimizer and
custom configurations as a default.

```python
class MyAwesomeMulticlassClsModel(OTXMulticlassClsModel):
def __init__(
self,
num_classes: int,
optimizer: OptimizerCallable = OptimizerCallableSupportHPO(
optimizer_cls=SGD,
optimizer_kwargs={
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 1e-4,
},
).to_lazy_instance(),
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
...
```
"""
return lazy_instance(
OptimizerCallableSupportHPO,
optimizer_cls=self.optimizer_path,
optimizer_kwargs=self.optimizer_kwargs,
search_hparams=self.search_hparams,
)

def __reduce__(self) -> str | tuple[Any, ...]:
return self.__class__, (
self.optimizer_path,
self.optimizer_kwargs,
self.search_hparams,
)

@classmethod
def from_callable(cls, func: OptimizerCallable) -> OptimizerCallableSupportHPO:
"""Create this class instance from an existing optimizer callable."""
dummy_params = [nn.Parameter()]
optimizer = func(dummy_params)

param_group = next(iter(optimizer.param_groups))

return OptimizerCallableSupportHPO(
optimizer_cls=optimizer.__class__,
optimizer_kwargs={key: value for key, value in param_group.items() if key != "params"},
)
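As a usage sketch for `from_callable` (not part of the diff): wrapping a plain lambda makes its hyperparameters readable as attributes, since `__init__` copies `optimizer_kwargs` into the instance dictionary after probing the optimizer's first param group:

```python
import torch
from otx.core.optimizer.callable import OptimizerCallableSupportHPO

# Wrap an existing lambda so HPO can read (and later tune) its hyperparameters.
wrapped = OptimizerCallableSupportHPO.from_callable(
    lambda params: torch.optim.SGD(params, lr=0.0049, momentum=0.9),
)

print(wrapped.lr, wrapped.momentum)  # 0.0049 0.9, read back from the param group
print(wrapped.optimizer_path)        # import path recorded for re-instantiation
```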
20 changes: 19 additions & 1 deletion src/otx/core/schedulers/__init__.py
@@ -5,14 +5,18 @@

from __future__ import annotations

from typing import Callable
from typing import TYPE_CHECKING, Callable

import dill
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch.cli import ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler, LinearWarmupSchedulerCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable

__all__ = [
"LRSchedulerListCallable",
"LinearWarmupScheduler",
@@ -21,3 +25,17 @@


LRSchedulerListCallable = Callable[[Optimizer], list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]]


class PicklableLRSchedulerCallable:
"""It converts unpicklable lr scheduler callable such as lambda function to picklable."""

def __init__(self, scheduler_callable: LRSchedulerCallable | LRSchedulerListCallable):
self.dumped_scheduler_callable = dill.dumps(scheduler_callable)

def __call__(
self,
optimizer: Optimizer,
) -> _TORCH_LRSCHEDULER | ReduceLROnPlateau | list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]:
scheduler_callable = dill.loads(self.dumped_scheduler_callable) # noqa: S301
return scheduler_callable(optimizer)
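A minimal sketch (not part of the diff) of why the dill round-trip matters: a bare lambda fails with the stdlib `pickle`, while the wrapper instance survives because its pickled state is just the serialized bytes:

```python
import pickle

import torch
from lightning.pytorch.cli import ReduceLROnPlateau
from otx.core.schedulers import PicklableLRSchedulerCallable

# pickle.dumps(lambda optimizer: ...) would raise; the wrapper does not.
scheduler_callable = PicklableLRSchedulerCallable(
    lambda optimizer: ReduceLROnPlateau(
        optimizer, monitor="val/accuracy", mode="max", patience=1,
    ),
)
restored = pickle.loads(pickle.dumps(scheduler_callable))

# Calling the restored wrapper behaves like calling the original lambda.
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = restored(optimizer)
```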