Skip to content

Commit

Permalink
Add option to load metrics with kwargs (#688)
Browse files Browse the repository at this point in the history
* add option to load metrics with kwargs

* make subfunctions private

* add tests

* address pr comments

* fix issu #685

* fix import

* address check issues in tests

* address pr requests

* doc it

* Apply suggestions from code review

* Apply suggestions from code review

* fix test import

* indent example

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
jpcbertoldo and samet-akcay authored Nov 8, 2022
1 parent 034953f commit b430573
Show file tree
Hide file tree
Showing 12 changed files with 335 additions and 50 deletions.
6 changes: 2 additions & 4 deletions anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,10 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
callbacks.append(post_processing_callback)

# Add metric configuration to the model via MetricsConfigurationCallback
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None
metrics_callback = MetricsConfigurationCallback(
config.dataset.task,
image_metric_names,
pixel_metric_names,
config.metrics.get("image", None),
config.metrics.get("pixel", None),
)
callbacks.append(metrics_callback)

Expand Down
6 changes: 3 additions & 3 deletions anomalib/utils/callbacks/metrics_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY

from anomalib.models.components.base.anomaly_module import AnomalyModule
from anomalib.utils.metrics import metric_collection_from_names
from anomalib.utils.metrics import create_metric_collection

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,8 +74,8 @@ def setup(
pixel_metric_names = self.pixel_metric_names

if isinstance(pl_module, AnomalyModule):
pl_module.image_metrics = metric_collection_from_names(image_metric_names, "image_")
pl_module.pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")

pl_module.image_metrics.set_threshold(pl_module.image_threshold.value)
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value)
141 changes: 123 additions & 18 deletions anomalib/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import importlib
import warnings
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union

import torchmetrics
from omegaconf import DictConfig, ListConfig
Expand All @@ -23,23 +23,6 @@
__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AnomalyScoreThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
"""Create metric collections based on the config.
Args:
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf
Returns:
AnomalibMetricCollection: Image-level metric collection
AnomalibMetricCollection: Pixel-level metric collection
"""
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else []
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else []
image_metrics = metric_collection_from_names(image_metric_names, "image_")
pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
return image_metrics, pixel_metrics


def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection:
"""Create a metric collection from a list of metric names.
Expand Down Expand Up @@ -68,3 +51,125 @@ def metric_collection_from_names(metric_names: List[str], prefix: Optional[str])
else:
warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.")
return metrics


def _validate_metrics_dict(metrics: Dict[str, Dict[str, Any]]) -> None:
"""Check the assumptions about metrics config dict.
- Keys are metric names
- Values are dictionaries.
- Internal dictionaries:
- have key "class_path" and its value is of type str
- have key init_args" and its value is of type dict).
"""
assert all(
isinstance(metric, str) for metric in metrics.keys()
), f"All keys (metric names) must be strings, found {sorted(metrics.keys())}"
assert all(
isinstance(metric, (dict, DictConfig)) for metric in metrics.values()
), f"All values must be dictionaries, found {list(metrics.values())}"
assert all("class_path" in metric and isinstance(metric["class_path"], str) for metric in metrics.values()), (
"All internal dictionaries must have a 'class_path' key whose value is of type str, "
f"found {list(metrics.values())}"
)
assert all(
"init_args" in metric and isinstance(metric["init_args"], (dict, DictConfig)) for metric in metrics.values()
), (
"All internal dictionaries must have a 'init_args' key whose value is of type dict, "
f"found {list(metrics.values())}"
)


def _get_class_from_path(class_path: str) -> Any:
"""Get a class from a module assuming the string format is `package.subpackage.module.ClassName`."""
module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
assert hasattr(module, class_name), f"Class {class_name} not found in module {module_name}"
cls = getattr(module, class_name)
return cls


def metric_collection_from_dicts(metrics: Dict[str, Dict[str, Any]], prefix: Optional[str]) -> AnomalibMetricCollection:
"""Create a metric collection from a dict of "metric name" -> "metric specifications".
Example:
metrics = {
"PixelWiseF1Score": {
"class_path": "torchmetrics.F1Score",
"init_args": {},
},
"PixelWiseAUROC": {
"class_path": "anomalib.utils.metrics.AUROC",
"init_args": {
"compute_on_cpu": True,
},
},
}
In the config file, the same specifications (for pixel-wise metrics) look like:
```yaml
metrics:
pixel:
PixelWiseF1Score:
class_path: torchmetrics.F1Score
init_args: {}
PixelWiseAUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
```
Args:
metrics (Dict[str, Dict[str, Any]]): keys are metric names, values are dictionaries.
Internal Dict[str, Any] keys are "class_path" (value is string) and "init_args" (value is dict),
following the convention in Pytorch Lightning CLI.
prefix (Optional[str]): prefix to assign to the metrics in the collection.
Returns:
AnomalibMetricCollection: Collection of metrics.
"""
_validate_metrics_dict(metrics)
metrics_collection = {}
for name, dict_ in metrics.items():
class_path = dict_["class_path"]
kwargs = dict_["init_args"]
cls = _get_class_from_path(class_path)
metrics_collection[name] = cls(**kwargs)
return AnomalibMetricCollection(metrics_collection, prefix=prefix)


def create_metric_collection(
metrics: Union[List[str], Dict[str, Dict[str, Any]]], prefix: Optional[str]
) -> AnomalibMetricCollection:
"""Create a metric collection from a list of metric names or dictionaries.
This function will dispatch the actual creation to the appropriate function depending on the input type:
- if List[str] (names of metrics): see `metric_collection_from_names`
- if Dict[str, Dict[str, Any]] (path and init args of a class): see `metric_collection_from_dicts`
The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module,
then in TorchMetrics package.
Args:
metrics (Union[List[str], Dict[str, Dict[str, Any]]]).
prefix (Optional[str]): prefix to assign to the metrics in the collection.
Returns:
AnomalibMetricCollection: Collection of metrics.
"""
# fallback is using the names

if isinstance(metrics, (ListConfig, list)):
assert all(isinstance(metric, str) for metric in metrics), f"All metrics must be strings, found {metrics}"
return metric_collection_from_names(metrics, prefix)

if isinstance(metrics, (DictConfig, dict)):
_validate_metrics_dict(metrics)
return metric_collection_from_dicts(metrics, prefix)

raise ValueError(f"metrics must be a list or a dict, found {type(metrics)}")
40 changes: 40 additions & 0 deletions docs/source/reference_guide/api/metrics.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
Metrics
=======

There are two ways of configuring metrics in the config file:

1. a list of metric names, or
2. a mapping of metric names to class path and init args.

Each subsection in the section ``metrics`` of the config file can have a different style but inside each one it must be the same style.

.. code-block:: yaml
:caption: Example of metrics configuration section in the config file.
metrics:
# imagewise metrics using the list of metric names style
image:
- F1Score
- AUROC
# pixelwise metrics using the mapping style
pixel:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
List of metric names
--------------------

A list of strings that match the name of a class in ``anomalib.utils.metrics`` or ``torchmetrics`` (in this order of priority), which will be instantiated with default arguments.

Mapping of metric names to class path and init args
---------------------------------------------------

A mapping of metric names (str) to a dictionary with two keys: "class_path" and "init_args".

"class_path" is a string with the full path to a metric (from root package down to the class name, e.g.: "anomalib.utils.metrics.AUROC").

"init_args" is a dictionary of arguments to be passed to the class constructor.

.. automodule:: anomalib.utils.metrics
:members:
:undoc-members:
Expand Down
43 changes: 43 additions & 0 deletions tests/helpers/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import shutil
import tempfile
from pathlib import Path

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from anomalib.utils.loggers.tensorboard import AnomalibTensorBoardLogger


class DummyDataset(Dataset):
def __len__(self):
return 1

def __getitem__(self, idx):
return torch.ones(1)


class DummyDataModule(pl.LightningDataModule):
def train_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())

def val_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())

def test_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())


class DummyModel(nn.Module):
pass


class DummyLogger(AnomalibTensorBoardLogger):
def __init__(self):
self.tempdir = Path(tempfile.mkdtemp())
super().__init__(name="tensorboard_logs", save_dir=self.tempdir)

def __del__(self):
if self.tempdir.exists():
shutil.rmtree(self.tempdir)
27 changes: 27 additions & 0 deletions tests/helpers/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Helpers for metrics tests."""

from typing import Tuple, Union

from omegaconf import DictConfig, ListConfig

from anomalib.utils.metrics import (
AnomalibMetricCollection,
metric_collection_from_names,
)


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
"""Create metric collections based on the config.
Args:
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf
Returns:
AnomalibMetricCollection: Image-level metric collection
AnomalibMetricCollection: Pixel-level metric collection
"""
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else []
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else []
image_metrics = metric_collection_from_names(image_metric_names, "image_")
pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
return image_metrics, pixel_metrics
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
metrics:
pixel:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
image:
- F1Score
- AUROC
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
metrics:
pixel:
- F1Score
- AUROC
image:
F1Score:
class_path: torchmetrics.F1Score
init_args:
compute_on_cpu: true
AUROC:
class_path: anomalib.utils.metrics.AUROC
init_args:
compute_on_cpu: true
Loading

0 comments on commit b430573

Please sign in to comment.