diff --git a/src/otx/algo/callbacks/adaptive_train_scheduling.py b/src/otx/algo/callbacks/adaptive_train_scheduling.py
index afeeaa0bb35..3d47f50fa63 100644
--- a/src/otx/algo/callbacks/adaptive_train_scheduling.py
+++ b/src/otx/algo/callbacks/adaptive_train_scheduling.py
@@ -1,6 +1,6 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
-#
+
 """Callback to reschedule the validation interval adaptively."""
 
 from __future__ import annotations
@@ -14,7 +14,7 @@
 
 if TYPE_CHECKING:
     from lightning import LightningModule, Trainer
-    from lightning.pytorch.utilities.types import LRSchedulerConfig
+    from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeUnion
 
 
 class AdaptiveTrainScheduling(Callback):
@@ -32,10 +32,13 @@ class AdaptiveTrainScheduling(Callback):
     def __init__(self, max_interval: int = 5, decay: float = -0.025):
         self.max_interval = max_interval
         self.decay = decay
+        self.min_earlystop_interval = 3
+        self.min_lrschedule_patience = 2
         self._saved_check_val_every_n_epoch: int | None = None
         self._saved_log_every_n_steps: int | None = None
-        self._revert_frequency: list = []
-        self._revert_patience: list = []
+        self._revert_lr_frequency: list = []
+        self._revert_lr_patience: list = []
+        self._revert_es_patience: list = []
 
     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         """Execute this function at starting the train stage."""
@@ -81,13 +84,14 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
             trainer.log_every_n_steps = self._saved_log_every_n_steps
             self._saved_log_every_n_steps = None
 
-        if len(self._revert_frequency) > 0:
-            for revert in self._revert_frequency:
-                revert()
+        # Revert frequency and patience independently: zipping the two lists would drop
+        # frequency reverts for schedulers that expose no ``patience`` attribute.
+        for revert_lr in (*self._revert_lr_frequency, *self._revert_lr_patience):
+            revert_lr()
 
-        if len(self._revert_patience) > 0:
-            for revert in self._revert_patience:
-                revert()
+        if len(self._revert_es_patience) > 0:
+            for revert_es in self._revert_es_patience:
+                revert_es()
 
     def _get_adaptive_interval(self, iter_per_epoch: int, max_interval: int) -> int:
         """Get adaptive interval."""
@@ -100,21 +104,36 @@ def _change_lr_scheduler_frequency(self, lr_configs: list[LRSchedulerConfig], ad
         should be changed according to the adaptive interval.
         """
 
-        def _revert_func(config: LRSchedulerConfig, saved_frequency: int) -> None:
+        def _revert_frequency(config: LRSchedulerConfig, saved_frequency: int) -> None:
             config.frequency = saved_frequency
 
+        def _revert_patience(lr_scheduler: LRSchedulerTypeUnion, saved_patience: int) -> None:
+            lr_scheduler.patience = saved_patience
+
         for config in lr_configs:
             if hasattr(config, "frequency") and hasattr(config, "interval") and config.interval == "epoch":
+                saved_frequency = config.frequency
+                config.frequency = adaptive_interval
                 msg = (
                     "The frequency of LRscheduler will be changed due to the effect of adaptive interval: "
-                    f"{config.frequency} --> {adaptive_interval}."
+                    f"{saved_frequency} --> {adaptive_interval}."
                 )
                 log.warning(msg)
-
-                saved_frequency = config.frequency
-                config.frequency = adaptive_interval
-
-                self._revert_frequency += [partial(_revert_func, config, saved_frequency)]
+                self._revert_lr_frequency += [partial(_revert_frequency, config, saved_frequency)]
+
+                if hasattr(config, "scheduler") and hasattr(config.scheduler, "patience"):
+                    saved_patience = config.scheduler.patience
+                    adjusted_patience = (
+                        max(int(config.scheduler.patience / adaptive_interval), self.min_lrschedule_patience) - 1
+                    )
+                    config.scheduler.patience = adjusted_patience
+
+                    msg = (
+                        "The patience of LRscheduler will be changed due to the effect of adaptive interval: "
+                        f"{saved_patience} --> {adjusted_patience}."
+                    )
+                    log.warning(msg)
+                    self._revert_lr_patience += [partial(_revert_patience, config.scheduler, saved_patience)]
 
     def _change_early_stopping_patience(self, callbacks: list[Callback], adaptive_interval: int) -> None:
         """Change the EarlyStopping patience to change the patience.
@@ -130,7 +149,7 @@ def _revert_func(callback: Callback, saved_patience: int) -> None:
 
         for callback in callbacks:
             if isinstance(callback, EarlyStopping):
-                adjusted_patience = int(callback.patience / adaptive_interval)
+                adjusted_patience = max(int(callback.patience / adaptive_interval), self.min_earlystop_interval)
                 msg = (
                     "The patience of early stopping will be changed due to the effect of adaptive interval: "
                     f"{callback.patience} --> {adjusted_patience}."
@@ -140,4 +159,4 @@ def _revert_func(callback: Callback, saved_patience: int) -> None:
                 saved_patience = callback.patience
                 callback.patience = adjusted_patience
 
-                self._revert_patience += [partial(_revert_func, callback, saved_patience)]
+                self._revert_es_patience += [partial(_revert_func, callback, saved_patience)]
diff --git a/src/otx/algo/schedulers/warmup_schedulers.py b/src/otx/algo/schedulers/warmup_schedulers.py
index ff72a2d44bc..dbfe26b8212 100644
--- a/src/otx/algo/schedulers/warmup_schedulers.py
+++ b/src/otx/algo/schedulers/warmup_schedulers.py
@@ -22,3 +22,10 @@ def __init__(
         self.num_warmup_steps = num_warmup_steps
         self.interval = interval
         super().__init__(optimizer, lambda step: min(step / num_warmup_steps, 1.0))
+
+    def step(self, epoch: int | None = None) -> None:
+        """Overriding the step to disable the warmup scheduler after n_steps."""
+        # `<=` rather than `<`: the base scheduler already takes one step on construction,
+        # so stopping at `<` would freeze the LR factor at (n - 1) / n instead of 1.0.
+        if self._step_count <= self.num_warmup_steps:
+            super().step(epoch)
diff --git a/tests/unit/algo/callbacks/test_adaptive_train_scheduling.py b/tests/unit/algo/callbacks/test_adaptive_train_scheduling.py
index 9d019be0f0a..553e93e640c 100644
--- a/tests/unit/algo/callbacks/test_adaptive_train_scheduling.py
+++ b/tests/unit/algo/callbacks/test_adaptive_train_scheduling.py
@@ -6,6 +6,7 @@
 
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+from lightning.pytorch.cli import ReduceLROnPlateau
 from lightning.pytorch.utilities.types import LRSchedulerConfig
 from otx.algo.callbacks.adaptive_train_scheduling import AdaptiveTrainScheduling
 from torch.utils.data import DataLoader
@@ -31,6 +32,8 @@ def test_callback(self, caplog) -> None:
         mock_trainer.callbacks = [mock_callback]
 
         mock_lr_scheduler_config = MagicMock(spec=LRSchedulerConfig)
+        mock_lr_scheduler_config.scheduler = MagicMock(spec=ReduceLROnPlateau)
+        mock_lr_scheduler_config.scheduler.patience = 5
         mock_lr_scheduler_config.frequency = 1
         mock_lr_scheduler_config.interval = "epoch"
         mock_trainer.lr_scheduler_configs = [mock_lr_scheduler_config]
@@ -40,8 +43,9 @@ def test_callback(self, caplog) -> None:
         assert mock_trainer.check_val_every_n_epoch != 1  # Adaptively updated
         assert mock_trainer.callbacks[0].patience != 5
         assert mock_trainer.lr_scheduler_configs[0].frequency != 1
+        assert mock_trainer.lr_scheduler_configs[0].scheduler.patience != 5
         assert mock_trainer.log_every_n_steps == 10  # Equal to len(train_dataloader)
-        assert len(caplog.records) == 4  # Warning two times
+        assert len(caplog.records) == 5  # Five warnings in total
 
         callback.on_train_end(trainer=mock_trainer, pl_module=mock_pl_module)
         # Restore temporarily updated values
@@ -49,3 +53,4 @@ def test_callback(self, caplog) -> None:
         assert mock_trainer.log_every_n_steps == 50
         assert mock_trainer.callbacks[0].patience == 5
         assert mock_trainer.lr_scheduler_configs[0].frequency == 1
+        assert mock_trainer.lr_scheduler_configs[0].scheduler.patience == 5