1
1
import collections .abc
2
- import contextlib
3
2
import csv
4
3
import functools
5
4
import gzip
9
8
import pathlib
10
9
import pickle
11
10
import random
12
- import tempfile
13
- import unittest .mock
14
11
import xml .etree .ElementTree as ET
15
12
from collections import defaultdict , Counter
16
13
21
18
from datasets_utils import make_zip , make_tar , create_image_folder , create_image_file
22
19
from torch .nn .functional import one_hot
23
20
from torch .testing import make_tensor as _make_tensor
24
- from torchvision .prototype import datasets
25
21
from torchvision .prototype .datasets ._api import find
26
22
from torchvision .prototype .utils ._internal import sequence_to_str
27
23
28
24
make_tensor = functools .partial (_make_tensor , device = "cpu" )
29
25
make_scalar = functools .partial (make_tensor , ())
30
26
31
- TEST_HOME = pathlib .Path (tempfile .mkdtemp ())
32
-
33
27
34
28
__all__ = ["DATASET_MOCKS" , "parametrize_dataset_mocks" ]
35
29
@@ -40,76 +34,48 @@ def __init__(self, name, mock_data_fn):
40
34
self .info = self .dataset .info
41
35
self .name = self .info .name
42
36
43
- self .root = TEST_HOME / self .dataset .name
44
37
self .mock_data_fn = mock_data_fn
45
38
self .configs = self .info ._configs
46
- self ._cache = {}
47
39
48
- def _parse_mock_data (self , config , mock_infos ):
49
- if mock_infos is None :
40
+ def _parse_mock_info (self , mock_info ):
41
+ if mock_info is None :
50
42
raise pytest .UsageError (
51
43
f"The mock data function for dataset '{ self .name } ' returned nothing. It needs to at least return an "
52
44
f"integer indicating the number of samples for the current `config`."
53
45
)
54
-
55
- key_types = set (type (key ) for key in mock_infos ) if isinstance (mock_infos , dict ) else {}
56
- if datasets .utils .DatasetConfig not in key_types :
57
- mock_infos = {config : mock_infos }
58
- elif len (key_types ) > 1 :
46
+ elif isinstance (mock_info , int ):
47
+ mock_info = dict (num_samples = mock_info )
48
+ elif not isinstance (mock_info , dict ):
59
49
raise pytest .UsageError (
60
- f"Unable to handle the returned dictionary of the mock data function for dataset { self .name } . If "
61
- f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
50
+ f"The mock data function for dataset '{ self .name } ' returned a { type (mock_info )} . The returned object "
51
+ f"should be a dictionary containing at least the number of samples for the key `'num_samples'`. If no "
52
+ f"additional information is required for specific tests, the number of samples can also be returned as "
53
+ f"an integer."
54
+ )
55
+ elif "num_samples" not in mock_info :
56
+ raise pytest .UsageError (
57
+ f"The dictionary returned by the mock data function for dataset '{ self .name } ' has to contain a "
58
+ f"`'num_samples'` entry indicating the number of samples."
62
59
)
63
60
64
- for config_ , mock_info in mock_infos .items ():
65
- if config_ in self ._cache :
66
- raise pytest .UsageError (
67
- f"The mock info for config { config_ } of dataset { self .name } generated for config { config } "
68
- f"already exists in the cache."
69
- )
70
- if isinstance (mock_info , int ):
71
- mock_infos [config_ ] = dict (num_samples = mock_info )
72
- elif not isinstance (mock_info , dict ):
73
- raise pytest .UsageError (
74
- f"The mock data function for dataset '{ self .name } ' returned a { type (mock_infos )} for `config` "
75
- f"{ config_ } . The returned object should be a dictionary containing at least the number of "
76
- f"samples for the key `'num_samples'`. If no additional information is required for specific "
77
- f"tests, the number of samples can also be returned as an integer."
78
- )
79
- elif "num_samples" not in mock_info :
80
- raise pytest .UsageError (
81
- f"The dictionary returned by the mock data function for dataset '{ self .name } ' and config "
82
- f"{ config_ } has to contain a `'num_samples'` entry indicating the number of samples."
83
- )
84
-
85
- return mock_infos
86
-
87
- def _prepare_resources (self , config ):
88
- if config in self ._cache :
89
- return self ._cache [config ]
90
-
91
- self .root .mkdir (exist_ok = True )
92
- mock_infos = self ._parse_mock_data (config , self .mock_data_fn (self .info , self .root , config ))
61
+ return mock_info
93
62
94
- available_file_names = {path .name for path in self .root .glob ("*" )}
95
- for config_ , mock_info in mock_infos .items ():
96
- required_file_names = {resource .file_name for resource in self .dataset .resources (config_ )}
97
- missing_file_names = required_file_names - available_file_names
98
- if missing_file_names :
99
- raise pytest .UsageError (
100
- f"Dataset '{ self .name } ' requires the files { sequence_to_str (sorted (missing_file_names ))} "
101
- f"for { config_ } , but they were not created by the mock data function."
102
- )
63
+ def prepare (self , home , config ):
64
+ root = home / self .name
65
+ root .mkdir (exist_ok = True )
103
66
104
- self ._cache [ config_ ] = mock_info
67
+ mock_info = self ._parse_mock_info ( self . mock_data_fn ( self . info , root , config ))
105
68
106
- return self ._cache [config ]
69
+ available_file_names = {path .name for path in root .glob ("*" )}
70
+ required_file_names = {resource .file_name for resource in self .dataset .resources (config )}
71
+ missing_file_names = required_file_names - available_file_names
72
+ if missing_file_names :
73
+ raise pytest .UsageError (
74
+ f"Dataset '{ self .name } ' requires the files { sequence_to_str (sorted (missing_file_names ))} "
75
+ f"for { config } , but they were not created by the mock data function."
76
+ )
107
77
108
- @contextlib .contextmanager
109
- def prepare (self , config ):
110
- mock_info = self ._prepare_resources (config )
111
- with unittest .mock .patch ("torchvision.prototype.datasets._api.home" , return_value = str (TEST_HOME )):
112
- yield mock_info
78
+ return mock_info
113
79
114
80
115
81
def config_id (name , config ):
@@ -254,32 +220,30 @@ def mnist(info, root, config):
254
220
255
221
256
222
@register_mock
257
- def emnist (info , root , _ ):
223
+ def emnist (info , root , config ):
258
224
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
259
225
# labels in the data files. Thus, num_categories != len(categories) there.
260
226
num_categories = defaultdict (
261
227
lambda : len (info .categories ), {image_set : 47 for image_set in ("Balanced" , "By_Merge" )}
262
228
)
263
229
264
- mock_infos = {}
230
+ num_samples_map = {}
265
231
file_names = set ()
266
- for config in info ._configs :
267
- prefix = f"emnist-{ config .image_set .replace ('_' , '' ).lower ()} -{ config .split } "
232
+ for config_ in info ._configs :
233
+ prefix = f"emnist-{ config_ .image_set .replace ('_' , '' ).lower ()} -{ config_ .split } "
268
234
images_file = f"{ prefix } -images-idx3-ubyte.gz"
269
235
labels_file = f"{ prefix } -labels-idx1-ubyte.gz"
270
236
file_names .update ({images_file , labels_file })
271
- mock_infos [config ] = dict (
272
- num_samples = MNISTMockData .generate (
273
- root ,
274
- num_categories = num_categories [config .image_set ],
275
- images_file = images_file ,
276
- labels_file = labels_file ,
277
- )
237
+ num_samples_map [config_ ] = MNISTMockData .generate (
238
+ root ,
239
+ num_categories = num_categories [config_ .image_set ],
240
+ images_file = images_file ,
241
+ labels_file = labels_file ,
278
242
)
279
243
280
244
make_zip (root , "emnist-gzip.zip" , * file_names )
281
245
282
- return mock_infos
246
+ return num_samples_map [ config ]
283
247
284
248
285
249
@register_mock
@@ -290,25 +254,23 @@ def qmnist(info, root, config):
290
254
prefix = "qmnist-train"
291
255
suffix = ".gz"
292
256
compressor = gzip .open
293
- mock_infos = num_samples
294
257
elif config .split .startswith ("test" ):
295
258
# The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
296
259
# more than 10000 images for the dataset to not be empty.
297
260
num_samples_gen = 10001
261
+ num_samples = {
262
+ "test" : num_samples_gen ,
263
+ "test10k" : min (num_samples_gen , 10_000 ),
264
+ "test50k" : num_samples_gen - 10_000 ,
265
+ }[config .split ]
298
266
prefix = "qmnist-test"
299
267
suffix = ".gz"
300
268
compressor = gzip .open
301
- mock_infos = {
302
- info .make_config (split = "test" ): num_samples_gen ,
303
- info .make_config (split = "test10k" ): min (num_samples_gen , 10_000 ),
304
- info .make_config (split = "test50k" ): num_samples_gen - 10_000 ,
305
- }
306
269
else : # config.split == "nist"
307
270
num_samples = num_samples_gen = num_categories + 3
308
271
prefix = "xnist"
309
272
suffix = ".xz"
310
273
compressor = lzma .open
311
- mock_infos = num_samples
312
274
313
275
MNISTMockData .generate (
314
276
root ,
@@ -320,7 +282,7 @@ def qmnist(info, root, config):
320
282
label_dtype = torch .int32 ,
321
283
compressor = compressor ,
322
284
)
323
- return mock_infos
285
+ return num_samples
324
286
325
287
326
288
class CIFARMockData :
@@ -624,12 +586,7 @@ def generate(
624
586
625
587
@register_mock
626
588
def coco (info , root , config ):
627
- return dict (
628
- zip (
629
- [config_ for config_ in info ._configs if config_ .year == config .year ],
630
- itertools .repeat (CocoMockData .generate (root , year = config .year , num_samples = 5 )),
631
- )
632
- )
589
+ return CocoMockData .generate (root , year = config .year , num_samples = 5 )
633
590
634
591
635
592
class SBDMockData :
@@ -702,9 +659,8 @@ def generate(cls, root):
702
659
703
660
704
661
@register_mock
705
- def sbd (info , root , _ ):
706
- num_samples_map = SBDMockData .generate (root )
707
- return {config : num_samples_map [config .split ] for config in info ._configs }
662
+ def sbd (info , root , config ):
663
+ return SBDMockData .generate (root )[config .split ]
708
664
709
665
710
666
@register_mock
@@ -821,12 +777,7 @@ def generate(cls, root, *, year, trainval):
821
777
@register_mock
822
778
def voc (info , root , config ):
823
779
trainval = config .split != "test"
824
- num_samples_map = VOCMockData .generate (root , year = config .year , trainval = trainval )
825
- return {
826
- config_ : num_samples_map [config_ .split ]
827
- for config_ in info ._configs
828
- if config_ .year == config .year and ((config_ .split == "test" ) ^ trainval )
829
- }
780
+ return VOCMockData .generate (root , year = config .year , trainval = trainval )[config .split ]
830
781
831
782
832
783
class CelebAMockData :
@@ -918,13 +869,12 @@ def generate(cls, root):
918
869
919
870
920
871
@register_mock
921
- def celeba (info , root , _ ):
922
- num_samples_map = CelebAMockData .generate (root )
923
- return {config : num_samples_map [config .split ] for config in info ._configs }
872
+ def celeba (info , root , config ):
873
+ return CelebAMockData .generate (root )[config .split ]
924
874
925
875
926
876
@register_mock
927
- def dtd (info , root , _ ):
877
+ def dtd (info , root , config ):
928
878
data_folder = root / "dtd"
929
879
930
880
num_images_per_class = 3
@@ -968,7 +918,7 @@ def dtd(info, root, _):
968
918
969
919
make_tar (root , "dtd-r1.0.1.tar.gz" , data_folder , compression = "gz" )
970
920
971
- return num_samples_map
921
+ return num_samples_map [ config ]
972
922
973
923
974
924
@register_mock
@@ -1108,7 +1058,7 @@ def clevr(info, root, config):
1108
1058
1109
1059
make_zip (root , f"{ data_folder .name } .zip" , data_folder )
1110
1060
1111
- return { config_ : num_samples_map [config_ .split ] for config_ in info . _configs }
1061
+ return num_samples_map [config .split ]
1112
1062
1113
1063
1114
1064
class OxfordIIITPetMockData :
@@ -1174,8 +1124,7 @@ def generate(self, root):
1174
1124
1175
1125
@register_mock
1176
1126
def oxford_iiit_pet (info , root , config ):
1177
- num_samples_map = OxfordIIITPetMockData .generate (root )
1178
- return {config_ : num_samples_map [config_ .split ] for config_ in info ._configs }
1127
+ return OxfordIIITPetMockData .generate (root )[config .split ]
1179
1128
1180
1129
1181
1130
class _CUB200MockData :
@@ -1342,7 +1291,7 @@ def generate(cls, root):
1342
1291
@register_mock
1343
1292
def cub200 (info , root , config ):
1344
1293
num_samples_map = (CUB2002011MockData if config .year == "2011" else CUB2002010MockData ).generate (root )
1345
- return { config_ : num_samples_map [config_ .split ] for config_ in info . _configs if config_ . year == config . year }
1294
+ return num_samples_map [config .split ]
1346
1295
1347
1296
1348
1297
@register_mock
0 commit comments