Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metrics] Detach bugfix #4313

Merged
merged 17 commits into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))


## [1.0.2] - 2020-10-15

Expand Down Expand Up @@ -97,7 +99,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed `current_epoch` property update to reflect true epoch number inside `LightningDataModule`, when `reload_dataloaders_every_epoch=True`. ([#3974](https://github.com/PyTorchLightning/pytorch-lightning/pull/3974))
- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
- Fixed mismatch between docstring and code regarding when `on_load_checkpoint` hook is called ([#3996](https://github.com/PyTorchLightning/pytorch-lightning/pull/3996))


Expand Down Expand Up @@ -442,7 +444,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))
- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))

Expand Down
28 changes: 20 additions & 8 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ The example below shows how to use a metric in your ``LightningModule``:
def __init__(self):
...
self.accuracy = pl.metrics.Accuracy()

def training_step(self, batch, batch_idx):
logits = self(x)
...
# log step metric
self.log('train_acc_step', self.accuracy(logits, y))
...

def training_epoch_end(self, outs):
# log epoch metric
self.log('train_acc_epoch', self.accuracy.compute())
Expand All @@ -57,15 +57,15 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
This however is only true for metrics that inherit the base class ``Metric``,
and thus the functional metric API provides no support for in-built distributed synchronization
or reduction functions.


.. code-block:: python

def __init__(self):
...
self.train_acc = pl.metrics.Accuracy()
self.valid_acc = pl.metrics.Accuracy()

def training_step(self, batch, batch_idx):
logits = self(x)
...
Expand All @@ -91,17 +91,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
for epoch in range(epochs):
for x, y in train_data:
y_hat = model(x)

# training step accuracy
batch_acc = train_accuracy(y_hat, y)

for x, y in valid_data:
y_hat = model(x)
valid_accuracy(y_hat, y)

# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()

# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()

Expand Down Expand Up @@ -144,6 +144,18 @@ Example implementation:
def compute(self):
return self.correct.float() / self.total

Metrics support backpropagation, if all computations involved in the metric calculation
are differentiable. However, note that the cashed state is detached from the computational
graph and cannot be backpropagated. Not doing this would mean storing the computational
graph for each update call, which can lead to out-of-memory errors.
In practise this means that:

.. code-block:: python
metric = MyMetric()
val = metric(pred, target) # this value can be backpropagated
val = metric.compute() # this value cannot be backpropagated


**********
Metric API
**********
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def forward(self, *args, **kwargs):
Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
"""
# add current step
self.update(*args, **kwargs)
with torch.no_grad():
self.update(*args, **kwargs)
self._forward_cache = None

if self.compute_on_step:
Expand Down