Add download tests for remaining datasets #3338

Merged (27 commits) on Feb 3, 2021. Changes shown from all commits.
test/test_datasets_download.py: 170 additions & 21 deletions (191 changes)
@@ -4,14 +4,14 @@
 import unittest.mock
 from datetime import datetime
 from os import path
-from urllib.error import HTTPError
+from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request
 
 import pytest
 
 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive
 
 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -48,35 +48,47 @@ def inner_wrapper(request, *args, **kwargs):
 @contextlib.contextmanager
 def log_download_attempts(
     urls_and_md5s=None,
+    file="utils",
     patch=True,
-    download_url_location=".utils",
-    patch_auxiliaries=None,
+    mock_auxiliaries=None,
 ):
+    def add_mock(stack, name, file, **kwargs):
+        try:
+            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
+        except AttributeError as error:
+            if file != "utils":
+                return add_mock(stack, name, "utils", **kwargs)
+            else:
+                raise pytest.UsageError from error
+
     if urls_and_md5s is None:
         urls_and_md5s = set()
-    if download_url_location.startswith("."):
-        download_url_location = f"torchvision.datasets{download_url_location}"
-    if patch_auxiliaries is None:
-        patch_auxiliaries = patch
+    if mock_auxiliaries is None:
+        mock_auxiliaries = patch
 
     with contextlib.ExitStack() as stack:
-        download_url_mock = stack.enter_context(
-            unittest.mock.patch(
-                f"{download_url_location}.download_url",
-                wraps=None if patch else download_url,
-            )
-        )
-        if patch_auxiliaries:
-            # download_and_extract_archive
-            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
+        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
+        google_drive_mock = add_mock(
+            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
+        )
+
+        if mock_auxiliaries:
+            add_mock(stack, "extract_archive", file)
 
         try:
             yield urls_and_md5s
         finally:
-            for args, kwargs in download_url_mock.call_args_list:
+            for args, kwargs in url_mock.call_args_list:
                 url = args[0]
                 md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                 urls_and_md5s.add((url, md5))
 
+            for args, kwargs in google_drive_mock.call_args_list:
+                id = args[0]
+                url = f"https://drive.google.com/file/d/{id}"
+                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))
+
 
 def retry(fn, times=1, wait=5.0):
     msgs = []
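
For orientation, a minimal sketch of how this context manager is plausibly consumed by the collect_download_configs helper that the dataset entries below call. That helper's body is not part of this diff, so its exact shape here is an assumption:

# Hedged sketch, not code from this PR: collect_download_configs lives
# elsewhere in this module; this is one plausible shape inferred from the
# call sites below (dataset lambda, display name, optional file= kwarg).
def collect_download_configs(dataset_loader, name, **kwargs):
    with log_download_attempts(**kwargs) as urls_and_md5s:
        try:
            dataset_loader()  # constructing the dataset triggers the patched downloads
        except Exception:
            pass  # only the recorded (url, md5) pairs matter here
    # 'name' presumably feeds the pytest ids assembled in make_parametrize_kwargs
    return [(name, url, md5) for url, md5 in urls_and_md5s]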
@@ -101,21 +113,23 @@ def retry(fn, times=1, wait=5.0):
 def assert_server_response_ok():
     try:
         yield
+    except URLError as error:
+        raise AssertionError("The request timed out.") from error
     except HTTPError as error:
         raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
 
 
 def assert_url_is_accessible(url):
     request = Request(url, headers=dict(method="HEAD"))
     with assert_server_response_ok():
-        urlopen(request)
+        urlopen(request, timeout=5.0)

Review comment (Contributor): Nit: Replace hardcoded numeric values with a constant.
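A minimal sketch of the suggested change, assuming a module-level constant; the name REQUEST_TIMEOUT is hypothetical, not part of this PR:

# Hypothetical constant illustrating the review suggestion; the name is an
# assumption made for this sketch.
REQUEST_TIMEOUT = 5.0


def assert_url_is_accessible(url):
    request = Request(url, headers=dict(method="HEAD"))
    with assert_server_response_ok():
        urlopen(request, timeout=REQUEST_TIMEOUT)  # shared constant instead of a hardcoded 5.0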
 def assert_file_downloads_correctly(url, md5):
     with get_tmp_dir() as root:
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
-            with urlopen(url) as response, open(file, "wb") as fh:
+            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:

Review comment (Contributor): Same here.

                 fh.write(response.read())
 
     assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
@@ -175,7 +189,7 @@ def cifar10():
 
 
 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")
 
 
 def voc():
@@ -184,7 +198,7 @@ def voc():
             collect_download_configs(
                 lambda: datasets.VOCSegmentation(".", year=year, download=True),
                 name=f"VOC, {year}",
-                download_url_location=".voc",
+                file="voc",
             )
             for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
         ]
@@ -199,6 +213,128 @@ def fashion_mnist():
     return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")
 
 
+def kmnist():
+    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+
+
+def emnist():
+    # the 'split' argument can be any valid one, since everything is downloaded anyway
+    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+
+
+def qmnist():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.QMNIST(".", what=what, download=True),
+                name=f"QMNIST, {what}",
+                file="mnist",
+            )
+            for what in ("train", "test", "nist")
+        ]
+    )
+
+
+def omniglot():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.Omniglot(".", background=background, download=True),
+                name=f"Omniglot, {'background' if background else 'evaluation'}",
+            )
+            for background in (True, False)
+        ]
+    )
+
+
+def phototour():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.PhotoTour(".", name=name, download=True),
+                name=f"PhotoTour, {name}",
+                file="phototour",
+            )
+            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason, all
+            # requests from within CI time out. These names are disabled until this is resolved.
+            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
+        ]
+    )
+
+
+def sbdataset():
+    return collect_download_configs(
+        lambda: datasets.SBDataset(".", download=True),
+        name="SBDataset",
+        file="voc",
+    )
+
+
+def sbu():
+    return collect_download_configs(
+        lambda: datasets.SBU(".", download=True),
+        name="SBU",
+        file="sbu",
+    )
+
+
+def semeion():
+    return collect_download_configs(
+        lambda: datasets.SEMEION(".", download=True),
+        name="SEMEION",
+        file="semeion",
+    )
+
+
+def stl10():
+    return collect_download_configs(
+        lambda: datasets.STL10(".", download=True),
+        name="STL10",
+    )
+
+
+def svhn():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.SVHN(".", split=split, download=True),
+                name=f"SVHN, {split}",
+                file="svhn",
+            )
+            for split in ("train", "test", "extra")
+        ]
+    )
+
+
+def usps():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.USPS(".", train=train, download=True),
+                name=f"USPS, {'train' if train else 'test'}",
+                file="usps",
+            )
+            for train in (True, False)
+        ]
+    )
+
+
+def celeba():
+    return collect_download_configs(
+        lambda: datasets.CelebA(".", download=True),
+        name="CelebA",
+        file="celeba",
+    )
+
+
+def widerface():
+    return collect_download_configs(
+        lambda: datasets.WIDERFace(".", download=True),
+        name="WIDERFace",
+        file="widerface",
+    )
+
+
 def make_parametrize_kwargs(download_configs):
     argvalues = []
     ids = []
@@ -221,6 +357,19 @@ def make_parametrize_kwargs(download_configs):
             # voc(),
             mnist(),
             fashion_mnist(),
+            kmnist(),
+            emnist(),
+            qmnist(),
+            omniglot(),
+            phototour(),
+            sbdataset(),
+            sbu(),
+            semeion(),
+            stl10(),
+            svhn(),
+            usps(),
+            celeba(),
+            widerface(),
         )
     )
 )
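
For context, a hedged sketch of how make_parametrize_kwargs plausibly feeds pytest.mark.parametrize. The decorator shape and the test body are assumptions inferred from the argvalues/ids names above and from retry/assert_url_is_accessible earlier in this file; they are not lines from this diff:

# Assumed shape: make_parametrize_kwargs returns something like
# dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids), so it can be
# splatted straight into the decorator. The test name is hypothetical.
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(mnist(), fashion_mnist())))
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))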