Commit f36a6ae

Fix warmup scheduler and add patience update for adaptive interval (#3056)

* Add warmup logic and update adaptive interval

* Fix precommit
sungmanc authored Mar 8, 2024
1 parent 854747a commit f36a6ae
Showing 3 changed files with 49 additions and 20 deletions.
57 changes: 38 additions & 19 deletions src/otx/algo/callbacks/adaptive_train_scheduling.py
@@ -1,6 +1,6 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

"""Callback to reschedule the validation interval adaptively."""

from __future__ import annotations
@@ -14,7 +14,7 @@

if TYPE_CHECKING:
from lightning import LightningModule, Trainer
from lightning.pytorch.utilities.types import LRSchedulerConfig
from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeUnion


class AdaptiveTrainScheduling(Callback):
@@ -32,10 +32,13 @@ class AdaptiveTrainScheduling(Callback):
def __init__(self, max_interval: int = 5, decay: float = -0.025):
self.max_interval = max_interval
self.decay = decay
self.min_earlystop_interval = 3
self.min_lrschedule_patience = 2
self._saved_check_val_every_n_epoch: int | None = None
self._saved_log_every_n_steps: int | None = None
self._revert_frequency: list = []
self._revert_patience: list = []
self._revert_lr_frequency: list = []
self._revert_lr_patience: list = []
self._revert_es_patience: list = []

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Execute this function at starting the train stage."""
@@ -81,13 +84,14 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
trainer.log_every_n_steps = self._saved_log_every_n_steps
self._saved_log_every_n_steps = None

if len(self._revert_frequency) > 0:
for revert in self._revert_frequency:
revert()
if len(self._revert_lr_frequency) > 0 and len(self._revert_lr_patience) > 0:
for revert_f, revert_p in zip(self._revert_lr_frequency, self._revert_lr_patience):
revert_f()
revert_p()

if len(self._revert_patience) > 0:
for revert in self._revert_patience:
revert()
if len(self._revert_es_patience) > 0:
for revert_es in self._revert_es_patience:
revert_es()

def _get_adaptive_interval(self, iter_per_epoch: int, max_interval: int) -> int:
"""Get adaptive interval."""
@@ -100,21 +104,36 @@ def _change_lr_scheduler_frequency(self, lr_configs: list[LRSchedulerConfig], ad
should be changed according to the adaptive interval.
"""

def _revert_func(config: LRSchedulerConfig, saved_frequency: int) -> None:
def _revert_frequency(config: LRSchedulerConfig, saved_frequency: int) -> None:
config.frequency = saved_frequency

def _revert_patience(lr_scheduler: LRSchedulerTypeUnion, saved_patience: int) -> None:
lr_scheduler.patience = saved_patience

for config in lr_configs:
if hasattr(config, "frequency") and hasattr(config, "interval") and config.interval == "epoch":
saved_frequency = config.frequency
config.frequency = adaptive_interval
msg = (
"The frequency of LRscheduler will be changed due to the effect of adaptive interval: "
f"{config.frequency} --> {adaptive_interval}."
f"{saved_frequency} --> {adaptive_interval}."
)
log.warning(msg)

saved_frequency = config.frequency
config.frequency = adaptive_interval

self._revert_frequency += [partial(_revert_func, config, saved_frequency)]
self._revert_lr_frequency += [partial(_revert_frequency, config, saved_frequency)]

if hasattr(config, "scheduler") and hasattr(config.scheduler, "patience"):
saved_patience = config.scheduler.patience
adjusted_patience = (
max(int(config.scheduler.patience / adaptive_interval), self.min_lrschedule_patience) - 1
)
config.scheduler.patience = adjusted_patience

msg = (
"The patience of LRscheduler will be changed due to the effect of adaptive interval: "
f"{saved_patience} --> {adjusted_patience}."
)
log.warning(msg)
self._revert_lr_patience += [partial(_revert_patience, config, saved_patience)]

def _change_early_stopping_patience(self, callbacks: list[Callback], adaptive_interval: int) -> None:
"""Change the EarlyStopping patience to change the patience.
@@ -130,7 +149,7 @@ def _revert_func(callback: Callback, saved_patience: int) -> None:

for callback in callbacks:
if isinstance(callback, EarlyStopping):
adjusted_patience = int(callback.patience / adaptive_interval)
adjusted_patience = max(int(callback.patience / adaptive_interval), self.min_earlystop_interval)
msg = (
"The patience of early stopping will be changed due to the effect of adaptive interval: "
f"{callback.patience} --> {adjusted_patience}."
@@ -140,4 +159,4 @@ def _revert_func(callback: Callback, saved_patience: int) -> None:
saved_patience = callback.patience
callback.patience = adjusted_patience

self._revert_patience += [partial(_revert_func, callback, saved_patience)]
self._revert_es_patience += [partial(_revert_func, callback, saved_patience)]
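
Not part of the diff: a minimal, self-contained sketch that isolates the patience arithmetic introduced above. Only the formulas and the floors (min_lrschedule_patience = 2, min_earlystop_interval = 3) come from the hunks; the helper names are made up for illustration.

def adjust_lr_scheduler_patience(patience: int, adaptive_interval: int, floor: int = 2) -> int:
    # Mirrors: max(int(config.scheduler.patience / adaptive_interval), self.min_lrschedule_patience) - 1
    return max(int(patience / adaptive_interval), floor) - 1

def adjust_early_stopping_patience(patience: int, adaptive_interval: int, floor: int = 3) -> int:
    # Mirrors: max(int(callback.patience / adaptive_interval), self.min_earlystop_interval)
    return max(int(patience / adaptive_interval), floor)

print(adjust_lr_scheduler_patience(5, 5))     # max(1, 2) - 1 = 1
print(adjust_early_stopping_patience(10, 4))  # max(2, 3) = 3

Both adjustments shrink the configured patience by the adaptive validation interval, presumably so that patience keeps roughly the same meaning when validation runs less often, while the floors stop it from collapsing to zero.
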
5 changes: 5 additions & 0 deletions src/otx/algo/schedulers/warmup_schedulers.py
@@ -22,3 +22,8 @@ def __init__(
self.num_warmup_steps = num_warmup_steps
self.interval = interval
super().__init__(optimizer, lambda step: min(step / num_warmup_steps, 1.0))

def step(self, epoch: int | None = None) -> None:
"""Overriding the step to disable the warmup scheduler after n_steps."""
if self._step_count < self.num_warmup_steps:
super().step(epoch)
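
Not part of the diff: a hedged, runnable sketch of what the new step() override buys. The class below only mirrors the constructor and the guard shown above; the real class name in warmup_schedulers.py is not visible in this hunk. Once the internal _step_count reaches num_warmup_steps, step() becomes a no-op, so the warmup lambda can no longer overwrite an LR that something else (for example a main scheduler) has set.

from __future__ import annotations

import torch
from torch.optim.lr_scheduler import LambdaLR

class WarmupSketch(LambdaLR):
    """Illustrative stand-in for the warmup scheduler touched by this commit."""

    def __init__(self, optimizer: torch.optim.Optimizer, num_warmup_steps: int = 2):
        self.num_warmup_steps = num_warmup_steps
        super().__init__(optimizer, lambda step: min(step / num_warmup_steps, 1.0))

    def step(self, epoch: int | None = None) -> None:
        # Same guard as the commit: skip stepping once warmup is finished.
        if self._step_count < self.num_warmup_steps:
            super().step(epoch)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = WarmupSketch(optimizer, num_warmup_steps=2)

optimizer.step()
scheduler.step()                        # still warming up: LR moves to 0.05
optimizer.param_groups[0]["lr"] = 0.01  # something else lowers the LR afterwards
optimizer.step()
scheduler.step()                        # warmup done: no-op, the 0.01 survives
print(optimizer.param_groups[0]["lr"])  # 0.01; an unguarded LambdaLR would reset it to 0.1
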
7 changes: 6 additions & 1 deletion tests/unit/algo/callbacks/test_adaptive_train_scheduling.py
@@ -6,6 +6,7 @@

from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.cli import ReduceLROnPlateau
from lightning.pytorch.utilities.types import LRSchedulerConfig
from otx.algo.callbacks.adaptive_train_scheduling import AdaptiveTrainScheduling
from torch.utils.data import DataLoader
@@ -31,6 +32,8 @@ def test_callback(self, caplog) -> None:
mock_trainer.callbacks = [mock_callback]

mock_lr_scheduler_config = MagicMock(spec=LRSchedulerConfig)
mock_lr_scheduler_config.scheduler = MagicMock(spec=ReduceLROnPlateau)
mock_lr_scheduler_config.scheduler.patience = 5
mock_lr_scheduler_config.frequency = 1
mock_lr_scheduler_config.interval = "epoch"
mock_trainer.lr_scheduler_configs = [mock_lr_scheduler_config]
@@ -40,12 +43,14 @@
assert mock_trainer.check_val_every_n_epoch != 1 # Adaptively updated
assert mock_trainer.callbacks[0].patience != 5
assert mock_trainer.lr_scheduler_configs[0].frequency != 1
assert mock_trainer.lr_scheduler_configs[0].scheduler.patience != 5
assert mock_trainer.log_every_n_steps == 10 # Equal to len(train_dataloader)
assert len(caplog.records) == 4 # Warning two times
assert len(caplog.records) == 5 # Five warnings are logged

callback.on_train_end(trainer=mock_trainer, pl_module=mock_pl_module)
# Restore temporarily updated values
assert mock_trainer.check_val_every_n_epoch == 1
assert mock_trainer.log_every_n_steps == 50
assert mock_trainer.callbacks[0].patience == 5
assert mock_trainer.lr_scheduler_configs[0].frequency == 1
assert mock_trainer.lr_scheduler_configs[0].scheduler.patience == 1
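
Not part of the diff: a quick arithmetic check of the value the test asserts for scheduler.patience. Assuming only the LR-scheduler formula from the callback hunk above (the concrete adaptive interval the mocked trainer produces is not shown here), any adaptive interval of 2 or more maps the mocked patience of 5 to 1:

# max(int(patience / adaptive_interval), min_lrschedule_patience) - 1, with patience = 5
for adaptive_interval in range(2, 6):
    print(adaptive_interval, max(int(5 / adaptive_interval), 2) - 1)  # always prints ... 1
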
