Skip to content

Commit 0aae427

Browse files
authored
refactor prototype datasets to inherit from IterDataPipe (#5448)
* refactor prototype datasets to inherit from IterDataPipe * depend on new architecture * fix missing file detection * remove unrelated file * reinstate decorator for mock registering * options -> config * remove passing of info to mock data functions * refactor categories file generation
1 parent 1916bd7 commit 0aae427

File tree

8 files changed

+210
-174
lines changed

8 files changed

+210
-174
lines changed

test/builtin_dataset_mocks.py

Lines changed: 54 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@
99
import pathlib
1010
import pickle
1111
import random
12+
import unittest.mock
1213
import xml.etree.ElementTree as ET
1314
from collections import defaultdict, Counter
1415

1516
import numpy as np
1617
import PIL.Image
1718
import pytest
1819
import torch
19-
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
20+
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
2021
from torch.nn.functional import one_hot
2122
from torch.testing import make_tensor as _make_tensor
22-
from torchvision.prototype.datasets._api import find
23+
from torchvision.prototype import datasets
2324
from torchvision.prototype.utils._internal import sequence_to_str
2425

2526
make_tensor = functools.partial(_make_tensor, device="cpu")
@@ -30,13 +31,11 @@
3031

3132

3233
class DatasetMock:
33-
def __init__(self, name, mock_data_fn):
34-
self.dataset = find(name)
35-
self.info = self.dataset.info
36-
self.name = self.info.name
37-
34+
def __init__(self, name, *, mock_data_fn, configs):
35+
# FIXME: error handling for unknown names
36+
self.name = name
3837
self.mock_data_fn = mock_data_fn
39-
self.configs = self.info._configs
38+
self.configs = configs
4039

4140
def _parse_mock_info(self, mock_info):
4241
if mock_info is None:
@@ -65,10 +64,13 @@ def prepare(self, home, config):
6564
root = home / self.name
6665
root.mkdir(exist_ok=True)
6766

68-
mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))
67+
mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
6968

69+
with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
70+
required_file_names = {
71+
resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
72+
}
7073
available_file_names = {path.name for path in root.glob("*")}
71-
required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
7274
missing_file_names = required_file_names - available_file_names
7375
if missing_file_names:
7476
raise pytest.UsageError(
@@ -123,10 +125,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
123125
DATASET_MOCKS = {}
124126

125127

126-
def register_mock(fn):
127-
name = fn.__name__.replace("_", "-")
128-
DATASET_MOCKS[name] = DatasetMock(name, fn)
129-
return fn
128+
def register_mock(name=None, *, configs):
129+
def wrapper(mock_data_fn):
130+
nonlocal name
131+
if name is None:
132+
name = mock_data_fn.__name__
133+
DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs)
134+
135+
return mock_data_fn
136+
137+
return wrapper
130138

131139

132140
class MNISTMockData:
@@ -204,7 +212,7 @@ def generate(
204212
return num_samples
205213

206214

207-
@register_mock
215+
# @register_mock
208216
def mnist(info, root, config):
209217
train = config.split == "train"
210218
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
@@ -217,10 +225,10 @@ def mnist(info, root, config):
217225
)
218226

219227

220-
DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
228+
# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
221229

222230

223-
@register_mock
231+
# @register_mock
224232
def emnist(info, root, config):
225233
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
226234
# labels in the data files. Thus, num_categories != len(categories) there.
@@ -247,7 +255,7 @@ def emnist(info, root, config):
247255
return num_samples_map[config]
248256

249257

250-
@register_mock
258+
# @register_mock
251259
def qmnist(info, root, config):
252260
num_categories = len(info.categories)
253261
if config.split == "train":
@@ -324,7 +332,7 @@ def generate(
324332
make_tar(root, name, folder, compression="gz")
325333

326334

327-
@register_mock
335+
# @register_mock
328336
def cifar10(info, root, config):
329337
train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
330338
test_files = ["test_batch"]
@@ -342,7 +350,7 @@ def cifar10(info, root, config):
342350
return len(train_files if config.split == "train" else test_files)
343351

344352

345-
@register_mock
353+
# @register_mock
346354
def cifar100(info, root, config):
347355
train_files = ["train"]
348356
test_files = ["test"]
@@ -360,7 +368,7 @@ def cifar100(info, root, config):
360368
return len(train_files if config.split == "train" else test_files)
361369

362370

363-
@register_mock
371+
# @register_mock
364372
def caltech101(info, root, config):
365373
def create_ann_file(root, name):
366374
import scipy.io
@@ -410,7 +418,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
410418
return num_images_per_category * len(info.categories)
411419

412420

413-
@register_mock
421+
# @register_mock
414422
def caltech256(info, root, config):
415423
dir = root / "256_ObjectCategories"
416424
num_images_per_category = 2
@@ -430,26 +438,26 @@ def caltech256(info, root, config):
430438
return num_images_per_category * len(info.categories)
431439

432440

433-
@register_mock
434-
def imagenet(info, root, config):
441+
@register_mock(configs=combinations_grid(split=("train", "val", "test")))
442+
def imagenet(root, config):
435443
from scipy.io import savemat
436444

437-
categories = info.categories
438-
wnids = [info.extra.category_to_wnid[category] for category in categories]
439-
if config.split == "train":
440-
num_samples = len(wnids)
445+
info = datasets.info("imagenet")
446+
447+
if config["split"] == "train":
448+
num_samples = len(info["wnids"])
441449
archive_name = "ILSVRC2012_img_train.tar"
442450

443451
files = []
444-
for wnid in wnids:
452+
for wnid in info["wnids"]:
445453
create_image_folder(
446454
root=root,
447455
name=wnid,
448456
file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
449457
num_examples=1,
450458
)
451459
files.append(make_tar(root, f"{wnid}.tar"))
452-
elif config.split == "val":
460+
elif config["split"] == "val":
453461
num_samples = 3
454462
archive_name = "ILSVRC2012_img_val.tar"
455463
files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
@@ -459,20 +467,20 @@ def imagenet(info, root, config):
459467
data_root.mkdir(parents=True)
460468

461469
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
462-
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
470+
for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist():
463471
file.write(f"{label}\n")
464472

465473
num_children = 0
466474
synsets = [
467475
(idx, wnid, category, "", num_children, [], 0, 0)
468-
for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
476+
for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1)
469477
]
470478
num_children = 1
471479
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
472480
savemat(data_root / "meta.mat", dict(synsets=synsets))
473481

474482
make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
475-
else: # config.split == "test"
483+
else: # config["split"] == "test"
476484
num_samples = 5
477485
archive_name = "ILSVRC2012_img_test_v10102019.tar"
478486
files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
@@ -587,7 +595,7 @@ def generate(
587595
return num_samples
588596

589597

590-
@register_mock
598+
# @register_mock
591599
def coco(info, root, config):
592600
return CocoMockData.generate(root, year=config.year, num_samples=5)
593601

@@ -661,12 +669,12 @@ def generate(cls, root):
661669
return num_samples_map
662670

663671

664-
@register_mock
672+
# @register_mock
665673
def sbd(info, root, config):
666674
return SBDMockData.generate(root)[config.split]
667675

668676

669-
@register_mock
677+
# @register_mock
670678
def semeion(info, root, config):
671679
num_samples = 3
672680
num_categories = len(info.categories)
@@ -779,7 +787,7 @@ def generate(cls, root, *, year, trainval):
779787
return num_samples_map
780788

781789

782-
@register_mock
790+
# @register_mock
783791
def voc(info, root, config):
784792
trainval = config.split != "test"
785793
return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]
@@ -873,12 +881,12 @@ def generate(cls, root):
873881
return num_samples_map
874882

875883

876-
@register_mock
884+
# @register_mock
877885
def celeba(info, root, config):
878886
return CelebAMockData.generate(root)[config.split]
879887

880888

881-
@register_mock
889+
# @register_mock
882890
def dtd(info, root, config):
883891
data_folder = root / "dtd"
884892

@@ -926,7 +934,7 @@ def dtd(info, root, config):
926934
return num_samples_map[config]
927935

928936

929-
@register_mock
937+
# @register_mock
930938
def fer2013(info, root, config):
931939
num_samples = 5 if config.split == "train" else 3
932940

@@ -951,7 +959,7 @@ def fer2013(info, root, config):
951959
return num_samples
952960

953961

954-
@register_mock
962+
# @register_mock
955963
def gtsrb(info, root, config):
956964
num_examples_per_class = 5 if config.split == "train" else 3
957965
classes = ("00000", "00042", "00012")
@@ -1021,7 +1029,7 @@ def _make_ann_file(path, num_examples, class_idx):
10211029
return num_examples
10221030

10231031

1024-
@register_mock
1032+
# @register_mock
10251033
def clevr(info, root, config):
10261034
data_folder = root / "CLEVR_v1.0"
10271035

@@ -1127,7 +1135,7 @@ def generate(self, root):
11271135
return num_samples_map
11281136

11291137

1130-
@register_mock
1138+
# @register_mock
11311139
def oxford_iiit_pet(info, root, config):
11321140
return OxfordIIITPetMockData.generate(root)[config.split]
11331141

@@ -1293,13 +1301,13 @@ def generate(cls, root):
12931301
return num_samples_map
12941302

12951303

1296-
@register_mock
1304+
# @register_mock
12971305
def cub200(info, root, config):
12981306
num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
12991307
return num_samples_map[config.split]
13001308

13011309

1302-
@register_mock
1310+
# @register_mock
13031311
def svhn(info, root, config):
13041312
import scipy.io as sio
13051313

@@ -1319,7 +1327,7 @@ def svhn(info, root, config):
13191327
return num_samples
13201328

13211329

1322-
@register_mock
1330+
# @register_mock
13231331
def pcam(info, root, config):
13241332
import h5py
13251333

0 commit comments

Comments
 (0)