9
9
import pathlib
10
10
import pickle
11
11
import random
12
+ import unittest .mock
12
13
import xml .etree .ElementTree as ET
13
14
from collections import defaultdict , Counter
14
15
15
16
import numpy as np
16
17
import PIL .Image
17
18
import pytest
18
19
import torch
19
- from datasets_utils import make_zip , make_tar , create_image_folder , create_image_file
20
+ from datasets_utils import make_zip , make_tar , create_image_folder , create_image_file , combinations_grid
20
21
from torch .nn .functional import one_hot
21
22
from torch .testing import make_tensor as _make_tensor
22
- from torchvision .prototype . datasets . _api import find
23
+ from torchvision .prototype import datasets
23
24
from torchvision .prototype .utils ._internal import sequence_to_str
24
25
25
26
make_tensor = functools .partial (_make_tensor , device = "cpu" )
30
31
31
32
32
33
class DatasetMock :
33
- def __init__ (self , name , mock_data_fn ):
34
- self .dataset = find (name )
35
- self .info = self .dataset .info
36
- self .name = self .info .name
37
-
34
+ def __init__ (self , name , * , mock_data_fn , configs ):
35
+ # FIXME: error handling for unknown names
36
+ self .name = name
38
37
self .mock_data_fn = mock_data_fn
39
- self .configs = self . info . _configs
38
+ self .configs = configs
40
39
41
40
def _parse_mock_info (self , mock_info ):
42
41
if mock_info is None :
@@ -65,10 +64,13 @@ def prepare(self, home, config):
65
64
root = home / self .name
66
65
root .mkdir (exist_ok = True )
67
66
68
- mock_info = self ._parse_mock_info (self .mock_data_fn (self . info , root , config ))
67
+ mock_info = self ._parse_mock_info (self .mock_data_fn (root , config ))
69
68
69
+ with unittest .mock .patch .object (datasets .utils .Dataset2 , "__init__" ):
70
+ required_file_names = {
71
+ resource .file_name for resource in datasets .load (self .name , root = root , ** config )._resources ()
72
+ }
70
73
available_file_names = {path .name for path in root .glob ("*" )}
71
- required_file_names = {resource .file_name for resource in self .dataset .resources (config )}
72
74
missing_file_names = required_file_names - available_file_names
73
75
if missing_file_names :
74
76
raise pytest .UsageError (
@@ -123,10 +125,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
123
125
DATASET_MOCKS = {}
124
126
125
127
126
- def register_mock (fn ):
127
- name = fn .__name__ .replace ("_" , "-" )
128
- DATASET_MOCKS [name ] = DatasetMock (name , fn )
129
- return fn
128
def register_mock(name=None, *, configs):
    """Build a decorator that registers a mock-data function in ``DATASET_MOCKS``.

    Args:
        name: Key under which the mock is stored. When ``None``, the decorated
            function's ``__name__`` is used instead.
        configs: Forwarded unchanged to ``DatasetMock`` as its ``configs`` argument.

    Returns:
        A decorator that stores a ``DatasetMock`` for the function it receives and
        returns that function unmodified.
    """

    def wrapper(mock_data_fn):
        # Resolve the key into a local instead of rebinding ``name`` through ``nonlocal``.
        # Rebinding pins the first decorated function's name inside the closure, so applying
        # the same decorator object to a second function would silently overwrite the first
        # registration.
        mock_name = mock_data_fn.__name__ if name is None else name
        DATASET_MOCKS[mock_name] = DatasetMock(mock_name, mock_data_fn=mock_data_fn, configs=configs)

        return mock_data_fn

    return wrapper
130
138
131
139
132
140
class MNISTMockData :
@@ -204,7 +212,7 @@ def generate(
204
212
return num_samples
205
213
206
214
207
- @register_mock
215
+ # @register_mock
208
216
def mnist (info , root , config ):
209
217
train = config .split == "train"
210
218
images_file = f"{ 'train' if train else 't10k' } -images-idx3-ubyte.gz"
@@ -217,10 +225,10 @@ def mnist(info, root, config):
217
225
)
218
226
219
227
220
- DATASET_MOCKS .update ({name : DatasetMock (name , mnist ) for name in ["fashionmnist" , "kmnist" ]})
228
+ # DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
221
229
222
230
223
- @register_mock
231
+ # @register_mock
224
232
def emnist (info , root , config ):
225
233
# The image sets that merge some lower case letters into their respective upper case variants still use dense
226
234
# labels in the data files. Thus, num_categories != len(categories) there.
@@ -247,7 +255,7 @@ def emnist(info, root, config):
247
255
return num_samples_map [config ]
248
256
249
257
250
- @register_mock
258
+ # @register_mock
251
259
def qmnist (info , root , config ):
252
260
num_categories = len (info .categories )
253
261
if config .split == "train" :
@@ -324,7 +332,7 @@ def generate(
324
332
make_tar (root , name , folder , compression = "gz" )
325
333
326
334
327
- @register_mock
335
+ # @register_mock
328
336
def cifar10 (info , root , config ):
329
337
train_files = [f"data_batch_{ idx } " for idx in range (1 , 6 )]
330
338
test_files = ["test_batch" ]
@@ -342,7 +350,7 @@ def cifar10(info, root, config):
342
350
return len (train_files if config .split == "train" else test_files )
343
351
344
352
345
- @register_mock
353
+ # @register_mock
346
354
def cifar100 (info , root , config ):
347
355
train_files = ["train" ]
348
356
test_files = ["test" ]
@@ -360,7 +368,7 @@ def cifar100(info, root, config):
360
368
return len (train_files if config .split == "train" else test_files )
361
369
362
370
363
- @register_mock
371
+ # @register_mock
364
372
def caltech101 (info , root , config ):
365
373
def create_ann_file (root , name ):
366
374
import scipy .io
@@ -410,7 +418,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
410
418
return num_images_per_category * len (info .categories )
411
419
412
420
413
- @register_mock
421
+ # @register_mock
414
422
def caltech256 (info , root , config ):
415
423
dir = root / "256_ObjectCategories"
416
424
num_images_per_category = 2
@@ -430,26 +438,26 @@ def caltech256(info, root, config):
430
438
return num_images_per_category * len (info .categories )
431
439
432
440
433
- @register_mock
434
- def imagenet (info , root , config ):
441
+ @register_mock ( configs = combinations_grid ( split = ( "train" , "val" , "test" )))
442
+ def imagenet (root , config ):
435
443
from scipy .io import savemat
436
444
437
- categories = info . categories
438
- wnids = [ info . extra . category_to_wnid [ category ] for category in categories ]
439
- if config . split == "train" :
440
- num_samples = len (wnids )
445
+ info = datasets . info ( "imagenet" )
446
+
447
+ if config [ " split" ] == "train" :
448
+ num_samples = len (info [ " wnids" ] )
441
449
archive_name = "ILSVRC2012_img_train.tar"
442
450
443
451
files = []
444
- for wnid in wnids :
452
+ for wnid in info [ " wnids" ] :
445
453
create_image_folder (
446
454
root = root ,
447
455
name = wnid ,
448
456
file_name_fn = lambda image_idx : f"{ wnid } _{ image_idx :04d} .JPEG" ,
449
457
num_examples = 1 ,
450
458
)
451
459
files .append (make_tar (root , f"{ wnid } .tar" ))
452
- elif config . split == "val" :
460
+ elif config [ " split" ] == "val" :
453
461
num_samples = 3
454
462
archive_name = "ILSVRC2012_img_val.tar"
455
463
files = [create_image_file (root , f"ILSVRC2012_val_{ idx + 1 :08d} .JPEG" ) for idx in range (num_samples )]
@@ -459,20 +467,20 @@ def imagenet(info, root, config):
459
467
data_root .mkdir (parents = True )
460
468
461
469
with open (data_root / "ILSVRC2012_validation_ground_truth.txt" , "w" ) as file :
462
- for label in torch .randint (0 , len (wnids ), (num_samples ,)).tolist ():
470
+ for label in torch .randint (0 , len (info [ " wnids" ] ), (num_samples ,)).tolist ():
463
471
file .write (f"{ label } \n " )
464
472
465
473
num_children = 0
466
474
synsets = [
467
475
(idx , wnid , category , "" , num_children , [], 0 , 0 )
468
- for idx , (category , wnid ) in enumerate (zip (categories , wnids ), 1 )
476
+ for idx , (category , wnid ) in enumerate (zip (info [ " categories" ], info [ " wnids" ] ), 1 )
469
477
]
470
478
num_children = 1
471
479
synsets .extend ((0 , "" , "" , "" , num_children , [], 0 , 0 ) for _ in range (5 ))
472
480
savemat (data_root / "meta.mat" , dict (synsets = synsets ))
473
481
474
482
make_tar (root , devkit_root .with_suffix (".tar.gz" ).name , compression = "gz" )
475
- else : # config. split == "test"
483
+ else : # config[" split"] == "test"
476
484
num_samples = 5
477
485
archive_name = "ILSVRC2012_img_test_v10102019.tar"
478
486
files = [create_image_file (root , f"ILSVRC2012_test_{ idx + 1 :08d} .JPEG" ) for idx in range (num_samples )]
@@ -587,7 +595,7 @@ def generate(
587
595
return num_samples
588
596
589
597
590
- @register_mock
598
+ # @register_mock
591
599
def coco (info , root , config ):
592
600
return CocoMockData .generate (root , year = config .year , num_samples = 5 )
593
601
@@ -661,12 +669,12 @@ def generate(cls, root):
661
669
return num_samples_map
662
670
663
671
664
- @register_mock
672
+ # @register_mock
665
673
def sbd (info , root , config ):
666
674
return SBDMockData .generate (root )[config .split ]
667
675
668
676
669
- @register_mock
677
+ # @register_mock
670
678
def semeion (info , root , config ):
671
679
num_samples = 3
672
680
num_categories = len (info .categories )
@@ -779,7 +787,7 @@ def generate(cls, root, *, year, trainval):
779
787
return num_samples_map
780
788
781
789
782
- @register_mock
790
+ # @register_mock
783
791
def voc (info , root , config ):
784
792
trainval = config .split != "test"
785
793
return VOCMockData .generate (root , year = config .year , trainval = trainval )[config .split ]
@@ -873,12 +881,12 @@ def generate(cls, root):
873
881
return num_samples_map
874
882
875
883
876
- @register_mock
884
+ # @register_mock
877
885
def celeba (info , root , config ):
878
886
return CelebAMockData .generate (root )[config .split ]
879
887
880
888
881
- @register_mock
889
+ # @register_mock
882
890
def dtd (info , root , config ):
883
891
data_folder = root / "dtd"
884
892
@@ -926,7 +934,7 @@ def dtd(info, root, config):
926
934
return num_samples_map [config ]
927
935
928
936
929
- @register_mock
937
+ # @register_mock
930
938
def fer2013 (info , root , config ):
931
939
num_samples = 5 if config .split == "train" else 3
932
940
@@ -951,7 +959,7 @@ def fer2013(info, root, config):
951
959
return num_samples
952
960
953
961
954
- @register_mock
962
+ # @register_mock
955
963
def gtsrb (info , root , config ):
956
964
num_examples_per_class = 5 if config .split == "train" else 3
957
965
classes = ("00000" , "00042" , "00012" )
@@ -1021,7 +1029,7 @@ def _make_ann_file(path, num_examples, class_idx):
1021
1029
return num_examples
1022
1030
1023
1031
1024
- @register_mock
1032
+ # @register_mock
1025
1033
def clevr (info , root , config ):
1026
1034
data_folder = root / "CLEVR_v1.0"
1027
1035
@@ -1127,7 +1135,7 @@ def generate(self, root):
1127
1135
return num_samples_map
1128
1136
1129
1137
1130
- @register_mock
1138
+ # @register_mock
1131
1139
def oxford_iiit_pet (info , root , config ):
1132
1140
return OxfordIIITPetMockData .generate (root )[config .split ]
1133
1141
@@ -1293,13 +1301,13 @@ def generate(cls, root):
1293
1301
return num_samples_map
1294
1302
1295
1303
1296
- @register_mock
1304
+ # @register_mock
1297
1305
def cub200 (info , root , config ):
1298
1306
num_samples_map = (CUB2002011MockData if config .year == "2011" else CUB2002010MockData ).generate (root )
1299
1307
return num_samples_map [config .split ]
1300
1308
1301
1309
1302
- @register_mock
1310
+ # @register_mock
1303
1311
def svhn (info , root , config ):
1304
1312
import scipy .io as sio
1305
1313
@@ -1319,7 +1327,7 @@ def svhn(info, root, config):
1319
1327
return num_samples
1320
1328
1321
1329
1322
- @register_mock
1330
+ # @register_mock
1323
1331
def pcam (info , root , config ):
1324
1332
import h5py
1325
1333
0 commit comments