diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index f0fd4330ae5..e99017d8b55 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -14,13 +14,7 @@ import pytest from torchvision import datasets -from torchvision.datasets.utils import ( - _get_redirect_url, - check_integrity, - download_file_from_google_drive, - download_url, - USER_AGENT, -) +from torchvision.datasets.utils import _get_redirect_url, USER_AGENT def limit_requests_per_time(min_secs_between_requests=2.0): @@ -84,47 +78,45 @@ def inner_wrapper(request, *args, **kwargs): @contextlib.contextmanager def log_download_attempts( - urls_and_md5s=None, - file="utils", - patch=True, - mock_auxiliaries=None, + urls, + *, + dataset_module, ): - def add_mock(stack, name, file, **kwargs): + def maybe_add_mock(*, module, name, stack, lst=None): + patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}") + try: - return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs)) - except AttributeError as error: - if file != "utils": - return add_mock(stack, name, "utils", **kwargs) - else: - raise pytest.UsageError from error - - if urls_and_md5s is None: - urls_and_md5s = set() - if mock_auxiliaries is None: - mock_auxiliaries = patch + mock = stack.enter_context(patcher) + except AttributeError: + return - with contextlib.ExitStack() as stack: - url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url) - google_drive_mock = add_mock( - stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive - ) + if lst is not None: + lst.append(mock) - if mock_auxiliaries: - add_mock(stack, "extract_archive", file) + with contextlib.ExitStack() as stack: + download_url_mocks = [] + download_file_from_google_drive_mocks = [] + for module in [dataset_module, "utils"]: + maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks) + maybe_add_mock( + module=module, + name="download_file_from_google_drive", + stack=stack, + lst=download_file_from_google_drive_mocks, + ) + maybe_add_mock(module=module, name="extract_archive", stack=stack) try: - yield urls_and_md5s + yield finally: - for args, kwargs in url_mock.call_args_list: - url = args[0] - md5 = args[-1] if len(args) == 4 else kwargs.get("md5") - urls_and_md5s.add((url, md5)) + for download_url_mock in download_url_mocks: + for args, kwargs in download_url_mock.call_args_list: + urls.append(args[0] if args else kwargs["url"]) - for args, kwargs in google_drive_mock.call_args_list: - id = args[0] - url = f"https://drive.google.com/file/d/{id}" - md5 = args[3] if len(args) == 4 else kwargs.get("md5") - urls_and_md5s.add((url, md5)) + for download_file_from_google_drive_mock in download_file_from_google_drive_mocks: + for args, kwargs in download_file_from_google_drive_mock.call_args_list: + file_id = args[0] if args else kwargs["file_id"] + urls.append(f"https://drive.google.com/file/d/{file_id}") def retry(fn, times=1, wait=5.0): @@ -170,45 +162,14 @@ def assert_url_is_accessible(url, timeout=5.0): urlopen(request, timeout=timeout) -def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0): - file = path.join(tmpdir, path.basename(url)) - with assert_server_response_ok(): - with open(file, "wb") as fh: - request = Request(url, headers={"User-Agent": USER_AGENT}) - response = urlopen(request, timeout=timeout) - fh.write(response.read()) - - assert check_integrity(file, md5=md5), "The MD5 
checksums mismatch" - - -class DownloadConfig: - def __init__(self, url, md5=None, id=None): - self.url = url - self.md5 = md5 - self.id = id or url - - def __repr__(self) -> str: - return self.id - +def collect_urls(dataset_cls, *args, **kwargs): + urls = [] + with contextlib.suppress(Exception), log_download_attempts( + urls, dataset_module=dataset_cls.__module__.split(".")[-1] + ): + dataset_cls(*args, **kwargs) -def make_download_configs(urls_and_md5s, name=None): - return [ - DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s - ] - - -def collect_download_configs(dataset_loader, name=None, **kwargs): - urls_and_md5s = set() - try: - with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs): - dataset = dataset_loader() - except Exception: - dataset = None - - if name is None and dataset is not None: - name = type(dataset).__name__ - - return make_download_configs(urls_and_md5s, name) + return [(url, f"{dataset_cls.__name__}, {url}") for url in urls] # This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a @@ -223,12 +184,14 @@ def root(): def places365(): - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.Places365(ROOT, split=split, small=small, download=True), - name=f"Places365, {split}, {'small' if small else 'large'}", - file="places365", + return itertools.chain.from_iterable( + [ + collect_urls( + datasets.Places365, + ROOT, + split=split, + small=small, + download=True, ) for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)) ] @@ -236,30 +199,26 @@ def places365(): def caltech101(): - return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101") + return collect_urls(datasets.Caltech101, ROOT, download=True) def caltech256(): - return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256") + return collect_urls(datasets.Caltech256, ROOT, download=True) def cifar10(): - return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10") + return collect_urls(datasets.CIFAR10, ROOT, download=True) def cifar100(): - return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100") + return collect_urls(datasets.CIFAR100, ROOT, download=True) def voc(): # TODO: Also test the "2007-test" key - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.VOCSegmentation(ROOT, year=year, download=True), - name=f"VOC, {year}", - file="voc", - ) + return itertools.chain.from_iterable( + [ + collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True) for year in ("2007", "2008", "2009", "2010", "2011", "2012") ] ) @@ -267,59 +226,42 @@ def voc(): def mnist(): with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]): - return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST") + return collect_urls(datasets.MNIST, ROOT, download=True) def fashion_mnist(): - return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST") + return collect_urls(datasets.FashionMNIST, ROOT, download=True) def kmnist(): - return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST") + return collect_urls(datasets.KMNIST, ROOT, download=True) def emnist(): # the 'split' argument can be any valid one, since everything is 
downloaded anyway - return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST") + return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True) def qmnist(): - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.QMNIST(ROOT, what=what, download=True), - name=f"QMNIST, {what}", - file="mnist", - ) - for what in ("train", "test", "nist") - ] + return itertools.chain.from_iterable( + [collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")] ) def moving_mnist(): - return collect_download_configs(lambda: datasets.MovingMNIST(ROOT, download=True), name="MovingMNIST") + return collect_urls(datasets.MovingMNIST, ROOT, download=True) def omniglot(): - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.Omniglot(ROOT, background=background, download=True), - name=f"Omniglot, {'background' if background else 'evaluation'}", - ) - for background in (True, False) - ] + return itertools.chain.from_iterable( + [collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)] ) def phototour(): - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.PhotoTour(ROOT, name=name, download=True), - name=f"PhotoTour, {name}", - file="phototour", - ) + return itertools.chain.from_iterable( + [ + collect_urls(datasets.PhotoTour, ROOT, name=name, download=True) # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all # requests timeout from within CI. They are disabled until this is resolved. for name in ("notredame", "yosemite", "liberty") # "notredame_harris", "yosemite_harris", "liberty_harris" @@ -328,91 +270,51 @@ def phototour(): def sbdataset(): - return collect_download_configs( - lambda: datasets.SBDataset(ROOT, download=True), - name="SBDataset", - file="voc", - ) + return collect_urls(datasets.SBDataset, ROOT, download=True) def sbu(): - return collect_download_configs( - lambda: datasets.SBU(ROOT, download=True), - name="SBU", - file="sbu", - ) + return collect_urls(datasets.SBU, ROOT, download=True) def semeion(): - return collect_download_configs( - lambda: datasets.SEMEION(ROOT, download=True), - name="SEMEION", - file="semeion", - ) + return collect_urls(datasets.SEMEION, ROOT, download=True) def stl10(): - return collect_download_configs( - lambda: datasets.STL10(ROOT, download=True), - name="STL10", - ) + return collect_urls(datasets.STL10, ROOT, download=True) def svhn(): - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.SVHN(ROOT, split=split, download=True), - name=f"SVHN, {split}", - file="svhn", - ) - for split in ("train", "test", "extra") - ] + return itertools.chain.from_iterable( + [collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")] ) def usps(): - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.USPS(ROOT, train=train, download=True), - name=f"USPS, {'train' if train else 'test'}", - file="usps", - ) - for train in (True, False) - ] + return itertools.chain.from_iterable( + [collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)] ) def celeba(): - return collect_download_configs( - lambda: datasets.CelebA(ROOT, download=True), - name="CelebA", - file="celeba", - ) + return collect_urls(datasets.CelebA, ROOT, download=True) def widerface(): - return collect_download_configs( - lambda: 
datasets.WIDERFace(ROOT, download=True), - name="WIDERFace", - file="widerface", - ) + return collect_urls(datasets.WIDERFace, ROOT, download=True) def kinetics(): - return itertools.chain( - *[ - collect_download_configs( - lambda: datasets.Kinetics( - path.join(ROOT, f"Kinetics{num_classes}"), - frames_per_clip=1, - num_classes=num_classes, - split=split, - download=True, - ), - name=f"Kinetics, {num_classes}, {split}", - file="kinetics", + return itertools.chain.from_iterable( + [ + collect_urls( + datasets.Kinetics, + path.join(ROOT, f"Kinetics{num_classes}"), + frames_per_clip=1, + num_classes=num_classes, + split=split, + download=True, ) for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val")) ] @@ -420,58 +322,55 @@ def kinetics(): def kitti(): - return itertools.chain( - *[ - collect_download_configs( - lambda train=train: datasets.Kitti(ROOT, train=train, download=True), - name=f"Kitti, {'train' if train else 'test'}", - file="kitti", - ) - for train in (True, False) - ] + return itertools.chain.from_iterable( + [collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)] ) -def make_parametrize_kwargs(download_configs): - argvalues = [] - ids = [] - for config in download_configs: - argvalues.append((config.url, config.md5)) - ids.append(config.id) - - return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids) - - -@pytest.mark.parametrize( - **make_parametrize_kwargs( - itertools.chain( - caltech101(), - caltech256(), - cifar10(), - cifar100(), - # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details. - # voc(), - mnist(), - fashion_mnist(), - kmnist(), - emnist(), - qmnist(), - omniglot(), - phototour(), - sbdataset(), - semeion(), - stl10(), - svhn(), - usps(), - celeba(), - widerface(), - kinetics(), - kitti(), - places365(), - ) +def stanford_cars(): + return itertools.chain.from_iterable( + [collect_urls(datasets.StanfordCars, ROOT, split=split, download=True) for split in ["train", "test"]] + ) + + +def url_parametrization(*dataset_urls_and_ids_fns): + return pytest.mark.parametrize( + "url", + [ + pytest.param(url, id=id) + for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns + for url, id in sorted(set(dataset_urls_and_ids_fn())) + ], ) + + +@url_parametrization( + caltech101, + caltech256, + cifar10, + cifar100, + # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details. + # voc, + mnist, + fashion_mnist, + kmnist, + emnist, + qmnist, + omniglot, + phototour, + sbdataset, + semeion, + stl10, + svhn, + usps, + celeba, + widerface, + kinetics, + kitti, + places365, + sbu, ) -def test_url_is_accessible(url, md5): +def test_url_is_accessible(url): """ If you see this test failing, find the offending dataset in the parametrization and move it to ``test_url_is_not_accessible`` and link an issue detailing the problem. @@ -479,15 +378,11 @@ def test_url_is_accessible(url, md5): retry(lambda: assert_url_is_accessible(url)) -@pytest.mark.parametrize( - **make_parametrize_kwargs( - itertools.chain( - sbu(), # https://github.com/pytorch/vision/issues/7005 - ) - ) +@url_parametrization( + stanford_cars, # https://github.com/pytorch/vision/issues/7545 ) @pytest.mark.xfail -def test_url_is_not_accessible(url, md5): +def test_url_is_not_accessible(url): """ As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. 
Since the download servers are beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they @@ -497,8 +392,3 @@ def test_url_is_not_accessible(url, md5): ``test_url_is_accessible``. """ retry(lambda: assert_url_is_accessible(url)) - - -@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain())) -def test_file_downloads_correctly(url, md5): - retry(lambda: assert_file_downloads_correctly(url, md5))
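
For reference, the refactor above hinges on one mock-and-record pattern: patch the download helpers where each dataset module looks them up so no network traffic happens, let the dataset constructor run, then harvest the attempted URLs from Mock.call_args_list. Below is a minimal, runnable sketch of that pattern; the `downloads` namespace and `TinyDataset` are hypothetical stand-ins, not torchvision code, and the sketch deliberately drops the patch's per-module fallback and Google Drive handling.

    import contextlib
    import types
    import unittest.mock

    # Hypothetical stand-in for torchvision.datasets.utils: the helper a
    # dataset calls when it actually wants to hit the network.
    downloads = types.SimpleNamespace(download_url=lambda url, root: None)


    class TinyDataset:  # hypothetical stand-in for a torchvision dataset class
        URL = "https://example.com/archive.tar.gz"

        def __init__(self, root, download=False):
            if download:
                downloads.download_url(self.URL, root)


    @contextlib.contextmanager
    def log_download_attempts(urls):
        # Patch the helper where the dataset looks it up; the MagicMock
        # swallows each call instead of downloading, and records it.
        with unittest.mock.patch.object(downloads, "download_url") as mock:
            try:
                yield
            finally:
                # call_args_list holds one (args, kwargs) pair per invocation;
                # recover the URL positionally or by keyword, as in the patch.
                for args, kwargs in mock.call_args_list:
                    urls.append(args[0] if args else kwargs["url"])


    def collect_urls(dataset_cls, *args, **kwargs):
        urls = []
        # Real constructors typically fail after the mocked download (nothing
        # was written to disk); suppress(Exception) keeps collection going,
        # since the finally block above captures the URLs regardless.
        with contextlib.suppress(Exception), log_download_attempts(urls):
            dataset_cls(*args, **kwargs)
        return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]


    print(collect_urls(TinyDataset, "/tmp/data", download=True))
    # [('https://example.com/archive.tar.gz', 'TinyDataset, https://example.com/archive.tar.gz')]

In the patch itself, the same harvest feeds url_parametrization, which dedupes the (url, id) pairs with sorted(set(...)) and turns each one into a pytest.param.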