Skip to content

Commit d5bc2b2

Browse files
pmeier and vfdev-5 authored
Add download tests for MNIST (#3336)
* cleanup * mnist * lint Co-authored-by: vfdev <[email protected]>
1 parent d5096a7 commit d5bc2b2

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

.github/workflows/tests-schedule.yml

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -31,6 +31,9 @@ jobs:
3131
pip install numpy
3232
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
3333
34+
- name: Install all optional dataset requirements
35+
run: pip install scipy pandas pycocotools lmdb requests
36+
3437
- name: Install tests requirements
3538
run: pip install pytest
3639

test/test_datasets_download.py

Lines changed: 38 additions & 18 deletions
Original file line number · Diff line number · Diff line change
@@ -49,17 +49,22 @@ def inner_wrapper(request, *args, **kwargs):
4949
def log_download_attempts(
5050
urls_and_md5s=None,
5151
patch=True,
52-
download_url_target="torchvision.datasets.utils.download_url",
52+
download_url_location=".utils",
5353
patch_auxiliaries=None,
5454
):
5555
if urls_and_md5s is None:
5656
urls_and_md5s = set()
57+
if download_url_location.startswith("."):
58+
download_url_location = f"torchvision.datasets{download_url_location}"
5759
if patch_auxiliaries is None:
5860
patch_auxiliaries = patch
5961

6062
with contextlib.ExitStack() as stack:
6163
download_url_mock = stack.enter_context(
62-
unittest.mock.patch(download_url_target, wraps=None if patch else download_url)
64+
unittest.mock.patch(
65+
f"{download_url_location}.download_url",
66+
wraps=None if patch else download_url,
67+
)
6368
)
6469
if patch_auxiliaries:
6570
# download_and_extract_archive
@@ -132,9 +137,17 @@ def make_download_configs(urls_and_md5s, name=None):
132137
]
133138

134139

135-
def collect_download_configs(dataset_loader, name, **kwargs):
136-
with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s:
137-
dataset_loader()
140+
def collect_download_configs(dataset_loader, name=None, **kwargs):
141+
urls_and_md5s = set()
142+
try:
143+
with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
144+
dataset = dataset_loader()
145+
except Exception:
146+
dataset = None
147+
148+
if name is None and dataset is not None:
149+
name = type(dataset).__name__
150+
138151
return make_download_configs(urls_and_md5s, name)
139152

140153

@@ -146,34 +159,40 @@ def places365():
146159

147160
datasets.Places365(root, split=split, small=small, download=True)
148161

149-
return make_download_configs(urls_and_md5s, "Places365")
162+
return make_download_configs(urls_and_md5s, name="Places365")
150163

151164

152165
def caltech101():
153-
return collect_download_configs(lambda: datasets.Caltech101(".", download=True), "Caltech101")
166+
return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")
154167

155168

156169
def caltech256():
157-
return collect_download_configs(lambda: datasets.Caltech256(".", download=True), "Caltech256")
170+
return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")
158171

159172

160173
def cifar10():
161-
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR10")
174+
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")
162175

163176

164177
def cifar100():
165-
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
178+
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
166179

167180

168181
def voc():
169-
download_configs = []
170-
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"):
171-
with contextlib.suppress(Exception), log_download_attempts(
172-
download_url_target="torchvision.datasets.voc.download_url"
173-
) as urls_and_md5s:
174-
datasets.VOCSegmentation(".", year=year, download=True)
175-
download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}"))
176-
return download_configs
182+
return itertools.chain(
183+
*[
184+
collect_download_configs(
185+
lambda: datasets.VOCSegmentation(".", year=year, download=True),
186+
name=f"VOC, {year}",
187+
download_url_location=".voc",
188+
)
189+
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
190+
]
191+
)
192+
193+
194+
def mnist():
195+
return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")
177196

178197

179198
def make_parametrize_kwargs(download_configs):
@@ -196,6 +215,7 @@ def make_parametrize_kwargs(download_configs):
196215
cifar100(),
197216
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
198217
# voc(),
218+
mnist(),
199219
)
200220
)
201221
)

0 commit comments

Comments (0)