mnist_module.py
from typing import Any

import torch
from lightning import LightningModule
from torchmetrics import MeanMetric, MetricCollection, MetricTracker

from lightning_hydra_template.utils import pad_keys


class MNISTLitModule(LightningModule):
    """Example of a `LightningModule` for MNIST classification.

    A `LightningModule` implements seven key methods:

    ```python
    def __init__(self):
    # Define initialization code here.

    def setup(self, stage):
    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
    # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
    # The complete training step.

    def validation_step(self, batch, batch_idx):
    # The complete validation step.

    def test_step(self, batch, batch_idx):
    # The complete test step.

    def predict_step(self, batch, batch_idx):
    # The complete predict step.

    def configure_optimizers(self):
    # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
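
    A minimal construction sketch (purely illustrative: the toy network, hyperparameter values, and metric choice
    below are assumptions, not part of this module). The optimizer and scheduler are passed as partially-instantiated
    callables, as a Hydra config with `_partial_: true` would produce:

    ```python
    from functools import partial

    import torch
    from torchmetrics import MetricCollection
    from torchmetrics.classification import MulticlassAccuracy

    module = MNISTLitModule(
        net=torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10)),
        optimizer=partial(torch.optim.Adam, lr=1e-3),
        scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, patience=3),
        criterion=torch.nn.CrossEntropyLoss(),
        metrics=MetricCollection({"acc": MulticlassAccuracy(num_classes=10)}),
    )
    ```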
"""

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        criterion: torch.nn.Module,
        metrics: MetricCollection | None = None,
    ) -> None:
        """Initialize a `MNISTLitModule`.

        Args:
            net: The model to train.
            optimizer: The optimizer to use for training.
            scheduler: The learning rate scheduler to use for training.
            criterion: The loss function to use for training.
            metrics: A collection of metrics to use for evaluation.
        """
        super().__init__()

        # this line allows accessing the init params through the `self.hparams` attribute
        # and also ensures that the init params are stored in the ckpt
        # it is good practice to exclude nn.Module instances (i.e. `net`, `criterion`, `metrics`) from the
        # hyperparameters, as they are already stored during checkpointing in the model's state_dict
        self.save_hyperparameters(logger=False, ignore=["net", "criterion", "metrics"])
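        # e.g., the optimizer passed above is now accessible as `self.hparams.optimizer`,
        # which is what `configure_optimizers` below relies on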
        self.net = net

        # loss function
        self.criterion = criterion

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.test_loss = MeanMetric()
        # for the validation loop, we wrap the loss inside a tracker to help keep track of the best value across epochs
        # this is useful for callbacks/optimizers that might need to monitor the validation loss
        self.val_loss_tracker = MetricTracker(MeanMetric(), maximize=False)

        # metric objects for calculating and averaging evaluation metrics across batches
        self._base_metrics = metrics
        if self._base_metrics:
            # torchmetrics recommends using different instances of the metrics for train, val, and test
            # to avoid conflicts, since the metrics are stateful
            self.train_metrics = self._base_metrics.clone(prefix="train/")
            self.test_metrics = self._base_metrics.clone(prefix="test/")
            # just as for the loss, we wrap the metrics inside a tracker to help track the best values across epochs
            # here, we explicitly set `maximize=None` to infer the best value from the underlying metric
            self.val_metrics_tracker = MetricTracker(self._base_metrics.clone(prefix="val/"), maximize=None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform a forward pass through the model `self.net`.

        Args:
            x: A tensor of images.

        Returns:
            A tensor of logits.
        """
        return self.net(x)

    def model_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        Args:
            batch: A batch of data containing the input tensor of images and target labels.

        Returns:
            A tuple of tensors containing the loss, the predicted class indices (argmax over the logits), and the
            target labels, respectively.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        Args:
            batch: A batch of data containing the input tensor of images and target labels.
            batch_idx: The index of the current batch.

        Returns:
            A tensor of losses between model predictions and targets.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.train_loss.update(loss)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        # only touch the metrics if a metric collection was provided; otherwise `self.train_metrics` does not exist
        if self._base_metrics:
            self.train_metrics.update(preds, targets)
            self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_validation_epoch_start(self) -> None:
        """Lightning hook that is called when a validation epoch starts."""
        # Initialize new instances of the tracked loss/metrics for the new epoch.
        # Since by default Lightning executes validation-step sanity checks before training starts,
        # this also makes sure that loss/metrics logged during the sanity check (i.e. 1st val increment)
        # are not used to compute loss/metrics in the 1st actual validation epoch (i.e. 2nd val increment).
        # This is a workaround to ignore sanity-check values, since trackers do not support deleting previous metrics,
        # and it is simpler than the alternative of reinitializing the val trackers in `on_train_start`.
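        # For reference, the rough lifecycle of a `MetricTracker` across an epoch (illustrative summary of the
        # torchmetrics API):
        #   tracker.increment()    -> start tracking a new epoch (done here)
        #   tracker.update(...)    -> accumulate batch values (done in `validation_step`)
        #   tracker.compute()      -> aggregate the current epoch (done in `on_validation_epoch_end`)
        #   tracker.best_metric()  -> best value across all tracked epochs (done in `on_validation_epoch_end`)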
        self.val_loss_tracker.increment()
        if self._base_metrics:
            self.val_metrics_tracker.increment()

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        Args:
            batch: A batch of data containing the input tensor of images and target labels.
            batch_idx: The index of the current batch.
        """
        loss, preds, targets = self.model_step(batch)

        # update loss and metrics (which will be logged at the end of the epoch)
        self.val_loss_tracker.update(loss)
        if self._base_metrics:
            self.val_metrics_tracker.update(preds, targets)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        epoch_loss = self.val_loss_tracker.compute()  # current val loss
        best_loss = self.val_loss_tracker.best_metric()  # best val loss so far
        self.log("val/loss", epoch_loss, prog_bar=True)
        self.log("val/loss/best", best_loss, prog_bar=True)

        if self._base_metrics:
            epoch_metrics = self.val_metrics_tracker.compute()  # current val metrics
            best_metrics = self.val_metrics_tracker.best_metric()  # best val metrics so far
            self.log_dict(epoch_metrics, prog_bar=True)
            self.log_dict(pad_keys(best_metrics, postfix="/best"), prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        Args:
            batch: A batch of data containing the input tensor of images and target labels.
            batch_idx: The index of the current batch.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.test_loss.update(loss)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        # only touch the metrics if a metric collection was provided; otherwise `self.test_metrics` does not exist
        if self._base_metrics:
            self.test_metrics.update(preds, targets)
            self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self) -> dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.

        Normally you would only need one, but in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        Returns:
            A dict containing the configured optimizers and learning-rate schedulers to be used for training.
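
        Note:
            The `optimizer` and `scheduler` stored in `self.hparams` are expected to be partially-instantiated
            callables (e.g. produced by a Hydra config with `_partial_: true`), which is why they are called with
            the model parameters / the optimizer below. The `monitor` key of the returned `lr_scheduler` dict must
            name a logged metric; here it matches the `"val/loss"` key logged in `on_validation_epoch_end`, which
            schedulers such as `ReduceLROnPlateau` rely on.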
"""
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}
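

# Illustrative smoke test: a minimal sketch assuming a toy linear "net" and made-up hyperparameters (none of the
# values below come from this module). It constructs the module and checks a forward pass on random MNIST-shaped
# input; the optimizer is passed as a partially-instantiated callable, mimicking what a Hydra `_partial_: true`
# config would provide.
if __name__ == "__main__":
    from functools import partial

    model = MNISTLitModule(
        net=torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10)),
        optimizer=partial(torch.optim.Adam, lr=1e-3),
        scheduler=None,
        criterion=torch.nn.CrossEntropyLoss(),
    )
    logits = model(torch.randn(4, 1, 28, 28))
    print(logits.shape)  # expected: torch.Size([4, 10])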