@@ -49,17 +49,22 @@ def inner_wrapper(request, *args, **kwargs):
49
49
def log_download_attempts (
50
50
urls_and_md5s = None ,
51
51
patch = True ,
52
- download_url_target = "torchvision.datasets. utils.download_url " ,
52
+ download_url_location = ". utils" ,
53
53
patch_auxiliaries = None ,
54
54
):
55
55
if urls_and_md5s is None :
56
56
urls_and_md5s = set ()
57
+ if download_url_location .startswith ("." ):
58
+ download_url_location = f"torchvision.datasets{ download_url_location } "
57
59
if patch_auxiliaries is None :
58
60
patch_auxiliaries = patch
59
61
60
62
with contextlib .ExitStack () as stack :
61
63
download_url_mock = stack .enter_context (
62
- unittest .mock .patch (download_url_target , wraps = None if patch else download_url )
64
+ unittest .mock .patch (
65
+ f"{ download_url_location } .download_url" ,
66
+ wraps = None if patch else download_url ,
67
+ )
63
68
)
64
69
if patch_auxiliaries :
65
70
# download_and_extract_archive
@@ -132,9 +137,17 @@ def make_download_configs(urls_and_md5s, name=None):
132
137
]
133
138
134
139
135
- def collect_download_configs (dataset_loader , name , ** kwargs ):
136
- with contextlib .suppress (Exception ), log_download_attempts (** kwargs ) as urls_and_md5s :
137
- dataset_loader ()
140
+ def collect_download_configs (dataset_loader , name = None , ** kwargs ):
141
+ urls_and_md5s = set ()
142
+ try :
143
+ with log_download_attempts (urls_and_md5s = urls_and_md5s , ** kwargs ):
144
+ dataset = dataset_loader ()
145
+ except Exception :
146
+ dataset = None
147
+
148
+ if name is None and dataset is not None :
149
+ name = type (dataset ).__name__
150
+
138
151
return make_download_configs (urls_and_md5s , name )
139
152
140
153
@@ -146,34 +159,40 @@ def places365():
146
159
147
160
datasets .Places365 (root , split = split , small = small , download = True )
148
161
149
- return make_download_configs (urls_and_md5s , "Places365" )
162
+ return make_download_configs (urls_and_md5s , name = "Places365" )
150
163
151
164
152
165
def caltech101 ():
153
- return collect_download_configs (lambda : datasets .Caltech101 ("." , download = True ), "Caltech101" )
166
+ return collect_download_configs (lambda : datasets .Caltech101 ("." , download = True ), name = "Caltech101" )
154
167
155
168
156
169
def caltech256 ():
157
- return collect_download_configs (lambda : datasets .Caltech256 ("." , download = True ), "Caltech256" )
170
+ return collect_download_configs (lambda : datasets .Caltech256 ("." , download = True ), name = "Caltech256" )
158
171
159
172
160
173
def cifar10 ():
161
- return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), "CIFAR10" )
174
+ return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), name = "CIFAR10" )
162
175
163
176
164
177
def cifar100 ():
165
- return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), "CIFAR100" )
178
+ return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), name = "CIFAR100" )
166
179
167
180
168
181
def voc ():
169
- download_configs = []
170
- for year in ("2007" , "2007-test" , "2008" , "2009" , "2010" , "2011" , "2012" ):
171
- with contextlib .suppress (Exception ), log_download_attempts (
172
- download_url_target = "torchvision.datasets.voc.download_url"
173
- ) as urls_and_md5s :
174
- datasets .VOCSegmentation ("." , year = year , download = True )
175
- download_configs .extend (make_download_configs (urls_and_md5s , f"VOC, { year } " ))
176
- return download_configs
182
+ return itertools .chain (
183
+ * [
184
+ collect_download_configs (
185
+ lambda : datasets .VOCSegmentation ("." , year = year , download = True ),
186
+ name = f"VOC, { year } " ,
187
+ download_url_location = ".voc" ,
188
+ )
189
+ for year in ("2007" , "2007-test" , "2008" , "2009" , "2010" , "2011" , "2012" )
190
+ ]
191
+ )
192
+
193
+
194
+ def mnist ():
195
+ return collect_download_configs (lambda : datasets .MNIST ("." , download = True ), name = "MNIST" )
177
196
178
197
179
198
def make_parametrize_kwargs (download_configs ):
@@ -196,6 +215,7 @@ def make_parametrize_kwargs(download_configs):
196
215
cifar100 (),
197
216
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
198
217
# voc(),
218
+ mnist (),
199
219
)
200
220
)
201
221
)
0 commit comments