From f396dd4ab6a40b7b2f95004278d328fff18c999e Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 24 Apr 2024 11:20:01 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20Update=20lightning=20inference?= =?UTF-8?q?=20(#2018)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update lightning inference * update changelog --- CHANGELOG.md | 2 ++ src/anomalib/cli/cli.py | 29 +++------------------ src/anomalib/engine/engine.py | 49 +++++++++++++++++++++++++++-------- 3 files changed, 43 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c2ed8c1a68..1de6d36ad2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- Add data_path argument to predict entrypoint and add properties for retrieving model path by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/2018 + ### Changed - 🔨Rename OptimalF1 to F1Max for consistency with the literature, by @samet-akcay in https://github.com/openvinotoolkit/anomalib/pull/1980 diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index 9de4a1d9c1..1accf96f57 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -6,7 +6,6 @@ import logging from collections.abc import Callable, Sequence from functools import partial -from inspect import signature from pathlib import Path from types import MethodType from typing import Any @@ -29,7 +28,6 @@ from torch.utils.data import DataLoader, Dataset from anomalib.data import AnomalibDataModule - from anomalib.data.predict import PredictDataset from anomalib.engine import Engine from anomalib.metrics.threshold import BaseThreshold from anomalib.models import AnomalyModule @@ -216,7 +214,7 @@ def add_predict_arguments(self, parser: ArgumentParser) -> None: added = parser.add_method_arguments( Engine, "predict", - skip={"model", "dataloaders", "datamodule", "dataset"}, + skip={"model", "dataloaders", "datamodule", "dataset", "data_path"}, ) self.subcommand_method_arguments["predict"] = added self.add_arguments_to_parser(parser) @@ -267,8 +265,6 @@ def before_instantiate_classes(self) -> None: """Modify the configuration to properly instantiate classes and sets up tiler.""" subcommand = self.config["subcommand"] if subcommand in (*self.subcommands(), "train", "predict"): - if self.config["subcommand"] == "predict" and isinstance(self.config["predict"]["data"], str | Path): - self.config["predict"]["data"] = self._set_predict_dataloader_namespace(self.config["predict"]["data"]) self.config[subcommand] = update_config(self.config[subcommand]) def instantiate_classes(self) -> None: @@ -415,27 +411,6 @@ def _add_trainer_arguments_to_parser( **scheduler_kwargs, ) - def _set_predict_dataloader_namespace(self, data_path: str | Path | Namespace) -> Namespace: - """Set the predict dataloader namespace. - - If the argument is of type str or Path, then it is assumed to be the path to the prediction data and is - assigned to PredictDataset. - - Args: - data_path (str | Path | Namespace): Path to the data. - - Returns: - Namespace: Namespace containing the predict dataloader. - """ - if isinstance(data_path, str | Path): - init_args = {key: value.default for key, value in signature(PredictDataset).parameters.items()} - init_args["path"] = data_path - data_path = Namespace( - class_path="anomalib.data.predict.PredictDataset", - init_args=Namespace(init_args), - ) - return data_path - def _add_default_arguments_to_parser(self, parser: ArgumentParser) -> None: """Adds default arguments to the parser.""" parser.add_argument( @@ -463,6 +438,8 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]: fn_kwargs["datamodule"] = self.datamodule elif isinstance(self.datamodule, DataLoader): fn_kwargs["dataloaders"] = self.datamodule + elif isinstance(self.datamodule, Path | str): + fn_kwargs["data_path"] = self.datamodule return fn_kwargs def _parser(self, subcommand: str | None) -> ArgumentParser: diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index a08b78e529..43e9e2d213 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -225,6 +225,28 @@ def threshold_callback(self) -> _ThresholdCallback | None: raise ValueError(msg) return callbacks[0] if len(callbacks) > 0 else None + @property + def checkpoint_callback(self) -> ModelCheckpoint | None: + """The ``ModelCheckpoint`` callback in the trainer.callbacks list, or ``None`` if it doesn't exist. + + Returns: + ModelCheckpoint | None: ModelCheckpoint callback, if available. + """ + if self._trainer is None: + return None + return self.trainer.checkpoint_callback + + @property + def best_model_path(self) -> str | None: + """The path to the best model checkpoint. + + Returns: + str: Path to the best model checkpoint. + """ + if self.checkpoint_callback is None: + return None + return self.checkpoint_callback.best_model_path + def _setup_workspace( self, model: AnomalyModule, @@ -672,6 +694,7 @@ def predict( dataset: Dataset | PredictDataset | None = None, return_predictions: bool | None = None, ckpt_path: str | Path | None = None, + data_path: str | Path | None = None, ) -> _PREDICT_OUTPUT | None: """Predict using the model using the trainer. @@ -703,6 +726,9 @@ def predict( Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. Defaults to None. + data_path (str | Path | None): + Path to the image or folder containing images to generate predictions for. + Defaults to None. Returns: _PREDICT_OUTPUT | None: Predictions. @@ -743,18 +769,19 @@ def predict( if not ckpt_path: logger.warning("ckpt_path is not provided. Model weights will not be loaded.") - # Handle the instance when a dataset is passed to the predict method + # Collect dataloaders + if dataloaders is None: + dataloaders = [] + elif isinstance(dataloaders, DataLoader): + dataloaders = [dataloaders] + elif not isinstance(dataloaders, list): + msg = f"Unknown type for dataloaders {type(dataloaders)}" + raise TypeError(msg) if dataset is not None: - dataloader = DataLoader(dataset) - if dataloaders is None: - dataloaders = dataloader - elif isinstance(dataloaders, DataLoader): - dataloaders = [dataloaders, dataloader] - elif isinstance(dataloaders, list): # dataloader is a list - dataloaders.append(dataloader) - else: - msg = f"Unknown type for dataloaders {type(dataloaders)}" - raise TypeError(msg) + dataloaders.append(DataLoader(dataset)) + if data_path is not None: + dataloaders.append(DataLoader(PredictDataset(data_path))) + dataloaders = dataloaders or None self._setup_dataset_task(dataloaders, datamodule) self._setup_transform(model or self.model, datamodule=datamodule, dataloaders=dataloaders, ckpt_path=ckpt_path)