Fix to allow Scheduler to be configured as a list and fix a bug in the AdaptiveTrainingCallback #2910

Merged: 11 commits, Feb 14, 2024
2 changes: 1 addition & 1 deletion src/otx/algo/callbacks/adaptive_train_scheduling.py
@@ -104,7 +104,7 @@ def _revert_func(config: LRSchedulerConfig, saved_frequency: int) -> None:
config.frequency = saved_frequency

for config in lr_configs:
if hasattr(config, "frequency"):
if hasattr(config, "frequency") and hasattr(config, "interval") and config.interval == "epoch":
msg = (
"The frequency of LRscheduler will be changed due to the effect of adaptive interval: "
f"{config.frequency} --> {adaptive_interval}."
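The guard added above fixes the AdaptiveTrainingCallback bug named in the title: the adaptive interval now rewrites the frequency only of scheduler configs that run on an "epoch" interval, whereas previously any config with a frequency attribute was rescaled, including step-interval entries such as a warmup scheduler. A minimal sketch of the resulting behaviour, reusing the names from the hunk above:

# Sketch only; lr_configs and adaptive_interval are the objects used in the callback above.
for config in lr_configs:
    if hasattr(config, "frequency") and hasattr(config, "interval") and config.interval == "epoch":
        # Only epoch-interval schedulers are rescaled; step-interval configs (e.g. warmup) keep their frequency.
        config.frequency = adaptive_interval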
23 changes: 16 additions & 7 deletions src/otx/cli/cli.py
@@ -156,18 +156,19 @@ def engine_subcommand_parser(**kwargs) -> ArgumentParser:
sub_configs=True,
)
# Optimizer & Scheduler Settings
from lightning.pytorch.cli import LRSchedulerTypeTuple
from lightning.pytorch.cli import ReduceLROnPlateau
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

optim_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
scheduler_kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
parser.add_subclass_arguments(
baseclass=(Optimizer,),
baseclass=(Optimizer, list),
nested_key="optimizer",
**optim_kwargs,
)
parser.add_subclass_arguments(
baseclass=LRSchedulerTypeTuple,
baseclass=(LRScheduler, ReduceLROnPlateau, list),
nested_key="scheduler",
**scheduler_kwargs,
)
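With list added to the accepted base classes, the optimizer and scheduler keys can be given either a single entry or a list of entries. A hedged sketch of the shape the parser now accepts for scheduler (class paths and values are illustrative, not defaults from this repository):

# Illustrative only: a list of scheduler entries in the class_path/init_args form the CLI parses.
scheduler_config = [
    {"class_path": "torch.optim.lr_scheduler.LinearLR", "init_args": {"total_iters": 10}},
    {"class_path": "torch.optim.lr_scheduler.CosineAnnealingLR", "init_args": {"T_max": 100}},
]
# A single {"class_path": ..., "init_args": ...} mapping keeps working as before.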
@@ -341,11 +342,17 @@ def instantiate_model(self, model_config: Namespace) -> tuple:
# Update self.config with model
self.config[self.subcommand].update(Namespace(model=model_config))

optimizer_kwargs = namespace_to_dict(self.get_config_value(self.config_init, "optimizer", Namespace()))
scheduler_kwargs = namespace_to_dict(self.get_config_value(self.config_init, "scheduler", Namespace()))
from otx.core.utils.instantiators import partial_instantiate_class

return model, partial_instantiate_class(optimizer_kwargs), partial_instantiate_class(scheduler_kwargs)
optimizer_kwargs = self.get_config_value(self.config_init, "optimizer", {})
optimizer_kwargs = optimizer_kwargs if isinstance(optimizer_kwargs, list) else [optimizer_kwargs]
optimizers = partial_instantiate_class([_opt for _opt in optimizer_kwargs if _opt])

scheduler_kwargs = self.get_config_value(self.config_init, "scheduler", {})
scheduler_kwargs = scheduler_kwargs if isinstance(scheduler_kwargs, list) else [scheduler_kwargs]
schedulers = partial_instantiate_class([_sch for _sch in scheduler_kwargs if _sch])

return model, optimizers, schedulers

def get_config_value(self, config: Namespace, key: str, default: Any = None) -> Any: # noqa: ANN401
"""Retrieves the value of a configuration key from the given config object.
@@ -357,8 +364,10 @@ def get_config_value(self, config: Namespace, key: str, default: Any = None) ->

Returns:
Any: The value of the configuration key, or the default value if the key is not found.
If the value is a Namespace, it is converted to a dictionary.
"""
return config.get(str(self.subcommand), config).get(key, default)
result = config.get(str(self.subcommand), config).get(key, default)
return namespace_to_dict(result) if isinstance(result, Namespace) else result

def get_subcommand_parser(self, subcommand: str | None) -> ArgumentParser:
"""Returns the argument parser for the specified subcommand.
2 changes: 1 addition & 1 deletion src/otx/cli/utils/jsonargparse.py
@@ -178,7 +178,7 @@ def list_override(configs: Namespace, key: str, overrides: list) -> None:
... ...
... ]
"""
if key not in configs:
if key not in configs or configs[key] is None:
return
for target in overrides:
class_path = target.get("class_path", None)
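The extra None check makes the override a no-op when the key is present in the config but holds no value. A hedged sketch of the behaviour, assuming jsonargparse's Namespace and an illustrative callback override:

from jsonargparse import Namespace

configs = Namespace(callbacks=None)  # key exists but was never populated
list_override(configs, "callbacks", [{"class_path": "lightning.pytorch.callbacks.EarlyStopping"}])
# Returns immediately instead of trying to apply the override to a None value.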
4 changes: 2 additions & 2 deletions src/otx/core/model/module/action_classification.py
@@ -28,8 +28,8 @@ def __init__(
self,
otx_model: OTXActionClsModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
4 changes: 2 additions & 2 deletions src/otx/core/model/module/action_detection.py
@@ -29,8 +29,8 @@ def __init__(
self,
otx_model: OTXActionDetModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
57 changes: 25 additions & 32 deletions src/otx/core/model/module/base.py
@@ -12,7 +12,6 @@
from lightning import LightningModule
from torch import Tensor

from otx.algo.schedulers.warmup_schedulers import BaseWarmupScheduler
from otx.core.data.entity.base import (
OTXBatchDataEntity,
OTXBatchLossEntity,
@@ -34,11 +33,13 @@ def __init__(
self,
optimizer: torch.optim.Optimizer,
num_warmup_steps: int = 1000,
interval: str = "step",
):
if num_warmup_steps > 0:
if not num_warmup_steps > 0:
msg = f"num_warmup_steps should be > 0, got {num_warmup_steps}"
ValueError(msg)
raise ValueError(msg)
self.num_warmup_steps = num_warmup_steps
self.interval = interval
super().__init__(optimizer, lambda step: min(step / num_warmup_steps, 1.0))


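Two fixes land in LinearWarmupScheduler: the validation was inverted and its ValueError was constructed but never raised, and the scheduler now carries an interval attribute that configure_optimizers picks up further down. A short hedged usage sketch with an illustrative optimizer:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # illustrative parameters
optimizer = torch.optim.SGD(params, lr=0.01)

warmup = LinearWarmupScheduler(optimizer, num_warmup_steps=100, interval="step")
assert warmup.interval == "step"
# LinearWarmupScheduler(optimizer, num_warmup_steps=0) now raises ValueError;
# before the fix the check was inverted, so the exception was built for valid values and never raised.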
@@ -50,8 +51,8 @@ def __init__(
*,
otx_model: OTXModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()

@@ -110,7 +111,7 @@ def setup(self, stage: str) -> None:
if self.torch_compile and stage == "fit":
self.model = torch.compile(self.model)

def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[torch.optim.Optimizer]]:
def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]:
"""Choose what optimizers and learning-rate schedulers to use in your optimization.

Normally you'd need one. But in the case of GANs or similar you might have multiple.
@@ -120,34 +121,26 @@

:return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
"""
optimizer = (
self.hparams.optimizer(params=self.parameters())
if callable(self.hparams.optimizer)
else self.hparams.optimizer
)

scheduler = (
self.hparams.scheduler(optimizer=optimizer) if callable(self.hparams.scheduler) else self.hparams.scheduler
)

lr_scheduler_configs = []
if isinstance(scheduler, BaseWarmupScheduler) and scheduler.warmup_steps > 0:
lr_scheduler_configs += [
{
"scheduler": LinearWarmupScheduler(optimizer, num_warmup_steps=scheduler.warmup_steps),
"interval": "step",
},
]
lr_scheduler_configs += [
{
"scheduler": scheduler,
"monitor": self.lr_scheduler_monitor_key,
"interval": "epoch",
"frequency": self.trainer.check_val_every_n_epoch,
},

def ensure_list(item: Any) -> list: # noqa: ANN401
return item if isinstance(item, list) else [item]

optimizers = [
optimizer(params=self.parameters()) if callable(optimizer) else optimizer
for optimizer in ensure_list(self.hparams.optimizer)
]

return [optimizer], lr_scheduler_configs
lr_schedulers = []
for scheduler_config in ensure_list(self.hparams.scheduler):
scheduler = scheduler_config(optimizers[0]) if callable(scheduler_config) else scheduler_config
lr_scheduler_config = {"scheduler": scheduler}
if hasattr(scheduler, "interval"):
lr_scheduler_config["interval"] = scheduler.interval
if hasattr(scheduler, "monitor"):
lr_scheduler_config["monitor"] = scheduler.monitor
lr_schedulers.append(lr_scheduler_config)

return optimizers, lr_schedulers

def register_load_state_dict_pre_hook(self, model_classes: list[str], ckpt_classes: list[str]) -> None:
"""Register self.model's load_state_dict_pre_hook.
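configure_optimizers now normalises both hyperparameters to lists, instantiates every optimizer, binds each scheduler to the first optimizer, and forwards any interval or monitor attribute a scheduler exposes into its Lightning config dict. A hedged sketch of that loop with illustrative inputs (LinearWarmupScheduler is the class defined earlier in this file):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # illustrative parameters
optimizer = torch.optim.SGD(params, lr=0.01)

scheduler_callables = [
    lambda opt: LinearWarmupScheduler(opt, num_warmup_steps=100, interval="step"),
    lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50),
]

lr_schedulers = []
for scheduler_config in scheduler_callables:
    scheduler = scheduler_config(optimizer) if callable(scheduler_config) else scheduler_config
    config = {"scheduler": scheduler}
    if hasattr(scheduler, "interval"):  # e.g. LinearWarmupScheduler carries interval="step"
        config["interval"] = scheduler.interval
    if hasattr(scheduler, "monitor"):   # e.g. a plateau-style scheduler could expose a monitor key
        config["monitor"] = scheduler.monitor
    lr_schedulers.append(config)

# -> [{"scheduler": <LinearWarmupScheduler>, "interval": "step"}, {"scheduler": <CosineAnnealingLR>}]

Note that every scheduler in the list is attached to optimizers[0], so multi-optimizer setups pair all schedulers with the first optimizer.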
12 changes: 6 additions & 6 deletions src/otx/core/model/module/classification.py
@@ -37,8 +37,8 @@ def __init__(
self,
otx_model: OTXMulticlassClsModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
@@ -130,8 +130,8 @@ def __init__(
self,
otx_model: OTXMultilabelClsModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
@@ -218,8 +218,8 @@ def __init__(
self,
otx_model: OTXHlabelClsModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
4 changes: 2 additions & 2 deletions src/otx/core/model/module/detection.py
@@ -29,8 +29,8 @@ def __init__(
self,
otx_model: ExplainableOTXDetModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
4 changes: 2 additions & 2 deletions src/otx/core/model/module/instance_segmentation.py
@@ -32,8 +32,8 @@ def __init__(
self,
otx_model: ExplainableOTXInstanceSegModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
4 changes: 2 additions & 2 deletions src/otx/core/model/module/rotated_detection.py
@@ -25,8 +25,8 @@ def __init__(
self,
otx_model: OTXRotatedDetModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
4 changes: 2 additions & 2 deletions src/otx/core/model/module/segmentation.py
@@ -29,8 +29,8 @@ def __init__(
self,
otx_model: OTXSegmentationModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
4 changes: 2 additions & 2 deletions src/otx/core/model/module/visual_prompting.py
@@ -36,8 +36,8 @@ def __init__(
self,
otx_model: OTXVisualPromptingModel,
torch_compile: bool,
optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
optimizer: list[OptimizerCallable] | OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__(
otx_model=otx_model,
23 changes: 14 additions & 9 deletions src/otx/core/utils/instantiators.py
@@ -66,24 +66,29 @@ def instantiate_loggers(logger_cfg: list | None) -> list[Logger]:
return logger


def partial_instantiate_class(init: dict | None) -> partial | None:
def partial_instantiate_class(init: list | dict | None) -> list[partial] | None:
"""Partially instantiates a class with the given initialization arguments.

Copied from lightning.pytorch.cli.instantiate_class and modified to return a partial instead of an instance.

Args:
init (dict): A dictionary containing the initialization arguments.
It should have the following keys:
init (list | dict | None): A dictionary, or a list of dictionaries, containing the initialization arguments.
Each dictionary should have the following keys:
- "init_args" (dict): A dictionary of keyword arguments to be passed to the class constructor.
- "class_path" (str): The fully qualified path of the class to be instantiated.

Returns:
partial: A partial object representing the partially instantiated class.
list[partial] | None: A list of partials for the requested classes, or None if init is empty.
"""
if not init:
return None
kwargs = init.get("init_args", {})
class_module, class_name = init["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
return partial(args_class, **kwargs)
if not isinstance(init, list):
init = [init]
items: list[partial] = []
for item in init:
kwargs = item.get("init_args", {})
class_module, class_name = item["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
items.append(partial(args_class, **kwargs))
return items
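partial_instantiate_class now always returns a list of partials (or None for an empty input), whether it receives a single mapping or a list of them. A hedged usage sketch with illustrative class paths:

single = partial_instantiate_class({"class_path": "torch.optim.SGD", "init_args": {"lr": 0.01}})
# -> [functools.partial(SGD, lr=0.01)]  (a one-element list, not a bare partial)

many = partial_instantiate_class([
    {"class_path": "torch.optim.lr_scheduler.LinearLR", "init_args": {"total_iters": 10}},
    {"class_path": "torch.optim.lr_scheduler.CosineAnnealingLR", "init_args": {"T_max": 100}},
])
# -> a list of two partials; partial_instantiate_class(None) and partial_instantiate_class([]) still return None.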
21 changes: 11 additions & 10 deletions src/otx/engine/engine.py
@@ -83,8 +83,8 @@ def __init__(
work_dir: PathLike = "./otx-workspace",
datamodule: OTXDataModule | None = None,
model: OTXModel | str | None = None,
optimizer: OptimizerCallable | None = None,
scheduler: LRSchedulerCallable | None = None,
optimizer: list[OptimizerCallable] | OptimizerCallable | None = None,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable | None = None,
checkpoint: PathLike | None = None,
device: DeviceType = DeviceType.auto,
**kwargs,
@@ -97,9 +97,10 @@ def __init__(
work_dir (PathLike, optional): Working directory for the engine. Defaults to "./otx-workspace".
datamodule (OTXDataModule | None, optional): The data module for the engine. Defaults to None.
model (OTXModel | str | None, optional): The model for the engine. Defaults to None.
optimizer (OptimizerCallable | None, optional): The optimizer for the engine. Defaults to None.
scheduler (LRSchedulerCallable | None, optional): The learning rate scheduler for the engine.
optimizer (list[OptimizerCallable] | OptimizerCallable | None, optional): The optimizer for the engine.
Defaults to None.
scheduler (list[LRSchedulerCallable] | LRSchedulerCallable | None, optional):
The learning rate scheduler for the engine. Defaults to None.
checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None.
device (DeviceType, optional): The device type to use. Defaults to DeviceType.auto.
**kwargs: Additional keyword arguments for pl.Trainer.
@@ -132,10 +133,10 @@ def __init__(
meta_info=self._datamodule.meta_info if self._datamodule is not None else None,
)
)
self.optimizer: OptimizerCallable | None = (
self.optimizer: list[OptimizerCallable] | OptimizerCallable | None = (
optimizer if optimizer is not None else self._auto_configurator.get_optimizer()
)
self.scheduler: LRSchedulerCallable | None = (
self.scheduler: list[LRSchedulerCallable] | LRSchedulerCallable | None = (
scheduler if scheduler is not None else self._auto_configurator.get_scheduler()
)

@@ -667,15 +668,15 @@ def datamodule(self) -> OTXDataModule:
def _build_lightning_module(
self,
model: OTXModel,
optimizer: OptimizerCallable,
scheduler: LRSchedulerCallable,
optimizer: list[OptimizerCallable] | OptimizerCallable | None,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable | None,
) -> OTXLitModule:
"""Builds a LightningModule for engine workflow.

Args:
model (OTXModel): The OTXModel instance.
optimizer (OptimizerCallable): The optimizer callable.
scheduler (LRSchedulerCallable): The learning rate scheduler callable.
optimizer (list[OptimizerCallable] | OptimizerCallable | None): The optimizer callable.
scheduler (list[LRSchedulerCallable] | LRSchedulerCallable | None): The learning rate scheduler callable.

Returns:
OTXLitModule: The built LightningModule instance.
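At the API level, the widened annotations mean Engine accepts either single callables or lists for both arguments. A minimal hedged sketch (the model name is illustrative, and a real run would also supply data via datamodule or the usual recipe wiring):

import torch
from otx.engine import Engine

engine = Engine(
    model="illustrative_model_recipe",  # an OTXModel instance or a model name string
    optimizer=lambda p: torch.optim.SGD(p, lr=0.01),
    scheduler=[
        lambda opt: torch.optim.lr_scheduler.LinearLR(opt, total_iters=10),
        lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100),
    ],
)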