Skip to content

Commit

Permalink
Convert AnomalyMapGenerator to nn.Module (#497)
Browse files Browse the repository at this point in the history
Convert anomaly-map-generator to nn.Module
  • Loading branch information
samet-akcay authored Aug 12, 2022
1 parent 27a31cd commit 5a452ec
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 27 deletions.
7 changes: 4 additions & 3 deletions anomalib/models/cflow/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""

def __init__(
self,
image_size: Union[ListConfig, Tuple],
pool_layers: List[str],
):
super().__init__()
self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True)
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
self.pool_layers: List[str] = pool_layers
Expand Down Expand Up @@ -60,7 +61,7 @@ def compute_anomaly_map(

return anomaly_map

def __call__(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor:
def forward(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor:
"""Returns anomaly_map.
Expects `distribution`, `height` and 'width' keywords to be passed explicitly
Expand Down
7 changes: 4 additions & 3 deletions anomalib/models/fastflow/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""

def __init__(self, input_size: Union[ListConfig, Tuple]):
super().__init__()
self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size)

def __call__(self, hidden_variables: List[Tensor]) -> Tensor:
def forward(self, hidden_variables: List[Tensor]) -> Tensor:
"""Generate Anomaly Heatmap.
This implementation generates the heatmap based on the flow maps
Expand Down
12 changes: 6 additions & 6 deletions anomalib/models/padim/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def compute_anomaly_map(self, embedding: Tensor, mean: Tensor, inv_covariance: T

return smoothed_anomaly_map

def forward(self, **kwds):
def forward(self, **kwargs):
"""Returns anomaly_map.
Expects `embedding`, `mean` and `covariance` keywords to be passed explicitly.
Expand All @@ -125,11 +125,11 @@ def forward(self, **kwds):
torch.Tensor: anomaly map
"""

if not ("embedding" in kwds and "mean" in kwds and "inv_covariance" in kwds):
raise ValueError(f"Expected keys `embedding`, `mean` and `covariance`. Found {kwds.keys()}")
if not ("embedding" in kwargs and "mean" in kwargs and "inv_covariance" in kwargs):
raise ValueError(f"Expected keys `embedding`, `mean` and `covariance`. Found {kwargs.keys()}")

embedding: Tensor = kwds["embedding"]
mean: Tensor = kwds["mean"]
inv_covariance: Tensor = kwds["inv_covariance"]
embedding: Tensor = kwargs["embedding"]
mean: Tensor = kwargs["mean"]
inv_covariance: Tensor = kwargs["inv_covariance"]

return self.compute_anomaly_map(embedding, mean, inv_covariance)
7 changes: 4 additions & 3 deletions anomalib/models/reverse_distillation/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap.
Args:
Expand All @@ -32,6 +32,7 @@ class AnomalyMapGenerator:
"""

def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4, mode: str = "multiply"):
super().__init__()
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
self.sigma = sigma
self.kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
Expand All @@ -40,7 +41,7 @@ def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4, mode: s
raise ValueError(f"Found mode {mode}. Only multiply and add are supported.")
self.mode = mode

def __call__(self, student_features: List[Tensor], teacher_features: List[Tensor]) -> Tensor:
def forward(self, student_features: List[Tensor], teacher_features: List[Tensor]) -> Tensor:
"""Computes anomaly map given encoder and decoder features.
Args:
Expand Down
15 changes: 8 additions & 7 deletions anomalib/models/stfpm/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""

def __init__(
self,
image_size: Union[ListConfig, Tuple],
):
super().__init__()
self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True)
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)

Expand Down Expand Up @@ -59,7 +60,7 @@ def compute_anomaly_map(

return anomaly_map

def __call__(self, **kwds: Dict[str, Tensor]) -> torch.Tensor:
def forward(self, **kwargs: Dict[str, Tensor]) -> torch.Tensor:
"""Returns anomaly map.
Expects `teach_features` and `student_features` keywords to be passed explicitly.
Expand All @@ -78,10 +79,10 @@ def __call__(self, **kwds: Dict[str, Tensor]) -> torch.Tensor:
torch.Tensor: anomaly map
"""

if not ("teacher_features" in kwds and "student_features" in kwds):
raise ValueError(f"Expected keys `teacher_features` and `student_features. Found {kwds.keys()}")
if not ("teacher_features" in kwargs and "student_features" in kwargs):
raise ValueError(f"Expected keys `teacher_features` and `student_features. Found {kwargs.keys()}")

teacher_features: Dict[str, Tensor] = kwds["teacher_features"]
student_features: Dict[str, Tensor] = kwds["student_features"]
teacher_features: Dict[str, Tensor] = kwargs["teacher_features"]
student_features: Dict[str, Tensor] = kwargs["student_features"]

return self.compute_anomaly_map(teacher_features, student_features)
8 changes: 4 additions & 4 deletions tests/helpers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def __init__(

def __call__(self, func):
@wraps(func)
def inner(*args, **kwds):
def inner(*args, **kwargs):
# If true, will use MVTech AD dataset for testing.
# Useful for nightly builds
if self.use_mvtec:
return func(*args, path=self.path, **kwds)
return func(*args, path=self.path, **kwargs)
else:
with GeneratedDummyDataset(
num_train=self.num_train,
Expand All @@ -145,8 +145,8 @@ def inner(*args, **kwds):
max_size=self.max_size,
seed=self.seed,
) as dataset_path:
kwds["category"] = "shapes"
return func(*args, path=dataset_path, **kwds)
kwargs["category"] = "shapes"
return func(*args, path=dataset_path, **kwargs)

return inner

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def test_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())


class DummyAnomalyMapGenerator:
class DummyAnomalyMapGenerator(nn.Module):
def __init__(self):
super().__init__()
self.input_size = (100, 100)
self.sigma = 4

Expand Down

0 comments on commit 5a452ec

Please sign in to comment.