
Configurable metrics #230

Merged · 11 commits · Apr 20, 2022
18 changes: 7 additions & 11 deletions anomalib/models/components/base/anomaly_module.py
@@ -21,13 +21,12 @@
 from omegaconf import DictConfig, ListConfig
 from pytorch_lightning.callbacks.base import Callback
 from torch import Tensor, nn
-from torchmetrics import F1, MetricCollection

 from anomalib.utils.metrics import (
-    AUROC,
     AdaptiveThreshold,
     AnomalyScoreDistribution,
     MinMax,
+    get_metrics,
 )

@@ -58,12 +57,9 @@ def __init__(self, params: Union[DictConfig, ListConfig]):
         self.model: nn.Module

         # metrics
-        image_auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
-        image_f1 = F1(num_classes=1, compute_on_step=False, threshold=self.hparams.model.threshold.image_default)
-        pixel_auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
-        pixel_f1 = F1(num_classes=1, compute_on_step=False, threshold=self.hparams.model.threshold.pixel_default)
-        self.image_metrics = MetricCollection([image_auroc, image_f1], prefix="image_").cpu()
-        self.pixel_metrics = MetricCollection([pixel_auroc, pixel_f1], prefix="pixel_").cpu()
+        self.image_metrics, self.pixel_metrics = get_metrics(self.hparams)
Collaborator:

Is there something planned for classification models which do not have pixel metrics? When I removed the pixel metric key from the config file, it threw an error for padim.

@djdameln (Contributor, Author), Apr 13, 2022:
This should work:

metrics:
  image:
    - F1
    - AUROC
  pixel: []

Collaborator:

Works now 🙂

Contributor:

Not sure why my comments disappeared from here, but ideally

metrics:
  image:
    - F1
    - AUROC

should also work fine.
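
For concreteness, a minimal sketch of how the two config variants behave with the get_metrics helper added in this PR (the OmegaConf usage here is illustrative, not part of the diff):

from omegaconf import OmegaConf

from anomalib.utils.metrics import get_metrics

# The workaround from this thread: declare an explicit empty pixel list.
config = OmegaConf.create({"metrics": {"image": ["F1", "AUROC"], "pixel": []}})
image_metrics, pixel_metrics = get_metrics(config)
# pixel_metrics is an empty collection; nothing ever calls update() on it, so
# _log_metrics skips it via the update_called property.

# Omitting the `pixel` key entirely still errors as of this diff, because
# get_metrics reads config.metrics.pixel unconditionally.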

+        self.image_metrics.set_threshold(self.hparams.model.threshold.image_default)
+        self.pixel_metrics.set_threshold(self.hparams.model.threshold.pixel_default)

     def forward(self, batch):  # pylint: disable=arguments-differ
         """Forward-pass input tensor to the module.
@@ -154,8 +150,8 @@ def _compute_adaptive_threshold(self, outputs):
         else:
             self.pixel_threshold.value = self.image_threshold.value

-        self.image_metrics.F1.threshold = self.image_threshold.value.item()
-        self.pixel_metrics.F1.threshold = self.pixel_threshold.value.item()
+        self.image_metrics.set_threshold(self.image_threshold.value.item())
+        self.pixel_metrics.set_threshold(self.pixel_threshold.value.item())

     def _collect_outputs(self, image_metric, pixel_metric, outputs):
         for output in outputs:
@@ -181,5 +177,5 @@ def _outputs_to_cpu(self, output):
     def _log_metrics(self):
         """Log computed performance metrics."""
         self.log_dict(self.image_metrics)
-        if self.hparams.dataset.task == "segmentation":
+        if self.pixel_metrics.update_called:
             self.log_dict(self.pixel_metrics)
9 changes: 8 additions & 1 deletion anomalib/models/padim/config.yaml
@@ -27,13 +27,20 @@ model:
     - layer1
     - layer2
     - layer3
-  metric: auc
   normalization_method: min_max # options: [none, min_max, cdf]
   threshold:
     image_default: 3
     pixel_default: 3
     adaptive: true

+metrics:
+  image:
+    - F1
+    - AUROC
+  pixel:
+    - F1
+    - AUROC
+
 project:
   seed: 42
   path: ./results
4 changes: 2 additions & 2 deletions anomalib/utils/callbacks/min_max_normalization.py
@@ -29,8 +29,8 @@ class MinMaxNormalizationCallback(Callback):

     def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
         """Called when the test begins."""
-        pl_module.image_metrics.F1.threshold = 0.5
-        pl_module.pixel_metrics.F1.threshold = 0.5
+        pl_module.image_metrics.set_threshold(0.5)
+        pl_module.pixel_metrics.set_threshold(0.5)

     def on_validation_batch_end(
         self,
2 changes: 1 addition & 1 deletion anomalib/utils/callbacks/visualizer_callback.py
@@ -111,7 +111,7 @@ def on_test_batch_end(
             normalize = False  # anomaly maps are already normalized
         else:
             normalize = True  # raw anomaly maps. Still need to normalize
-            threshold = pl_module.pixel_metrics.F1.threshold
+            threshold = pl_module.pixel_metrics.threshold

         for i, (filename, image, anomaly_map, pred_score, gt_label) in enumerate(
             zip(
53 changes: 53 additions & 0 deletions anomalib/utils/metrics/__init__.py
@@ -1,8 +1,61 @@
"""Custom anomaly evaluation metrics."""
import importlib
import warnings
from typing import List, Optional, Tuple, Union

import torchmetrics
from omegaconf import DictConfig, ListConfig

from .adaptive_threshold import AdaptiveThreshold
from .anomaly_score_distribution import AnomalyScoreDistribution
from .auroc import AUROC
from .collection import AnomalibMetricCollection
from .min_max import MinMax
from .optimal_f1 import OptimalF1

__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
Contributor:
We might need to modify this for LightningCLI

"""Create metric collections based on the config.

Args:
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf

Returns:
AnomalibMetricCollection: Image-level metric collection
AnomalibMetricCollection: Pixel-level metric collection
"""
image_metrics = metric_collection_from_names(config.metrics.image, "image_")
pixel_metrics = metric_collection_from_names(config.metrics.pixel, "pixel_")
return image_metrics, pixel_metrics


+def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection:
+    """Create a metric collection from a list of metric names.
+
+    The function will first try to retrieve the metric from the metrics defined in the Anomalib metrics module,
+    then from the TorchMetrics package.
+
+    Args:
+        metric_names (List[str]): List of metric names to be included in the collection.
+        prefix (Optional[str]): Prefix to assign to the metrics in the collection.
+
+    Returns:
+        AnomalibMetricCollection: Collection of metrics.
+    """
+    metrics_module = importlib.import_module("anomalib.utils.metrics")
+    metrics = AnomalibMetricCollection([], prefix=prefix)
+    for metric_name in metric_names:
+        if hasattr(metrics_module, metric_name):
+            metric_cls = getattr(metrics_module, metric_name)
+            metrics.add_metrics(metric_cls(compute_on_step=False))
+        elif hasattr(torchmetrics, metric_name):
+            try:
+                metric_cls = getattr(torchmetrics, metric_name)
+                metrics.add_metrics(metric_cls(compute_on_step=False))
+            except TypeError:
+                warnings.warn(f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package.")
+        else:
+            warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.")
+    return metrics
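
A minimal usage sketch of the lookup order (metric names chosen for illustration; exact availability depends on the installed TorchMetrics version):

from anomalib.utils.metrics import metric_collection_from_names

# "AUROC" resolves to the Anomalib implementation; "Precision" is not defined
# in anomalib.utils.metrics, so the helper falls back to TorchMetrics.
# "NoSuchMetric" matches neither and only triggers a warning.
metrics = metric_collection_from_names(["AUROC", "Precision", "NoSuchMetric"], prefix="image_")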
48 changes: 48 additions & 0 deletions anomalib/utils/metrics/collection.py
@@ -0,0 +1,48 @@
"""Anomalib Metric Collection."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from torchmetrics import MetricCollection


class AnomalibMetricCollection(MetricCollection):
"""Extends the MetricCollection class for use in the Anomalib pipeline."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._update_called = False
self._threshold = 0.5

def set_threshold(self, threshold_value):
"""Update the threshold value for all metrics that have the threshold attribute."""
self._threshold = threshold_value
for metric in self.values():
if hasattr(metric, "threshold"):
metric.threshold = threshold_value

def update(self, *args, **kwargs) -> None:
"""Add data to the metrics."""
super().update(*args, **kwargs)
self._update_called = True

@property
def update_called(self) -> bool:
"""Returns a boolean indicating if the update method has been called at least once."""
return self._update_called

@property
def threshold(self) -> float:
"""Return the value of the anomaly threshold."""
return self._threshold
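
A short usage sketch of the new collection (input values invented for illustration; F1 is the TorchMetrics class name current at the time of this PR):

import torch
from torchmetrics import F1

from anomalib.utils.metrics import AUROC, AnomalibMetricCollection

metrics = AnomalibMetricCollection(
    [AUROC(num_classes=1, pos_label=1, compute_on_step=False), F1(num_classes=1, compute_on_step=False)],
    prefix="image_",
)
metrics.set_threshold(0.7)        # propagates to F1, which exposes a `threshold` attribute
assert not metrics.update_called  # no data seen yet

metrics.update(torch.tensor([0.2, 0.9]), torch.tensor([0, 1]))
assert metrics.update_called and metrics.threshold == 0.7

Keeping the threshold on the collection means callers such as MinMaxNormalizationCallback no longer need to know which individual metrics are threshold-based.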