diff --git a/pyproject.toml b/pyproject.toml index e913a9df994..c6aaef76da3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ xpu = [ base = [ "torch==2.2.2", - "lightning==2.2", + "lightning==2.3.3", "pytorchcv", "timm==1.0.3", "openvino==2024.3", diff --git a/src/otx/core/schedulers/__init__.py b/src/otx/core/schedulers/__init__.py index ed74aaf62cf..d4d928e3c61 100644 --- a/src/otx/core/schedulers/__init__.py +++ b/src/otx/core/schedulers/__init__.py @@ -7,19 +7,18 @@ from typing import Callable -from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER from lightning.pytorch.cli import ReduceLROnPlateau +from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer from otx.core.schedulers.callable import SchedulerCallableSupportHPO from otx.core.schedulers.warmup_schedulers import LinearWarmupScheduler, LinearWarmupSchedulerCallable +LRSchedulerListCallable = Callable[[Optimizer], list[LRScheduler | ReduceLROnPlateau]] + __all__ = [ "LRSchedulerListCallable", "LinearWarmupScheduler", "LinearWarmupSchedulerCallable", "SchedulerCallableSupportHPO", ] - - -LRSchedulerListCallable = Callable[[Optimizer], list[_TORCH_LRSCHEDULER | ReduceLROnPlateau]]