Fix efficient ad #2015

Merged · 5 commits · Apr 28, 2024
3 changes: 1 addition & 2 deletions configs/model/efficient_ad.yaml
@@ -7,12 +7,11 @@ model:
     weight_decay: 1.0e-05
     padding: false
     pad_maps: true
-    batch_size: 1

 metrics:
   pixel:
     - AUROC

 trainer:
-  max_epochs: 200
+  max_epochs: 1000
   max_steps: 70000
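For context on the trainer change: Lightning stops at whichever of max_epochs and max_steps is reached first. With a train batch size of 1, an MVTec category of a few hundred images yields only a few hundred steps per epoch, so the old cap of 200 epochs could end training well before the 70000-step budget from the EfficientAd paper; raising it to 1000 makes max_steps the binding limit. A minimal sketch of that stopping behavior (anomalib builds this trainer internally; constructing a Trainer directly here is only illustrative):

```python
from lightning.pytorch import Trainer

# Lightning halts at whichever limit is hit first. With ~300 steps per epoch
# (batch size 1 on a typical MVTec category), max_epochs=200 would stop near
# 60k steps; max_epochs=1000 leaves the 70k-step budget as the effective limit.
trainer = Trainer(max_epochs=1000, max_steps=70000)
```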
2 changes: 1 addition & 1 deletion src/anomalib/models/image/efficient_ad/README.md
@@ -18,7 +18,7 @@ Anomalies are detected as the difference in output feature maps between the teac

 ## Usage

-`python tools/train.py --model efficient_ad`
+`anomalib train --model EfficientAd --data anomalib.data.MVTec --data.train_batch_size 1`

 ## Benchmark
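For readers using the Python API rather than the CLI, the equivalent setup looks roughly like this (a sketch assuming the v1 Engine/datamodule interface; the category is illustrative):

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import EfficientAd

# EfficientAd trains with a batch size of 1, matching the paper setup.
datamodule = MVTec(category="bottle", train_batch_size=1)
model = EfficientAd()

engine = Engine()
engine.fit(model=model, datamodule=datamodule)
```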
21 changes: 13 additions & 8 deletions src/anomalib/models/image/efficient_ad/lightning_model.py
@@ -58,8 +58,6 @@ class EfficientAd(AnomalyModule):
         pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
             output anomaly maps so that their size matches the size in the padding = True case.
             Defaults to ``True``.
-        batch_size (int): batch size for imagenet dataloader
-            Defaults to ``1``.
     """

     def __init__(
@@ -71,7 +69,6 @@ def __init__(
         weight_decay: float = 0.00001,
         padding: bool = False,
         pad_maps: bool = True,
-        batch_size: int = 1,
     ) -> None:
         super().__init__()

@@ -83,7 +80,7 @@ def __init__(
             padding=padding,
             pad_maps=pad_maps,
         )
-        self.batch_size = batch_size
+        self.batch_size = 1  # ImageNet dataloader batch size is 1, per the paper
         self.lr = lr
         self.weight_decay = weight_decay

@@ -237,9 +234,18 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
     def on_train_start(self) -> None:
         """Called before the first training epoch.

-        First sets up the pretrained teacher model, then prepares the imagenette data, and finally calculates or
-        loads the channel-wise mean and std of the training dataset and push to the model.
+        First checks that EfficientAd-specific parameters are set correctly (a train_batch_size
+        of 1 and no ImageNet normalization in the transforms), then sets up the pretrained
+        teacher model, prepares the Imagenette data, and finally calculates or loads the
+        channel-wise mean and std of the training dataset and pushes them to the model.
         """
+        if self.trainer.datamodule.train_batch_size != 1:
+            msg = "train_batch_size for EfficientAd should be 1."
+            raise ValueError(msg)
+        if self._transform and any(isinstance(transform, Normalize) for transform in self._transform.transforms):
+            msg = "Transforms for EfficientAd should not contain Normalize."
+            raise ValueError(msg)
+
         sample = next(iter(self.trainer.train_dataloader))
         image_size = sample["image"].shape[-2:]
         self.prepare_pretrained_model()
@@ -314,11 +320,10 @@ def learning_type(self) -> LearningType:
         return LearningType.ONE_CLASS

     def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform:
-        """Default transform for Padim."""
+        """Default transform for EfficientAd. ImageNet normalization is applied in forward()."""
         image_size = image_size or (256, 256)
         return Compose(
             [
                 Resize(image_size, antialias=True),
-                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
             ],
         )
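The net effect of the checks above is that a misconfigured run now fails fast at the start of fit rather than training silently with the wrong settings: both a datamodule with a train_batch_size other than 1 and a transform pipeline containing Normalize (normalization now happens inside the model's forward pass) raise a ValueError. A sketch of the failure mode, under the same API assumptions as the usage example above:

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import EfficientAd

# train_batch_size=2 trips the new on_train_start guard.
datamodule = MVTec(category="bottle", train_batch_size=2)
engine = Engine()
engine.fit(model=EfficientAd(), datamodule=datamodule)
# -> ValueError: train_batch_size for EfficientAd should be 1.
```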
3 changes: 2 additions & 1 deletion tests/integration/model/test_models.py
@@ -204,7 +204,8 @@ def _get_objects(
         root=dataset_path / "mvtec",
         category="dummy",
         task=task_type,
-        train_batch_size=2,
+        # EfficientAd requires train batch size 1
+        train_batch_size=1 if model_name == "efficient_ad" else 2,
     )

     model = get_model(model_name, **extra_args)