From 97f2bf5339c86cf28a6295b38c6522f679d82d30 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 11 Apr 2022 10:53:11 +0200 Subject: [PATCH] Optionally Avoid recomputing features (#722) Co-authored-by: Jirka Borovec Co-authored-by: Nicki Skafte Detlefsen --- CHANGELOG.md | 3 +++ tests/image/test_fid.py | 25 +++++++++++++++++++++++++ tests/image/test_kid.py | 25 +++++++++++++++++++++++++ torchmetrics/image/fid.py | 23 ++++++++++++++++++++++- torchmetrics/image/kid.py | 22 +++++++++++++++++++++- 5 files changed, 96 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8139fff992..8fde1841881 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `adaptive_k` for the `RetrievalPrecision` metric ([#910](https://github.com/PyTorchLightning/metrics/pull/910)) +- Added `reset_real_features` argument to image quality assessment metrics ([#722](https://github.com/PyTorchLightning/metrics/pull/722)) + + ### Changed - Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914)) diff --git a/tests/image/test_fid.py b/tests/image/test_fid.py index 2c6d7455cf9..dd44af83c50 100644 --- a/tests/image/test_fid.py +++ b/tests/image/test_fid.py @@ -153,3 +153,28 @@ def test_compare_fid(tmpdir, feature=2048): tm_res = metric.compute() assert torch.allclose(tm_res.cpu(), torch.tensor([torch_fid["frechet_inception_distance"]]), atol=1e-3) + + +@pytest.mark.parametrize("reset_real_features", [True, False]) +def test_reset_real_features_arg(reset_real_features): + metric = FrechetInceptionDistance(feature=64, reset_real_features=reset_real_features) + + metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True) + metric.update(torch.randint(0, 180, (2, 3, 299, 299), 
dtype=torch.uint8), real=False) + + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] + + assert len(metric.fake_features) == 1 + assert list(metric.fake_features[0].shape) == [2, 64] + + metric.reset() + + # fake features should always reset + assert len(metric.fake_features) == 0 + + if reset_real_features: + assert len(metric.real_features) == 0 + else: + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] diff --git a/tests/image/test_kid.py b/tests/image/test_kid.py index c9459bc57c5..dca29cd1c97 100644 --- a/tests/image/test_kid.py +++ b/tests/image/test_kid.py @@ -163,3 +163,28 @@ def test_compare_kid(tmpdir, feature=2048): assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid["kernel_inception_distance_mean"]]), atol=1e-3) assert torch.allclose(tm_std.cpu(), torch.tensor([torch_fid["kernel_inception_distance_std"]]), atol=1e-3) + + +@pytest.mark.parametrize("reset_real_features", [True, False]) +def test_reset_real_features_arg(reset_real_features): + metric = KernelInceptionDistance(feature=64, reset_real_features=reset_real_features) + + metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True) + metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False) + + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] + + assert len(metric.fake_features) == 1 + assert list(metric.fake_features[0].shape) == [2, 64] + + metric.reset() + + # fake features should always reset + assert len(metric.fake_features) == 0 + + if reset_real_features: + assert len(metric.real_features) == 0 + else: + assert len(metric.real_features) == 1 + assert list(metric.real_features[0].shape) == [2, 64] diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 2b7bdef7ee8..51986d3fee6 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -161,6 +161,10 @@ class 
FrechetInceptionDistance(Metric): - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size. + reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not + change, the features can be cached to avoid recomputing them, which is costly. Set this to ``False`` if + your dataset does not change. + compute_on_step: Forward only calls ``update()`` and returns None if this is set to False. @@ -186,6 +190,8 @@ class FrechetInceptionDistance(Metric): If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048] TypeError: If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module`` + ValueError: + If ``reset_real_features`` is not a ``bool`` Example: >>> import torch @@ -203,11 +209,13 @@ class FrechetInceptionDistance(Metric): """ real_features: List[Tensor] fake_features: List[Tensor] - higher_is_better = False + higher_is_better: bool = False + is_differentiable: bool = False def __init__( self, feature: Union[int, torch.nn.Module] = 2048, + reset_real_features: bool = True, compute_on_step: Optional[bool] = None, **kwargs: Dict[str, Any], ) -> None: @@ -237,6 +245,10 @@ def __init__( else: raise TypeError("Got unknown input to argument `feature`") + if not isinstance(reset_real_features, bool): + raise ValueError("Arugment `reset_real_features` expected to be a bool") + self.reset_real_features = reset_real_features + self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) @@ -274,3 +286,12 @@ def compute(self) -> Tensor: # compute fid return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype) + + def reset(self) -> None: + if not self.reset_real_features: + # remove temporarily to avoid resetting + value = self._defaults.pop("real_features") + super().reset() + self._defaults["real_features"] = value + else: + super().reset() diff --git 
a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 80683fae3f3..9bc70069f50 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -112,6 +112,9 @@ class KernelInceptionDistance(Metric): Scale-length of polynomial kernel. If set to ``None`` will be automatically set to the feature size coef: Bias term in the polynomial kernel. + reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not + change, the features can be cached to avoid recomputing them, which is costly. Set this to ``False`` if + your dataset does not change. compute_on_step: Forward only calls ``update()`` and returns None if this is set to False. @@ -145,6 +148,8 @@ class KernelInceptionDistance(Metric): If ``gamma`` is niether ``None`` or a float larger than 0 ValueError: If ``coef`` is not an float larger than 0 + ValueError: + If ``reset_real_features`` is not a ``bool`` Example: >>> import torch @@ -163,7 +168,8 @@ class KernelInceptionDistance(Metric): """ real_features: List[Tensor] fake_features: List[Tensor] - higher_is_better = False + higher_is_better: bool = False + is_differentiable: bool = False def __init__( self, @@ -173,6 +179,7 @@ def __init__( degree: int = 3, gamma: Optional[float] = None, # type: ignore coef: float = 1.0, + reset_real_features: bool = True, compute_on_step: Optional[bool] = None, **kwargs: Dict[str, Any], ) -> None: @@ -222,6 +229,10 @@ def __init__( raise ValueError("Argument `coef` expected to be float larger than 0") self.coef = coef + if not isinstance(reset_real_features, bool): + raise ValueError("Arugment `reset_real_features` expected to be a bool") + self.reset_real_features = reset_real_features + # states for extracted features self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) @@ -267,3 +278,12 @@ def compute(self) -> Tuple[Tensor, Tensor]: kid_scores_.append(o) kid_scores = torch.stack(kid_scores_) return 
kid_scores.mean(), kid_scores.std(unbiased=False) + + def reset(self) -> None: + if not self.reset_real_features: + # remove temporarily to avoid resetting + value = self._defaults.pop("real_features") + super().reset() + self._defaults["real_features"] = value + else: + super().reset()