Skip to content

Commit

Permalink
Refactored-assert-statements-with-explicit-error-handling (#1825)
Browse files Browse the repository at this point in the history
* Refactored-assert-statements-with-explicit-error-handling

Signed-off-by: sahusiddharth <siddharth.sahu@plaksha.edu.in>

* fixed ruff

Signed-off-by: sahusiddharth <siddharth.sahu@plaksha.edu.in>

* fixed logic in mvtec

Signed-off-by: sahusiddharth <siddharth.sahu@plaksha.edu.in>

---------

Signed-off-by: sahusiddharth <siddharth.sahu@plaksha.edu.in>
  • Loading branch information
sahusiddharth authored Mar 18, 2024
1 parent 2c87edf commit 4c7d7c3
Show file tree
Hide file tree
Showing 31 changed files with 259 additions and 117 deletions.
34 changes: 24 additions & 10 deletions src/anomalib/callbacks/nncf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,26 @@ def compose_nncf_config(nncf_config: dict, enabled_options: list[str]) -> dict:
# So, user can define `order_of_parts` in the optimisation_config
# to specify the order of applying the parts.
order_of_parts = optimisation_parts["order_of_parts"]
assert isinstance(order_of_parts, list), 'The field "order_of_parts" in optimisation config should be a list'
if not isinstance(order_of_parts, list):
msg = 'The field "order_of_parts" in optimization config should be a list'
raise TypeError(msg)

for part in enabled_options:
assert (
part in order_of_parts
), f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
if part not in order_of_parts:
msg = f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
raise ValueError(msg)

optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]

assert "base" in optimisation_parts, 'Error: the optimisation config does not contain the "base" part'
if "base" not in optimisation_parts:
msg = 'Error: the optimisation config does not contain the "base" part'
raise KeyError(msg)
nncf_config_part = optimisation_parts["base"]

for part in optimisation_parts_to_choose:
assert part in optimisation_parts, f'Error: the optimisation config does not contain the part "{part}"'
if part not in optimisation_parts:
msg = f'Error: the optimisation config does not contain the part "{part}"'
raise KeyError(msg)
optimisation_part_dict = optimisation_parts[part]
try:
nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
Expand Down Expand Up @@ -205,9 +211,16 @@ def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) ->
f" type(b) = {type(_b)}"
)

assert isinstance(a, dict | list), f"Can merge only dicts and lists, whereas type(a)={type(a)}"
assert isinstance(b, dict | list), _err_str(a, b, cur_key)
assert isinstance(a, list) == isinstance(b, list), _err_str(a, b, cur_key)
if not (isinstance(a, dict | list)):
msg = f"Can merge only dicts and lists, whereas type(a)={type(a)}"
raise TypeError(msg)

if not (isinstance(b, dict | list)):
raise TypeError(_err_str(a, b, cur_key))

if (isinstance(a, list) and not isinstance(b, list)) or (isinstance(b, list) and not isinstance(a, list)):
raise TypeError(_err_str(a, b, cur_key))

if isinstance(a, list) and isinstance(b, list):
# the main difference w.r.t. mmcv.Config -- lists are merged by concatenation
return a + b
Expand All @@ -222,7 +235,8 @@ def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) ->
a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key)
continue

assert not isinstance(b[k], dict | list), _err_str(a[k], b[k], new_cur_key)
if any(isinstance(b[k], t) for t in [dict, list]):
raise TypeError(_err_str(a[k], b[k], new_cur_key))

# suppose here that a[k] and b[k] are scalars, just overwrite
a[k] = b[k]
Expand Down
24 changes: 17 additions & 7 deletions src/anomalib/data/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibD
inplace (bool): When true, the subsampling will be performed on the instance itself.
Defaults to ``False``.
"""
assert len(set(indices)) == len(indices), "No duplicates allowed in indices."
if len(set(indices)) != len(indices):
msg = "No duplicates allowed in indices."
raise ValueError(msg)
dataset = self if inplace else copy.deepcopy(self)
dataset.samples = self.samples.iloc[indices].reset_index(drop=True)
return dataset
Expand All @@ -116,12 +118,18 @@ def samples(self, samples: DataFrame) -> None:
samples (DataFrame): DataFrame with new samples.
"""
# validate the passed samples by checking the type, the expected columns and the image paths
assert isinstance(samples, DataFrame), f"samples must be a pandas.DataFrame, found {type(samples)}"
if not isinstance(samples, DataFrame):
msg = f"samples must be a pandas.DataFrame, found {type(samples)}"
raise TypeError(msg)

expected_columns = _EXPECTED_COLUMNS_PERTASK[self.task]
assert all(
col in samples.columns for col in expected_columns
), f"samples must have (at least) columns {expected_columns}, found {samples.columns}"
assert samples["image_path"].apply(lambda p: Path(p).exists()).all(), "missing file path(s) in samples"
if not all(col in samples.columns for col in expected_columns):
msg = f"samples must have (at least) columns {expected_columns}, found {samples.columns}"
raise ValueError(msg)

if not samples["image_path"].apply(lambda p: Path(p).exists()).all():
msg = "missing file path(s) in samples"
raise FileNotFoundError(msg)

self._samples = samples.sort_values(by="image_path", ignore_index=True)

Expand Down Expand Up @@ -193,7 +201,9 @@ def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset":
Returns:
AnomalibDataset: Concatenated dataset.
"""
assert isinstance(other_dataset, self.__class__), "Cannot concatenate datasets that are not of the same type."
if not isinstance(other_dataset, self.__class__):
msg = "Cannot concatenate datasets that are not of the same type."
raise TypeError(msg)
dataset = copy.deepcopy(self)
dataset.samples = pd.concat([self.samples, other_dataset.samples], ignore_index=True)
return dataset
22 changes: 16 additions & 6 deletions src/anomalib/data/base/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def __init__(

def __len__(self) -> int:
"""Get length of the dataset."""
assert isinstance(self.indexer, ClipsIndexer)
if not isinstance(self.indexer, ClipsIndexer):
msg = "self.indexer must be an instance of ClipsIndexer."
raise TypeError(msg)
return self.indexer.num_clips()

@property
Expand All @@ -94,7 +96,9 @@ def _setup_clips(self) -> None:
Should be called after each change to self._samples
"""
assert callable(self.indexer_cls)
if not callable(self.indexer_cls):
msg = "self.indexer_cls must be callable."
raise TypeError(msg)
self.indexer = self.indexer_cls( # pylint: disable=not-callable
video_paths=list(self.samples.image_path),
mask_paths=list(self.samples.mask_path),
Expand Down Expand Up @@ -145,8 +149,9 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
Returns:
dict[str, str | torch.Tensor]: Dictionary containing the mask, clip and file system information.
"""
assert isinstance(self.indexer, ClipsIndexer)

if not isinstance(self.indexer, ClipsIndexer):
msg = "self.indexer must be an instance of ClipsIndexer."
raise TypeError(msg)
item = self.indexer.get_item(index)
# include the untransformed image for visualization
item["original_image"] = item["image"].to(torch.uint8)
Expand Down Expand Up @@ -185,8 +190,13 @@ def _setup(self, _stage: str | None = None) -> None:
Video datamodules are not compatible with synthetic anomaly generation.
"""
assert self.train_data is not None
assert self.test_data is not None
if self.train_data is None:
msg = "self.train_data cannot be None."
raise ValueError(msg)

if self.test_data is None:
msg = "self.test_data cannot be None."
raise ValueError(msg)

self.train_data.setup()
self.test_data.setup()
Expand Down
31 changes: 21 additions & 10 deletions src/anomalib/data/depth/folder_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset
from anomalib.data.errors import MisMatchError
from anomalib.data.utils import (
DirType,
LabelName,
Expand All @@ -24,7 +25,7 @@
from anomalib.data.utils.path import _prepare_files_labels, validate_and_resolve_path


def make_folder3d_dataset(
def make_folder3d_dataset( # noqa: C901
normal_dir: str | Path,
root: str | Path | None = None,
abnormal_dir: str | Path | None = None,
Expand Down Expand Up @@ -74,7 +75,9 @@ def make_folder3d_dataset(
abnormal_depth_dir = validate_and_resolve_path(abnormal_depth_dir, root) if abnormal_depth_dir else None
normal_test_depth_dir = validate_and_resolve_path(normal_test_depth_dir, root) if normal_test_depth_dir else None

assert normal_dir.is_dir(), "A folder location must be provided in normal_dir."
if not normal_dir.is_dir():
msg = "A folder location must be provided in normal_dir."
raise ValueError(msg)

filenames = []
labels = []
Expand Down Expand Up @@ -129,17 +132,23 @@ def make_folder3d_dataset(
].image_path.to_numpy()

# make sure every rgb image has a corresponding depth image and that the file exists
assert (
mismatch = (
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and depth images. Make sure the mask files in 'xyz' \
folder follow the same naming convention as the anomalous images in the dataset \
(e.g. image: '000.png', depth: '000.tiff')."
)
if not mismatch:
msg = """Mismatch between anomalous images and depth images. Make sure the mask files
in 'xyz' folder follow the same naming convention as the anomalous images in the dataset
(e.g. image: '000.png', depth: '000.tiff')."""
raise MisMatchError(msg)

assert samples.depth_path.apply(
missing_depth_files = samples.depth_path.apply(
lambda x: Path(x).exists() if not isna(x) else True,
).all(), "missing depth image files"
).all()
if not missing_depth_files:
msg = "Missing depth image files."
raise FileNotFoundError(msg)

samples = samples.astype({"depth_path": "str"})

Expand All @@ -152,9 +161,11 @@ def make_folder3d_dataset(
samples = samples.astype({"mask_path": "str"})

# make sure all the files exist
assert samples.mask_path.apply(
if not samples.mask_path.apply(
lambda x: Path(x).exists() if x != "" else True,
).all(), f"missing mask files, mask_dir={mask_dir}"
).all():
msg = f"Missing mask files. mask_dir={mask_dir}"
raise FileNotFoundError(msg)
else:
samples["mask_path"] = ""

Expand Down
24 changes: 15 additions & 9 deletions src/anomalib/data/depth/mvtec_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset
from anomalib.data.errors import MisMatchError
from anomalib.data.utils import (
DownloadInfo,
LabelName,
Expand Down Expand Up @@ -146,22 +147,27 @@ def make_mvtec_3d_dataset(
samples = samples.astype({"image_path": "str", "mask_path": "str", "depth_path": "str"})

# assert that the right mask files are associated with the right test images
assert (
mismatch_masks = (
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and ground truth masks. Make sure the mask files in 'ground_truth' \
folder follow the same naming convention as the anomalous images in the dataset (e.g. image: '000.png', \
mask: '000.png' or '000_mask.png')."
)
if not mismatch_masks:
msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files
in 'ground_truth' folder follow the same naming convention as the anomalous images in
the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png')."""
raise MisMatchError(msg)

# assert that the right depth image files are associated with the right test images
assert (
mismatch_depth = (
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and depth images. Make sure the mask files in 'xyz' \
folder follow the same naming convention as the anomalous images in the dataset (e.g. image: '000.png', \
depth: '000.tiff')."
)
if not mismatch_depth:
msg = """Mismatch between anomalous images and depth images. Make sure the mask files in
'xyz' folder follow the same naming convention as the anomalous images in the dataset
(e.g. image: '000.png', depth: '000.tiff')."""
raise MisMatchError(msg)

if split:
samples = samples[samples.split == split].reset_index(drop=True)
Expand Down
19 changes: 19 additions & 0 deletions src/anomalib/data/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Custom Exception Class for Mismatch Detection (MisMatchError)."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


class MisMatchError(Exception):
    """Error signalling that a mismatch was detected.

    Attributes:
        message (str): Explanation of the error. When the caller supplies no
            message (or an empty one), the text ``"Mismatch detected."`` is used.
    """

    def __init__(self, message: str = "") -> None:
        # A falsy message falls back to the generic default text.
        self.message = message or "Mismatch detected."
        super().__init__(self.message)
16 changes: 11 additions & 5 deletions src/anomalib/data/image/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.errors import MisMatchError
from anomalib.data.utils import (
DirType,
LabelName,
Expand Down Expand Up @@ -102,7 +103,9 @@ def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] |
abnormal_dir = _resolve_path_and_convert_to_list(abnormal_dir)
normal_test_dir = _resolve_path_and_convert_to_list(normal_test_dir)
mask_dir = _resolve_path_and_convert_to_list(mask_dir)
assert len(normal_dir) > 0, "A folder location must be provided in normal_dir."
if len(normal_dir) == 0:
msg = "A folder location must be provided in normal_dir."
raise ValueError(msg)

filenames = []
labels = []
Expand Down Expand Up @@ -144,13 +147,16 @@ def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] |
samples = samples.astype({"mask_path": "str"})

# make sure every rgb image has a corresponding mask image.
assert (
if not (
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and mask images. Make sure the mask files \
folder follow the same naming convention as the anomalous images in the dataset \
(e.g. image: '000.png', mask: '000.png')."
):
msg = """Mismatch between anomalous images and mask images. Make sure the mask files "
"folder follow the same naming convention as the anomalous images in the dataset "
"(e.g. image: '000.png', mask: '000.png')."""
raise MisMatchError(msg)

else:
samples["mask_path"] = ""

Expand Down
11 changes: 7 additions & 4 deletions src/anomalib/data/image/kolektor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.errors import MisMatchError
from anomalib.data.utils import (
DownloadInfo,
Split,
Expand Down Expand Up @@ -165,13 +166,15 @@ def make_kolektor_dataset(
samples = samples[["path", "item", "split", "label", "image_path", "mask_path", "label_index"]]

# assert that the right mask files are associated with the right test images
assert (
if not (
samples.loc[samples.label_index == 1]
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and ground truth masks. Make sure the mask files \
follow the same naming convention as the anomalous images in the dataset (e.g. image: 'Part0.jpg', \
mask: 'Part0_label.bmp')."
):
msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files
follow the same naming convention as the anomalous images in the dataset
(e.g. image: 'Part0.jpg', mask: 'Part0_label.bmp')."""
raise MisMatchError(msg)

# Get the dataframe for the required split
if split:
Expand Down
18 changes: 10 additions & 8 deletions src/anomalib/data/image/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.errors import MisMatchError
from anomalib.data.utils import (
DownloadInfo,
LabelName,
Expand Down Expand Up @@ -154,14 +155,15 @@ def make_mvtec_dataset(
] = mask_samples.image_path.to_numpy()

# assert that the right mask files are associated with the right test images
if len(samples.loc[samples.label_index == LabelName.ABNORMAL]):
assert (
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and ground truth masks. Make sure the mask files in 'ground_truth' \
folder follow the same naming convention as the anomalous images in the dataset (e.g. image: \
'000.png', mask: '000.png' or '000_mask.png')."
abnormal_samples = samples.loc[samples.label_index == LabelName.ABNORMAL]
if (
len(abnormal_samples)
and not abnormal_samples.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1).all()
):
msg = """Mismatch between anomalous images and ground truth masks. Make sure t
he mask files in 'ground_truth' folder follow the same naming convention as the
anomalous images in the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png')."""
raise MisMatchError(msg)

if split:
samples = samples[samples.split == split].reset_index(drop=True)
Expand Down
Loading

0 comments on commit 4c7d7c3

Please sign in to comment.