From db29ba98a0f6ff65ac6333806b3603494b6d01c5 Mon Sep 17 00:00:00 2001
From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com>
Date: Tue, 1 Nov 2022 02:06:54 +0100
Subject: [PATCH 1/7] make n_features an arg and deduce other dimensions
 dynamically if necessary

---
 anomalib/models/padim/lightning_model.py |  5 +-
 anomalib/models/padim/torch_model.py     | 68 ++++++++++++++++++++----
 2 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/anomalib/models/padim/lightning_model.py b/anomalib/models/padim/lightning_model.py
index 2fa029c488..d0832b6022 100644
--- a/anomalib/models/padim/lightning_model.py
+++ b/anomalib/models/padim/lightning_model.py
@@ -7,7 +7,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from omegaconf import DictConfig, ListConfig
@@ -39,6 +39,7 @@ def __init__(
         input_size: Tuple[int, int],
         backbone: str,
         pre_trained: bool = True,
+        n_features: Optional[int] = None,
     ):
         super().__init__()

@@ -48,6 +49,7 @@ def __init__(
             backbone=backbone,
             pre_trained=pre_trained,
             layers=layers,
+            n_features=n_features,
         ).eval()

         self.stats: List[Tensor] = []
@@ -119,6 +121,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
             layers=hparams.model.layers,
             backbone=hparams.model.backbone,
             pre_trained=hparams.model.pre_trained,
+            n_features=hparams.model.n_features if "n_features" in hparams.model else None,
         )
         self.hparams: Union[DictConfig, ListConfig]  # type: ignore
         self.save_hyperparameters(hparams)
diff --git a/anomalib/models/padim/torch_model.py b/anomalib/models/padim/torch_model.py
index 17fa51f453..72d8c73c66 100644
--- a/anomalib/models/padim/torch_model.py
+++ b/anomalib/models/padim/torch_model.py
@@ -14,11 +14,39 @@
 from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
 from anomalib.pre_processing import Tiler

-DIMS = {
-    "resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4},
-    "wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4},
+_DIMS = {
+    "resnet18": {"orig_dims": 448, "emb_scale": 4},
+    "wide_resnet50_2": {"orig_dims": 1792, "emb_scale": 4},
 }

+# defaults from the paper
+_N_FEATURES_DEFAULTS = {
+    "resnet18": 100,
+    "wide_resnet50_2": 550,
+}
+
+
+def _deduce_dims(
+    feature_extractor: FeatureExtractor, input_size: Tuple[int, int], layers: List[str]
+) -> Tuple[int, int]:
+    """Run a dry run to deduce the dimensions of the extracted features.
+
+    Returns:
+        Tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches)
+    """
+
+    dryrun_input = torch.empty(1, 3, *input_size)
+    dryrun_features = feature_extractor(dryrun_input)
+
+    # the first layer in `layers` is the largest spatial size
+    dryrun_emb_first_layer = dryrun_features[layers[0]]
+    n_patches = torch.tensor(dryrun_emb_first_layer.shape[-2:]).prod().int().item()
+
+    # the original embedding size is the sum of the channels of all layers
+    n_features_original = sum(dryrun_features[layer].shape[1] for layer in layers)
+
+    return n_features_original, n_patches
+

 class PadimModel(nn.Module):
     """Padim Module.
@@ -36,6 +64,7 @@ def __init__(
         layers: List[str],
         backbone: str = "resnet18",
         pre_trained: bool = True,
+        n_features: Optional[int] = None,
     ):
         super().__init__()
         self.tiler: Optional[Tiler] = None
@@ -43,21 +72,42 @@
         self.backbone = backbone
         self.layers = layers
         self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
-        self.dims = DIMS[backbone]
+
+        if backbone in _DIMS:
+            backbone_dims = _DIMS[backbone]
+            self.n_features_original = backbone_dims["orig_dims"]
+            emb_scale = backbone_dims["emb_scale"]
+            patches_dims = torch.tensor(input_size) / emb_scale
+            self.n_patches = patches_dims.ceil().prod().int().item()
+
+        else:
+            self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)
+
+        if n_features is None:
+
+            if self.backbone in _N_FEATURES_DEFAULTS:
+                n_features = _N_FEATURES_DEFAULTS[self.backbone]
+
+            else:
+                raise ValueError(
+                    f"{self.__class__.__name__}.n_features must be specified for backbone {self.backbone}. "
+                    f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}"
+                )
+        assert (
+            n_features <= self.n_features_original
+        ), f"n_features ({n_features}) must be <= n_features_original ({self.n_features_original})"
+        self.n_features = n_features

         # pylint: disable=not-callable
         # Since idx is randomly selected, save it with model to get same results
         self.register_buffer(
             "idx",
-            torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])),
+            torch.tensor(sample(range(0, self.n_features_original), self.n_features)),
         )
         self.idx: Tensor
         self.loss = None
         self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)

-        n_features = DIMS[backbone]["reduced_dims"]
-        patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"]
-        n_patches = patches_dims.ceil().prod().int().item()
-        self.gaussian = MultiVariateGaussian(n_features, n_patches)
+        self.gaussian = MultiVariateGaussian(self.n_features, self.n_patches)

     def forward(self, input_tensor: Tensor) -> Tensor:
         """Forward-pass image-batch (N, C, H, W) into model to extract features.

From ffb2190a963f3a1eb5ccee99f1da5f507aba506d Mon Sep 17 00:00:00 2001
From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com>
Date: Wed, 2 Nov 2022 17:38:22 +0100
Subject: [PATCH 2/7] only use deduced values

---
 anomalib/models/padim/torch_model.py | 37 ++++++++++------------------
 1 file changed, 13 insertions(+), 24 deletions(-)

diff --git a/anomalib/models/padim/torch_model.py b/anomalib/models/padim/torch_model.py
index 72d8c73c66..ace09f9a7d 100644
--- a/anomalib/models/padim/torch_model.py
+++ b/anomalib/models/padim/torch_model.py
@@ -14,11 +14,6 @@
 from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
 from anomalib.pre_processing import Tiler

-_DIMS = {
-    "resnet18": {"orig_dims": 448, "emb_scale": 4},
-    "wide_resnet50_2": {"orig_dims": 1792, "emb_scale": 4},
-}
-
 # defaults from the paper
 _N_FEATURES_DEFAULTS = {
     "resnet18": 100,
@@ -31,6 +26,9 @@ def _deduce_dims(
 ) -> Tuple[int, int]:
     """Run a dry run to deduce the dimensions of the extracted features.

+    Important: `layers` is assumed to be ordered and the first (layers[0])
+    is assumed to be the layer with the largest resolution.
+
     Returns:
         Tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches)
     """
@@ -72,31 +70,22 @@
         self.backbone = backbone
         self.layers = layers
         self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
+        self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)

-        if backbone in _DIMS:
-            backbone_dims = _DIMS[backbone]
-            self.n_features_original = backbone_dims["orig_dims"]
-            emb_scale = backbone_dims["emb_scale"]
-            patches_dims = torch.tensor(input_size) / emb_scale
-            self.n_patches = patches_dims.ceil().prod().int().item()
-
-        else:
-            self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)
+        n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone)

         if n_features is None:
+            raise ValueError(
+                f"n_features must be specified for backbone {self.backbone}. "
+                f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}"
+            )

-            if self.backbone in _N_FEATURES_DEFAULTS:
-                n_features = _N_FEATURES_DEFAULTS[self.backbone]
-
-            else:
-                raise ValueError(
-                    f"{self.__class__.__name__}.n_features must be specified for backbone {self.backbone}. "
-                    f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}"
-                )
         assert (
-            n_features <= self.n_features_original
-        ), f"n_features ({n_features}) must be <= n_features_original ({self.n_features_original})"
+            0 < n_features <= self.n_features_original
+        ), f"for backbone {self.backbone}, 0 < n_features <= {self.n_features_original}, found {n_features}"
+
         self.n_features = n_features
+
         # pylint: disable=not-callable
         # Since idx is randomly selected, save it with model to get same results
         self.register_buffer(
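What the dry run computes can be checked by hand. Below is a minimal, standalone sketch (not part of the series), assuming a torchvision resnet18 with the usual `layer1`..`layer3` taps; the expected numbers in the comments come from the `_DIMS` table that PATCH 2/7 deletes (orig_dims 448, emb_scale 4):

    import torch

    from anomalib.models.components import FeatureExtractor

    # dry run an empty image through the extractor, as _deduce_dims does
    layers = ["layer1", "layer2", "layer3"]
    extractor = FeatureExtractor(backbone="resnet18", layers=layers, pre_trained=True)
    with torch.no_grad():
        features = extractor(torch.empty(1, 3, 256, 256))

    # layer1 has the largest resolution: 256 / 4 (emb_scale) = 64 per side
    n_patches = torch.tensor(features["layer1"].shape[-2:]).prod().int().item()

    # original embedding size = sum of channels over the layers: 64 + 128 + 256
    n_features_original = sum(features[layer].shape[1] for layer in layers)

    print(n_patches, n_features_original)  # 4096, 448 for this backbone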
From d6e14b72f7db4fed4c559653c10b41433d744936 Mon Sep 17 00:00:00 2001
From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com>
Date: Wed, 2 Nov 2022 17:44:46 +0100
Subject: [PATCH 5/7] update docs

---
 anomalib/models/padim/lightning_model.py | 2 ++
 anomalib/models/padim/torch_model.py     | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/anomalib/models/padim/lightning_model.py b/anomalib/models/padim/lightning_model.py
index d0832b6022..e9b1e70278 100644
--- a/anomalib/models/padim/lightning_model.py
+++ b/anomalib/models/padim/lightning_model.py
@@ -31,6 +31,8 @@ class Padim(AnomalyModule):
         input_size (Tuple[int, int]): Size of the model input.
         backbone (str): Backbone CNN network
         pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
+        n_features (int, optional): Number of features to retain in the dimension reduction step.
+            Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550).
     """

     def __init__(
diff --git a/anomalib/models/padim/torch_model.py b/anomalib/models/padim/torch_model.py
index ace09f9a7d..9b35901bea 100644
--- a/anomalib/models/padim/torch_model.py
+++ b/anomalib/models/padim/torch_model.py
@@ -54,6 +54,8 @@ class PadimModel(nn.Module):
         layers (List[str]): Layers used for feature extraction
         backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18".
         pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
+        n_features (int, optional): Number of features to retain in the dimension reduction step.
+            Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550).
""" def __init__( From e4e230b10c144e29765af40e0a4206bd13022558 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Mon, 7 Nov 2022 20:20:22 +0100 Subject: [PATCH 6/7] encapsulate dryrun --- .../components/feature_extractors/__init__.py | 3 +- .../components/feature_extractors/utils.py | 29 +++++++++++++++++++ anomalib/models/padim/torch_model.py | 13 ++++----- 3 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 anomalib/models/components/feature_extractors/utils.py diff --git a/anomalib/models/components/feature_extractors/__init__.py b/anomalib/models/components/feature_extractors/__init__.py index 0922fd3701..100e5e234d 100644 --- a/anomalib/models/components/feature_extractors/__init__.py +++ b/anomalib/models/components/feature_extractors/__init__.py @@ -4,5 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 from .feature_extractor import FeatureExtractor +from .utils import dryrun_find_featuremap_dims -__all__ = ["FeatureExtractor"] +__all__ = ["FeatureExtractor", "dryrun_find_featuremap_dims"] diff --git a/anomalib/models/components/feature_extractors/utils.py b/anomalib/models/components/feature_extractors/utils.py new file mode 100644 index 0000000000..0efd1011fd --- /dev/null +++ b/anomalib/models/components/feature_extractors/utils.py @@ -0,0 +1,29 @@ +"""Utility functions to manipulate feature extractors.""" + +from typing import Dict, List, Tuple, Union + +import torch + +from anomalib.models.components.feature_extractors.feature_extractor import ( + FeatureExtractor, +) + + +def dryrun_find_featuremap_dims( + feature_extractor: FeatureExtractor, + input_size: Tuple[int, int], + layers: List[str], +) -> Dict[str, Dict[str, Union[int, Tuple[int, int]]]]: + """Dry run an empty image of `input_size` size to get the featuremap tensors' dimensions (num_features, resolution). + + Returns: + Tuple[int, int]: maping of `layer -> dimensions dict` + Each `dimension dict` has two keys: `num_features` (int) and `resolution`(Tuple[int, int]). 
+ """ + + dryrun_input = torch.empty(1, 3, *input_size) + dryrun_features = feature_extractor(dryrun_input) + return { + layer: {"num_features": dryrun_features[layer].shape[1], "resolution": dryrun_features[layer].shape[2:]} + for layer in layers + } diff --git a/anomalib/models/padim/torch_model.py b/anomalib/models/padim/torch_model.py index 9b35901bea..98cdada3c2 100644 --- a/anomalib/models/padim/torch_model.py +++ b/anomalib/models/padim/torch_model.py @@ -11,6 +11,7 @@ from torch import Tensor, nn from anomalib.models.components import FeatureExtractor, MultiVariateGaussian +from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims from anomalib.models.padim.anomaly_map import AnomalyMapGenerator from anomalib.pre_processing import Tiler @@ -32,16 +33,14 @@ def _deduce_dims( Returns: Tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches) """ + dimensions_mapping = dryrun_find_featuremap_dims(feature_extractor, input_size, layers) - dryrun_input = torch.empty(1, 3, *input_size) - dryrun_features = feature_extractor(dryrun_input) - - # the first layer in `layers` is the largest spatial size - dryrun_emb_first_layer = dryrun_features[layers[0]] - n_patches = torch.tensor(dryrun_emb_first_layer.shape[-2:]).prod().int().item() + # the first layer in `layers` has the largest resolution + first_layer_resolution = dimensions_mapping[layers[0]]["resolution"] + n_patches = torch.tensor(first_layer_resolution).prod().int().item() # the original embedding size is the sum of the channels of all layers - n_features_original = sum(dryrun_features[layer].shape[1] for layer in layers) + n_features_original = sum(dimensions_mapping[layer]["num_features"] for layer in layers) # type: ignore return n_features_original, n_patches From dbd885a9b8df599bf88148e586e8b0467f9c92b0 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Mon, 7 Nov 2022 20:44:35 +0100 Subject: [PATCH 7/7] add test for feature map size deduction --- .../models/test_feature_extractor.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/pre_merge/models/test_feature_extractor.py b/tests/pre_merge/models/test_feature_extractor.py index 7355e8fb2f..cfe456a6ca 100644 --- a/tests/pre_merge/models/test_feature_extractor.py +++ b/tests/pre_merge/models/test_feature_extractor.py @@ -1,7 +1,12 @@ +from typing import Tuple + import pytest import torch -from anomalib.models.components.feature_extractors import FeatureExtractor +from anomalib.models.components.feature_extractors import ( + FeatureExtractor, + dryrun_find_featuremap_dims, +) class TestFeatureExtractor: @@ -33,3 +38,26 @@ def test_feature_extraction(self, backbone, pretrained): assert model.idx == [1, 2, 3] else: pass + + +@pytest.mark.parametrize( + "backbone", + ["resnet18", "wide_resnet50_2"], +) +@pytest.mark.parametrize( + "input_size", + [(256, 256), (224, 224), (128, 128)], +) +def test_dryrun_find_featuremap_dims(backbone: str, input_size: Tuple[int, int]): + """Use the function and check the expected output format.""" + layers = ["layer1", "layer2", "layer3"] + model = FeatureExtractor(backbone=backbone, layers=layers, pre_trained=True) + mapping = dryrun_find_featuremap_dims(model, input_size, layers) + for lay in layers: + layer_mapping = mapping[lay] + num_features = layer_mapping["num_features"] + assert isinstance(num_features, int), f"{type(num_features)}" + resolution = layer_mapping["resolution"] + assert 
isinstance(resolution, tuple), f"{type(resolution)}" + assert len(resolution) == len(input_size), f"{len(resolution)}, {len(input_size)}" + assert all(isinstance(x, int) for x in resolution), f"{[type(x) for x in resolution]}"
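For completeness, a minimal usage sketch of the helper introduced in PATCH 6/7, mirroring what the test above asserts; the printed values assume a torchvision resnet18 at 256x256 input and are illustrative only:

    from anomalib.models.components.feature_extractors import (
        FeatureExtractor,
        dryrun_find_featuremap_dims,
    )

    layers = ["layer1", "layer2", "layer3"]
    model = FeatureExtractor(backbone="resnet18", layers=layers, pre_trained=True)
    dims = dryrun_find_featuremap_dims(model, input_size=(256, 256), layers=layers)

    # each layer maps to {"num_features": int, "resolution": Tuple[int, int]}
    print(dims["layer1"])  # {'num_features': 64, 'resolution': torch.Size([64, 64])}
    print(dims["layer3"])  # {'num_features': 256, 'resolution': torch.Size([16, 16])}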