Optimized anomaly score calculation for PatchCore for both num_neighb… #633

Merged · 7 commits · Nov 2, 2022
8 changes: 4 additions & 4 deletions anomalib/data/utils/generators/perlin.py
@@ -41,7 +41,7 @@ def generate_perlin_noise_2d(shape, res):
     """Fractal perlin noise."""

     def f(t):
-        return 6 * t**5 - 15 * t**4 + 10 * t**3
+        return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3

     delta = (res[0] / shape[0], res[1] / shape[1])
     d = (shape[0] // res[0], shape[1] // res[1])
@@ -68,7 +68,7 @@ def f(t):
 def random_2d_perlin(
     shape: Tuple,
     res: Tuple[Union[int, Tensor], Union[int, Tensor]],
-    fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3,
+    fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3,
 ) -> Union[np.ndarray, Tensor]:
     """Returns a random 2d perlin noise array.
@@ -90,7 +90,7 @@ def random_2d_perlin(
     return result


-def _rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
+def _rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
     """Generate a random image containing Perlin noise. Numpy version."""
     delta = (res[0] / shape[0], res[1] / shape[1])
     d = (shape[0] // res[0], shape[1] // res[1])
@@ -116,7 +116,7 @@ def dot(grad, shift):
     return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])


-def _rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
+def _rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
     """Generate a random image containing Perlin noise. PyTorch version."""
     delta = (res[0] / shape[0], res[1] / shape[1])
     d = (shape[0] // res[0], shape[1] // res[1])
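Context for the fade changes above: 6t⁵ − 15t⁴ + 10t³ is Perlin's quintic smoothstep, used because it runs from 0 to 1 with zero first and second derivatives at both endpoints, which keeps the interpolated noise smooth across grid cells. A quick self-contained check (a sketch, independent of anomalib):

def fade(t):
    return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3

def fade_d1(t):  # fade'(t) = 30t^2 (t - 1)^2
    return 30 * t ** 4 - 60 * t ** 3 + 30 * t ** 2

def fade_d2(t):  # fade''(t) = 60t (2t - 1)(t - 1)
    return 120 * t ** 3 - 180 * t ** 2 + 60 * t

# endpoints map to 0 and 1; both derivatives vanish there
assert fade(0.0) == 0.0 and fade(1.0) == 1.0
assert fade_d1(0.0) == fade_d1(1.0) == 0.0
assert fade_d2(0.0) == fade_d2(1.0) == 0.0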
2 changes: 1 addition & 1 deletion anomalib/models/cflow/utils.py
@@ -28,7 +28,7 @@ def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor)
         torch.Tensor: Log probability
     """
     ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi))  # ln(sqrt(2*pi))
-    logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u**2, 1) + logdet_j
+    logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u ** 2, 1) + logdet_j
     return logp


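For reference, get_logp is the change-of-variables log-likelihood under a standard normal base distribution: log p(x) = −(D/2)·log(2π) − ‖u‖²/2 + log|det J|. A minimal sanity check against torch.distributions (a sketch with made-up batch and dim values, not part of this PR):

import numpy as np
import torch
from torch.distributions import Normal

dim, batch = 16, 4
p_u = torch.randn(batch, dim)  # latent variables u from the flow
logdet_j = torch.randn(batch)  # log|det J| per sample

ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi))
logp = dim * ln_sqrt_2pi - 0.5 * torch.sum(p_u ** 2, 1) + logdet_j

# the same quantity via an explicit standard normal
reference = Normal(0.0, 1.0).log_prob(p_u).sum(1) + logdet_j
assert torch.allclose(logp, reference, atol=1e-5)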
anomalib/models/components/dimensionality_reduction/random_projection.py
@@ -88,7 +88,7 @@ def johnson_lindenstrauss_min_dim(self, n_samples: int, eps: float = 0.1):
             eps (float, optional): Minimum distortion rate. Defaults to 0.1.
         """

-        denominator = (eps**2 / 2) - (eps**3 / 3)
+        denominator = (eps ** 2 / 2) - (eps ** 3 / 3)
         return (4 * np.log(n_samples) / denominator).astype(np.int64)

     def fit(self, embedding: Tensor) -> "SparseRandomProjection":
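This is the Johnson–Lindenstrauss bound in the same form scikit-learn uses: n_components ≥ 4·ln(n_samples) / (ε²/2 − ε³/3). A worked example cross-checked against sklearn (a sketch; the sample count is arbitrary):

import numpy as np
from sklearn.random_projection import johnson_lindenstrauss_min_dim

n_samples, eps = 10_000, 0.1
denominator = (eps ** 2 / 2) - (eps ** 3 / 3)
min_dim = int(4 * np.log(n_samples) / denominator)

print(min_dim)                                            # 7894
print(johnson_lindenstrauss_min_dim(n_samples, eps=eps))  # same value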
2 changes: 1 addition & 1 deletion anomalib/models/components/stats/kde.py
@@ -67,7 +67,7 @@ def fit(self, dataset: Tensor) -> None:

         cov_mat = self.cov(dataset.T)
         inv_cov_mat = torch.linalg.inv(cov_mat)
-        inv_cov = inv_cov_mat / factor**2
+        inv_cov = inv_cov_mat / factor ** 2

         # transform data to account for bandwidth
         bw_transform = torch.linalg.cholesky(inv_cov)
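Dividing the inverse covariance by factor² scales the kernel bandwidth by factor, and the Cholesky factor of that matrix whitens the data so the kernel becomes isotropic. A small illustration of the identity behind bw_transform (a sketch with made-up data; torch.cov stands in for the model's own cov helper):

import torch

torch.manual_seed(0)
dataset = torch.randn(100, 3)
factor = 1.7  # bandwidth scale, e.g. from Scott's rule

cov_mat = torch.cov(dataset.T)
inv_cov = torch.linalg.inv(cov_mat) / factor ** 2
bw_transform = torch.linalg.cholesky(inv_cov)  # lower-triangular L, inv_cov = L @ L.T

x = dataset @ bw_transform
# squared distances in the whitened space equal the
# bandwidth-scaled Mahalanobis distances in the original space
d2 = ((x[0] - x[1]) ** 2).sum()
delta = dataset[0] - dataset[1]
assert torch.allclose(d2, delta @ inv_cov @ delta, atol=1e-4)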
2 changes: 1 addition & 1 deletion anomalib/models/fastflow/anomaly_map.py
@@ -34,7 +34,7 @@ def forward(self, hidden_variables: List[Tensor]) -> Tensor:
         """
         flow_maps: List[Tensor] = []
         for hidden_variable in hidden_variables:
-            log_prob = -torch.mean(hidden_variable**2, dim=1, keepdim=True) * 0.5
+            log_prob = -torch.mean(hidden_variable ** 2, dim=1, keepdim=True) * 0.5
             prob = torch.exp(log_prob)
             flow_map = F.interpolate(
                 input=-prob,
2 changes: 1 addition & 1 deletion anomalib/models/fastflow/loss.py
@@ -24,5 +24,5 @@ def forward(self, hidden_variables: List[Tensor], jacobians: List[Tensor]) -> Te
         """
         loss = torch.tensor(0.0, device=hidden_variables[0].device)  # pylint: disable=not-callable
         for (hidden_variable, jacobian) in zip(hidden_variables, jacobians):
-            loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian)
+            loss += torch.mean(0.5 * torch.sum(hidden_variable ** 2, dim=(1, 2, 3)) - jacobian)
         return loss
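This loss and the anomaly map above come from the same normalizing-flow likelihood: log p(x) = −½‖z‖² − (D/2)·log(2π) + log|det J|, so maximizing the likelihood means minimizing E[½‖z‖² − log|det J|], and the per-location density gives the anomaly score at inference. A toy shape check (a sketch with invented tensor sizes, not the model's API):

import torch

hidden_variables = [torch.randn(8, 64, 28, 28)]  # z from one flow block
jacobians = [torch.randn(8)]                     # log|det J| per image

# training objective: negative log-likelihood up to an additive constant
loss = torch.mean(0.5 * torch.sum(hidden_variables[0] ** 2, dim=(1, 2, 3)) - jacobians[0])

# inference: per-pixel (unnormalized) likelihood used for the anomaly map
log_prob = -torch.mean(hidden_variables[0] ** 2, dim=1, keepdim=True) * 0.5
prob = torch.exp(log_prob)  # shape (8, 1, 28, 28); -prob is then upsampled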
2 changes: 1 addition & 1 deletion anomalib/models/ganomaly/torch_model.py
@@ -123,7 +123,7 @@ def __init__(

         # Calculate input channel size to recreate inverse pyramid
         exp_factor = math.ceil(math.log(min(input_size) // 2, 2)) - 2
-        n_input_features = n_features * (2**exp_factor)
+        n_input_features = n_features * (2 ** exp_factor)

         # CNN layer for latent vector input
         self.latent_input.add_module(
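To make the exp_factor arithmetic concrete: with input_size=(256, 256) and n_features=64 (illustrative values, not defaults confirmed by this diff), min(input_size) // 2 = 128, log₂(128) = 7, so exp_factor = 5 and the latent input expects 64 · 2⁵ = 2048 channels:

import math

input_size, n_features = (256, 256), 64  # illustrative values
exp_factor = math.ceil(math.log(min(input_size) // 2, 2)) - 2
n_input_features = n_features * (2 ** exp_factor)
print(exp_factor, n_input_features)  # 5 2048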
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
@@ -6,7 +6,7 @@ dataset:
   category: bottle
   image_size: 224
   train_batch_size: 32
-  test_batch_size: 1
+  test_batch_size: 32
   num_workers: 8
   transform_config:
     train: null
11 changes: 9 additions & 2 deletions anomalib/models/patchcore/torch_model.py
@@ -154,7 +154,11 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> Tuple[Tensor
             Tensor: Locations of the nearest neighbor(s).
         """
         distances = torch.cdist(embedding, self.memory_bank, p=2.0)  # euclidean norm
-        patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
+        if n_neighbors == 1:
+            # when n_neighbors is 1, speed up computation by using min instead of topk
+            patch_scores, locations = distances.min(1)
+        else:
+            patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
         return patch_scores, locations

     def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embedding: Tensor) -> Tensor:
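The min/topk branch is purely an optimization: for k=1 both return the same values and indices (assuming no exact distance ties), but min avoids topk's partial sort over the whole memory bank. Note that min(1) drops the trailing k dimension that topk keeps. A quick equivalence check (a sketch with arbitrary shapes):

import torch

torch.manual_seed(0)
distances = torch.cdist(torch.randn(512, 384), torch.randn(1000, 384), p=2.0)

scores_min, locs_min = distances.min(1)                             # shapes (512,)
scores_topk, locs_topk = distances.topk(k=1, largest=False, dim=1)  # shapes (512, 1)

assert torch.allclose(scores_min, scores_topk.squeeze(1))
assert torch.equal(locs_min, locs_topk.squeeze(1))  # holds when distances are tie-free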
@@ -168,6 +172,9 @@ def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embeddi
             Tensor: Image-level anomaly scores
         """

+        # Don't need to compute weights if num_neighbors is 1
+        if self.num_neighbors == 1:
+            return patch_scores.amax(1)
         # 1. Find the patch with the largest distance to it's nearest neighbor in each image
         max_patches = torch.argmax(patch_scores, dim=1)  # (m^test,* in the paper)
         # 2. Find the distance of the patch to it's nearest neighbor, and the location of the nn in the membank
@@ -179,7 +186,7 @@ def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embeddi
         # 4. Find the distance of the patch features to each of the support samples
         distances = torch.cdist(embedding[max_patches].unsqueeze(1), self.memory_bank[support_samples], p=2.0)
         # 5. Apply softmax to find the weights
-        weights = (1 - F.softmax(distances.squeeze()))[..., 0]
+        weights = (1 - F.softmax(distances.squeeze(), 1))[..., 0]
         # 6. Apply the weight factor to the score
         score = weights * score  # S^* in the paper
         return score
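The only functional change in step 5 is making the softmax dimension explicit. Calling F.softmax without dim is deprecated and infers the axis from the tensor rank, which can go wrong once squeeze() changes the rank (for example with a batch of one); dim=1 pins the normalization to the support-sample axis of the squeezed (batch, num_neighbors) distance matrix. Roughly (a sketch with toy shapes):

import torch
import torch.nn.functional as F

# distances from each image's max patch to its nearest neighbor's support set
distances = torch.rand(4, 9)  # (batch, num_neighbors)

# dim=1 normalizes over support samples; [..., 0] keeps the weight
# assigned to the nearest neighbor itself, one scalar per image
weights = (1 - F.softmax(distances, dim=1))[..., 0]
print(weights.shape)  # torch.Size([4])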