Add download tests for remaining datasets #3338
Merged

Commits (27):
6ee89aa kmnist (pmeier)
ecf18a1 emnist (pmeier)
8dc1cae qmnist (pmeier)
cc02fe9 omniglot (pmeier)
8296c84 phototour (pmeier)
3593f27 sbdataset (pmeier)
761abb4 sbu (pmeier)
16c6990 semeion (pmeier)
ef26a6a stl10 (pmeier)
925c5b7 svhn (pmeier)
80d9df2 usps (pmeier)
fe32dd0 cifar100 (pmeier)
f666ef6 enable download logging for google drive (pmeier)
44997b0 celeba (pmeier)
c8f54c4 widerface (pmeier)
4bbacc3 lint (pmeier)
9087060 Merge branch 'master' into download-tests-kmnist (datumbox)
3bff918 add timeout logic (pmeier)
95a1400 lint (pmeier)
7bc3173 Merge branch 'master' into download-tests-kmnist (datumbox)
69339b9 debug CI connection to problematic server (pmeier)
8c8a055 Merge remote-tracking branch 'pmeier/download-tests-kmnist' into down… (pmeier)
4482ec9 set timeout for ping (pmeier)
0f0ed08 [ci skip] remove ping (pmeier)
6fee41f revert debugging (pmeier)
f6c9b1d disable requests to problematic server (pmeier)
8323abb re-enable all other tests (pmeier)

Diff:
@@ -4,14 +4,14 @@
 import unittest.mock
 from datetime import datetime
 from os import path
-from urllib.error import HTTPError
+from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request

 import pytest

 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive

 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -48,35 +48,47 @@ def inner_wrapper(request, *args, **kwargs):
 @contextlib.contextmanager
 def log_download_attempts(
     urls_and_md5s=None,
-    download_url_location=".utils",
+    file="utils",
     patch=True,
-    patch_auxiliaries=None,
+    mock_auxiliaries=None,
 ):
+    def add_mock(stack, name, file, **kwargs):
+        try:
+            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
+        except AttributeError as error:
+            if file != "utils":
+                return add_mock(stack, name, "utils", **kwargs)
+            else:
+                raise pytest.UsageError from error
+
     if urls_and_md5s is None:
         urls_and_md5s = set()
-    if download_url_location.startswith("."):
-        download_url_location = f"torchvision.datasets{download_url_location}"
-    if patch_auxiliaries is None:
-        patch_auxiliaries = patch
+    if mock_auxiliaries is None:
+        mock_auxiliaries = patch

     with contextlib.ExitStack() as stack:
-        download_url_mock = stack.enter_context(
-            unittest.mock.patch(
-                f"{download_url_location}.download_url",
-                wraps=None if patch else download_url,
-            )
-        )
-        if patch_auxiliaries:
-            # download_and_extract_archive
-            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
+        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
+        google_drive_mock = add_mock(
+            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
+        )
+
+        if mock_auxiliaries:
+            add_mock(stack, "extract_archive", file)

         try:
             yield urls_and_md5s
         finally:
-            for args, kwargs in download_url_mock.call_args_list:
+            for args, kwargs in url_mock.call_args_list:
                 url = args[0]
                 md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                 urls_and_md5s.add((url, md5))
+
+            for args, kwargs in google_drive_mock.call_args_list:
+                id = args[0]
+                url = f"https://drive.google.com/file/d/{id}"
+                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))


 def retry(fn, times=1, wait=5.0):
     msgs = []
@@ -101,21 +113,23 @@ def retry(fn, times=1, wait=5.0):
 def assert_server_response_ok():
     try:
         yield
+    except URLError as error:
+        raise AssertionError("The request timed out.") from error
     except HTTPError as error:
         raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error


 def assert_url_is_accessible(url):
     request = Request(url, headers=dict(method="HEAD"))
     with assert_server_response_ok():
-        urlopen(request)
+        urlopen(request, timeout=5.0)


 def assert_file_downloads_correctly(url, md5):
     with get_tmp_dir() as root:
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
-            with urlopen(url) as response, open(file, "wb") as fh:
+            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:
                 fh.write(response.read())

         assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

Review comment on the second `timeout=5.0` line: Same here.
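Taken together, these helpers let a download test probe a server without being derailed by transient failures. A minimal sketch of the wiring (the real test bodies are outside this diff, so both the combination and the URL below are assumptions):

    # Hypothetical test built from the helpers above: retry absorbs one
    # transient server error, and the 5-second timeout keeps CI from hanging.
    def test_url_is_accessible():
        url = "https://example.com/dataset.tar.gz"  # placeholder URL
        retry(lambda: assert_url_is_accessible(url), times=1, wait=5.0)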
@@ -175,7 +189,7 @@ def cifar10():


 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


 def voc():
@@ -184,7 +198,7 @@ def voc(): | |
collect_download_configs( | ||
lambda: datasets.VOCSegmentation(".", year=year, download=True), | ||
name=f"VOC, {year}", | ||
download_url_location=".voc", | ||
file="voc", | ||
) | ||
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012") | ||
] | ||
|
@@ -199,6 +213,128 @@ def fashion_mnist():
     return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


+def kmnist():
+    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+
+
+def emnist():
+    # the 'split' argument can be any valid one, since everything is downloaded anyway
+    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+
+
+def qmnist():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.QMNIST(".", what=what, download=True),
+                name=f"QMNIST, {what}",
+                file="mnist",
+            )
+            for what in ("train", "test", "nist")
+        ]
+    )
+
+
+def omniglot():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.Omniglot(".", background=background, download=True),
+                name=f"Omniglot, {'background' if background else 'evaluation'}",
+            )
+            for background in (True, False)
+        ]
+    )
+
+
+def phototour():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.PhotoTour(".", name=name, download=True),
+                name=f"PhotoTour, {name}",
+                file="phototour",
+            )
+            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
+            # requests timeout from within CI. They are disabled until this is resolved.
+            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
+        ]
+    )
+
+
+def sbdataset():
+    return collect_download_configs(
+        lambda: datasets.SBDataset(".", download=True),
+        name="SBDataset",
+        file="voc",
+    )
+
+
+def sbu():
+    return collect_download_configs(
+        lambda: datasets.SBU(".", download=True),
+        name="SBU",
+        file="sbu",
+    )
+
+
+def semeion():
+    return collect_download_configs(
+        lambda: datasets.SEMEION(".", download=True),
+        name="SEMEION",
+        file="semeion",
+    )
+
+
+def stl10():
+    return collect_download_configs(
+        lambda: datasets.STL10(".", download=True),
+        name="STL10",
+    )
+
+
+def svhn():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.SVHN(".", split=split, download=True),
+                name=f"SVHN, {split}",
+                file="svhn",
+            )
+            for split in ("train", "test", "extra")
+        ]
+    )
+
+
+def usps():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.USPS(".", train=train, download=True),
+                name=f"USPS, {'train' if train else 'test'}",
+                file="usps",
+            )
+            for train in (True, False)
+        ]
+    )
+
+
+def celeba():
+    return collect_download_configs(
+        lambda: datasets.CelebA(".", download=True),
+        name="CelebA",
+        file="celeba",
+    )
+
+
+def widerface():
+    return collect_download_configs(
+        lambda: datasets.WIDERFace(".", download=True),
+        name="WIDERFace",
+        file="widerface",
+    )
+
+
 def make_parametrize_kwargs(download_configs):
     argvalues = []
     ids = []
@@ -221,6 +357,19 @@ def make_parametrize_kwargs(download_configs):
             # voc(),
             mnist(),
             fashion_mnist(),
+            kmnist(),
+            emnist(),
+            qmnist(),
+            omniglot(),
+            phototour(),
+            sbdataset(),
+            sbu(),
+            semeion(),
+            stl10(),
+            svhn(),
+            usps(),
+            celeba(),
+            widerface(),
         )
     )
 )
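The body of make_parametrize_kwargs is truncated in this diff; only its first two statements appear as context. A plausible completion, assuming each collected config reduces to a (url, md5) pair as gathered by log_download_attempts (the exact return shape is an assumption):

    # Hypothetical completion of the truncated helper: turn the collected
    # download configs into keyword arguments for pytest.mark.parametrize,
    # using the URL as a human-readable test id.
    def make_parametrize_kwargs(download_configs):
        argvalues = []
        ids = []
        for url, md5 in download_configs:
            argvalues.append((url, md5))
            ids.append(url)
        return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)

The three closing parentheses at the end of the hunk above presumably close a decorator of the form pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(...))), which is why each new dataset only needs one line added to the chain.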
Review comment: Nit: Replace hardcoded numeric values with a constant.
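A minimal sketch of that suggestion, hoisting the repeated 5.0 into a module-level constant (the constant name is hypothetical, not taken from the diff):

    # Hypothetical refactor for the nit above: name the timeout once instead
    # of repeating the literal 5.0 at every urlopen call site.
    TIMEOUT = 5.0  # seconds

    def assert_url_is_accessible(url):
        request = Request(url, headers=dict(method="HEAD"))
        with assert_server_response_ok():
            urlopen(request, timeout=TIMEOUT)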