From 6b420a7fcf1287254012e46203ebfff9fa8dcc4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Mon, 29 Nov 2021 14:49:23 +0100 Subject: [PATCH 01/13] Fix CompositionalMetric --- torchmetrics/metric.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 1bc8911b494..9a3b0df9908 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -756,6 +756,33 @@ def compute(self) -> Any: return self.op(val_a, val_b) + @torch.jit.unused + def forward(self, *args: Any, **kwargs: Any) -> Any: + + if isinstance(self.metric_a, Metric): + val_a = self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs)) + else: + val_a = self.metric_a + + if isinstance(self.metric_b, Metric): + val_b = self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs)) + else: + val_b = self.metric_b + + if val_a is None: + # compute_on_step of metric_a is False + return None + elif val_b is None: + if isinstance(self.metric_b, Metric): + # compute_on_step of metric_b is False + return None + else: + # Unary op + return self.op(val_a) + else: + # Binary op + return self.op(val_a, val_b) + def reset(self) -> None: if isinstance(self.metric_a, Metric): self.metric_a.reset() @@ -774,3 +801,6 @@ def __repr__(self) -> str: repr_str = self.__class__.__name__ + _op_metrics return repr_str + + def _wrap_compute(self, compute: Callable) -> Callable: + return compute From 22c18bef7ec00df1bd73cd0939df5a40fa67f3e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Mon, 29 Nov 2021 15:40:33 +0100 Subject: [PATCH 02/13] Added test case for CompositionalMetric.forward() --- tests/bases/test_composition.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index df6a90b7171..fcae3d92524 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -24,7 +24,7 @@ class DummyMetric(Metric): def __init__(self, val_to_return): super().__init__() - self._num_updates = 0 + self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum") self._val_to_return = val_to_return self._update_called = True @@ -34,10 +34,6 @@ def update(self, *args, **kwargs) -> None: def compute(self): return tensor(self._val_to_return) - def reset(self): - self._num_updates = 0 - return super().reset() - @pytest.mark.parametrize( ["second_operand", "expected_result"], @@ -557,3 +553,19 @@ def test_compositional_metrics_update(): assert compos.metric_a._num_updates == 3 assert compos.metric_b._num_updates == 3 + + +def test_compositional_metrics_forward(): + + compos = DummyMetric(5) + DummyMetric(4) + + assert isinstance(compos, CompositionalMetric) + compos() + compos() + compos() + + assert isinstance(compos.metric_a, DummyMetric) + assert isinstance(compos.metric_b, DummyMetric) + + assert compos.metric_a._num_updates == 3 + assert compos.metric_b._num_updates == 3 From a92fa7d1ab6f9b7b84e6a871cd9a9de5bf436b37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Mon, 29 Nov 2021 15:50:42 +0100 Subject: [PATCH 03/13] Improved code style compliance --- torchmetrics/metric.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 9a3b0df9908..f1bb4c874f7 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -772,16 +772,17 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: if val_a is None: # compute_on_step of metric_a is False return None - elif val_b is None: + + if val_b is None: if isinstance(self.metric_b, Metric): # compute_on_step of metric_b is False return None - else: - # Unary op - return self.op(val_a) - else: - # Binary op - return self.op(val_a, val_b) + + # Unary op + return self.op(val_a) + + # Binary op + return self.op(val_a, val_b) def reset(self) -> None: if isinstance(self.metric_a, Metric): From fdf85c334f5521781aa4787fe8d52bb1ba693e63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Tue, 30 Nov 2021 13:53:47 +0100 Subject: [PATCH 04/13] Extended unit test as suggested by SkafteNicki Co-authored-by: Nicki Skafte Detlefsen --- tests/bases/test_composition.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index fcae3d92524..da81a99e8c1 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -555,14 +555,15 @@ def test_compositional_metrics_update(): assert compos.metric_b._num_updates == 3 -def test_compositional_metrics_forward(): - - compos = DummyMetric(5) + DummyMetric(4) +@pytest.mark.parametrize("metric_b", [4, DummyMetric(4)]) +def test_compositional_metrics_forward(metric_b): + """ test forward method of compositional metric """ + compos = DummyMetric(5) + metric_b assert isinstance(compos, CompositionalMetric) - compos() - compos() - compos() + assert compos() == 9 + assert compos() == 9 + assert compos() == 9 assert isinstance(compos.metric_a, DummyMetric) assert isinstance(compos.metric_b, DummyMetric) From 00d375af9c0af3434f6df1951c703b06aa8e33c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 12:54:18 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_composition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index da81a99e8c1..d6ef1e45c08 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -557,7 +557,7 @@ def test_compositional_metrics_update(): @pytest.mark.parametrize("metric_b", [4, DummyMetric(4)]) def test_compositional_metrics_forward(metric_b): - """ test forward method of compositional metric """ + """test forward method of compositional metric.""" compos = DummyMetric(5) + metric_b assert isinstance(compos, CompositionalMetric) From ce79a865ea9948d14631a09d00b7276a82fa790a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 30 Nov 2021 15:03:54 +0100 Subject: [PATCH 06/13] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88d7db99dce..182f8dd0e24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606)) +- Fixed `forward` in compositional metrics ([#645](https://github.com/PyTorchLightning/metrics/pull/645)) + ## [0.6.0] - 2021-10-28 From 9ab9e5b65cc9dddb6b8a3ae3e247eafb071b43d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Tue, 30 Nov 2021 15:05:53 +0100 Subject: [PATCH 07/13] Fixed test_compositional_metrics_forward --- tests/bases/test_composition.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index d6ef1e45c08..acba57d89e2 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -566,7 +566,8 @@ def test_compositional_metrics_forward(metric_b): assert compos() == 9 assert isinstance(compos.metric_a, DummyMetric) - assert isinstance(compos.metric_b, DummyMetric) - assert compos.metric_a._num_updates == 3 - assert compos.metric_b._num_updates == 3 + + if isinstance(metric_b, Metric): + assert isinstance(compos.metric_b, DummyMetric) + assert compos.metric_b._num_updates == 3 From 88af0b387fc9a6865b9325b3f3bc6ef54a48a586 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 30 Nov 2021 15:15:24 +0100 Subject: [PATCH 08/13] fix + improve tests --- tests/bases/test_composition.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index d6ef1e45c08..a44042aa5c5 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -540,7 +540,7 @@ def test_metrics_getitem(value, idx, expected_result): def test_compositional_metrics_update(): - + """ test update method for compositional metrics.""" compos = DummyMetric(5) + DummyMetric(4) assert isinstance(compos, CompositionalMetric) @@ -555,18 +555,25 @@ def test_compositional_metrics_update(): assert compos.metric_b._num_updates == 3 +@pytest.mark.parametrize("compute_on_step", [True, False]) @pytest.mark.parametrize("metric_b", [4, DummyMetric(4)]) -def test_compositional_metrics_forward(metric_b): - """test forward method of compositional metric.""" - compos = DummyMetric(5) + metric_b +def test_compositional_metrics_forward(compute_on_step, metric_b): + """test forward method of compositional metrics.""" + metric_a = DummyMetric(5) + metric_a.compute_on_step = compute_on_step + compos = metric_a + metric_b assert isinstance(compos, CompositionalMetric) - assert compos() == 9 - assert compos() == 9 - assert compos() == 9 + for _ in range(3): + val = compos() + assert val == 9 if compute_on_step else val == None assert isinstance(compos.metric_a, DummyMetric) - assert isinstance(compos.metric_b, DummyMetric) - assert compos.metric_a._num_updates == 3 - assert compos.metric_b._num_updates == 3 + + if isinstance(metric_b, DummyMetric): + assert isinstance(compos.metric_b, DummyMetric) + assert compos.metric_b._num_updates == 3 + + compos.reset() + From 87386aceb29ea297c617846cc9ecc5a765b0f6b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 14:18:38 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_composition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index dc81d821cfa..a4f6391146c 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -540,7 +540,7 @@ def test_metrics_getitem(value, idx, expected_result): def test_compositional_metrics_update(): - """ test update method for compositional metrics.""" + """test update method for compositional metrics.""" compos = DummyMetric(5) + DummyMetric(4) assert isinstance(compos, CompositionalMetric) From 7aa166431612e51ed2fc0870924ba4790e1befa9 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 30 Nov 2021 15:19:24 +0100 Subject: [PATCH 10/13] flake8 --- tests/bases/test_composition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index dc81d821cfa..2110f7f613f 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -566,7 +566,7 @@ def test_compositional_metrics_forward(compute_on_step, metric_b): assert isinstance(compos, CompositionalMetric) for _ in range(3): val = compos() - assert val == 9 if compute_on_step else val == None + assert val == 9 if compute_on_step else val is None assert isinstance(compos.metric_a, DummyMetric) assert compos.metric_a._num_updates == 3 From be20ddb1549325a2630ffa24cf9f4ec4605952df Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 30 Nov 2021 19:34:24 +0100 Subject: [PATCH 11/13] Apply suggestions from code review --- torchmetrics/metric.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index f1bb4c874f7..5dcf8d9d99d 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -759,15 +759,9 @@ def compute(self) -> Any: @torch.jit.unused def forward(self, *args: Any, **kwargs: Any) -> Any: - if isinstance(self.metric_a, Metric): - val_a = self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs)) - else: - val_a = self.metric_a + val_a = self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs)) if isinstance(self.metric_a, Metric) else self.metric_a - if isinstance(self.metric_b, Metric): - val_b = self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs)) - else: - val_b = self.metric_b + val_b = self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs)) if isinstance(self.metric_b, Metric) else self.metric_b if val_a is None: # compute_on_step of metric_a is False From e57142ff7130efe5bbfb877206ec4e6de7bac3b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 18:34:56 +0000 Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/metric.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 5dcf8d9d99d..06f28be7eb9 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -759,9 +759,17 @@ def compute(self) -> Any: @torch.jit.unused def forward(self, *args: Any, **kwargs: Any) -> Any: - val_a = self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs)) if isinstance(self.metric_a, Metric) else self.metric_a + val_a = ( + self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs)) + if isinstance(self.metric_a, Metric) + else self.metric_a + ) - val_b = self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs)) if isinstance(self.metric_b, Metric) else self.metric_b + val_b = ( + self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs)) + if isinstance(self.metric_b, Metric) + else self.metric_b + ) if val_a is None: # compute_on_step of metric_a is False From 5c26c932000090b3b5a107eda6ee28f69205306d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 30 Nov 2021 19:35:32 +0100 Subject: [PATCH 13/13] Apply suggestions from code review --- torchmetrics/metric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 06f28be7eb9..e32aace6cfa 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -764,7 +764,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: if isinstance(self.metric_a, Metric) else self.metric_a ) - val_b = ( self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs)) if isinstance(self.metric_b, Metric)