diff --git a/CHANGELOG.md b/CHANGELOG.md
index 747deb5185..c2ed8c1a68 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- 🔨Rename OptimalF1 to F1Max for consistency with the literature, by @samet-akcay in https://github.com/openvinotoolkit/anomalib/pull/1980
 - 🐞Update OptimalF1 score to use BinaryPrecisionRecallCurve and remove num_classes by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/1972
 
 ### Deprecated
diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py
index 2eefb4882d..4c3eafa811 100644
--- a/src/anomalib/metrics/__init__.py
+++ b/src/anomalib/metrics/__init__.py
@@ -3,7 +3,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-
 import importlib
 import logging
 from collections.abc import Callable
@@ -17,6 +16,7 @@
 from .aupro import AUPRO
 from .auroc import AUROC
 from .collection import AnomalibMetricCollection
+from .f1_max import F1Max
 from .f1_score import F1Score
 from .min_max import MinMax
 from .precision_recall_curve import BinaryPrecisionRecallCurve
@@ -30,6 +30,7 @@
     "AnomalyScoreDistribution",
     "BinaryPrecisionRecallCurve",
     "F1AdaptiveThreshold",
+    "F1Max",
     "F1Score",
     "ManualThreshold",
     "MinMax",
diff --git a/src/anomalib/metrics/f1_max.py b/src/anomalib/metrics/f1_max.py
new file mode 100644
index 0000000000..8b9b42f305
--- /dev/null
+++ b/src/anomalib/metrics/f1_max.py
@@ -0,0 +1,100 @@
+"""Implementation of F1Max score based on TorchMetrics."""
+
+# Copyright (C) 2022-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import torch
+from torchmetrics import Metric
+
+from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve
+
+logger = logging.getLogger(__name__)
+
+
+class F1Max(Metric):
+    """F1Max Metric for Computing the Maximum F1 Score.
+
+    This class is designed to calculate the maximum F1 score from the precision-
+    recall curve for binary classification tasks. The F1 score is a harmonic
+    mean of precision and recall, offering a balance between these two metrics.
+    The maximum F1 score (F1-Max) is particularly useful in scenarios where an
+    optimal balance between precision and recall is desired, such as in
+    imbalanced datasets or when both false positives and false negatives carry
+    significant costs.
+
+    After computing the F1Max score, the class also identifies and stores the
+    threshold that yields this maximum F1 score, which provides insight into
+    the optimal point for the classification decision.
+
+    Args:
+        **kwargs: Variable keyword arguments that can be passed to the parent class.
+
+    Attributes:
+        full_state_update (bool): Indicates whether the metric requires updating
+            the entire state. Set to False for this metric as it calculates the
+            F1 score based on the current state without needing historical data.
+        precision_recall_curve (BinaryPrecisionRecallCurve): Utility to compute
+            precision and recall values across different thresholds.
+        threshold (torch.Tensor): Stores the threshold value that results in the
+            maximum F1 score.
+
+    Examples:
+        >>> from anomalib.metrics import F1Max
+        >>> import torch
+
+        >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
+        >>> target = torch.tensor([0, 0, 1, 1])
+
+        >>> f1_max = F1Max()
+        >>> f1_max.update(preds, target)
+
+        >>> optimal_f1_score = f1_max.compute()
+        >>> print(f"Optimal F1 Score: {optimal_f1_score}")
+        >>> print(f"Optimal Threshold: {f1_max.threshold}")
+
+    Note:
+        - Use `update` method to input predictions and target labels.
+        - Use `compute` method to calculate the maximum F1 score after all
+          updates.
+        - Use `reset` method to clear the current state and prepare for a new
+          set of calculations.
+    """
+
+    full_state_update: bool = False
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        self.precision_recall_curve = BinaryPrecisionRecallCurve()
+
+        self.threshold: torch.Tensor
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor, *args, **kwargs) -> None:
+        """Update the precision-recall curve metric."""
+        del args, kwargs  # These variables are not used.
+
+        self.precision_recall_curve.update(preds, target)
+
+    def compute(self) -> torch.Tensor:
+        """Compute the value of the optimal F1 score.
+
+        Compute the F1 scores while varying the threshold. Store the optimal
+        threshold as attribute and return the maximum value of the F1 score.
+
+        Returns:
+            Value of the F1 score at the optimal threshold.
+        """
+        precision: torch.Tensor
+        recall: torch.Tensor
+        thresholds: torch.Tensor
+
+        precision, recall, thresholds = self.precision_recall_curve.compute()
+        f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
+        self.threshold = thresholds[torch.argmax(f1_score)]
+        return torch.max(f1_score)
+
+    def reset(self) -> None:
+        """Reset the metric."""
+        self.precision_recall_curve.reset()
diff --git a/src/anomalib/metrics/optimal_f1.py b/src/anomalib/metrics/optimal_f1.py
deleted file mode 100644
index d9d4537973..0000000000
--- a/src/anomalib/metrics/optimal_f1.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Implementation of Optimal F1 score based on TorchMetrics."""
-
-# Copyright (C) 2022-2024 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-import logging
-
-import torch
-from torchmetrics import Metric
-
-from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve
-
-logger = logging.getLogger(__name__)
-
-
-class OptimalF1(Metric):
-    """Optimal F1 Metric.
-
-    Compute the optimal F1 score at the adaptive threshold, based on the F1
-    metric of the true labels and the predicted anomaly scores.
-
-    Args:
-        kwargs: Any keyword arguments.
-
-    .. deprecated:: 1.0.0
-        OptimalF1 metric is deprecated and will be removed in a future release.
-        The optimal F1 score for Anomalib predictions can be obtained by
-        computing the adaptive threshold with the AnomalyScoreThreshold metric
-        and setting the computed threshold value in TorchMetrics F1Score metric.
-    """
-
-    full_state_update: bool = False
-
-    def __init__(self, **kwargs) -> None:
-        msg = (
-            "OptimalF1 metric is deprecated and will be removed in a future release. The optimal F1 score for "
-            "Anomalib predictions can be obtained by computing the adaptive threshold with the "
-            "AnomalyScoreThreshold metric and setting the computed threshold value in TorchMetrics F1Score metric."
-        )
-        logger.warning(msg)
-        super().__init__(**kwargs)
-
-        self.precision_recall_curve = BinaryPrecisionRecallCurve()
-
-        self.threshold: torch.Tensor
-
-    def update(self, preds: torch.Tensor, target: torch.Tensor, *args, **kwargs) -> None:
-        """Update the precision-recall curve metric."""
-        del args, kwargs  # These variables are not used.
-
-        self.precision_recall_curve.update(preds, target)
-
-    def compute(self) -> torch.Tensor:
-        """Compute the value of the optimal F1 score.
-
-        Compute the F1 scores while varying the threshold. Store the optimal
-        threshold as attribute and return the maximum value of the F1 score.
-
-        Returns:
-            Value of the F1 score at the optimal threshold.
-        """
-        precision: torch.Tensor
-        recall: torch.Tensor
-        thresholds: torch.Tensor
-
-        precision, recall, thresholds = self.precision_recall_curve.compute()
-        f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
-        self.threshold = thresholds[torch.argmax(f1_score)]
-        return torch.max(f1_score)
-
-    def reset(self) -> None:
-        """Reset the metric."""
-        self.precision_recall_curve.reset()
diff --git a/tests/unit/metrics/test_optimal_f1.py b/tests/unit/metrics/test_f1_max.py
similarity index 76%
rename from tests/unit/metrics/test_optimal_f1.py
rename to tests/unit/metrics/test_f1_max.py
index 8dcece255d..7ce60e9996 100644
--- a/tests/unit/metrics/test_optimal_f1.py
+++ b/tests/unit/metrics/test_f1_max.py
@@ -1,19 +1,19 @@
-"""Test OptimalF1 metric."""
+"""Test F1Max metric."""
 
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
 
-from anomalib.metrics.optimal_f1 import OptimalF1
+from anomalib.metrics.f1_max import F1Max
 
 
-def test_optimal_f1_logits() -> None:
-    """Checks if OptimalF1 metric computes the optimal F1 score.
+def test_f1_max_logits() -> None:
+    """Checks if F1Max metric computes the optimal F1 score.
 
     Test when the preds are in [0, 1]
     """
-    metric = OptimalF1()
+    metric = F1Max()
     preds = torch.tensor([0.1, 0.5, 0.9, 1.0])
     labels = torch.tensor([0, 1, 1, 1])
@@ -30,12 +30,12 @@ def test_optimal_f1_logits() -> None:
     assert metric.threshold == 0.1
 
 
-def test_optimal_f1_raw() -> None:
-    """Checks if OptimalF1 metric computes the optimal F1 score.
+def test_f1_max_raw() -> None:
+    """Checks if F1Max metric computes the optimal F1 score.
 
     Test when the preds are outside [0, 1]. BinaryPrecisionRecall
     automatically applies sigmoid.
     """
-    metric = OptimalF1()
+    metric = F1Max()
     preds = torch.tensor([-0.5, 0, 0.5, 1.0, 2])
     labels = torch.tensor([0, 1, 1, 1, 1])
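Migration sketch: since the rename keeps the `update`/`compute`/`reset` interface and the stored `threshold` attribute of the old `OptimalF1` unchanged, existing call sites only need the new import and class name. A minimal before/after, reusing the toy tensors from the docstring example (expected outputs are not shown, as they depend on the precision-recall curve computed by torchmetrics):

>>> import torch
>>> # before the rename: from anomalib.metrics.optimal_f1 import OptimalF1
>>> from anomalib.metrics import F1Max  # after the rename
>>> metric = F1Max()
>>> metric.update(torch.tensor([0.1, 0.4, 0.35, 0.8]), torch.tensor([0, 0, 1, 1]))
>>> f1_max_value = metric.compute()   # maximum F1 over all thresholds on the PR curve
>>> best_threshold = metric.threshold  # threshold attaining that maximum, set by compute()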