add types in bolts/callbacks #444

Merged
merged 26 commits on Jan 4, 2021
Changes from 17 commits
1 change: 1 addition & 0 deletions .github/workflows/ci_test-base.yml
@@ -49,6 +49,7 @@ jobs:
run: |
python -m pip install --upgrade --user pip
pip install --requirement ./requirements.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
pip install --requirement ./requirements/loggers.txt --quiet --upgrade-strategy only-if-needed
pip install --requirement ./requirements/test.txt --quiet --upgrade-strategy only-if-needed
# pip install tox coverage
python --version
2 changes: 1 addition & 1 deletion .github/workflows/install-pkg.yml
@@ -54,7 +54,7 @@ jobs:
pip install virtualenv
virtualenv vEnv ; source vEnv/bin/activate
# pip install -r requirements.txt
pip install torchvision
pip install torchvision matplotlib
pip install dist/*
cd .. & python -c "import pytorch_lightning as pl ; print(pl.__version__)"
cd .. & python -c "import pl_bolts ; print(pl_bolts.__version__)"
24 changes: 18 additions & 6 deletions pl_bolts/callbacks/byol_updates.py
@@ -1,6 +1,9 @@
import math
from typing import Sequence, Union

from pytorch_lightning import Callback
from pytorch_lightning import Callback, LightningModule, Trainer
from torch import Tensor
from torch.nn import Module


class BYOLMAWeightUpdate(Callback):
@@ -36,7 +39,15 @@ def __init__(self, initial_tau: float = 0.996):
self.initial_tau = initial_tau
self.current_tau = initial_tau

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: Sequence,
batch: Sequence,
batch_idx: int,
dataloader_idx: int,
) -> None:
# get networks
online_net = pl_module.online_network
target_net = pl_module.target_network
@@ -47,13 +58,14 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
# update tau after
self.current_tau = self.update_tau(pl_module, trainer)

def update_tau(self, pl_module, trainer):
max_steps = len(trainer.train_dataloader) * trainer.max_epochs
def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
max_steps = len(trainer.train_dataloader) * trainer.max_epochs # type: ignore[attr-defined]
tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
return tau

def update_weights(self, online_net, target_net):
def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
# apply MA weight update
for (name, online_p), (_, target_p) in zip(online_net.named_parameters(), target_net.named_parameters()):
for (name, online_p), (_, target_p) in zip(
online_net.named_parameters(), target_net.named_parameters()): # type: ignore[union-attr]
if 'weight' in name:
target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
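For orientation, here is a minimal sketch of how this callback is attached to a Trainer. The attribute names online_network and target_network follow what the hook reads above; the import path from pl_bolts.callbacks is an assumption.

from pytorch_lightning import Trainer
from pl_bolts.callbacks import BYOLMAWeightUpdate  # assumed export path

# The LightningModule must expose `online_network` and `target_network`
# (see the hook body above). tau anneals from initial_tau toward 1.0 on a
# cosine schedule over len(train_dataloader) * max_epochs global steps.
weight_update = BYOLMAWeightUpdate(initial_tau=0.996)
trainer = Trainer(max_epochs=100, callbacks=[weight_update])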
42 changes: 25 additions & 17 deletions pl_bolts/callbacks/data_monitor.py
@@ -4,10 +4,11 @@
import torch
import torch.nn as nn
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger, WandbLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

from pl_bolts.utils import _WANDB_AVAILABLE
@@ -38,22 +39,22 @@ def __init__(self, log_every_n_steps: int = None):
interval defined in the Trainer. Use this to override the Trainer default.
"""
super().__init__()
self._log_every_n_steps = log_every_n_steps
self._log_every_n_steps: Optional[int] = log_every_n_steps
self._log = False
self._trainer = None
self._train_batch_idx = None
self._trainer: Trainer
self._train_batch_idx: int

def on_train_start(self, trainer, pl_module):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self._log = self._is_logger_available(trainer.logger)
self._log_every_n_steps = self._log_every_n_steps or trainer.log_every_n_steps
self._log_every_n_steps = self._log_every_n_steps or trainer.log_every_n_steps # type: ignore[attr-defined]
self._trainer = trainer

def on_train_batch_start(
self, trainer, pl_module, batch, batch_idx, dataloader_idx
):
self, trainer: Trainer, pl_module: LightningModule, batch: Sequence, batch_idx: int, dataloader_idx: int
Reviewer comment (Contributor): for batch this is too strict, it could be almost anything (Any).

) -> None:
self._train_batch_idx = batch_idx

def log_histograms(self, batch, group="") -> None:
def log_histograms(self, batch: Sequence, group: str = "") -> None:
Reviewer comment (Contributor): same here.

"""
Logs the histograms at the interval defined by `row_log_interval`, given a logger is available.

@@ -64,11 +65,11 @@ def log_histograms(self, batch, group="") -> None:
Each label also has the tensor's shape as suffix.
group: Name under which the histograms will be grouped.
"""
if not self._log or (self._train_batch_idx + 1) % self._log_every_n_steps != 0:
if not self._log or (self._train_batch_idx + 1) % self._log_every_n_steps != 0: # type: ignore[operator]
return

batch = apply_to_collection(batch, dtype=np.ndarray, function=torch.from_numpy)
named_tensors = {}
named_tensors: Dict[str, Tensor] = {}
collect_and_name_tensors(batch, output=named_tensors, parent_name=group)

for name, tensor in named_tensors.items():
@@ -100,7 +101,7 @@ def log_histogram(self, tensor: Tensor, name: str) -> None:
data={name: wandb.Histogram(tensor)}, commit=False,
)

def _is_logger_available(self, logger) -> bool:
def _is_logger_available(self, logger: LightningLoggerBase) -> bool:
available = True
if not logger:
rank_zero_warn("Cannot log histograms because Trainer has no logger.")
@@ -154,9 +155,9 @@ def __init__(
"""
super().__init__(log_every_n_steps=log_every_n_steps)
self._submodule_names = submodules
self._hook_handles = []
self._hook_handles: List = []

def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_train_start(trainer, pl_module)
submodule_dict = dict(pl_module.named_modules())
self._hook_handles = []
@@ -170,7 +171,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
handle = self._register_hook(name, submodule_dict[name])
self._hook_handles.append(handle)

def on_train_end(self, trainer, pl_module):
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
for handle in self._hook_handles:
handle.remove()

@@ -198,7 +199,7 @@ def _register_hook(self, module_name: str, module: nn.Module) -> RemovableHandle
else self.GROUP_NAME_OUTPUT
)

def hook(_, inp, out):
def hook(_: Module, inp: Sequence, out: Sequence) -> None:
inp = inp[0] if len(inp) == 1 else inp
self.log_histograms(inp, group=input_group_name)
self.log_histograms(out, group=output_group_name)
@@ -228,7 +229,14 @@ def __init__(self, log_every_n_steps: int = None):
"""
super().__init__(log_every_n_steps=log_every_n_steps)

def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
def on_train_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Sequence,
Reviewer comment (Contributor): same here.

*args: int,
**kwargs: int,
) -> None:
super().on_train_batch_start(trainer, pl_module, batch, *args, **kwargs)
self.log_histograms(batch, group=self.GROUP_NAME)

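A short usage sketch for the data monitors in this file, assuming the class names ModuleDataMonitor and TrainingDataMonitor and their export from pl_bolts.callbacks; the submodule names are placeholders. As the review comments note, the batch argument can be almost any collection, which is why log_histograms walks it with apply_to_collection before logging.

from pytorch_lightning import Trainer
from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor  # assumed export path

# log histograms of every training batch, here every 25 global steps
batch_monitor = TrainingDataMonitor(log_every_n_steps=25)

# log histograms of the inputs and outputs of selected submodules;
# "encoder" and "decoder" are placeholder names for modules of your LightningModule
module_monitor = ModuleDataMonitor(submodules=["encoder", "decoder"])

trainer = Trainer(callbacks=[batch_monitor, module_monitor])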
27 changes: 14 additions & 13 deletions pl_bolts/callbacks/printing.py
@@ -2,6 +2,7 @@
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_info

@@ -33,10 +34,10 @@ class PrintTableMetricsCallback(Callback):

"""

def __init__(self):
self.metrics = []
def __init__(self) -> None:
self.metrics: List = []

def on_epoch_end(self, trainer, pl_module):
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
metrics_dict = copy.copy(trainer.callback_metrics)
self.metrics.append(metrics_dict)
rank_zero_info(dicts_to_table(self.metrics))
@@ -49,7 +50,7 @@ def dicts_to_table(dicts: List[Dict],
convert_headers: Optional[Dict[str, Callable]] = None,
header_names: Optional[List[str]] = None,
skip_none_lines: bool = False,
replace_values: Optional[Dict[str, Any]] = None):
replace_values: Optional[Dict[str, Any]] = None) -> str:
"""
Generate ascii table from dictionary
Taken from (https://stackoverflow.com/questions/40056747/print-a-list-of-dictionaries-in-table-form)
@@ -79,30 +80,30 @@ def dicts_to_table(dicts: List[Dict],
# optional arg prelude
if keys is None:
if len(dicts) > 0:
keys = dicts[0].keys()
keys = dicts[0].keys() # type: ignore[assignment]
elif header_names is not None:
keys = header_names
else:
raise ValueError('keys or header_names mandatory on empty input list')
if pads is None:
pads = [''] * len(keys)
elif len(pads) != len(keys):
raise ValueError(f'bad pad length {len(pads)}, expected: {len(keys)}')
pads = [''] * len(keys) # type: ignore[arg-type]
elif len(pads) != len(keys): # type: ignore[arg-type]
raise ValueError(f'bad pad length {len(pads)}, expected: {len(keys)}') # type: ignore[arg-type]
if fcodes is None:
fcodes = [''] * len(keys)
fcodes = [''] * len(keys) # type: ignore[arg-type]
elif len(fcodes) != len(keys):
raise ValueError(f'bad fcodes length {len(fcodes)}, expected: {len(keys)}')
raise ValueError(f'bad fcodes length {len(fcodes)}, expected: {len(keys)}') # type: ignore[arg-type]
if convert_headers is None:
convert_headers = {}
if header_names is None:
header_names = keys
if replace_values is None:
replace_values = {}
# build header
headline = '│'.join(f"{v:{pad}}" for v, pad in zip_longest(header_names, pads))
headline = '│'.join(f"{v:{pad}}" for v, pad in zip_longest(header_names, pads)) # type: ignore[arg-type]
underline = '─' * len(headline)
# suffix special keys to apply converters to later on
marked_keys = [h + '____' if h in convert_headers else h for h in keys]
marked_keys = [h + '____' if h in convert_headers else h for h in keys] # type: ignore[union-attr]
marked_values = {}
s = '│'.join(f"{{{h}:{pad}{fcode}}}" for h, pad, fcode in zip_longest(marked_keys, pads, fcodes))
lines = [headline, underline, ]
@@ -119,7 +120,7 @@ def dicts_to_table(dicts: List[Dict],
elif none_keys:
raise ValueError(f'keys {none_keys} are None in {d}. Do skip or use replace mapping.')
for h in convert_headers:
if h in keys:
if h in keys: # type: ignore[operator]
converter = convert_headers[h]
marked_values[h + '____'] = converter(d)
line = s.format(**d, **marked_values)
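A minimal sketch of the two entry points in this file: dicts_to_table as a standalone formatter and PrintTableMetricsCallback attached to a Trainer; the metric values below are illustrative.

from pytorch_lightning import Trainer
from pl_bolts.callbacks.printing import PrintTableMetricsCallback, dicts_to_table

# dicts_to_table renders a list of dicts as an ASCII table, one row per dict
rows = [{'epoch': 0, 'loss': 0.25}, {'epoch': 1, 'loss': 0.19}]
print(dicts_to_table(rows))

# the callback copies trainer.callback_metrics at every epoch end and
# prints the accumulated rows as a table via rank_zero_info
trainer = Trainer(callbacks=[PrintTableMetricsCallback()])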
25 changes: 15 additions & 10 deletions pl_bolts/callbacks/ssl_online.py
@@ -1,9 +1,11 @@
from typing import Optional
from typing import Optional, Sequence, Tuple, Union

import torch
from pytorch_lightning import Callback
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.metrics.functional import accuracy
from torch import Tensor, device
from torch.nn import functional as F
from torch.optim import Optimizer


class SSLOnlineEvaluator(Callback): # pragma: no cover
@@ -24,6 +26,7 @@ class SSLOnlineEvaluator(Callback):  # pragma: no cover
)

"""

def __init__(
self,
dataset: str,
@@ -44,13 +47,13 @@ def __init__(

self.hidden_dim = hidden_dim
self.drop_p = drop_p
self.optimizer = None
self.optimizer: Optimizer

self.z_dim = z_dim
self.num_classes = num_classes
self.dataset = dataset

def on_pretrain_routine_start(self, trainer, pl_module):
def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

pl_module.non_linear_evaluator = SSLEvaluator(
@@ -64,12 +67,12 @@ def on_pretrain_routine_start(self, trainer, pl_module):
pl_module.non_linear_evaluator.parameters(), lr=1e-4
)

def get_representations(self, pl_module, x):
def get_representations(self, pl_module: LightningModule, x: Tensor) -> Tensor:
representations = pl_module(x)
representations = representations.reshape(representations.size(0), -1)
return representations

def to_device(self, batch, device):
def to_device(self, batch: Sequence, device: Union[str, device]) -> Tuple[Tensor, Tensor]:
# get the labeled batch
if self.dataset == 'stl10':
labeled_batch = batch[1]
@@ -84,7 +87,8 @@ def to_device(self, batch, device):

return x, y

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Sequence,
batch: Sequence, batch_idx: int, dataloader_idx: int) -> None:
x, y = self.to_device(batch, pl_module.device)

with torch.no_grad():
@@ -93,7 +97,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
representations = representations.detach()

# forward pass
mlp_preds = pl_module.non_linear_evaluator(representations)
mlp_preds = pl_module.non_linear_evaluator(representations) # type: ignore[operator]
mlp_loss = F.cross_entropy(mlp_preds, y)

# update finetune weights
@@ -106,7 +110,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
pl_module.log('online_train_acc', train_acc, on_step=True, on_epoch=False)
pl_module.log('online_train_loss', mlp_loss, on_step=True, on_epoch=False)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule,
outputs: Sequence, batch: Sequence, batch_idx: int, dataloader_idx: int) -> None:
x, y = self.to_device(batch, pl_module.device)

with torch.no_grad():
@@ -115,7 +120,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
representations = representations.detach()

# forward pass
mlp_preds = pl_module.non_linear_evaluator(representations)
mlp_preds = pl_module.non_linear_evaluator(representations) # type: ignore[operator]
mlp_loss = F.cross_entropy(mlp_preds, y)

# log metrics
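For context, a sketch of attaching the online evaluator; the keyword values are illustrative and z_dim must match the flattened representation size produced by pl_module(x).

from pytorch_lightning import Trainer
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator

# Attaches a small MLP (pl_module.non_linear_evaluator) during pretraining and
# trains it on detached representations, logging metrics such as online_train_acc.
online_eval = SSLOnlineEvaluator(
    dataset='cifar10',   # 'stl10' switches to_device to the labeled half of the batch
    z_dim=2048,          # size of the flattened representation from pl_module(x)
    hidden_dim=1024,
    num_classes=10,
    drop_p=0.2,
)
trainer = Trainer(callbacks=[online_eval])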
12 changes: 8 additions & 4 deletions pl_bolts/callbacks/variational.py
@@ -1,8 +1,11 @@
import math
from typing import List

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from torch import Tensor

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
@@ -58,18 +61,19 @@ def __init__(
self.normalize = normalize
self.steps = steps

def on_epoch_end(self, trainer, pl_module):
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0:
images = self.interpolate_latent_space(pl_module, latent_dim=pl_module.hparams.latent_dim)
images = torch.cat(images, dim=0)
images = self.interpolate_latent_space(
pl_module, latent_dim=pl_module.hparams.latent_dim) # type: ignore[union-attr]
images = torch.cat(images, dim=0) # type: ignore[assignment]

num_images = (self.range_end - self.range_start) ** 2
num_rows = int(math.sqrt(num_images))
grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize)
str_title = f'{pl_module.__class__.__name__}_latent_space'
trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)

def interpolate_latent_space(self, pl_module, latent_dim):
def interpolate_latent_space(self, pl_module: LightningModule, latent_dim: int) -> List[Tensor]:
images = []
with torch.no_grad():
pl_module.eval()
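Finally, a sketch of wiring up the latent-space interpolation callback defined in this file; the class name LatentDimInterpolator is assumed from the pl_bolts module, and the LightningModule is expected to expose hparams.latent_dim and decode latent vectors in its forward.

from pytorch_lightning import Trainer
from pl_bolts.callbacks.variational import LatentDimInterpolator  # class name assumed

# Every `interpolate_epoch_interval` epochs the callback decodes a grid of
# latent points between range_start and range_end and logs the resulting
# image grid to the Trainer's logger.
latent_viz = LatentDimInterpolator(interpolate_epoch_interval=20, range_start=-5, range_end=5)
trainer = Trainer(callbacks=[latent_viz])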