Commit ec49a5a

Merge branch 'release/v2.0.0' of github.com:openvinotoolkit/anomalib into add-missing-aux-components

samet-akcay committed Dec 10, 2024
2 parents e2472db + 8bd06a9
Showing 15 changed files with 56 additions and 43 deletions.
21 changes: 11 additions & 10 deletions src/anomalib/data/validators/torch/video.py
@@ -588,10 +588,10 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoBatchValidator
-            >>> gt_masks = torch.rand(2, 10, 224, 224) > 0.5  # 2 videos, 10 frames each
+            >>> gt_masks = torch.rand(10, 224, 224) > 0.5  # 10 frames each
             >>> validated_masks = VideoBatchValidator.validate_gt_mask(gt_masks)
             >>> print(validated_masks.shape)
-            torch.Size([2, 10, 224, 224])
+            torch.Size([10, 224, 224])
             >>> single_frame_masks = torch.rand(4, 456, 256) > 0.5  # 4 single-frame images
             >>> validated_single_frame = VideoBatchValidator.validate_gt_mask(single_frame_masks)
             >>> print(validated_single_frame.shape)
@@ -600,17 +600,18 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
         if mask is None:
             return None
         if not isinstance(mask, torch.Tensor):
-            msg = f"Masks must be a torch.Tensor, got {type(mask)}."
+            msg = f"Ground truth mask must be a torch.Tensor, got {type(mask)}."
             raise TypeError(msg)
-        if mask.ndim not in {3, 4, 5}:
-            msg = f"Masks must have shape [B, H, W], [B, T, H, W] or [B, T, 1, H, W], got shape {mask.shape}."
+        if mask.ndim not in {2, 3, 4}:
+            msg = f"Ground truth mask must have shape [H, W] or [N, H, W] or [N, 1, H, W] got shape {mask.shape}."
             raise ValueError(msg)
-        if mask.ndim == 5:
-            if mask.shape[2] != 1:
-                msg = f"Masks must have 1 channel, got {mask.shape[2]}."
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        if mask.ndim == 4:
+            if mask.shape[1] != 1:
+                msg = f"Ground truth mask must have 1 channel, got {mask.shape[1]}."
                 raise ValueError(msg)
-            mask = mask.squeeze(2)
-
+            mask = mask.squeeze(1)
         return Mask(mask, dtype=torch.bool)
 
     @staticmethod
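The net effect: the validator now accepts [H, W], [N, H, W], and [N, 1, H, W] ground truth masks and always returns a boolean Mask with no channel dimension. A short usage sketch, consistent with the updated docstring and tests in this commit (shapes illustrative):

import torch
from anomalib.data.validators import VideoBatchValidator

# [N, 1, H, W] input: the singleton channel dimension is squeezed away.
masks = torch.randint(0, 2, (10, 1, 224, 224))
validated = VideoBatchValidator.validate_gt_mask(masks)
print(validated.shape, validated.dtype)  # torch.Size([10, 224, 224]) torch.bool

# [H, W] input: a leading frame dimension is prepended.
single = torch.rand(224, 224) > 0.5
print(VideoBatchValidator.validate_gt_mask(single).shape)  # torch.Size([1, 224, 224])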
3 changes: 0 additions & 3 deletions src/anomalib/engine/engine.py
@@ -259,9 +259,6 @@ def _setup_trainer(self, model: AnomalibModule) -> None:
         # Setup anomalib callbacks to be used with the trainer
         self._setup_anomalib_callbacks()
 
-        # Temporarily set devices to 1 to avoid issues with multiple processes
-        self._cache.args["devices"] = 1
-
         # Instantiate the trainer if it is not already instantiated
         if self._trainer is None:
             self._trainer = Trainer(**self._cache.args)
14 changes: 12 additions & 2 deletions src/anomalib/metrics/evaluator.py
@@ -3,6 +3,7 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import logging
 from collections.abc import Sequence
 from typing import Any
 
@@ -14,6 +15,8 @@
 
 from anomalib.metrics import AnomalibMetric
 
+logger = logging.getLogger(__name__)
+
 
 class Evaluator(nn.Module, Callback):
     """Evaluator module for LightningModule.
@@ -53,8 +56,15 @@ def __init__(
         super().__init__()
         self.val_metrics = ModuleList(self.validate_metrics(val_metrics))
         self.test_metrics = ModuleList(self.validate_metrics(test_metrics))
-
-        if compute_on_cpu:
+        self.compute_on_cpu = compute_on_cpu
+
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        """Move metrics to cpu if ``num_devices == 1`` and ``compute_on_cpu`` is set to ``True``."""
+        del pl_module, stage  # Unused arguments.
+        if trainer.num_devices > 1:
+            if self.compute_on_cpu:
+                logger.warning("Number of devices is greater than 1, setting compute_on_cpu to False.")
+        elif self.compute_on_cpu:
             self.metrics_to_cpu(self.val_metrics)
             self.metrics_to_cpu(self.test_metrics)
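In practice the flag is stored at construction time and only acted on in setup(), once the number of devices is known. A minimal sketch, assuming anomalib's v2 metric API where metric fields are named explicitly (metric choice illustrative):

from anomalib.metrics import AUROC, Evaluator

val_metrics = [AUROC(fields=["pred_score", "gt_label"])]
evaluator = Evaluator(val_metrics=val_metrics, compute_on_cpu=True)
# Attached to a single-device Trainer, setup() moves the metrics to CPU;
# with num_devices > 1 the flag is ignored and a warning is logged instead.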
5 changes: 3 additions & 2 deletions src/anomalib/models/components/base/memory_bank_module.py
@@ -19,6 +19,7 @@ class MemoryBankMixin(nn.Module):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.register_buffer("_is_fitted", torch.tensor([False]))
+        self.device: torch.device  # defined in lightning module
         self._is_fitted: torch.Tensor
 
     @abstractmethod
@@ -34,10 +35,10 @@ def on_validation_start(self) -> None:
         """Ensure that the model is fitted before validation starts."""
         if not self._is_fitted:
             self.fit()
-            self._is_fitted = torch.tensor([True])
+            self._is_fitted = torch.tensor([True], device=self.device)
 
     def on_train_epoch_end(self) -> None:
         """Ensure that the model is fitted before validation starts."""
         if not self._is_fitted:
             self.fit()
-            self._is_fitted = torch.tensor([True])
+            self._is_fitted = torch.tensor([True], device=self.device)
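The explicit device argument matters because torch.tensor([True]) always allocates on the CPU: reassigning the _is_fitted buffer without it would silently move the flag off the accelerator after Lightning has placed the module. A standalone illustration of the pitfall (not anomalib code):

import torch

class Flagged(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("_is_fitted", torch.tensor([False]))

device = "cuda" if torch.cuda.is_available() else "cpu"
module = Flagged().to(device)  # the buffer follows the module
module._is_fitted = torch.tensor([True])                 # always lands on CPU
module._is_fitted = torch.tensor([True], device=device)  # stays with the module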
@@ -93,7 +93,10 @@ def fit(self, embeddings: torch.Tensor) -> bool:
 
         # if max training points is non-zero and smaller than number of staged features, select random subset
         if embeddings.shape[0] > self.max_training_points:
-            selected_idx = torch.tensor(random.sample(range(embeddings.shape[0]), self.max_training_points))
+            selected_idx = torch.tensor(
+                random.sample(range(embeddings.shape[0]), self.max_training_points),
+                device=embeddings.device,
+            )
             selected_features = embeddings[selected_idx]
         else:
             selected_features = embeddings
@@ -74,7 +74,7 @@ def fit(self, dataset: torch.Tensor) -> None:
         else:
             num_components = int(self.n_components)
 
-        self.num_components = torch.Tensor([num_components])
+        self.num_components = torch.tensor([num_components], device=dataset.device)
 
         self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components].float()
         self.singular_values = sig[:num_components].float()
@@ -98,7 +98,7 @@ def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor:
         mean = dataset.mean(dim=0)
         dataset -= mean
         num_components = int(self.n_components)
-        self.num_components = torch.Tensor([num_components])
+        self.num_components = torch.tensor([num_components], device=dataset.device)
 
         v_h = torch.linalg.svd(dataset)[-1]
         self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components]
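Besides pinning the device, both hunks swap the legacy torch.Tensor constructor for the torch.tensor factory, which is the right tool here: torch.Tensor([n]) always produces a float32 CPU tensor, while torch.tensor([n]) infers the dtype from its data and accepts a device argument. A quick illustration:

import torch

n = 16
print(torch.Tensor([n]).dtype)                 # torch.float32 (legacy, CPU only)
print(torch.tensor([n]).dtype)                 # torch.int64 (inferred from data)
print(torch.tensor([n], device="cpu").device)  # device is settable explicitly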
3 changes: 3 additions & 0 deletions src/anomalib/models/image/dfkde/lightning_model.py
@@ -94,6 +94,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         embedding = self.model(batch.image)
         self.embeddings.append(embedding)
 
+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
+
     def fit(self) -> None:
         """Fit a KDE Model to the embedding collected from the training set."""
         embeddings = torch.vstack(self.embeddings)
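DFKDE, like DFM, PaDiM, PatchCore, and AI-VAD below, learns nothing by gradient descent: training_step only stages embeddings, and the estimator is fitted afterwards in fit(). Lightning's automatic optimization nevertheless expects training_step to return a loss, so these models now return a zero scalar that backward() can consume without updating anything. A standalone sketch of the pattern (hypothetical module, not anomalib code):

import torch
from lightning.pytorch import LightningModule

class EmbeddingCollector(LightningModule):
    """Hypothetical fit-after-training model using the dummy-loss pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.embeddings: list[torch.Tensor] = []

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        self.embeddings.append(batch.detach())  # stage features, no gradients
        # A zero scalar with requires_grad=True satisfies the training loop
        # while leaving all parameters untouched.
        return torch.tensor(0.0, requires_grad=True, device=self.device)

    def configure_optimizers(self) -> None:
        return None  # nothing to optimize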
3 changes: 3 additions & 0 deletions src/anomalib/models/image/dfm/lightning_model.py
@@ -100,6 +100,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         embedding = self.model.get_features(batch.image).squeeze()
         self.embeddings.append(embedding)
 
+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
+
     def fit(self) -> None:
         """Fit a PCA transformation and a Gaussian model to dataset."""
         logger.info("Aggregating the embedding extracted from the training set.")
2 changes: 1 addition & 1 deletion src/anomalib/models/image/dfm/torch_model.py
@@ -41,7 +41,7 @@ def fit(self, dataset: torch.Tensor) -> None:
             dataset (torch.Tensor): Input dataset to fit the model.
         """
         num_samples = dataset.shape[1]
-        self.mean_vec = torch.mean(dataset, dim=1)
+        self.mean_vec = torch.mean(dataset, dim=1, device=dataset.device)
         data_centered = (dataset - self.mean_vec.reshape(-1, 1)) / math.sqrt(num_samples)
         self.u_mat, self.sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False)
 
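A caveat on this hunk as rendered: torch.mean does not accept a device keyword, and reductions already return a tensor on their input's device, so the added argument would raise a TypeError at runtime. The device-correct form is simply the plain reduction:

import torch

dataset = torch.rand(8, 32)
mean_vec = torch.mean(dataset, dim=1)  # result is already on dataset.device
# torch.mean(dataset, dim=1, device=dataset.device)  # TypeError: unexpected keyword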
2 changes: 1 addition & 1 deletion src/anomalib/models/image/dsr/anomaly_generator.py
@@ -73,7 +73,7 @@ def augment_batch(self, batch: Tensor) -> Tensor:
         masks_list: list[Tensor] = []
         for _ in range(batch_size):
             if torch.rand(1) > self.p_anomalous:  # include normal samples
-                masks_list.append(torch.zeros((1, height, width)))
+                masks_list.append(torch.zeros((1, height, width), device=batch.device))
             else:
                 mask = self.generate_anomaly(height, width)
                 masks_list.append(mask)
5 changes: 4 additions & 1 deletion src/anomalib/models/image/padim/lightning_model.py
@@ -91,7 +91,10 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         del args, kwargs  # These variables are not used.
 
         embedding = self.model(batch.image)
-        self.embeddings.append(embedding.cpu())
+        self.embeddings.append(embedding)
+
+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
 
     def fit(self) -> None:
         """Fit a Gaussian to the embedding collected from the training set."""
2 changes: 2 additions & 0 deletions src/anomalib/models/image/patchcore/lightning_model.py
@@ -118,6 +118,8 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
 
         embedding = self.model(batch.image)
         self.embeddings.append(embedding)
+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
 
     def fit(self) -> None:
         """Apply subsampling to the embedding collected from the training set."""
13 changes: 5 additions & 8 deletions src/anomalib/models/video/ai_vad/lightning_model.py
@@ -7,9 +7,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
-from dataclasses import replace
 from typing import Any
 
+import torch
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 from anomalib import LearningType
@@ -124,6 +124,9 @@ def training_step(self, batch: VideoBatch) -> None:
         self.model.density_estimator.update(features, video_path)
         self.total_detections += len(next(iter(features.values())))
 
+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
+
     def fit(self) -> None:
         """Fit the density estimators to the extracted features from the training set."""
         if self.total_detections == 0:
@@ -147,13 +150,7 @@ def validation_step(self, batch: VideoBatch, *args, **kwargs) -> STEP_OUTPUT:
         del args, kwargs  # Unused arguments.
 
         predictions = self.model(batch.image)
-
-        return replace(
-            batch,
-            pred_score=predictions.pred_score,
-            anomaly_map=predictions.anomaly_map,
-            pred_mask=predictions.pred_mask,
-        )
+        return batch.update(pred_score=predictions.pred_score, anomaly_map=predictions.anomaly_map)
 
     @property
     def trainer_arguments(self) -> dict[str, Any]:
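The validation path drops dataclasses.replace, which builds a new batch object, in favour of the update method on anomalib's batch dataclasses; note that pred_mask is no longer propagated. A self-contained sketch of the difference, assuming update mutates and returns the same object as in anomalib's dataclass mixin:

from dataclasses import dataclass, replace

@dataclass
class TinyBatch:
    """Stand-in for anomalib's VideoBatch, just to contrast the two styles."""

    pred_score: float | None = None

    def update(self, **changes: float) -> "TinyBatch":
        for key, value in changes.items():  # assumed in-place semantics
            setattr(self, key, value)
        return self

batch = TinyBatch()
copy = replace(batch, pred_score=0.9)  # new object; batch is unchanged
same = batch.update(pred_score=0.9)    # same object, mutated in place
assert same is batch and copy is not batch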
7 changes: 0 additions & 7 deletions src/anomalib/utils/config.py
@@ -254,10 +254,3 @@ def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None:
         "Anomalib's models and visualizer are currently not compatible with video datasets with a clip length > 1. "
         "Custom changes to these modules will be needed to prevent errors and/or unpredictable behaviour.",
     )
-    if (
-        "devices" in config.trainer
-        and (config.trainer.devices is None or config.trainer.devices != 1)
-        and config.trainer.accelerator != "cpu"
-    ):
-        logger.warning("Anomalib currently does not support multi-gpu training. Setting devices to 1.")
-        config.trainer.devices = 1
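With this guard gone and the matching override removed from engine.py above, a user-supplied device count now reaches the Lightning Trainer unchanged. A hypothetical invocation (class names follow anomalib's public API; multi-GPU behaviour still depends on the model, and as the evaluator change above notes, metrics stay on-device when num_devices > 1):

from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import Padim

# devices is forwarded to lightning.pytorch.Trainer instead of being reset to 1.
engine = Engine(accelerator="gpu", devices=2)
engine.fit(model=Padim(), datamodule=MVTec(category="bottle"))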
10 changes: 5 additions & 5 deletions tests/unit/data/validators/torch/test_video.py
@@ -174,10 +174,10 @@ def test_validate_gt_label_invalid_type(self) -> None:
 
     def test_validate_gt_mask_valid(self) -> None:
         """Test validation of valid ground truth masks."""
-        masks = torch.randint(0, 2, (2, 10, 224, 224))
+        masks = torch.randint(0, 2, (10, 1, 224, 224))
         validated_masks = self.validator.validate_gt_mask(masks)
         assert isinstance(validated_masks, Mask)
-        assert validated_masks.shape == (2, 10, 224, 224)
+        assert validated_masks.shape == (10, 224, 224)
         assert validated_masks.dtype == torch.bool
 
     def test_validate_gt_mask_none(self) -> None:
@@ -186,13 +186,13 @@ def test_validate_gt_mask_none(self) -> None:
 
     def test_validate_gt_mask_invalid_type(self) -> None:
         """Test validation of ground truth masks with invalid type."""
-        with pytest.raises(TypeError, match="Masks must be a torch.Tensor"):
+        with pytest.raises(TypeError, match="Ground truth mask must be a torch.Tensor"):
             self.validator.validate_gt_mask([torch.zeros(10, 224, 224)])
 
     def test_validate_gt_mask_invalid_shape(self) -> None:
         """Test validation of ground truth masks with invalid shape."""
-        with pytest.raises(ValueError, match="Masks must have 1 channel, got 2."):
-            self.validator.validate_gt_mask(torch.zeros(2, 10, 2, 224, 224))
+        with pytest.raises(ValueError, match="Ground truth mask must have 1 channel, got 2."):
+            self.validator.validate_gt_mask(torch.zeros(10, 2, 224, 224))
 
     def test_validate_anomaly_map_valid(self) -> None:
         """Test validation of a valid anomaly map batch."""
