docstring, annotations
ananyahjha93 committed Aug 5, 2020
1 parent 31e35ec commit a0afec2
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions pl_bolts/optimizers/lr_scheduler.py
@@ -1,5 +1,6 @@
 import math
 import warnings
+from typing import List
 
 import torch.nn as nn
 from torch.optim import Optimizer, Adam
@@ -52,13 +53,13 @@ class LinearWarmupCosineAnnealingLR(_LRScheduler):
 
     def __init__(
         self,
-        optimizer,
-        warmup_epochs,
-        max_epochs,
-        warmup_start_lr=0.0,
-        eta_min=0.0,
-        last_epoch=-1,
-    ):
+        optimizer: Optimizer,
+        warmup_epochs: int,
+        max_epochs: int,
+        warmup_start_lr: float = 0.0,
+        eta_min: float = 0.0,
+        last_epoch: int = -1,
+    ) -> None:
 
         self.warmup_epochs = warmup_epochs
         self.max_epochs = max_epochs
@@ -67,7 +68,10 @@ def __init__(
 
         super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
 
-    def get_lr(self):
+    def get_lr(self) -> List[float]:
+        """
+        Compute learning rate using chainable form of the scheduler
+        """
         if not self._get_lr_called_within_step:
             warnings.warn(
                 "To get the last learning rate computed by the scheduler, "
@@ -118,7 +122,10 @@ def get_lr(self):
             for group in self.optimizer.param_groups
         ]
 
-    def _get_closed_form_lr(self):
+    def _get_closed_form_lr(self) -> List[float]:
+        """
+        Called when epoch is passed as a param to the `step` function of the scheduler.
+        """
         if self.last_epoch < self.warmup_epochs:
             return [
                 self.warmup_start_lr
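For context, here is a minimal usage sketch of the scheduler this commit annotates. The model, optimizer settings, and epoch counts below are hypothetical and not part of the commit; only the constructor signature comes from the diff above:

    import torch.nn as nn
    from torch.optim import Adam

    from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

    # Hypothetical model and optimizer, purely for illustration.
    model = nn.Linear(10, 2)
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Warm up linearly from warmup_start_lr to the base lr over 10 epochs,
    # then cosine-anneal down to eta_min by epoch 40.
    scheduler = LinearWarmupCosineAnnealingLR(
        optimizer,
        warmup_epochs=10,
        max_epochs=40,
        warmup_start_lr=1e-5,
        eta_min=1e-6,
    )

    for epoch in range(40):
        # ... run the training batches for this epoch ...
        scheduler.step()  # chainable form: computes the next lr via get_lr()

Per the docstring added to `_get_closed_form_lr`, passing an explicit epoch to `step` (e.g. `scheduler.step(epoch)`) routes through the closed-form computation instead of the chainable `get_lr` path.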
