Skip to content

Commit

Permalink
✨ Add hash check to data download (#284)
Browse files Browse the repository at this point in the history
* Add hash check to datadownload

* filename variable

* Fix assert string

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
  • Loading branch information
ashwinvaidya17 and Ashwin Vaidya authored Apr 29, 2022
1 parent 1108571 commit 9076a95
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 14 deletions.
5 changes: 3 additions & 2 deletions anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from tqdm import tqdm

from anomalib.data.inference import InferenceDataset
from anomalib.data.utils import DownloadProgressBar, read_image
from anomalib.data.utils import DownloadProgressBar, hash_check, read_image
from anomalib.data.utils.split import (
create_validation_set_from_test_set,
split_normal_images_in_train_set,
Expand Down Expand Up @@ -359,7 +359,8 @@ def prepare_data(self) -> None:
filename=zip_filename,
reporthook=progress_bar.update_to,
) # nosec

logger.info("Checking hash")
hash_check(zip_filename, "c1fa4d56ac50dd50908ce04e81037a8e")
logger.info("Extracting the dataset.")
with zipfile.ZipFile(zip_filename, "r") as zip_file:
zip_file.extractall(self.root.parent)
Expand Down
11 changes: 7 additions & 4 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from torchvision.datasets.folder import VisionDataset

from anomalib.data.inference import InferenceDataset
from anomalib.data.utils import DownloadProgressBar, read_image
from anomalib.data.utils import DownloadProgressBar, hash_check, read_image
from anomalib.data.utils.split import (
create_validation_set_from_test_set,
split_normal_images_in_train_set,
Expand Down Expand Up @@ -378,19 +378,22 @@ def prepare_data(self) -> None:
logger.info("Downloading the Mvtec AD dataset.")
url = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094"
dataset_name = "mvtec_anomaly_detection.tar.xz"
zip_filename = self.root / dataset_name
with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc="MVTec AD") as progress_bar:
urlretrieve(
url=f"{url}/{dataset_name}",
filename=self.root / dataset_name,
filename=zip_filename,
reporthook=progress_bar.update_to,
)
logger.info("Checking hash")
hash_check(zip_filename, "eefca59f2cede9c3fc5b6befbfec275e")

logger.info("Extracting the dataset.")
with tarfile.open(self.root / dataset_name) as tar_file:
with tarfile.open(zip_filename) as tar_file:
tar_file.extractall(self.root)

logger.info("Cleaning the tar file")
(self.root / dataset_name).unlink()
(zip_filename).unlink()

def setup(self, stage: Optional[str] = None) -> None:
"""Setup train, validation and test data.
Expand Down
4 changes: 2 additions & 2 deletions anomalib/data/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .download import DownloadProgressBar
from .download import DownloadProgressBar, hash_check
from .image import get_image_filenames, read_image

__all__ = ["get_image_filenames", "read_image", "DownloadProgressBar"]
__all__ = ["get_image_filenames", "hash_check", "read_image", "DownloadProgressBar"]
25 changes: 19 additions & 6 deletions anomalib/data/utils/download.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Helper to show progress bars with `urlretrieve`.
Based on https://stackoverflow.com/a/53877507
"""
"""Helper to show progress bars with `urlretrieve`, check hash of file."""

# Copyright (C) 2020 Intel Corporation
#
Expand All @@ -17,7 +14,9 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import hashlib
import io
from pathlib import Path
from typing import Dict, Iterable, Optional, Union

from tqdm import tqdm
Expand Down Expand Up @@ -146,7 +145,7 @@ def __init__(
colour: Optional[str] = None,
delay: Optional[float] = 0,
gui: Optional[bool] = False,
**kwargs
**kwargs,
):
super().__init__(
iterable=iterable,
Expand Down Expand Up @@ -175,13 +174,14 @@ def __init__(
colour=colour,
delay=delay,
gui=gui,
**kwargs
**kwargs,
)
self.total: Optional[Union[int, float]]

def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size=None):
"""Progress bar hook for tqdm.
Based on https://stackoverflow.com/a/53877507
The implementor does not have to bother about passing parameters to this as it gets them from urlretrieve.
However the context needs a few parameters. Refer to the example.
Expand All @@ -193,3 +193,16 @@ def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size=N
if total_size is not None:
self.total = total_size
self.update(chunk_number * max_chunk_size - self.n)


def hash_check(file_path: Path, expected_hash: str):
"""Raise assert error if hash does not match the calculated hash of the file.
Args:
file_path (Path): Path to file.
expected_hash (str): Expected hash of the file.
"""
with open(file_path, "rb") as hash_file:
assert (
hashlib.md5(hash_file.read()).hexdigest() == expected_hash
), f"Downloaded file {file_path} does not match the required hash."

0 comments on commit 9076a95

Please sign in to comment.