
Commit 66bc73a

Merge branch 'main' into bugfix/zoomout
2 parents (1f5cfb9 + 067dc30), commit 66bc73a

File tree: 2 files changed (+101, -136 lines)


test/builtin_dataset_mocks.py

Lines changed: 55 additions & 106 deletions
@@ -1,5 +1,4 @@
 import collections.abc
-import contextlib
 import csv
 import functools
 import gzip
@@ -9,8 +8,6 @@
 import pathlib
 import pickle
 import random
-import tempfile
-import unittest.mock
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter

@@ -21,15 +18,12 @@
 from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
-from torchvision.prototype import datasets
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.utils._internal import sequence_to_str

 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())

-TEST_HOME = pathlib.Path(tempfile.mkdtemp())
-

 __all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]

@@ -40,76 +34,48 @@ def __init__(self, name, mock_data_fn):
         self.info = self.dataset.info
         self.name = self.info.name

-        self.root = TEST_HOME / self.dataset.name
         self.mock_data_fn = mock_data_fn
         self.configs = self.info._configs
-        self._cache = {}

-    def _parse_mock_data(self, config, mock_infos):
-        if mock_infos is None:
+    def _parse_mock_info(self, mock_info):
+        if mock_info is None:
             raise pytest.UsageError(
                 f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
                 f"integer indicating the number of samples for the current `config`."
             )
-
-        key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
-        if datasets.utils.DatasetConfig not in key_types:
-            mock_infos = {config: mock_infos}
-        elif len(key_types) > 1:
+        elif isinstance(mock_info, int):
+            mock_info = dict(num_samples=mock_info)
+        elif not isinstance(mock_info, dict):
             raise pytest.UsageError(
-                f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
-                f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
+                f"The mock data function for dataset '{self.name}' returned a {type(mock_info)}. The returned object "
+                f"should be a dictionary containing at least the number of samples for the key `'num_samples'`. If no "
+                f"additional information is required for specific tests, the number of samples can also be returned as "
+                f"an integer."
+            )
+        elif "num_samples" not in mock_info:
+            raise pytest.UsageError(
+                f"The dictionary returned by the mock data function for dataset '{self.name}' has to contain a "
+                f"`'num_samples'` entry indicating the number of samples."
             )

-        for config_, mock_info in mock_infos.items():
-            if config_ in self._cache:
-                raise pytest.UsageError(
-                    f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
-                    f"already exists in the cache."
-                )
-            if isinstance(mock_info, int):
-                mock_infos[config_] = dict(num_samples=mock_info)
-            elif not isinstance(mock_info, dict):
-                raise pytest.UsageError(
-                    f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
-                    f"{config_}. The returned object should be a dictionary containing at least the number of "
-                    f"samples for the key `'num_samples'`. If no additional information is required for specific "
-                    f"tests, the number of samples can also be returned as an integer."
-                )
-            elif "num_samples" not in mock_info:
-                raise pytest.UsageError(
-                    f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
-                    f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
-                )
-
-        return mock_infos
-
-    def _prepare_resources(self, config):
-        if config in self._cache:
-            return self._cache[config]
-
-        self.root.mkdir(exist_ok=True)
-        mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
+        return mock_info

-        available_file_names = {path.name for path in self.root.glob("*")}
-        for config_, mock_info in mock_infos.items():
-            required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
-            missing_file_names = required_file_names - available_file_names
-            if missing_file_names:
-                raise pytest.UsageError(
-                    f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
-                    f"for {config_}, but they were not created by the mock data function."
-                )
+    def prepare(self, home, config):
+        root = home / self.name
+        root.mkdir(exist_ok=True)

-            self._cache[config_] = mock_info
+        mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))

-        return self._cache[config]
+        available_file_names = {path.name for path in root.glob("*")}
+        required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
+        missing_file_names = required_file_names - available_file_names
+        if missing_file_names:
+            raise pytest.UsageError(
+                f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
+                f"for {config}, but they were not created by the mock data function."
+            )

-    @contextlib.contextmanager
-    def prepare(self, config):
-        mock_info = self._prepare_resources(config)
-        with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
-            yield mock_info
+        return mock_info


 def config_id(name, config):
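
With this refactor, a mock data function is invoked for a single `config`: it writes its fake files into `root` and reports the sample count for that config alone, either as a bare integer or as a dict containing at least `'num_samples'`. Below is a minimal sketch of a function satisfying the new `_parse_mock_info` contract; the dataset name, archive name, and file layout are hypothetical and only for illustration (a real mock is additionally registered with `@register_mock` and must create every file listed by the dataset's `resources(config)`), while the helpers mirror the `datasets_utils` imports used elsewhere in this file.

def fakedata(info, root, config):
    # Hypothetical dataset: a few images for the requested split, packed into a zip.
    num_samples = 3
    create_image_folder(
        root,
        name=f"fakedata-{config.split}",
        file_name_fn=lambda idx: f"{idx:04d}.png",
        num_examples=num_samples,
    )
    make_zip(root, f"fakedata-{config.split}.zip", root / f"fakedata-{config.split}")

    # A bare `return num_samples` would also be accepted; _parse_mock_info
    # normalizes an int into dict(num_samples=num_samples).
    return dict(num_samples=num_samples)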
@@ -254,32 +220,30 @@ def mnist(info, root, config):


 @register_mock
-def emnist(info, root, _):
+def emnist(info, root, config):
     # The image sets that merge some lower case letters in their respective upper case variant, still use dense
     # labels in the data files. Thus, num_categories != len(categories) there.
     num_categories = defaultdict(
         lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
     )

-    mock_infos = {}
+    num_samples_map = {}
     file_names = set()
-    for config in info._configs:
-        prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}"
+    for config_ in info._configs:
+        prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
         file_names.update({images_file, labels_file})
-        mock_infos[config] = dict(
-            num_samples=MNISTMockData.generate(
-                root,
-                num_categories=num_categories[config.image_set],
-                images_file=images_file,
-                labels_file=labels_file,
-            )
+        num_samples_map[config_] = MNISTMockData.generate(
+            root,
+            num_categories=num_categories[config_.image_set],
+            images_file=images_file,
+            labels_file=labels_file,
         )

     make_zip(root, "emnist-gzip.zip", *file_names)

-    return mock_infos
+    return num_samples_map[config]


 @register_mock
@@ -290,25 +254,23 @@ def qmnist(info, root, config):
         prefix = "qmnist-train"
         suffix = ".gz"
         compressor = gzip.open
-        mock_infos = num_samples
     elif config.split.startswith("test"):
         # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
         # more than 10000 images for the dataset to not be empty.
         num_samples_gen = 10001
+        num_samples = {
+            "test": num_samples_gen,
+            "test10k": min(num_samples_gen, 10_000),
+            "test50k": num_samples_gen - 10_000,
+        }[config.split]
         prefix = "qmnist-test"
         suffix = ".gz"
         compressor = gzip.open
-        mock_infos = {
-            info.make_config(split="test"): num_samples_gen,
-            info.make_config(split="test10k"): min(num_samples_gen, 10_000),
-            info.make_config(split="test50k"): num_samples_gen - 10_000,
-        }
     else:  # config.split == "nist"
         num_samples = num_samples_gen = num_categories + 3
         prefix = "xnist"
         suffix = ".xz"
         compressor = lzma.open
-        mock_infos = num_samples

     MNISTMockData.generate(
         root,
@@ -320,7 +282,7 @@ def qmnist(info, root, config):
         label_dtype=torch.int32,
         compressor=compressor,
     )
-    return mock_infos
+    return num_samples


 class CIFARMockData:
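
As a quick worked check of the test-split arithmetic above: the mock generates 10001 test images, so the new lookup table resolves to the following per-split counts.

num_samples_gen = 10001
num_samples = {
    "test": num_samples_gen,                  # 10001: the whole generated test set
    "test10k": min(num_samples_gen, 10_000),  # 10000: capped at the first 10k images
    "test50k": num_samples_gen - 10_000,      # 1: everything from index 10000 onwards
}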
@@ -624,12 +586,7 @@ def generate(

 @register_mock
 def coco(info, root, config):
-    return dict(
-        zip(
-            [config_ for config_ in info._configs if config_.year == config.year],
-            itertools.repeat(CocoMockData.generate(root, year=config.year, num_samples=5)),
-        )
-    )
+    return CocoMockData.generate(root, year=config.year, num_samples=5)


 class SBDMockData:
@@ -702,9 +659,8 @@ def generate(cls, root):


 @register_mock
-def sbd(info, root, _):
-    num_samples_map = SBDMockData.generate(root)
-    return {config: num_samples_map[config.split] for config in info._configs}
+def sbd(info, root, config):
+    return SBDMockData.generate(root)[config.split]


 @register_mock
@@ -821,12 +777,7 @@ def generate(cls, root, *, year, trainval):
 @register_mock
 def voc(info, root, config):
     trainval = config.split != "test"
-    num_samples_map = VOCMockData.generate(root, year=config.year, trainval=trainval)
-    return {
-        config_: num_samples_map[config_.split]
-        for config_ in info._configs
-        if config_.year == config.year and ((config_.split == "test") ^ trainval)
-    }
+    return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]


 class CelebAMockData:
@@ -918,13 +869,12 @@ def generate(cls, root):


 @register_mock
-def celeba(info, root, _):
-    num_samples_map = CelebAMockData.generate(root)
-    return {config: num_samples_map[config.split] for config in info._configs}
+def celeba(info, root, config):
+    return CelebAMockData.generate(root)[config.split]


 @register_mock
-def dtd(info, root, _):
+def dtd(info, root, config):
     data_folder = root / "dtd"

     num_images_per_class = 3
@@ -968,7 +918,7 @@ def dtd(info, root, _):

     make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz")

-    return num_samples_map
+    return num_samples_map[config]


 @register_mock
@@ -1108,7 +1058,7 @@ def clevr(info, root, config):

     make_zip(root, f"{data_folder.name}.zip", data_folder)

-    return {config_: num_samples_map[config_.split] for config_ in info._configs}
+    return num_samples_map[config.split]


 class OxfordIIITPetMockData:
@@ -1174,8 +1124,7 @@ def generate(self, root):

 @register_mock
 def oxford_iiit_pet(info, root, config):
-    num_samples_map = OxfordIIITPetMockData.generate(root)
-    return {config_: num_samples_map[config_.split] for config_ in info._configs}
+    return OxfordIIITPetMockData.generate(root)[config.split]


 class _CUB200MockData:
@@ -1342,7 +1291,7 @@ def generate(cls, root):
 @register_mock
 def cub200(info, root, config):
     num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
-    return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
+    return num_samples_map[config.split]


 @register_mock
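
Since `prepare` no longer patches `torchvision.prototype.datasets._api.home` itself (the removed context manager used to), the caller now supplies the home directory and is responsible for any patching. The companion test-module changes are not shown in this diff, so the following is only a rough sketch of how the new API might be driven, assuming pytest's `tmp_path` fixture, the exported `DATASET_MOCKS` registry, and the prototype `datasets.load` entry point:

import unittest.mock

from builtin_dataset_mocks import DATASET_MOCKS
from torchvision.prototype import datasets


def test_mnist_smoke(tmp_path):
    mock = DATASET_MOCKS["mnist"]
    config = mock.configs[0]

    # Writes the fake files under tmp_path/<dataset name> and returns the mock
    # info dict validated by _parse_mock_info (contains at least "num_samples").
    mock_info = mock.prepare(tmp_path, config)

    # Point the prototype datasets API at the same directory before loading,
    # exactly as the removed context manager used to do.
    with unittest.mock.patch(
        "torchvision.prototype.datasets._api.home", return_value=str(tmp_path)
    ):
        dataset = datasets.load(mock.name, split=config.split)
        assert len(list(dataset)) == mock_info["num_samples"]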
