Skip to content

Commit

Permalink
fix synthetic anomaly datasets (#2497)
Browse files Browse the repository at this point in the history
* fix synthetic dataset

* remove print

* fix test

* remove try except

---------

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
alexriedel1 and samet-akcay authored Jan 20, 2025
1 parent 2801fc5 commit 9da426b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
12 changes: 8 additions & 4 deletions src/anomalib/data/utils/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,26 @@ class SyntheticAnomalyDataset(AnomalibDataset):
augmentations (Transform | None): Transform object describing the input data augmentations.
source_samples: DataFrame containing normal samples used as source for
synthetic anomalies.
dataset_name: str dataset name for path of temporary anomalous samples
Example:
>>> transform = Compose([...])
>>> dataset = SyntheticAnomalyDataset(
... transform=transform,
... source_samples=normal_df
... source_samples=normal_df,
... dataset_name="synthetic"
... )
>>> len(dataset) # 50/50 normal/anomalous split
100
"""

def __init__(self, augmentations: Transform | None, source_samples: DataFrame) -> None:
def __init__(self, augmentations: Transform | None, source_samples: DataFrame, dataset_name: str) -> None:
super().__init__(augmentations=augmentations)

self.source_samples = source_samples

# Files will be written to a temporary directory in the workdir
root = Path(ROOT)
root = Path(ROOT) / dataset_name
root.mkdir(parents=True, exist_ok=True)

self.root = Path(mkdtemp(dir=root))
Expand All @@ -194,6 +196,8 @@ def __init__(self, augmentations: Transform | None, source_samples: DataFrame) -
0.5,
)

self.samples.attrs["task"] = "segmentation"

@classmethod
def from_dataset(
cls: type["SyntheticAnomalyDataset"],
Expand All @@ -212,7 +216,7 @@ def from_dataset(
>>> normal_dataset = Dataset(...)
>>> synthetic = SyntheticAnomalyDataset.from_dataset(normal_dataset)
"""
return cls(augmentations=dataset.augmentations, source_samples=dataset.samples)
return cls(augmentations=dataset.augmentations, source_samples=dataset.samples, dataset_name=dataset.name)

def __copy__(self) -> "SyntheticAnomalyDataset":
"""Return shallow copy and prevent cleanup of original.
Expand Down
16 changes: 7 additions & 9 deletions src/anomalib/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,10 @@ def generate_output_filename(
>>> generate_output_filename(path, "/output", "MyDataset", "class_A")
PosixPath('/output/image_001.jpg')
Dataset not found raises error:
Dataset not found returns the output path:
>>> generate_output_filename("/wrong/path/image.png", "/out", "Missing")
Traceback (most recent call last):
...
ValueError: Dataset name 'Missing' not found in the input path.
PosixPath('/out/wrong/path/image.png')
Note:
- Directory structure after ``dataset_name`` (or ``category`` if provided) is
Expand All @@ -315,12 +313,12 @@ def generate_output_filename(
input_path = Path(input_path)
output_path = Path(output_path)

# Find the position of the dataset name in the path
try:
# Check if the dataset name is in the input path. If not, just use the output path
if dataset_name.lower() not in [x.lower() for x in input_path.parts]:
dataset_index = len(input_path.parts)
else:
# Find the position of the dataset name in the path
dataset_index = next(i for i, part in enumerate(input_path.parts) if part.lower() == dataset_name.lower())
except ValueError:
msg = f"Dataset name '{dataset_name}' not found in the input path."
raise ValueError(msg) from None

# Determine the start index for preserving subdirectories
start_index = dataset_index + 1
Expand Down
1 change: 1 addition & 0 deletions tests/unit/data/utils/test_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def synthetic_dataset_from_samples(folder_dataset: FolderDataset) -> SyntheticAn
return SyntheticAnomalyDataset(
augmentations=folder_dataset.augmentations,
source_samples=folder_dataset.samples,
dataset_name=folder_dataset.name,
)


Expand Down

0 comments on commit 9da426b

Please sign in to comment.