diff --git a/src/anomalib/data/validators/torch/video.py b/src/anomalib/data/validators/torch/video.py index b7ca50c943..bcfca62451 100644 --- a/src/anomalib/data/validators/torch/video.py +++ b/src/anomalib/data/validators/torch/video.py @@ -588,10 +588,10 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None: Examples: >>> import torch >>> from anomalib.data.validators import VideoBatchValidator - >>> gt_masks = torch.rand(2, 10, 224, 224) > 0.5 # 2 videos, 10 frames each + >>> gt_masks = torch.rand(10, 224, 224) > 0.5 # 10 frames each >>> validated_masks = VideoBatchValidator.validate_gt_mask(gt_masks) >>> print(validated_masks.shape) - torch.Size([2, 10, 224, 224]) + torch.Size([10, 224, 224]) >>> single_frame_masks = torch.rand(4, 456, 256) > 0.5 # 4 single-frame images >>> validated_single_frame = VideoBatchValidator.validate_gt_mask(single_frame_masks) >>> print(validated_single_frame.shape) @@ -600,17 +600,18 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None: if mask is None: return None if not isinstance(mask, torch.Tensor): - msg = f"Masks must be a torch.Tensor, got {type(mask)}." + msg = f"Ground truth mask must be a torch.Tensor, got {type(mask)}." raise TypeError(msg) - if mask.ndim not in {3, 4, 5}: - msg = f"Masks must have shape [B, H, W], [B, T, H, W] or [B, T, 1, H, W], got shape {mask.shape}." + if mask.ndim not in {2, 3, 4}: + msg = f"Ground truth mask must have shape [H, W] or [N, H, W] or [N, 1, H, W] got shape {mask.shape}." raise ValueError(msg) - if mask.ndim == 5: - if mask.shape[2] != 1: - msg = f"Masks must have 1 channel, got {mask.shape[2]}." + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 4: + if mask.shape[1] != 1: + msg = f"Ground truth mask must have 1 channel, got {mask.shape[1]}." 
raise ValueError(msg) - mask = mask.squeeze(2) - + mask = mask.squeeze(1) return Mask(mask, dtype=torch.bool) @staticmethod diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index ecd0a4f062..a548dd23e4 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -259,9 +259,6 @@ def _setup_trainer(self, model: AnomalibModule) -> None: # Setup anomalib callbacks to be used with the trainer self._setup_anomalib_callbacks() - # Temporarily set devices to 1 to avoid issues with multiple processes - self._cache.args["devices"] = 1 - # Instantiate the trainer if it is not already instantiated if self._trainer is None: self._trainer = Trainer(**self._cache.args) diff --git a/src/anomalib/metrics/evaluator.py b/src/anomalib/metrics/evaluator.py index 53f05af3b2..460a2a4b0b 100644 --- a/src/anomalib/metrics/evaluator.py +++ b/src/anomalib/metrics/evaluator.py @@ -3,6 +3,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import logging from collections.abc import Sequence from typing import Any @@ -14,6 +15,8 @@ from anomalib.metrics import AnomalibMetric +logger = logging.getLogger(__name__) + class Evaluator(nn.Module, Callback): """Evaluator module for LightningModule. @@ -53,8 +56,15 @@ def __init__( super().__init__() self.val_metrics = ModuleList(self.validate_metrics(val_metrics)) self.test_metrics = ModuleList(self.validate_metrics(test_metrics)) - - if compute_on_cpu: + self.compute_on_cpu = compute_on_cpu + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + """Move metrics to cpu if ``num_devices == 1`` and ``compute_on_cpu`` is set to ``True``.""" + del pl_module, stage # Unused arguments. 
+ if trainer.num_devices > 1: + if self.compute_on_cpu: + logger.warning("Number of devices is greater than 1, setting compute_on_cpu to False.") + elif self.compute_on_cpu: self.metrics_to_cpu(self.val_metrics) self.metrics_to_cpu(self.test_metrics) diff --git a/src/anomalib/models/components/base/memory_bank_module.py b/src/anomalib/models/components/base/memory_bank_module.py index 738dff6185..501e8dc11a 100644 --- a/src/anomalib/models/components/base/memory_bank_module.py +++ b/src/anomalib/models/components/base/memory_bank_module.py @@ -19,6 +19,7 @@ class MemoryBankMixin(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.register_buffer("_is_fitted", torch.tensor([False])) + self.device: torch.device # defined in lightning module self._is_fitted: torch.Tensor @abstractmethod @@ -34,10 +35,10 @@ def on_validation_start(self) -> None: """Ensure that the model is fitted before validation starts.""" if not self._is_fitted: self.fit() - self._is_fitted = torch.tensor([True]) + self._is_fitted = torch.tensor([True], device=self.device) def on_train_epoch_end(self) -> None: """Ensure that the model is fitted before validation starts.""" if not self._is_fitted: self.fit() - self._is_fitted = torch.tensor([True]) + self._is_fitted = torch.tensor([True], device=self.device) diff --git a/src/anomalib/models/components/classification/kde_classifier.py b/src/anomalib/models/components/classification/kde_classifier.py index 88362ff3de..d50e5cca31 100644 --- a/src/anomalib/models/components/classification/kde_classifier.py +++ b/src/anomalib/models/components/classification/kde_classifier.py @@ -93,7 +93,10 @@ def fit(self, embeddings: torch.Tensor) -> bool: # if max training points is non-zero and smaller than number of staged features, select random subset if embeddings.shape[0] > self.max_training_points: - selected_idx = torch.tensor(random.sample(range(embeddings.shape[0]), self.max_training_points)) + selected_idx = 
torch.tensor( + random.sample(range(embeddings.shape[0]), self.max_training_points), + device=embeddings.device, + ) selected_features = embeddings[selected_idx] else: selected_features = embeddings diff --git a/src/anomalib/models/components/dimensionality_reduction/pca.py b/src/anomalib/models/components/dimensionality_reduction/pca.py index 93c60b2b56..3e9bd4bb65 100644 --- a/src/anomalib/models/components/dimensionality_reduction/pca.py +++ b/src/anomalib/models/components/dimensionality_reduction/pca.py @@ -74,7 +74,7 @@ def fit(self, dataset: torch.Tensor) -> None: else: num_components = int(self.n_components) - self.num_components = torch.Tensor([num_components]) + self.num_components = torch.tensor([num_components], device=dataset.device) self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components].float() self.singular_values = sig[:num_components].float() @@ -98,7 +98,7 @@ def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor: mean = dataset.mean(dim=0) dataset -= mean num_components = int(self.n_components) - self.num_components = torch.Tensor([num_components]) + self.num_components = torch.tensor([num_components], device=dataset.device) v_h = torch.linalg.svd(dataset)[-1] self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components] diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py index 16ccac6403..666fb5507d 100644 --- a/src/anomalib/models/image/dfkde/lightning_model.py +++ b/src/anomalib/models/image/dfkde/lightning_model.py @@ -94,6 +94,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: embedding = self.model(batch.image) self.embeddings.append(embedding) + # Return a dummy loss tensor + return torch.tensor(0.0, requires_grad=True, device=self.device) + def fit(self) -> None: """Fit a KDE Model to the embedding collected from the training set.""" embeddings = torch.vstack(self.embeddings) diff --git 
a/src/anomalib/models/image/dfm/lightning_model.py b/src/anomalib/models/image/dfm/lightning_model.py index 96a4388835..1bdad50e1e 100644 --- a/src/anomalib/models/image/dfm/lightning_model.py +++ b/src/anomalib/models/image/dfm/lightning_model.py @@ -100,6 +100,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: embedding = self.model.get_features(batch.image).squeeze() self.embeddings.append(embedding) + # Return a dummy loss tensor + return torch.tensor(0.0, requires_grad=True, device=self.device) + def fit(self) -> None: """Fit a PCA transformation and a Gaussian model to dataset.""" logger.info("Aggregating the embedding extracted from the training set.") diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py index ab133d045f..520cbf8196 100644 --- a/src/anomalib/models/image/dfm/torch_model.py +++ b/src/anomalib/models/image/dfm/torch_model.py @@ -41,7 +41,7 @@ def fit(self, dataset: torch.Tensor) -> None: dataset (torch.Tensor): Input dataset to fit the model. 
""" num_samples = dataset.shape[1] - self.mean_vec = torch.mean(dataset, dim=1) + self.mean_vec = torch.mean(dataset, dim=1) data_centered = (dataset - self.mean_vec.reshape(-1, 1)) / math.sqrt(num_samples) self.u_mat, self.sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False) diff --git a/src/anomalib/models/image/dsr/anomaly_generator.py b/src/anomalib/models/image/dsr/anomaly_generator.py index 9bb262500c..2d1d5c4a75 100644 --- a/src/anomalib/models/image/dsr/anomaly_generator.py +++ b/src/anomalib/models/image/dsr/anomaly_generator.py @@ -73,7 +73,7 @@ def augment_batch(self, batch: Tensor) -> Tensor: masks_list: list[Tensor] = [] for _ in range(batch_size): if torch.rand(1) > self.p_anomalous: # include normal samples - masks_list.append(torch.zeros((1, height, width))) + masks_list.append(torch.zeros((1, height, width), device=batch.device)) else: mask = self.generate_anomaly(height, width) masks_list.append(mask) diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py index 4a223c9e62..78f17861c0 100644 --- a/src/anomalib/models/image/padim/lightning_model.py +++ b/src/anomalib/models/image/padim/lightning_model.py @@ -91,7 +91,10 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: del args, kwargs # These variables are not used. 
embedding = self.model(batch.image) - self.embeddings.append(embedding.cpu()) + self.embeddings.append(embedding) + + # Return a dummy loss tensor + return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: """Fit a Gaussian to the embedding collected from the training set.""" diff --git a/src/anomalib/models/image/patchcore/lightning_model.py b/src/anomalib/models/image/patchcore/lightning_model.py index 689b7ac81f..e58185e50e 100644 --- a/src/anomalib/models/image/patchcore/lightning_model.py +++ b/src/anomalib/models/image/patchcore/lightning_model.py @@ -118,6 +118,8 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: embedding = self.model(batch.image) self.embeddings.append(embedding) + # Return a dummy loss tensor + return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: """Apply subsampling to the embedding collected from the training set.""" diff --git a/src/anomalib/models/video/ai_vad/lightning_model.py b/src/anomalib/models/video/ai_vad/lightning_model.py index bbebbe5edf..3afd674673 100644 --- a/src/anomalib/models/video/ai_vad/lightning_model.py +++ b/src/anomalib/models/video/ai_vad/lightning_model.py @@ -7,9 +7,9 @@ # SPDX-License-Identifier: Apache-2.0 import logging -from dataclasses import replace from typing import Any +import torch from lightning.pytorch.utilities.types import STEP_OUTPUT from anomalib import LearningType @@ -124,6 +124,9 @@ def training_step(self, batch: VideoBatch) -> None: self.model.density_estimator.update(features, video_path) self.total_detections += len(next(iter(features.values()))) + # Return a dummy loss tensor + return torch.tensor(0.0, requires_grad=True, device=self.device) + def fit(self) -> None: """Fit the density estimators to the extracted features from the training set.""" if self.total_detections == 0: @@ -147,13 +150,7 @@ def validation_step(self, batch: VideoBatch, *args, **kwargs) -> STEP_OUTPUT: del args, kwargs # Unused 
arguments. predictions = self.model(batch.image) - - return replace( - batch, - pred_score=predictions.pred_score, - anomaly_map=predictions.anomaly_map, - pred_mask=predictions.pred_mask, - ) + return batch.update(pred_score=predictions.pred_score, anomaly_map=predictions.anomaly_map) @property def trainer_arguments(self) -> dict[str, Any]: diff --git a/src/anomalib/utils/config.py b/src/anomalib/utils/config.py index f41617f355..aadaa6a42b 100644 --- a/src/anomalib/utils/config.py +++ b/src/anomalib/utils/config.py @@ -254,10 +254,3 @@ def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None: "Anomalib's models and visualizer are currently not compatible with video datasets with a clip length > 1. " "Custom changes to these modules will be needed to prevent errors and/or unpredictable behaviour.", ) - if ( - "devices" in config.trainer - and (config.trainer.devices is None or config.trainer.devices != 1) - and config.trainer.accelerator != "cpu" - ): - logger.warning("Anomalib currently does not support multi-gpu training. 
Setting devices to 1.") - config.trainer.devices = 1 diff --git a/tests/unit/data/validators/torch/test_video.py b/tests/unit/data/validators/torch/test_video.py index 2933ddb7f4..04e3373a5a 100644 --- a/tests/unit/data/validators/torch/test_video.py +++ b/tests/unit/data/validators/torch/test_video.py @@ -174,10 +174,10 @@ def test_validate_gt_label_invalid_type(self) -> None: def test_validate_gt_mask_valid(self) -> None: """Test validation of valid ground truth masks.""" - masks = torch.randint(0, 2, (2, 10, 224, 224)) + masks = torch.randint(0, 2, (10, 1, 224, 224)) validated_masks = self.validator.validate_gt_mask(masks) assert isinstance(validated_masks, Mask) - assert validated_masks.shape == (2, 10, 224, 224) + assert validated_masks.shape == (10, 224, 224) assert validated_masks.dtype == torch.bool def test_validate_gt_mask_none(self) -> None: @@ -186,13 +186,13 @@ def test_validate_gt_mask_none(self) -> None: def test_validate_gt_mask_invalid_type(self) -> None: """Test validation of ground truth masks with invalid type.""" - with pytest.raises(TypeError, match="Masks must be a torch.Tensor"): + with pytest.raises(TypeError, match="Ground truth mask must be a torch.Tensor"): self.validator.validate_gt_mask([torch.zeros(10, 224, 224)]) def test_validate_gt_mask_invalid_shape(self) -> None: """Test validation of ground truth masks with invalid shape.""" - with pytest.raises(ValueError, match="Masks must have 1 channel, got 2."): - self.validator.validate_gt_mask(torch.zeros(2, 10, 2, 224, 224)) + with pytest.raises(ValueError, match="Ground truth mask must have 1 channel, got 2."): + self.validator.validate_gt_mask(torch.zeros(10, 2, 224, 224)) def test_validate_anomaly_map_valid(self) -> None: """Test validation of a valid anomaly map batch."""