padim arguments improvements #664

Merged
3 changes: 2 additions & 1 deletion anomalib/models/components/feature_extractors/__init__.py
@@ -4,5 +4,6 @@
# SPDX-License-Identifier: Apache-2.0

from .feature_extractor import FeatureExtractor
from .utils import dryrun_find_featuremap_dims

__all__ = ["FeatureExtractor"]
__all__ = ["FeatureExtractor", "dryrun_find_featuremap_dims"]
29 changes: 29 additions & 0 deletions anomalib/models/components/feature_extractors/utils.py
@@ -0,0 +1,29 @@
"""Utility functions to manipulate feature extractors."""

from typing import Dict, List, Tuple, Union

import torch

from anomalib.models.components.feature_extractors.feature_extractor import (
FeatureExtractor,
)


def dryrun_find_featuremap_dims(
feature_extractor: FeatureExtractor,
input_size: Tuple[int, int],
layers: List[str],
) -> Dict[str, Dict[str, Union[int, Tuple[int, int]]]]:
"""Dry run an empty image of `input_size` size to get the featuremap tensors' dimensions (num_features, resolution).

Returns:
Tuple[int, int]: maping of `layer -> dimensions dict`
Each `dimension dict` has two keys: `num_features` (int) and `resolution`(Tuple[int, int]).
"""

dryrun_input = torch.empty(1, 3, *input_size)
dryrun_features = feature_extractor(dryrun_input)
return {
layer: {"num_features": dryrun_features[layer].shape[1], "resolution": dryrun_features[layer].shape[2:]}
for layer in layers
}
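For reference, a minimal usage sketch of the new helper; the commented output assumes the standard torchvision resnet18 (64 channels at stride 4 in layer1), so treat the exact numbers as illustrative:

from anomalib.models.components.feature_extractors import (
    FeatureExtractor,
    dryrun_find_featuremap_dims,
)

layers = ["layer1", "layer2", "layer3"]
extractor = FeatureExtractor(backbone="resnet18", layers=layers, pre_trained=True)
dims = dryrun_find_featuremap_dims(extractor, input_size=(256, 256), layers=layers)
# Expected shape of the result (values assume torchvision's resnet18):
# dims["layer1"] -> {"num_features": 64, "resolution": torch.Size([64, 64])}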
7 changes: 6 additions & 1 deletion anomalib/models/padim/lightning_model.py
@@ -7,7 +7,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from omegaconf import DictConfig, ListConfig
@@ -31,6 +31,8 @@ class Padim(AnomalyModule):
input_size (Tuple[int, int]): Size of the model input.
backbone (str): Backbone CNN network
pre_trained (bool, optional): Whether to use a pre-trained backbone. Defaults to True.
n_features (int, optional): Number of features to retain in the dimension reduction step.
Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550).
"""

def __init__(
@@ -39,6 +41,7 @@ def __init__(
input_size: Tuple[int, int],
backbone: str,
pre_trained: bool = True,
n_features: Optional[int] = None,
):
super().__init__()

@@ -48,6 +51,7 @@
backbone=backbone,
pre_trained=pre_trained,
layers=layers,
n_features=n_features,
).eval()

self.stats: List[Tensor] = []
@@ -119,6 +123,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
layers=hparams.model.layers,
backbone=hparams.model.backbone,
pre_trained=hparams.model.pre_trained,
n_features=hparams.model.n_features if "n_features" in hparams.model else None,
)
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(hparams)
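For reference, a minimal sketch of instantiating Padim with the new argument (keyword arguments are used here to avoid relying on the exact positional order):

from anomalib.models.padim.lightning_model import Padim

# Omitting n_features falls back to the paper defaults
# (100 for resnet18, 550 for wide_resnet50_2).
model = Padim(
    layers=["layer1", "layer2", "layer3"],
    input_size=(256, 256),
    backbone="resnet18",
    pre_trained=True,
    n_features=64,  # hypothetical override of the resnet18 default (100)
)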
58 changes: 49 additions & 9 deletions anomalib/models/padim/torch_model.py
@@ -11,15 +11,40 @@
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims
from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler

DIMS = {
"resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4},
"wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4},
# defaults from the paper
_N_FEATURES_DEFAULTS = {
"resnet18": 100,
"wide_resnet50_2": 550,
}


def _deduce_dims(
feature_extractor: FeatureExtractor, input_size: Tuple[int, int], layers: List[str]
) -> Tuple[int, int]:
"""Run a dry run to deduce the dimensions of the extracted features.

Important: `layers` is assumed to be ordered, and the first one (`layers[0]`)
is assumed to be the layer with the largest resolution.

Returns:
Tuple[int, int]: Dimensions of the extracted features: (n_features_original, n_patches)
"""
dimensions_mapping = dryrun_find_featuremap_dims(feature_extractor, input_size, layers)

# the first layer in `layers` has the largest resolution
first_layer_resolution = dimensions_mapping[layers[0]]["resolution"]
n_patches = torch.tensor(first_layer_resolution).prod().int().item()

# the original embedding size is the sum of the channels of all layers
n_features_original = sum(dimensions_mapping[layer]["num_features"] for layer in layers) # type: ignore

return n_features_original, n_patches


class PadimModel(nn.Module):
"""Padim Module.

@@ -28,6 +53,8 @@ class PadimModel(nn.Module):
layers (List[str]): Layers used for feature extraction
backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18".
pre_trained (bool, optional): Whether to use a pre-trained backbone. Defaults to True.
n_features (int, optional): Number of features to retain in the dimension reduction step.
Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550).
"""

def __init__(
@@ -36,28 +63,41 @@ def __init__(
layers: List[str],
backbone: str = "resnet18",
pre_trained: bool = True,
n_features: Optional[int] = None,
):
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = backbone
self.layers = layers
self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
self.dims = DIMS[backbone]
self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)

n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone)

if n_features is None:
raise ValueError(
f"n_features must be specified for backbone {self.backbone}. "
f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}"
)

assert (
0 < n_features <= self.n_features_original
), f"for backbone {self.backbone}, 0 < n_features <= {self.n_features_original}, found {n_features}"

self.n_features = n_features

# pylint: disable=not-callable
# Since idx is randomly selected, save it with model to get same results
self.register_buffer(
"idx",
torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])),
torch.tensor(sample(range(0, self.n_features_original), self.n_features)),
)
self.idx: Tensor
self.loss = None
self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)

n_features = DIMS[backbone]["reduced_dims"]
patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"]
n_patches = patches_dims.ceil().prod().int().item()
self.gaussian = MultiVariateGaussian(n_features, n_patches)
self.gaussian = MultiVariateGaussian(self.n_features, self.n_patches)

def forward(self, input_tensor: Tensor) -> Tensor:
"""Forward-pass image-batch (N, C, H, W) into model to extract features.
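To make the replacement of the hardcoded DIMS table concrete, here is a sketch of what the dry run deduces for resnet18; the numbers assume torchvision's resnet18 and reproduce the old "orig_dims" (448) and "emb_scale" (4) entries:

from anomalib.models.components import FeatureExtractor
from anomalib.models.padim.torch_model import _deduce_dims

layers = ["layer1", "layer2", "layer3"]
extractor = FeatureExtractor(backbone="resnet18", layers=layers, pre_trained=True)
n_features_original, n_patches = _deduce_dims(extractor, (224, 224), layers)

# Channels of layer1 + layer2 + layer3: 64 + 128 + 256 == 448 (old "orig_dims")
assert n_features_original == 448
# layer1 is at stride 4 (old "emb_scale"), so a 224x224 input gives a 56x56 map
assert n_patches == 56 * 56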
30 changes: 29 additions & 1 deletion tests/pre_merge/models/test_feature_extractor.py
@@ -1,7 +1,12 @@
from typing import Tuple

import pytest
import torch

from anomalib.models.components.feature_extractors import FeatureExtractor
from anomalib.models.components.feature_extractors import (
FeatureExtractor,
dryrun_find_featuremap_dims,
)


class TestFeatureExtractor:
@@ -33,3 +38,26 @@ def test_feature_extraction(self, backbone, pretrained):
assert model.idx == [1, 2, 3]
else:
pass


@pytest.mark.parametrize(
"backbone",
["resnet18", "wide_resnet50_2"],
)
@pytest.mark.parametrize(
"input_size",
[(256, 256), (224, 224), (128, 128)],
)
def test_dryrun_find_featuremap_dims(backbone: str, input_size: Tuple[int, int]):
"""Use the function and check the expected output format."""
layers = ["layer1", "layer2", "layer3"]
model = FeatureExtractor(backbone=backbone, layers=layers, pre_trained=True)
mapping = dryrun_find_featuremap_dims(model, input_size, layers)
for lay in layers:
layer_mapping = mapping[lay]
num_features = layer_mapping["num_features"]
assert isinstance(num_features, int), f"{type(num_features)}"
resolution = layer_mapping["resolution"]
assert isinstance(resolution, tuple), f"{type(resolution)}"
assert len(resolution) == len(input_size), f"{len(resolution)}, {len(input_size)}"
assert all(isinstance(x, int) for x in resolution), f"{[type(x) for x in resolution]}"
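Assuming a standard pytest setup, the new test can be run on its own with:

pytest tests/pre_merge/models/test_feature_extractor.py -k dryrun_find_featuremap_dims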