diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py
index c927b553942..ac189f4f635 100644
--- a/test/test_datasets_download.py
+++ b/test/test_datasets_download.py
@@ -4,14 +4,14 @@
 import unittest.mock
 from datetime import datetime
 from os import path
-from urllib.error import HTTPError
+from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request

 import pytest

 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive

 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -48,35 +48,47 @@ def inner_wrapper(request, *args, **kwargs):
 @contextlib.contextmanager
 def log_download_attempts(
     urls_and_md5s=None,
+    file="utils",
     patch=True,
-    download_url_location=".utils",
-    patch_auxiliaries=None,
+    mock_auxiliaries=None,
 ):
+    def add_mock(stack, name, file, **kwargs):
+        try:
+            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
+        except AttributeError as error:
+            if file != "utils":
+                return add_mock(stack, name, "utils", **kwargs)
+            else:
+                raise pytest.UsageError from error
+
     if urls_and_md5s is None:
         urls_and_md5s = set()
-    if download_url_location.startswith("."):
-        download_url_location = f"torchvision.datasets{download_url_location}"
-    if patch_auxiliaries is None:
-        patch_auxiliaries = patch
+    if mock_auxiliaries is None:
+        mock_auxiliaries = patch

     with contextlib.ExitStack() as stack:
-        download_url_mock = stack.enter_context(
-            unittest.mock.patch(
-                f"{download_url_location}.download_url",
-                wraps=None if patch else download_url,
-            )
+        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
+        google_drive_mock = add_mock(
+            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
         )
-        if patch_auxiliaries:
-            # download_and_extract_archive
-            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
+
+        if mock_auxiliaries:
+            add_mock(stack, "extract_archive", file)
+
         try:
             yield urls_and_md5s
         finally:
-            for args, kwargs in download_url_mock.call_args_list:
+            for args, kwargs in url_mock.call_args_list:
                 url = args[0]
                 md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                 urls_and_md5s.add((url, md5))

+            for args, kwargs in google_drive_mock.call_args_list:
+                id = args[0]
+                url = f"https://drive.google.com/file/d/{id}"
+                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))
+

 def retry(fn, times=1, wait=5.0):
     msgs = []
@@ -101,6 +113,8 @@ def retry(fn, times=1, wait=5.0):
 def assert_server_response_ok():
     try:
         yield
+    except URLError as error:
+        raise AssertionError("The request timed out.") from error
     except HTTPError as error:
         raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error

@@ -108,14 +122,14 @@ def assert_server_response_ok():
 def assert_url_is_accessible(url):
     request = Request(url, headers=dict(method="HEAD"))
     with assert_server_response_ok():
-        urlopen(request)
+        urlopen(request, timeout=5.0)


 def assert_file_downloads_correctly(url, md5):
     with get_tmp_dir() as root:
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
-            with urlopen(url) as response, open(file, "wb") as fh:
+            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:
                 fh.write(response.read())

         assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
@@ -175,7 +189,7 @@ def cifar10():


 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


 def voc():
@@ -184,7 +198,7 @@ def voc():
             collect_download_configs(
                 lambda: datasets.VOCSegmentation(".", year=year, download=True),
                 name=f"VOC, {year}",
-                download_url_location=".voc",
+                file="voc",
             )
             for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
         ]
@@ -199,6 +213,128 @@ def fashion_mnist():
     return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


+def kmnist():
+    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+
+
+def emnist():
+    # the 'split' argument can be any valid one, since everything is downloaded anyway
+    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+
+
+def qmnist():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.QMNIST(".", what=what, download=True),
+                name=f"QMNIST, {what}",
+                file="mnist",
+            )
+            for what in ("train", "test", "nist")
+        ]
+    )
+
+
+def omniglot():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.Omniglot(".", background=background, download=True),
+                name=f"Omniglot, {'background' if background else 'evaluation'}",
+            )
+            for background in (True, False)
+        ]
+    )
+
+
+def phototour():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.PhotoTour(".", name=name, download=True),
+                name=f"PhotoTour, {name}",
+                file="phototour",
+            )
+            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
+            # requests time out from within CI. They are disabled until this is resolved.
+            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
+        ]
+    )
+
+
+def sbdataset():
+    return collect_download_configs(
+        lambda: datasets.SBDataset(".", download=True),
+        name="SBDataset",
+        file="voc",
+    )
+
+
+def sbu():
+    return collect_download_configs(
+        lambda: datasets.SBU(".", download=True),
+        name="SBU",
+        file="sbu",
+    )
+
+
+def semeion():
+    return collect_download_configs(
+        lambda: datasets.SEMEION(".", download=True),
+        name="SEMEION",
+        file="semeion",
+    )
+
+
+def stl10():
+    return collect_download_configs(
+        lambda: datasets.STL10(".", download=True),
+        name="STL10",
+    )
+
+
+def svhn():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.SVHN(".", split=split, download=True),
+                name=f"SVHN, {split}",
+                file="svhn",
+            )
+            for split in ("train", "test", "extra")
+        ]
+    )
+
+
+def usps():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.USPS(".", train=train, download=True),
+                name=f"USPS, {'train' if train else 'test'}",
+                file="usps",
+            )
+            for train in (True, False)
+        ]
+    )
+
+
+def celeba():
+    return collect_download_configs(
+        lambda: datasets.CelebA(".", download=True),
+        name="CelebA",
+        file="celeba",
+    )
+
+
+def widerface():
+    return collect_download_configs(
+        lambda: datasets.WIDERFace(".", download=True),
+        name="WIDERFace",
+        file="widerface",
+    )
+
+
 def make_parametrize_kwargs(download_configs):
     argvalues = []
     ids = []
@@ -221,6 +357,19 @@ def make_parametrize_kwargs(download_configs):
             # voc(),
             mnist(),
             fashion_mnist(),
+            kmnist(),
+            emnist(),
+            qmnist(),
+            omniglot(),
+            phototour(),
+            sbdataset(),
+            sbu(),
+            semeion(),
+            stl10(),
+            svhn(),
+            usps(),
+            celeba(),
+            widerface(),
         )
     )
 )
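Note (not part of the diff): the sketch below illustrates the mocking pattern the reworked log_download_attempts is built on: patch the download helpers in torchvision.datasets.<file>, fall back to torchvision.datasets.utils when the dataset module does not expose them, and read the attempted (url, md5) pairs off the mock's call_args_list. The name demo_collect and the MNIST call are hypothetical, chosen only for illustration; the real collection logic in this test file lives in collect_download_configs / log_download_attempts.

    import contextlib
    import unittest.mock

    from torchvision import datasets


    def demo_collect(dataset_ctor, file="utils"):
        # Patch torchvision.datasets.<file>.<name>; fall back to the shared utils module
        # if the dataset module does not re-export the helper (mirrors add_mock above).
        def add_mock(stack, name, file, **kwargs):
            try:
                return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
            except AttributeError:
                if file != "utils":
                    return add_mock(stack, name, "utils", **kwargs)
                raise

        with contextlib.ExitStack() as stack:
            url_mock = add_mock(stack, "download_url", file)
            add_mock(stack, "extract_archive", file)
            with contextlib.suppress(Exception):
                # The constructor usually fails after the mocked "download"; that is fine,
                # since only the URLs it tried to fetch matter here.
                dataset_ctor()

        # download_url(url, root, filename, md5) calls are recorded by the mock.
        return {
            (args[0], args[-1] if len(args) == 4 else kwargs.get("md5"))
            for args, kwargs in url_mock.call_args_list
        }


    if __name__ == "__main__":
        # file="utils" is sufficient for datasets that download via the shared helpers.
        print(demo_collect(lambda: datasets.MNIST(".", download=True)))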