Commit c5d6f1f

pmeier and datumbox authored
Add download tests for remaining datasets (#3338)
* kmnist
* emnist
* qmnist
* omniglot
* phototour
* sbdataset
* sbu
* semeion
* stl10
* svhn
* usps
* cifar100
* enable download logging for google drive
* celeba
* widerface
* lint
* add timeout logic
* lint
* debug CI connection to problematic server
* set timeout for ping
* [ci skip] remove ping
* revert debugging
* disable requests to problematic server
* re-enable all other tests

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 8317295 commit c5d6f1f

File tree

1 file changed

test/test_datasets_download.py

Lines changed: 170 additions & 21 deletions
@@ -4,14 +4,14 @@
 import unittest.mock
 from datetime import datetime
 from os import path
-from urllib.error import HTTPError
+from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request

 import pytest

 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive

 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -48,35 +48,47 @@ def inner_wrapper(request, *args, **kwargs):
 @contextlib.contextmanager
 def log_download_attempts(
     urls_and_md5s=None,
+    file="utils",
     patch=True,
-    download_url_location=".utils",
-    patch_auxiliaries=None,
+    mock_auxiliaries=None,
 ):
+    def add_mock(stack, name, file, **kwargs):
+        try:
+            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
+        except AttributeError as error:
+            if file != "utils":
+                return add_mock(stack, name, "utils", **kwargs)
+            else:
+                raise pytest.UsageError from error
+
     if urls_and_md5s is None:
         urls_and_md5s = set()
-    if download_url_location.startswith("."):
-        download_url_location = f"torchvision.datasets{download_url_location}"
-    if patch_auxiliaries is None:
-        patch_auxiliaries = patch
+    if mock_auxiliaries is None:
+        mock_auxiliaries = patch

     with contextlib.ExitStack() as stack:
-        download_url_mock = stack.enter_context(
-            unittest.mock.patch(
-                f"{download_url_location}.download_url",
-                wraps=None if patch else download_url,
-            )
+        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
+        google_drive_mock = add_mock(
+            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
         )
-        if patch_auxiliaries:
-            # download_and_extract_archive
-            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
+
+        if mock_auxiliaries:
+            add_mock(stack, "extract_archive", file)

         try:
             yield urls_and_md5s
         finally:
-            for args, kwargs in download_url_mock.call_args_list:
+            for args, kwargs in url_mock.call_args_list:
                 url = args[0]
                 md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                 urls_and_md5s.add((url, md5))

+            for args, kwargs in google_drive_mock.call_args_list:
+                id = args[0]
+                url = f"https://drive.google.com/file/d/{id}"
+                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))
+

 def retry(fn, times=1, wait=5.0):
     msgs = []
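The reworked context manager records every (url, md5) pair a dataset constructor passes to download_url or download_file_from_google_drive, trying the patch target torchvision.datasets.<file> first and falling back to torchvision.datasets.utils when the submodule does not define the helper. A minimal usage sketch, not part of the commit; the MNIST call and the try/except around it are illustrative assumptions (with patch=True the mocked downloaders write no files, so the constructor's post-processing is expected to fail):

    from torchvision import datasets

    # Collect MNIST's download targets without hitting the network.
    with log_download_attempts(file="mnist") as urls_and_md5s:
        try:
            datasets.MNIST(".", download=True)
        except Exception:
            # The mocked downloads leave no files behind; only the
            # recorded (url, md5) pairs matter here.
            pass

    for url, md5 in sorted(urls_and_md5s):
        print(url, md5)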
@@ -101,21 +113,23 @@ def retry(fn, times=1, wait=5.0):
 def assert_server_response_ok():
     try:
         yield
+    except URLError as error:
+        raise AssertionError("The request timed out.") from error
     except HTTPError as error:
         raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error


 def assert_url_is_accessible(url):
     request = Request(url, headers=dict(method="HEAD"))
     with assert_server_response_ok():
-        urlopen(request)
+        urlopen(request, timeout=5.0)


 def assert_file_downloads_correctly(url, md5):
     with get_tmp_dir() as root:
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
-            with urlopen(url) as response, open(file, "wb") as fh:
+            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:
                 fh.write(response.read())

     assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
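The new timeout=5.0 arguments pair with the new except URLError branch: when urlopen cannot connect within the timeout, the standard library raises a URLError whose reason is the underlying socket.timeout, which assert_server_response_ok now converts into a plain AssertionError. A self-contained sketch of that mechanic, not part of the commit; the helper name is illustrative:

    import socket
    from urllib.error import URLError
    from urllib.request import urlopen

    def responds_in_time(url, timeout=5.0):
        """Return False instead of hanging when `url` does not answer in time."""
        try:
            urlopen(url, timeout=timeout)
            return True
        except URLError as error:
            # Connection timeouts surface as URLError wrapping socket.timeout.
            if isinstance(error.reason, socket.timeout):
                return False
            raise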
@@ -175,7 +189,7 @@ def cifar10():


 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


 def voc():
@@ -184,7 +198,7 @@ def voc():
         collect_download_configs(
             lambda: datasets.VOCSegmentation(".", year=year, download=True),
             name=f"VOC, {year}",
-            download_url_location=".voc",
+            file="voc",
         )
         for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
     ]
@@ -199,6 +213,128 @@ def fashion_mnist():
     return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


+def kmnist():
+    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+
+
+def emnist():
+    # the 'split' argument can be any valid one, since everything is downloaded anyway
+    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+
+
+def qmnist():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.QMNIST(".", what=what, download=True),
+                name=f"QMNIST, {what}",
+                file="mnist",
+            )
+            for what in ("train", "test", "nist")
+        ]
+    )
+
+
+def omniglot():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.Omniglot(".", background=background, download=True),
+                name=f"Omniglot, {'background' if background else 'evaluation'}",
+            )
+            for background in (True, False)
+        ]
+    )
+
+
+def phototour():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.PhotoTour(".", name=name, download=True),
+                name=f"PhotoTour, {name}",
+                file="phototour",
+            )
+            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
+            # requests timeout from within CI. They are disabled until this is resolved.
+            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
+        ]
+    )
+
+
+def sbdataset():
+    return collect_download_configs(
+        lambda: datasets.SBDataset(".", download=True),
+        name="SBDataset",
+        file="voc",
+    )
+
+
+def sbu():
+    return collect_download_configs(
+        lambda: datasets.SBU(".", download=True),
+        name="SBU",
+        file="sbu",
+    )
+
+
+def semeion():
+    return collect_download_configs(
+        lambda: datasets.SEMEION(".", download=True),
+        name="SEMEION",
+        file="semeion",
+    )
+
+
+def stl10():
+    return collect_download_configs(
+        lambda: datasets.STL10(".", download=True),
+        name="STL10",
+    )
+
+
+def svhn():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.SVHN(".", split=split, download=True),
+                name=f"SVHN, {split}",
+                file="svhn",
+            )
+            for split in ("train", "test", "extra")
+        ]
+    )
+
+
+def usps():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.USPS(".", train=train, download=True),
+                name=f"USPS, {'train' if train else 'test'}",
+                file="usps",
+            )
+            for train in (True, False)
+        ]
+    )
+
+
+def celeba():
+    return collect_download_configs(
+        lambda: datasets.CelebA(".", download=True),
+        name="CelebA",
+        file="celeba",
+    )
+
+
+def widerface():
+    return collect_download_configs(
+        lambda: datasets.WIDERFace(".", download=True),
+        name="WIDERFace",
+        file="widerface",
+    )
+
+
 def make_parametrize_kwargs(download_configs):
     argvalues = []
     ids = []
@@ -221,6 +357,19 @@ def make_parametrize_kwargs(download_configs):
            # voc(),
            mnist(),
            fashion_mnist(),
+           kmnist(),
+           emnist(),
+           qmnist(),
+           omniglot(),
+           phototour(),
+           sbdataset(),
+           sbu(),
+           semeion(),
+           stl10(),
+           svhn(),
+           usps(),
+           celeba(),
+           widerface(),
         )
     )
 )
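The three closing parentheses above are the tail of a pytest parametrization that feeds the chained download configs into pytest.mark.parametrize via make_parametrize_kwargs. A hedged reconstruction of the consuming test, assuming the argnames are url and md5 and that retry wraps the accessibility check; neither detail is confirmed by this hunk:

    @pytest.mark.parametrize(
        **make_parametrize_kwargs(
            itertools.chain(
                mnist(),
                fashion_mnist(),
                kmnist(),
                # ... the remaining configs listed in the hunk above
            )
        )
    )
    def test_url_is_accessible(url, md5):
        # Retry once on transient failures before declaring the URL broken.
        retry(lambda: assert_url_is_accessible(url))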
