Feature/dick/anomaly score normalization #35
Merged: samet-akcay merged 44 commits into development from feature/dick/anomaly-score-normalization on Dec 21, 2021
Changes shown from the first 26 of 44 commits:
- 1e2ac6e implement anomaly score and anomaly map normalization (djdameln)
- 95c15c1 switch back to callback design (djdameln)
- cd1612b fix docstrings and typing (djdameln)
- 149e026 always cast thresholds to float (djdameln)
- 0cdb25c improve logic of training stats metric (djdameln)
- 122bbb6 switch to torchvision feature extractor (djdameln)
- 5d82316 update configs (djdameln)
- b4246ab merge development (djdameln)
- 891c709 switch back to saving training stats in model (djdameln)
- 95de8fc Tensor -> tensor (djdameln)
- d84ea11 loading checkpoint no longer needed (djdameln)
- 1cb7a15 add normalization to inferencer (djdameln)
- 21dfefb subtract image mean from anomaly maps (djdameln)
- ae70160 small refactor (djdameln)
- 8416336 switch to torchmetrics design for threshold computation (djdameln)
- a41d7c1 import training stats from init (djdameln)
- 1494e5a small fix (djdameln)
- 9a196b2 switch to torchmetrics for persisting training stats (djdameln)
- ac4573e add test case for normalization callback (djdameln)
- d05acb9 fix visualizer (djdameln)
- 54cd83e fix mypy issues (djdameln)
- 4c1a756 fix compression tests (djdameln)
- cecb23e remove print statement (djdameln)
- 721ea0b revert checkpoint loading (djdameln)
- 4a3b2b1 revert changing weight path (djdameln)
- 416a4c8 Merge branch 'development' into feature/dick/anomaly-score-normalization (djdameln)
- 95067fe rename normalization callback (djdameln)
- 8453260 rename anomaly score dsitrbution class (djdameln)
- e30d870 change function ordering (djdameln)
- 895bf1e remove cuda version from torch and torchvision (djdameln)
- 5d46732 add deprecation warning to feature extractor. (djdameln)
- 918f422 training_stats -> training_distribution (djdameln)
- a55e6d5 update requirements (djdameln)
- 7a07fa4 revert to anomalib feature extractor (djdameln)
- 9853828 workaround for torch 1.8.1 compatibility (djdameln)
- aee4802 rename normalization callback (djdameln)
- 497d3bc merge development (djdameln)
- c002f2a Revert "add deprecation warning to feature extractor." (djdameln)
- fc01635 Add batch size support to patchcore (samet-akcay)
- fc8b3d9 Score Normalization doesnt work for Patchcore (samet-akcay)
- 00dc01f Merge branch 'development' of github.com:openvinotoolkit/anomalib int… (samet-akcay)
- a56ba72 check to prevent using both normalization and nncf (djdameln)
- e23f82d use get_dataset_path (djdameln)
- 1ffb0fb Merge branch 'feature/dick/anomaly-score-normalization' of github.com… (djdameln)
New file, +111 lines (@@ -0,0 +1,111 @@), the anomaly score normalization callback:

```python
"""Anomaly Score Normalization Callback."""
import copy
from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.distributions import LogNormal, Normal


class NormalizationCallback(Callback):
    """Callback that standardizes the image-level and pixel-level anomaly scores."""

    def __init__(self):
        self.image_dist: Optional[LogNormal] = None
        self.pixel_dist: Optional[LogNormal] = None

    def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Called when the test begins."""
        pl_module.image_metrics.F1.threshold = 0.5
        pl_module.pixel_metrics.F1.threshold = 0.5

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, _unused: Optional[Any] = None
    ) -> None:
        """Called when the train epoch ends.

        Use the current model to compute the anomaly score distributions
        of the normal training data. This is needed after every epoch, because the statistics must be
        stored in the state dict of the checkpoint file.
        """
        self._collect_stats(trainer, pl_module)

    def on_validation_batch_end(
        self,
        _trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Optional[STEP_OUTPUT],
        _batch: Any,
        _batch_idx: int,
        _dataloader_idx: int,
    ) -> None:
        """Called when the validation batch ends, standardizes the predicted scores and anomaly maps."""
        self._standardize(outputs, pl_module)

    def on_test_batch_end(
        self,
        _trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Optional[STEP_OUTPUT],
        _batch: Any,
        _batch_idx: int,
        _dataloader_idx: int,
    ) -> None:
        """Called when the test batch ends, normalizes the predicted scores and anomaly maps."""
        self._standardize(outputs, pl_module)
        self._normalize(outputs, pl_module)

    def on_predict_batch_end(
        self,
        _trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Dict,
        _batch: Any,
        _batch_idx: int,
        _dataloader_idx: int,
    ) -> None:
        """Called when the predict batch ends, normalizes the predicted scores and anomaly maps."""
        self._standardize(outputs, pl_module)
        self._normalize(outputs, pl_module)
        outputs["pred_labels"] = outputs["pred_scores"] >= 0.5

    def _collect_stats(self, trainer, pl_module):
        """Collect the statistics of the normal training data.

        Create a trainer and use it to predict the anomaly maps and scores of the normal training data. Then
        estimate the distribution of anomaly scores for normal data at the image and pixel level by computing
        the mean and standard deviations. A dictionary containing the computed statistics is stored in self.stats.
        """
        predictions = Trainer(gpus=trainer.gpus).predict(
            model=copy.deepcopy(pl_module), dataloaders=trainer.datamodule.train_dataloader()
        )
        pl_module.training_stats.reset()
        for batch in predictions:
            if "pred_scores" in batch.keys():
                pl_module.training_stats.update(anomaly_scores=batch["pred_scores"])
            if "anomaly_maps" in batch.keys():
                pl_module.training_stats.update(anomaly_maps=batch["anomaly_maps"])
        pl_module.training_stats.compute()

    def _standardize(self, outputs: STEP_OUTPUT, pl_module) -> None:
        """Standardize the predicted scores and anomaly maps to the z-domain."""
        stats = pl_module.training_stats.to(outputs["pred_scores"].device)

        outputs["pred_scores"] = torch.log(outputs["pred_scores"])
        outputs["pred_scores"] = (outputs["pred_scores"] - stats.image_mean) / stats.image_std
        if "anomaly_maps" in outputs.keys():
            outputs["anomaly_maps"] = (torch.log(outputs["anomaly_maps"]) - stats.pixel_mean) / stats.pixel_std
            outputs["anomaly_maps"] -= (stats.image_mean - stats.pixel_mean) / stats.pixel_std

    def _normalize(self, outputs: STEP_OUTPUT, pl_module: pl.LightningModule) -> None:
        """Normalize the predicted scores and anomaly maps by first standardizing and then computing the CDF."""
        device = outputs["pred_scores"].device
        image_threshold = pl_module.image_threshold.value.cpu()
        pixel_threshold = pl_module.pixel_threshold.value.cpu()

        norm = Normal(torch.Tensor([0]), torch.Tensor([1]))
        outputs["pred_scores"] = norm.cdf(outputs["pred_scores"].cpu() - image_threshold).to(device)
        if "anomaly_maps" in outputs.keys():
            outputs["anomaly_maps"] = norm.cdf(outputs["anomaly_maps"].cpu() - pixel_threshold).to(device)
```
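The two-stage mapping in `_standardize` and `_normalize` can be sketched in isolation: raw (positive) scores are log-transformed and standardized with the training-set statistics, then passed through the standard normal CDF shifted by the standardized threshold, so a score exactly at the threshold maps to 0.5. A minimal sketch, with made-up stand-in values for the statistics and threshold (the real callback reads them from the module and its `training_stats` metric):

```python
import torch
from torch.distributions import Normal


def standardize(scores: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    """Map raw positive anomaly scores to the z-domain of the normal training data."""
    return (torch.log(scores) - mean) / std


def normalize(z_scores: torch.Tensor, z_threshold: torch.Tensor) -> torch.Tensor:
    """Map standardized scores into [0, 1]; the threshold itself lands at 0.5."""
    norm = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
    return norm.cdf(z_scores - z_threshold)


# Hypothetical raw scores with hypothetical stats (mean=0.0, std=1.0 in log-space).
raw = torch.tensor([0.5, 1.0, 2.0])
z = standardize(raw, mean=0.0, std=1.0)
out = normalize(z, z_threshold=torch.tensor(0.0))
# raw score 1.0 has log-score 0.0 == threshold, so its normalized score is 0.5
```

Scores below the threshold end up below 0.5 and scores above it end up above 0.5, which is what allows `on_test_start` to fix the F1 threshold at 0.5.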
Changes to the custom metrics `__init__` module (@@ -1,5 +1,7 @@), exporting the two new metrics:

```diff
 """Custom anomaly evaluation metrics."""
+from .adaptive_threshold import AdaptiveThreshold
 from .auroc import AUROC
 from .optimal_f1 import OptimalF1
+from .training_stats import TrainingStats

-__all__ = ["AUROC", "OptimalF1"]
+__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "TrainingStats"]
```
New file, +41 lines (@@ -0,0 +1,41 @@), the `AdaptiveThreshold` metric:

```python
"""Implementation of Optimal F1 score based on TorchMetrics."""
import torch
from torchmetrics import Metric, PrecisionRecallCurve


class AdaptiveThreshold(Metric):
    """Optimal F1 Metric.

    Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the
    predicted anomaly scores.
    """

    def __init__(self, default_value: float, **kwargs):
        super().__init__(**kwargs)

        self.precision_recall_curve = PrecisionRecallCurve(num_classes=1, compute_on_step=False)
        self.add_state("value", default=torch.tensor(default_value), persistent=True)  # pylint: disable=not-callable
        self.value = torch.tensor(default_value)  # pylint: disable=not-callable

    # pylint: disable=arguments-differ
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:  # type: ignore
        """Update the precision-recall curve metric."""
        self.precision_recall_curve.update(preds, target)

    def compute(self) -> torch.Tensor:
        """Compute the threshold that yields the optimal F1 score.

        Compute the F1 scores while varying the threshold. Store the optimal
        threshold as attribute and return the maximum value of the F1 score.

        Returns:
            Value of the F1 score at the optimal threshold.
        """
        precision: torch.Tensor
        recall: torch.Tensor
        thresholds: torch.Tensor

        precision, recall, thresholds = self.precision_recall_curve.compute()
        f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
        self.value = thresholds[torch.argmax(f1_score)]
        return self.value
```
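`compute` selects the threshold at the argmax of the F1 curve derived from the precision-recall curve. The same selection can be written in plain torch without `PrecisionRecallCurve`, which makes the behavior easy to check on a toy example; this is an illustrative sketch, not the torchmetrics implementation:

```python
import torch


def adaptive_threshold(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the candidate threshold that maximizes F1 (brute-force sketch)."""
    best_f1, best_thr = -1.0, preds.min()
    for thr in torch.unique(preds):  # every distinct score is a candidate threshold
        pred_labels = (preds >= thr).int()
        tp = int(((pred_labels == 1) & (target == 1)).sum())
        fp = int(((pred_labels == 1) & (target == 0)).sum())
        fn = int(((pred_labels == 0) & (target == 1)).sum())
        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        f1 = 2 * precision * recall / (precision + recall + 1e-10)
        if f1 > best_f1:
            best_f1, best_thr = f1, thr
    return best_thr


# Anomalous samples score high, normal samples score low: the optimal
# threshold separates them and yields F1 = 1.0 here.
preds = torch.tensor([0.1, 0.2, 0.8, 0.9])
target = torch.tensor([0, 0, 1, 1])
print(float(adaptive_threshold(preds, target)))  # prints 0.8
```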
New file, +52 lines (@@ -0,0 +1,52 @@), the `TrainingStats` metric:

```python
"""Module that computes the parameters of the normal data distribution of the training set."""
from typing import Optional, Tuple

import torch
from torch import Tensor
from torchmetrics import Metric


class TrainingStats(Metric):
    """Mean and standard deviation of the anomaly scores of normal training data."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.anomaly_maps = []
        self.anomaly_scores = []

        self.add_state("image_mean", torch.empty(0), persistent=True)
        self.add_state("image_std", torch.empty(0), persistent=True)
        self.add_state("pixel_mean", torch.empty(0), persistent=True)
        self.add_state("pixel_std", torch.empty(0), persistent=True)

        self.image_mean = torch.empty(0)
        self.image_std = torch.empty(0)
        self.pixel_mean = torch.empty(0)
        self.pixel_std = torch.empty(0)

    # pylint: disable=arguments-differ
    def update(  # type: ignore
        self, anomaly_scores: Optional[Tensor] = None, anomaly_maps: Optional[Tensor] = None
    ) -> None:
        """Update the precision-recall curve metric."""
        if anomaly_maps is not None:
            self.anomaly_maps.append(anomaly_maps)
        if anomaly_scores is not None:
            self.anomaly_scores.append(anomaly_scores)

    def compute(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute stats."""
        anomaly_scores = torch.hstack(self.anomaly_scores)
        anomaly_scores = torch.log(anomaly_scores)

        self.image_mean = anomaly_scores.mean()
        self.image_std = anomaly_scores.std()

        if self.anomaly_maps:
            anomaly_maps = torch.vstack(self.anomaly_maps)
            anomaly_maps = torch.log(anomaly_maps).cpu()

            self.pixel_mean = anomaly_maps.mean(dim=0).squeeze()
            self.pixel_std = anomaly_maps.std(dim=0).squeeze()

        return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std
```
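`compute` concatenates all collected batches and takes the mean and standard deviation in log-space (the raw scores are assumed positive). The image-level path reduces to a few lines, shown here with made-up score batches standing in for the per-epoch predictions:

```python
import torch

# Hypothetical batches of image-level anomaly scores from normal training data.
batches = [torch.tensor([1.0, 2.0]), torch.tensor([4.0])]

scores = torch.hstack(batches)   # flatten the per-batch score vectors
log_scores = torch.log(scores)   # statistics are computed in log-space
image_mean = log_scores.mean()   # later used to standardize test scores
image_std = log_scores.std()     # unbiased sample standard deviation
```

The pixel-level path is analogous but stacks the anomaly maps with `torch.vstack` and reduces over the batch dimension, giving a per-pixel mean and standard deviation.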
Review conversation on the callback name:

> Reviewer: I think `NormalizationCallback` is not sufficient to understand what the callback does. I think it should be something like `ScoreNormalizationCallback` or something similar.

> Author: Yeah, I tried `AnomalyScoreNormalizationCallback` but found it too verbose. `ScoreNormalizationCallback` would be better, though maybe a bit vague. How about `OutputNormalizationCallback`, because it normalizes the outputs of the model?

> Author: I changed it to `OutputNormalizationCallback`. Let me know what you think.