diff --git a/anomalib/models/patchcore/anomaly_map.py b/anomalib/models/patchcore/anomaly_map.py
index 086f44ff38..7b29f83ef3 100644
--- a/anomalib/models/patchcore/anomaly_map.py
+++ b/anomalib/models/patchcore/anomaly_map.py
@@ -8,7 +8,7 @@
 import torch
 import torch.nn.functional as F
 from omegaconf import ListConfig
-from torch import nn
+from torch import Tensor, nn
 
 from anomalib.models.components import GaussianBlur2d
 
@@ -26,67 +26,32 @@ def __init__(
         kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
         self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1)
 
-    def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> torch.Tensor:
+    def compute_anomaly_map(self, patch_scores: Tensor) -> torch.Tensor:
         """Pixel Level Anomaly Heatmap.
 
         Args:
-            patch_scores (torch.Tensor): Patch-level anomaly scores
-            feature_map_shape (torch.Size): 2-D feature map shape (width, height)
+            patch_scores (Tensor): Patch-level anomaly scores
 
         Returns:
             torch.Tensor: Map of the pixel-level anomaly scores
         """
-        width, height = feature_map_shape
-        batch_size = len(patch_scores) // (width * height)
-
-        anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height))
-        anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1]))
-
+        anomaly_map = F.interpolate(patch_scores, size=(self.input_size[0], self.input_size[1]))
         anomaly_map = self.blur(anomaly_map)
 
         return anomaly_map
 
-    @staticmethod
-    def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor:
-        """Compute Image-Level Anomaly Score.
-
-        Args:
-            patch_scores (torch.Tensor): Patch-level anomaly scores
-        Returns:
-            torch.Tensor: Image-level anomaly scores
-        """
-        max_scores = torch.argmax(patch_scores[:, 0])
-        confidence = torch.index_select(patch_scores, 0, max_scores)
-        weights = 1 - torch.max(F.softmax(confidence, dim=-1))
-        score = weights * torch.max(patch_scores[:, 0])
-        return score
-
-    def forward(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, patch_scores: Tensor) -> Tensor:
         """Returns anomaly_map and anomaly_score.
 
-        Expects `patch_scores` keyword to be passed explicitly
-        Expects `feature_map_shape` keyword to be passed explicitly
+        Args:
+            patch_scores (Tensor): Patch-level anomaly scores
 
         Example
         >>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
-        >>> map, score = anomaly_map_generator(patch_scores=numpy_array, feature_map_shape=feature_map_shape)
-
-        Raises:
-            ValueError: If `patch_scores` key is not found
+        >>> map = anomaly_map_generator(patch_scores=patch_scores)
 
         Returns:
-            Tuple[torch.Tensor, torch.Tensor]: anomaly_map, anomaly_score
+            Tensor: anomaly_map
         """
-
-        if "patch_scores" not in kwargs:
-            raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}")
-
-        if "feature_map_shape" not in kwargs:
-            raise ValueError(f"Expected key `feature_map_shape`. Found {kwargs.keys()}")
-
-        patch_scores = kwargs["patch_scores"]
-        feature_map_shape = kwargs["feature_map_shape"]
-
-        anomaly_map = self.compute_anomaly_map(patch_scores, feature_map_shape)
-        anomaly_score = self.compute_anomaly_score(patch_scores)
-        return anomaly_map, anomaly_score
+        anomaly_map = self.compute_anomaly_map(patch_scores)
+        return anomaly_map
diff --git a/anomalib/models/patchcore/lightning_model.py b/anomalib/models/patchcore/lightning_model.py
index 631ea10146..5a5583d42b 100644
--- a/anomalib/models/patchcore/lightning_model.py
+++ b/anomalib/models/patchcore/lightning_model.py
@@ -107,7 +107,7 @@ def validation_step(self, batch, _):  # pylint: disable=arguments-differ
 
         anomaly_maps, anomaly_score = self.model(batch["image"])
         batch["anomaly_maps"] = anomaly_maps
-        batch["pred_scores"] = anomaly_score.unsqueeze(0)
+        batch["pred_scores"] = anomaly_score
 
         return batch
 
diff --git a/anomalib/models/patchcore/torch_model.py b/anomalib/models/patchcore/torch_model.py
index 7c931dff47..21ca16b944 100644
--- a/anomalib/models/patchcore/torch_model.py
+++ b/anomalib/models/patchcore/torch_model.py
@@ -41,10 +41,10 @@ def __init__(
         self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
         self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
 
-        self.register_buffer("memory_bank", torch.Tensor())
-        self.memory_bank: torch.Tensor
+        self.register_buffer("memory_bank", Tensor())
+        self.memory_bank: Tensor
 
-    def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self, input_tensor: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         """Return Embedding during training, or a tuple of anomaly map and anomaly score during testing.
 
         Steps performed:
@@ -56,7 +56,7 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
             input_tensor (Tensor): Input tensor
 
         Returns:
-            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Embedding for training,
+            Union[Tensor, Tuple[Tensor, Tensor]]: Embedding for training,
             anomaly map and anomaly score for testing.
         """
         if self.tiler:
@@ -71,21 +71,29 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
         if self.tiler:
             embedding = self.tiler.untile(embedding)
 
-        feature_map_shape = embedding.shape[-2:]
+        batch_size, _, width, height = embedding.shape
         embedding = self.reshape_embedding(embedding)
 
         if self.training:
             output = embedding
         else:
-            patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=self.num_neighbors)
-            anomaly_map, anomaly_score = self.anomaly_map_generator(
-                patch_scores=patch_scores, feature_map_shape=feature_map_shape
-            )
+            # apply nearest neighbor search
+            patch_scores, locations = self.nearest_neighbors(embedding=embedding, n_neighbors=1)
+            # reshape to batch dimension
+            patch_scores = patch_scores.reshape((batch_size, -1))
+            locations = locations.reshape((batch_size, -1))
+            # compute anomaly score
+            anomaly_score = self.compute_anomaly_score(patch_scores, locations, embedding)
+            # reshape to w, h
+            patch_scores = patch_scores.reshape((batch_size, 1, width, height))
+            # get anomaly map
+            anomaly_map = self.anomaly_map_generator(patch_scores)
+
             output = (anomaly_map, anomaly_score)
 
         return output
 
-    def generate_embedding(self, features: Dict[str, Tensor]) -> torch.Tensor:
+    def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
         """Generate embedding from hierarchical feature map.
 
         Args:
@@ -121,7 +129,7 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
         embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
         return embedding
 
-    def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
+    def subsample_embedding(self, embedding: Tensor, sampling_ratio: float) -> None:
         """Subsample embedding based on coreset sampling and store to memory.
 
         Args:
@@ -134,7 +142,7 @@ def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) ->
         coreset = sampler.sample_coreset()
         self.memory_bank = coreset
 
-    def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
+    def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> Tuple[Tensor, Tensor]:
         """Nearest Neighbours using brute force method and euclidean norm.
 
         Args:
@@ -143,7 +151,35 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
 
         Returns:
             Tensor: Patch scores.
+            Tensor: Locations of the nearest neighbor(s).
         """
         distances = torch.cdist(embedding, self.memory_bank, p=2.0)  # euclidean norm
-        patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1)
-        return patch_scores
+        patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
+        return patch_scores, locations
+
+    def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embedding: Tensor) -> Tensor:
+        """Compute Image-Level Anomaly Score.
+
+        Args:
+            patch_scores (Tensor): Patch-level anomaly scores
+            locations: Memory bank locations of the nearest neighbor for each patch location
+            embedding: The feature embeddings that generated the patch scores
+        Returns:
+            Tensor: Image-level anomaly scores
+        """
+
+        # 1. Find the patch with the largest distance to its nearest neighbor in each image
+        max_patches = torch.argmax(patch_scores, dim=1)  # (m^test,* in the paper)
+        # 2. Find the distance of the patch to its nearest neighbor, and the location of the nn in the membank
+        score = patch_scores[torch.arange(len(patch_scores)), max_patches]  # s in the paper
+        nn_index = locations[torch.arange(len(patch_scores)), max_patches]  # m^* in the paper
+        # 3. Find the support samples of the nearest neighbor in the membank
+        nn_sample = self.memory_bank[nn_index, :]
+        _, support_samples = self.nearest_neighbors(nn_sample, n_neighbors=self.num_neighbors)  # N_b(m^*) in the paper
+        # 4. Find the distance of the patch features to each of the support samples
+        distances = torch.cdist(embedding[max_patches].unsqueeze(1), self.memory_bank[support_samples], p=2.0)
+        # 5. Apply softmax to find the weights
+        weights = (1 - F.softmax(distances.squeeze()))[..., 0]
+        # 6. Apply the weight factor to the score
+        score = weights * score  # S^* in the paper
+        return score
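For readers following the new scoring path, below is a minimal standalone sketch of the neighbourhood re-weighting (steps 1-6 in `compute_anomaly_score`) on synthetic tensors. The memory-bank size, feature dimension and patch grid are arbitrary assumptions, and the sketch gathers each image's winning patch features explicitly before the support-set distance computation; it illustrates the idea from the PatchCore paper rather than reproducing the model code verbatim.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# assumed toy sizes: 2 images, a 4x4 patch grid, 8-dim features, 100 coreset entries
batch_size, num_patches, dim = 2, 16, 8
num_neighbors = 9  # size of the neighbourhood N_b(m^*)

memory_bank = torch.randn(100, dim)                      # coreset-subsampled training features
embedding = torch.randn(batch_size * num_patches, dim)   # test patch features

# 1-NN search: distance and memory-bank index of the closest entry per patch
distances = torch.cdist(embedding, memory_bank, p=2.0)
patch_scores, locations = distances.topk(k=1, largest=False, dim=1)
patch_scores = patch_scores.reshape(batch_size, -1)
locations = locations.reshape(batch_size, -1)

# 1. patch with the largest 1-NN distance in each image (m^test,*)
max_patches = torch.argmax(patch_scores, dim=1)
# 2. its distance (s) and the index of its nearest neighbour in the memory bank (m^*)
score = patch_scores[torch.arange(batch_size), max_patches]
nn_index = locations[torch.arange(batch_size), max_patches]
# 3. the num_neighbors nearest memory-bank entries around m^*, i.e. N_b(m^*)
nn_sample = memory_bank[nn_index]
_, support_samples = torch.cdist(nn_sample, memory_bank, p=2.0).topk(k=num_neighbors, largest=False, dim=1)
# 4. distances from the winning patch features to that neighbourhood
max_features = embedding.reshape(batch_size, num_patches, dim)[torch.arange(batch_size), max_patches]
support_distances = torch.cdist(max_features.unsqueeze(1), memory_bank[support_samples], p=2.0)
# 5./6. softmax re-weighting and the final image-level score (s^*)
weights = 1 - F.softmax(support_distances.squeeze(1), dim=1)[..., 0]
image_scores = weights * score
print(image_scores.shape)  # torch.Size([2])
```

In the model itself, the per-patch 1-NN distances are reshaped to `(batch_size, 1, width, height)` and handed to `AnomalyMapGenerator`, which interpolates them to the input resolution and blurs them to produce the pixel-level anomaly map, while the re-weighted score becomes `pred_scores`.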