@@ -4,14 +4,14 @@
 import unittest.mock
 from datetime import datetime
 from os import path
-from urllib.error import HTTPError
+from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request
 
 import pytest
 
 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive
 
 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -48,35 +48,47 @@ def inner_wrapper(request, *args, **kwargs):
 @contextlib.contextmanager
 def log_download_attempts(
     urls_and_md5s=None,
+    file="utils",
     patch=True,
-    download_url_location=".utils",
-    patch_auxiliaries=None,
+    mock_auxiliaries=None,
 ):
+    def add_mock(stack, name, file, **kwargs):
+        try:
+            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
+        except AttributeError as error:
+            if file != "utils":
+                return add_mock(stack, name, "utils", **kwargs)
+            else:
+                raise pytest.UsageError from error
+
     if urls_and_md5s is None:
         urls_and_md5s = set()
-    if download_url_location.startswith("."):
-        download_url_location = f"torchvision.datasets{download_url_location}"
-    if patch_auxiliaries is None:
-        patch_auxiliaries = patch
+    if mock_auxiliaries is None:
+        mock_auxiliaries = patch
 
     with contextlib.ExitStack() as stack:
-        download_url_mock = stack.enter_context(
-            unittest.mock.patch(
-                f"{download_url_location}.download_url",
-                wraps=None if patch else download_url,
-            )
+        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
+        google_drive_mock = add_mock(
+            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
         )
-        if patch_auxiliaries:
-            # download_and_extract_archive
-            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
+
+        if mock_auxiliaries:
+            add_mock(stack, "extract_archive", file)
+
         try:
             yield urls_and_md5s
         finally:
-            for args, kwargs in download_url_mock.call_args_list:
+            for args, kwargs in url_mock.call_args_list:
                 url = args[0]
                 md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                 urls_and_md5s.add((url, md5))
 
+            for args, kwargs in google_drive_mock.call_args_list:
+                id = args[0]
+                url = f"https://drive.google.com/file/d/{id}"
+                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))
+
 
 def retry(fn, times=1, wait=5.0):
     msgs = []
@@ -101,21 +113,23 @@ def retry(fn, times=1, wait=5.0):
 def assert_server_response_ok():
     try:
         yield
+    except URLError as error:
+        raise AssertionError("The request timed out.") from error
     except HTTPError as error:
         raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
 
 
 def assert_url_is_accessible(url):
     request = Request(url, headers=dict(method="HEAD"))
     with assert_server_response_ok():
-        urlopen(request)
+        urlopen(request, timeout=5.0)
 
 
 def assert_file_downloads_correctly(url, md5):
     with get_tmp_dir() as root:
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
-            with urlopen(url) as response, open(file, "wb") as fh:
+            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:
                 fh.write(response.read())
 
         assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
@@ -175,7 +189,7 @@ def cifar10():
 
 
 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")
 
 
 def voc():
@@ -184,7 +198,7 @@ def voc():
         collect_download_configs(
             lambda: datasets.VOCSegmentation(".", year=year, download=True),
             name=f"VOC, {year}",
-            download_url_location=".voc",
+            file="voc",
         )
         for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
     ]
@@ -199,6 +213,128 @@ def fashion_mnist():
     return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")
 
 
+def kmnist():
+    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+
+
+def emnist():
+    # the 'split' argument can be any valid one, since everything is downloaded anyway
+    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+
+
+def qmnist():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.QMNIST(".", what=what, download=True),
+                name=f"QMNIST, {what}",
+                file="mnist",
+            )
+            for what in ("train", "test", "nist")
+        ]
+    )
+
+
+def omniglot():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.Omniglot(".", background=background, download=True),
+                name=f"Omniglot, {'background' if background else 'evaluation'}",
+            )
+            for background in (True, False)
+        ]
+    )
+
+
+def phototour():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.PhotoTour(".", name=name, download=True),
+                name=f"PhotoTour, {name}",
+                file="phototour",
+            )
+            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
+            # requests time out from within CI. They are disabled until this is resolved.
+            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
+        ]
+    )
+
+
+def sbdataset():
+    return collect_download_configs(
+        lambda: datasets.SBDataset(".", download=True),
+        name="SBDataset",
+        file="voc",
+    )
+
+
+def sbu():
+    return collect_download_configs(
+        lambda: datasets.SBU(".", download=True),
+        name="SBU",
+        file="sbu",
+    )
+
+
+def semeion():
+    return collect_download_configs(
+        lambda: datasets.SEMEION(".", download=True),
+        name="SEMEION",
+        file="semeion",
+    )
+
+
+def stl10():
+    return collect_download_configs(
+        lambda: datasets.STL10(".", download=True),
+        name="STL10",
+    )
+
+
+def svhn():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.SVHN(".", split=split, download=True),
+                name=f"SVHN, {split}",
+                file="svhn",
+            )
+            for split in ("train", "test", "extra")
+        ]
+    )
+
+
+def usps():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.USPS(".", train=train, download=True),
+                name=f"USPS, {'train' if train else 'test'}",
+                file="usps",
+            )
+            for train in (True, False)
+        ]
+    )
+
+
+def celeba():
+    return collect_download_configs(
+        lambda: datasets.CelebA(".", download=True),
+        name="CelebA",
+        file="celeba",
+    )
+
+
+def widerface():
+    return collect_download_configs(
+        lambda: datasets.WIDERFace(".", download=True),
+        name="WIDERFace",
+        file="widerface",
+    )
+
+
 def make_parametrize_kwargs(download_configs):
     argvalues = []
     ids = []
@@ -221,6 +357,19 @@ def make_parametrize_kwargs(download_configs):
             # voc(),
             mnist(),
             fashion_mnist(),
+            kmnist(),
+            emnist(),
+            qmnist(),
+            omniglot(),
+            phototour(),
+            sbdataset(),
+            sbu(),
+            semeion(),
+            stl10(),
+            svhn(),
+            usps(),
+            celeba(),
+            widerface(),
         )
     )
 )
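
Taken together, the rewritten context manager records every (url, md5) pair a dataset constructor tries to fetch, whether through download_url or download_file_from_google_drive. A minimal usage sketch, assuming only the code in this diff with its default arguments (the contextlib.suppress and the MNIST choice are illustrative, not part of the change):

import contextlib

from torchvision import datasets

with log_download_attempts() as urls_and_md5s:
    # Downloads are mocked away, so nothing is actually fetched; constructing
    # the dataset usually fails once it tries to read the never-written files,
    # but by then the mock has already recorded the attempted URLs.
    with contextlib.suppress(Exception):
        datasets.MNIST(".", download=True)

# The yielded set is populated in the context manager's `finally` block, so it
# is only complete after the `with` body exits.
for url, md5 in urls_and_md5s:
    print(url, md5)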