-
Notifications
You must be signed in to change notification settings - Fork 705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimized anomaly score calculation for PatchCore for both num_neighb… #633
Conversation
@VdLMV can you please provide a little bit more description? |
@samet-akcay It looks like the original title of PR was too long and is broken in two parts. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for spotting this! In addition to speeding up the computation, it turns out that your suggestion also fixes a bug in the patchcore model which caused the forward pass to produce an error when both test_batch_size
and num_neighbors
were 1.
I do feel that the implementation could be made a bit cleaner to improve readability of the code. I would suggest to move the logic related to n_neighbors == 1
out of the forward
method and into self.nearest_neighbors
and self.compute_anomaly_score
.
in self.nearest_neighbors
:
distances = torch.cdist(embedding, self.memory_bank, p=2.0)
if n_neighbors == 1:
# when n_neighbors is 1, speed up computation by using min instead of topk
patch_scores, locations = distances.min(1)
else:
patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
return patch_scores, locations
and at the top of self.compute_anomaly_score
:
# Don't need to compute weights if num_neighbors is 1
if self.num_neighbors == 1:
return patch_scores.amax(1)
The forward
method would then remain unchanged which improves the understandability of the data flow in the model.
I have addressed the comments. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I'm almost happy with the changes. Just one minor comment left
# reshape to w, h | ||
patch_scores = patch_scores.reshape((batch_size, 1, width, height)) | ||
# get anomaly map | ||
anomaly_map = self.anomaly_map_generator(patch_scores) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these lines are duplicate and can be removed
@djdameln sorry I overlooked this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@VdLMV, the CI unfortunately fails because of a formatting issue. Since it failed, the tests didn't run neither. Will you be able to run |
Description
Changes
Checklist