From b4305730cf5ddcae28df404dc1108dd519428b7c Mon Sep 17 00:00:00 2001
From: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com>
Date: Tue, 8 Nov 2022 13:29:24 +0100
Subject: [PATCH] Add option to load metrics with kwargs (#688)

* add option to load metrics with kwargs
* make subfunctions private
* add tests
* address pr comments
* fix issue #685
* fix import
* address check issues in tests
* address pr requests
* doc it
* Apply suggestions from code review
* Apply suggestions from code review
* fix test import
* indent example

Co-authored-by: Samet Akcay
---
 anomalib/utils/callbacks/__init__.py          |   6 +-
 .../utils/callbacks/metrics_configuration.py  |   6 +-
 anomalib/utils/metrics/__init__.py            | 141 +++++++++++++++---
 docs/source/reference_guide/api/metrics.rst   |  40 +++++
 tests/helpers/dummy.py                        |  43 ++++++
 tests/helpers/metrics.py                      |  27 ++++
 .../__init__.py                               |   0
 .../data/config-good-00.yaml                  |  13 ++
 .../data/config-good-01.yaml                  |  13 ++
 .../test_metrics_configuration_callback.py    |  63 ++++++++
 .../dummy_lightning_model.py                  |  30 +---
 .../visualizer_callback/test_visualizer.py    |   3 +-
 12 files changed, 335 insertions(+), 50 deletions(-)
 create mode 100644 tests/helpers/dummy.py
 create mode 100644 tests/helpers/metrics.py
 create mode 100644 tests/pre_merge/utils/callbacks/metrics_configuration_callback/__init__.py
 create mode 100644 tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-00.yaml
 create mode 100644 tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-01.yaml
 create mode 100644 tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py

diff --git a/anomalib/utils/callbacks/__init__.py b/anomalib/utils/callbacks/__init__.py
index 59cc25db0c..01f0675cc3 100644
--- a/anomalib/utils/callbacks/__init__.py
+++ b/anomalib/utils/callbacks/__init__.py
@@ -84,12 +84,10 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
         callbacks.append(post_processing_callback)
 
     # Add metric configuration to the model via MetricsConfigurationCallback
-    image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None
-    pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None
     metrics_callback = MetricsConfigurationCallback(
         config.dataset.task,
-        image_metric_names,
-        pixel_metric_names,
+        config.metrics.get("image", None),
+        config.metrics.get("pixel", None),
     )
     callbacks.append(metrics_callback)
 
diff --git a/anomalib/utils/callbacks/metrics_configuration.py b/anomalib/utils/callbacks/metrics_configuration.py
index 2abc47f848..ef80b619e5 100644
--- a/anomalib/utils/callbacks/metrics_configuration.py
+++ b/anomalib/utils/callbacks/metrics_configuration.py
@@ -12,7 +12,7 @@
 from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
 
 from anomalib.models.components.base.anomaly_module import AnomalyModule
-from anomalib.utils.metrics import metric_collection_from_names
+from anomalib.utils.metrics import create_metric_collection
 
 logger = logging.getLogger(__name__)
 
@@ -74,8 +74,8 @@ def setup(
             pixel_metric_names = self.pixel_metric_names
 
         if isinstance(pl_module, AnomalyModule):
-            pl_module.image_metrics = metric_collection_from_names(image_metric_names, "image_")
-            pl_module.pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
+            pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
+            pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
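+            # create_metric_collection accepts either a list of metric names or a mapping
+            # of names to {"class_path": ..., "init_args": ...} specs, so both config
+            # styles reach this point unchanged (see anomalib/utils/metrics/__init__.py).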
 
             pl_module.image_metrics.set_threshold(pl_module.image_threshold.value)
             pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value)
diff --git a/anomalib/utils/metrics/__init__.py b/anomalib/utils/metrics/__init__.py
index a31a94a559..e1fd7541fb 100644
--- a/anomalib/utils/metrics/__init__.py
+++ b/anomalib/utils/metrics/__init__.py
@@ -5,7 +5,7 @@
 import importlib
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torchmetrics
 from omegaconf import DictConfig, ListConfig
@@ -23,23 +23,6 @@
 __all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AnomalyScoreThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]
 
 
-def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
-    """Create metric collections based on the config.
-
-    Args:
-        config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf
-
-    Returns:
-        AnomalibMetricCollection: Image-level metric collection
-        AnomalibMetricCollection: Pixel-level metric collection
-    """
-    image_metric_names = config.metrics.image if "image" in config.metrics.keys() else []
-    pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else []
-    image_metrics = metric_collection_from_names(image_metric_names, "image_")
-    pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
-    return image_metrics, pixel_metrics
-
-
 def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection:
     """Create a metric collection from a list of metric names.
@@ -68,3 +51,125 @@ def metric_collection_from_names(metric_names: List[str], prefix: Optional[str])
     else:
         warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.")
     return metrics
+
+
+def _validate_metrics_dict(metrics: Dict[str, Dict[str, Any]]) -> None:
+    """Check the assumptions about the metrics config dict.
+
+    - Keys are metric names.
+    - Values are dictionaries.
+    - Internal dictionaries:
+        - have a key "class_path" whose value is of type str
+        - have a key "init_args" whose value is of type dict
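+
+    Example:
+        A minimal sketch of a valid input (the metric name "MyAUROC" is arbitrary):
+
+        {"MyAUROC": {"class_path": "anomalib.utils.metrics.AUROC", "init_args": {}}}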
+    """
+    assert all(
+        isinstance(metric, str) for metric in metrics.keys()
+    ), f"All keys (metric names) must be strings, found {sorted(metrics.keys())}"
+    assert all(
+        isinstance(metric, (dict, DictConfig)) for metric in metrics.values()
+    ), f"All values must be dictionaries, found {list(metrics.values())}"
+    assert all("class_path" in metric and isinstance(metric["class_path"], str) for metric in metrics.values()), (
+        "All internal dictionaries must have a 'class_path' key whose value is of type str, "
+        f"found {list(metrics.values())}"
+    )
+    assert all(
+        "init_args" in metric and isinstance(metric["init_args"], (dict, DictConfig)) for metric in metrics.values()
+    ), (
+        "All internal dictionaries must have an 'init_args' key whose value is of type dict, "
+        f"found {list(metrics.values())}"
+    )
+
+
+def _get_class_from_path(class_path: str) -> Any:
+    """Get a class from a module, assuming the string format is `package.subpackage.module.ClassName`."""
+    module_name, class_name = class_path.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    assert hasattr(module, class_name), f"Class {class_name} not found in module {module_name}"
+    cls = getattr(module, class_name)
+    return cls
+
+
+def metric_collection_from_dicts(metrics: Dict[str, Dict[str, Any]], prefix: Optional[str]) -> AnomalibMetricCollection:
+    """Create a metric collection from a dict of "metric name" -> "metric specifications".
+
+    Example:
+
+        metrics = {
+            "PixelWiseF1Score": {
+                "class_path": "torchmetrics.F1Score",
+                "init_args": {},
+            },
+            "PixelWiseAUROC": {
+                "class_path": "anomalib.utils.metrics.AUROC",
+                "init_args": {
+                    "compute_on_cpu": True,
+                },
+            },
+        }
+
+        In the config file, the same specifications (for pixel-wise metrics) look like:
+
+        ```yaml
+        metrics:
+          pixel:
+            PixelWiseF1Score:
+              class_path: torchmetrics.F1Score
+              init_args: {}
+            PixelWiseAUROC:
+              class_path: anomalib.utils.metrics.AUROC
+              init_args:
+                compute_on_cpu: true
+        ```
+
+    Args:
+        metrics (Dict[str, Dict[str, Any]]): keys are metric names, values are dictionaries.
+            Internal Dict[str, Any] keys are "class_path" (value is a string) and "init_args" (value is a dict),
+            following the convention of the PyTorch Lightning CLI.
+
+        prefix (Optional[str]): prefix to assign to the metrics in the collection.
+
+    Returns:
+        AnomalibMetricCollection: Collection of metrics.
+    """
+    _validate_metrics_dict(metrics)
+    metrics_collection = {}
+    for name, dict_ in metrics.items():
+        class_path = dict_["class_path"]
+        kwargs = dict_["init_args"]
+        cls = _get_class_from_path(class_path)
+        metrics_collection[name] = cls(**kwargs)
+    return AnomalibMetricCollection(metrics_collection, prefix=prefix)
+
+
+def create_metric_collection(
+    metrics: Union[List[str], Dict[str, Dict[str, Any]]], prefix: Optional[str]
+) -> AnomalibMetricCollection:
+    """Create a metric collection from a list of metric names or a dict of metric specifications.
+
+    This function dispatches the actual creation to the appropriate function depending on the input type:
+
+    - List[str] (names of metrics): see `metric_collection_from_names`
+    - Dict[str, Dict[str, Any]] (class path and init args of each metric): see `metric_collection_from_dicts`
+
+    When given names, the function first tries to retrieve each metric from the metrics defined in the
+    Anomalib metrics module, then from the TorchMetrics package.
+
+    Args:
+        metrics (Union[List[str], Dict[str, Dict[str, Any]]]): List of metric names or dict of metric specifications.
+        prefix (Optional[str]): prefix to assign to the metrics in the collection.
+
+    Returns:
+        AnomalibMetricCollection: Collection of metrics.
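+
+    Example:
+        A minimal sketch of both dispatch paths (the names and class paths below come from
+        this module and from torchmetrics):
+
+        >>> create_metric_collection(["F1Score", "AUROC"], "image_")
+        >>> create_metric_collection(
+        ...     {"AUROC": {"class_path": "anomalib.utils.metrics.AUROC", "init_args": {"compute_on_cpu": True}}},
+        ...     "pixel_",
+        ... )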
+ """ + # fallback is using the names + + if isinstance(metrics, (ListConfig, list)): + assert all(isinstance(metric, str) for metric in metrics), f"All metrics must be strings, found {metrics}" + return metric_collection_from_names(metrics, prefix) + + if isinstance(metrics, (DictConfig, dict)): + _validate_metrics_dict(metrics) + return metric_collection_from_dicts(metrics, prefix) + + raise ValueError(f"metrics must be a list or a dict, found {type(metrics)}") diff --git a/docs/source/reference_guide/api/metrics.rst b/docs/source/reference_guide/api/metrics.rst index 24a8d12ec9..a10fd46a72 100644 --- a/docs/source/reference_guide/api/metrics.rst +++ b/docs/source/reference_guide/api/metrics.rst @@ -1,6 +1,46 @@ Metrics ======= +There are two ways of configuring metrics in the config file: + +1. a list of metric names, or +2. a mapping of metric names to class path and init args. + +Each subsection in the section ``metrics`` of the config file can have a different style but inside each one it must be the same style. + +.. code-block:: yaml + :caption: Example of metrics configuration section in the config file. + + metrics: + # imagewise metrics using the list of metric names style + image: + - F1Score + - AUROC + # pixelwise metrics using the mapping style + pixel: + F1Score: + class_path: torchmetrics.F1Score + init_args: + compute_on_cpu: true + AUROC: + class_path: anomalib.utils.metrics.AUROC + init_args: + compute_on_cpu: true + +List of metric names +-------------------- + +A list of strings that match the name of a class in ``anomalib.utils.metrics`` or ``torchmetrics`` (in this order of priority), which will be instantiated with default arguments. + +Mapping of metric names to class path and init args +--------------------------------------------------- + +A mapping of metric names (str) to a dictionary with two keys: "class_path" and "init_args". + +"class_path" is a string with the full path to a metric (from root package down to the class name, e.g.: "anomalib.utils.metrics.AUROC"). + +"init_args" is a dictionary of arguments to be passed to the class constructor. + .. 
 .. automodule:: anomalib.utils.metrics
    :members:
    :undoc-members:
diff --git a/tests/helpers/dummy.py b/tests/helpers/dummy.py
new file mode 100644
index 0000000000..a98875046e
--- /dev/null
+++ b/tests/helpers/dummy.py
@@ -0,0 +1,43 @@
+import shutil
+import tempfile
+from pathlib import Path
+
+import pytorch_lightning as pl
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+
+from anomalib.utils.loggers.tensorboard import AnomalibTensorBoardLogger
+
+
+class DummyDataset(Dataset):
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, idx):
+        return torch.ones(1)
+
+
+class DummyDataModule(pl.LightningDataModule):
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(DummyDataset())
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(DummyDataset())
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(DummyDataset())
+
+
+class DummyModel(nn.Module):
+    pass
+
+
+class DummyLogger(AnomalibTensorBoardLogger):
+    def __init__(self):
+        self.tempdir = Path(tempfile.mkdtemp())
+        super().__init__(name="tensorboard_logs", save_dir=self.tempdir)
+
+    def __del__(self):
+        if self.tempdir.exists():
+            shutil.rmtree(self.tempdir)
diff --git a/tests/helpers/metrics.py b/tests/helpers/metrics.py
new file mode 100644
index 0000000000..8758de9354
--- /dev/null
+++ b/tests/helpers/metrics.py
@@ -0,0 +1,27 @@
+"""Helpers for metrics tests."""
+
+from typing import Tuple, Union
+
+from omegaconf import DictConfig, ListConfig
+
+from anomalib.utils.metrics import (
+    AnomalibMetricCollection,
+    metric_collection_from_names,
+)
+
+
+def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
+    """Create metric collections based on the config.
+
+    Args:
+        config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf
+
+    Returns:
+        AnomalibMetricCollection: Image-level metric collection
+        AnomalibMetricCollection: Pixel-level metric collection
+    """
+    image_metric_names = config.metrics.image if "image" in config.metrics.keys() else []
+    pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else []
+    image_metrics = metric_collection_from_names(image_metric_names, "image_")
+    pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
+    return image_metrics, pixel_metrics
diff --git a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/__init__.py b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-00.yaml b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-00.yaml
new file mode 100644
index 0000000000..67fc6d4f8c
--- /dev/null
+++ b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-00.yaml
@@ -0,0 +1,13 @@
+metrics:
+  pixel:
+    F1Score:
+      class_path: torchmetrics.F1Score
+      init_args:
+        compute_on_cpu: true
+    AUROC:
+      class_path: anomalib.utils.metrics.AUROC
+      init_args:
+        compute_on_cpu: true
+  image:
+    - F1Score
+    - AUROC
diff --git a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-01.yaml b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-01.yaml
new file mode 100644
index 0000000000..0eec22351d
--- /dev/null
+++ b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-01.yaml
@@ -0,0 +1,13 @@
+metrics:
+  pixel:
+    - F1Score
+    - AUROC
+  image:
+    F1Score:
+      class_path: torchmetrics.F1Score
+      init_args:
+        compute_on_cpu: true
+    AUROC:
+      class_path: anomalib.utils.metrics.AUROC
+      init_args:
+        compute_on_cpu: true
diff --git a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py
new file mode 100644
index 0000000000..997eb53ad3
--- /dev/null
+++ b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+
+import pytest
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+
+from anomalib.models.components import AnomalyModule
+from anomalib.utils.callbacks.metrics_configuration import MetricsConfigurationCallback
+from anomalib.utils.metrics.collection import AnomalibMetricCollection
+from tests.helpers.dummy import DummyDataModule, DummyLogger, DummyModel
+
+
+class _DummyAnomalyModule(AnomalyModule):
+    def __init__(self):
+        super().__init__()
+        self.model = DummyModel()
+        self.task = "segmentation"
+        self.mode = "full"
+        self.callbacks = []
+
+    def test_step(self, batch, _):
+        return None
+
+    def validation_epoch_end(self, outputs):
+        return None
+
+    def test_epoch_end(self, outputs):
+        return None
+
+    def configure_optimizers(self):
+        return None
+
+
+@pytest.fixture
+def config_from_yaml(request):
+    return OmegaConf.load(Path(__file__).parent / request.param)
+
+
+@pytest.mark.parametrize(
+    ["config_from_yaml"],
+    [("data/config-good-00.yaml",), ("data/config-good-01.yaml",)],
+    indirect=["config_from_yaml"],
+)
+def test_metric_collection_configuration_callback(config_from_yaml):
+    """Test if metrics are properly instantiated."""
+
+    callback = MetricsConfigurationCallback(
+        task="segmentation", image_metrics=config_from_yaml.metrics.image, pixel_metrics=config_from_yaml.metrics.pixel
+    )
+
+    dummy_logger = DummyLogger()
+    dummy_anomaly_module = _DummyAnomalyModule()
+    trainer = pl.Trainer(
+        callbacks=[callback], logger=dummy_logger, checkpoint_callback=False, default_root_dir=dummy_logger.tempdir
+    )
+    callback.setup(trainer, dummy_anomaly_module, DummyDataModule())
+
+    assert isinstance(
+        dummy_anomaly_module.image_metrics, AnomalibMetricCollection
+    ), f"{dummy_anomaly_module.image_metrics}"
+    assert isinstance(
+        dummy_anomaly_module.pixel_metrics, AnomalibMetricCollection
+    ), f"{dummy_anomaly_module.pixel_metrics}"
diff --git a/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py b/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py
index 8644871661..064e7f8936 100644
--- a/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py
+++ b/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py
@@ -1,46 +1,28 @@
 from pathlib import Path
 from typing import Union
 
-import pytorch_lightning as pl
 import torch
 from omegaconf.dictconfig import DictConfig
 from omegaconf.listconfig import ListConfig
 from torch import nn
-from torch.utils.data import DataLoader, Dataset
 
 from anomalib.models.components import AnomalyModule
 from anomalib.utils.callbacks import ImageVisualizerCallback
-from anomalib.utils.metrics import get_metrics
 from tests.helpers.dataset import get_dataset_path
+from tests.helpers.metrics import get_metrics
 
 
-class DummyDataset(Dataset):
-    def __init__(self):
-        super().__init__()
-
-    def __len__(self):
-        return 1
-
-    def __getitem__(self, idx):
-        return torch.ones(1)
-
-
-class DummyDataModule(pl.LightningDataModule):
-    def test_dataloader(self) -> DataLoader:
-        return DataLoader(DummyDataset())
-
-
-class DummyAnomalyMapGenerator(nn.Module):
+class _DummyAnomalyMapGenerator(nn.Module):
     def __init__(self):
         super().__init__()
         self.input_size = (100, 100)
         self.sigma = 4
 
 
-class DummyModel(nn.Module):
+class _DummyModel(nn.Module):
     def __init__(self):
         super().__init__()
-        self.anomaly_map_generator = DummyAnomalyMapGenerator()
+        self.anomaly_map_generator = _DummyAnomalyMapGenerator()
 
 
 class DummyModule(AnomalyModule):
@@ -48,7 +30,7 @@ class DummyModule(AnomalyModule):
     def __init__(self, hparams: Union[DictConfig, ListConfig]):
         super().__init__()
-        self.model = DummyModel()
+        self.model = _DummyModel()
         self.task = "segmentation"
         self.mode = "full"
         self.callbacks = [
@@ -79,7 +61,7 @@ def test_step(self, batch, _):
         )
         return outputs
 
-    def validation_epoch_end(self, output):
+    def validation_epoch_end(self, outputs):
         return None
 
     def test_epoch_end(self, outputs):
diff --git a/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py b/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py
index 35fdb2e129..df068d0b11 100644
--- a/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py
+++ b/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py
@@ -8,8 +8,9 @@
 from omegaconf.omegaconf import OmegaConf
 
 from anomalib.utils.loggers import AnomalibTensorBoardLogger
+from tests.helpers.dummy import DummyDataModule
 
-from .dummy_lightning_model import DummyDataModule, DummyModule
+from .dummy_lightning_model import DummyModule
 
 
 def get_dummy_module(config):