Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸ”¨ Rename OptimalF1 to F1Max #1980

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Changed

- πŸ”¨Rename OptimalF1 to F1Max for consistency with the literature, by @samet-akcay in https://github.com/openvinotoolkit/anomalib/pull/1980
- 🐞Update OptimalF1 score to use BinaryPrecisionRecallCurve and remove num_classes by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/1972

### Deprecated
Expand Down
3 changes: 2 additions & 1 deletion src/anomalib/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import importlib
import logging
from collections.abc import Callable
Expand All @@ -17,6 +16,7 @@
from .aupro import AUPRO
from .auroc import AUROC
from .collection import AnomalibMetricCollection
from .f1_max import F1Max
from .f1_score import F1Score
from .min_max import MinMax
from .precision_recall_curve import BinaryPrecisionRecallCurve
Expand All @@ -30,6 +30,7 @@
"AnomalyScoreDistribution",
"BinaryPrecisionRecallCurve",
"F1AdaptiveThreshold",
"F1Max",
"F1Score",
"ManualThreshold",
"MinMax",
Expand Down
100 changes: 100 additions & 0 deletions src/anomalib/metrics/f1_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Implementation of F1Max score based on TorchMetrics."""

# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging

import torch
from torchmetrics import Metric

from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve

logger = logging.getLogger(__name__)


class F1Max(Metric):
"""F1Max Metric for Computing the Maximum F1 Score.

This class is designed to calculate the maximum F1 score from the precision-
recall curve for binary classification tasks. The F1 score is a harmonic
mean of precision and recall, offering a balance between these two metrics.
The maximum F1 score (F1-Max) is particularly useful in scenarios where an
optimal balance between precision and recall is desired, such as in
imbalanced datasets or when both false positives and false negatives carry
significant costs.

After computing the F1Max score, the class also identifies and stores the
threshold that yields this maximum F1 score, which providing insight into
the optimal point for the classification decision.

Args:
**kwargs: Variable keyword arguments that can be passed to the parent class.

Attributes:
full_state_update (bool): Indicates whether the metric requires updating
the entire state. Set to False for this metric as it calculates the
F1 score based on the current state without needing historical data.
precision_recall_curve (BinaryPrecisionRecallCurve): Utility to compute
precision and recall values across different thresholds.
threshold (torch.Tensor): Stores the threshold value that results in the
maximum F1 score.

Examples:
>>> from anomalib.metrics import F1Max
>>> import torch

>>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
>>> target = torch.tensor([0, 0, 1, 1])

>>> f1_max = F1Max()
>>> f1_max.update(preds, target)

>>> optimal_f1_score = f1_max.compute()
>>> print(f"Optimal F1 Score: {f1_max_score}")
>>> print(f"Optimal Threshold: {f1_max.threshold}")

Note:
- Use `update` method to input predictions and target labels.
- Use `compute` method to calculate the maximum F1 score after all
updates.
- Use `reset` method to clear the current state and prepare for a new
set of calculations.
"""

full_state_update: bool = False

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

self.precision_recall_curve = BinaryPrecisionRecallCurve()

self.threshold: torch.Tensor

def update(self, preds: torch.Tensor, target: torch.Tensor, *args, **kwargs) -> None:
"""Update the precision-recall curve metric."""
del args, kwargs # These variables are not used.

self.precision_recall_curve.update(preds, target)

def compute(self) -> torch.Tensor:
"""Compute the value of the optimal F1 score.

Compute the F1 scores while varying the threshold. Store the optimal
threshold as attribute and return the maximum value of the F1 score.

Returns:
Value of the F1 score at the optimal threshold.
"""
precision: torch.Tensor
recall: torch.Tensor
thresholds: torch.Tensor

precision, recall, thresholds = self.precision_recall_curve.compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
self.threshold = thresholds[torch.argmax(f1_score)]
return torch.max(f1_score)

def reset(self) -> None:
"""Reset the metric."""
self.precision_recall_curve.reset()
73 changes: 0 additions & 73 deletions src/anomalib/metrics/optimal_f1.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
"""Test OptimalF1 metric."""
"""Test F1Max metric."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch

from anomalib.metrics.optimal_f1 import OptimalF1
from anomalib.metrics.f1_max import F1Max


def test_optimal_f1_logits() -> None:
"""Checks if OptimalF1 metric computes the optimal F1 score.
def test_f1_max_logits() -> None:
"""Checks if F1Max metric computes the optimal F1 score.

Test when the preds are in [0, 1]
"""
metric = OptimalF1()
metric = F1Max()

preds = torch.tensor([0.1, 0.5, 0.9, 1.0])
labels = torch.tensor([0, 1, 1, 1])
Expand All @@ -30,12 +30,12 @@ def test_optimal_f1_logits() -> None:
assert metric.threshold == 0.1


def test_optimal_f1_raw() -> None:
"""Checks if OptimalF1 metric computes the optimal F1 score.
def test_f1_max_raw() -> None:
"""Checks if F1Max metric computes the optimal F1 score.

Test when the preds are outside [0, 1]. BinaryPrecisionRecall automatically applies sigmoid.
"""
metric = OptimalF1()
metric = F1Max()

preds = torch.tensor([-0.5, 0, 0.5, 1.0, 2])
labels = torch.tensor([0, 1, 1, 1, 1])
Expand Down
Loading