From 43d38dd57e6bfa44284159be7df0950625667e37 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 10 Apr 2024 13:48:15 +0200 Subject: [PATCH] compute precision recall on raw scores Signed-off-by: Ashwin Vaidya --- src/anomalib/metrics/optimal_f1.py | 3 ++- tests/unit/metrics/test_optimal_f1.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/anomalib/metrics/optimal_f1.py b/src/anomalib/metrics/optimal_f1.py index c7f983c696..d9d4537973 100644 --- a/src/anomalib/metrics/optimal_f1.py +++ b/src/anomalib/metrics/optimal_f1.py @@ -7,7 +7,8 @@ import torch from torchmetrics import Metric -from torchmetrics.classification import BinaryPrecisionRecallCurve + +from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve logger = logging.getLogger(__name__) diff --git a/tests/unit/metrics/test_optimal_f1.py b/tests/unit/metrics/test_optimal_f1.py index 4f6aaa0126..8dcece255d 100644 --- a/tests/unit/metrics/test_optimal_f1.py +++ b/tests/unit/metrics/test_optimal_f1.py @@ -42,4 +42,10 @@ def test_optimal_f1_raw() -> None: metric.update(preds, labels) assert metric.compute() == 1.0 - assert metric.threshold == 0.5 + assert metric.threshold == 0.0 + + metric.reset() + preds = torch.tensor([-0.5, 0.0, 1.0, 2.0, -0.1]) + metric.update(preds, labels) + assert metric.compute() == torch.tensor(1.0) + assert metric.threshold == -0.1