
Commit 08c8f0e

Merge mock data preparation and dataset logic in prototype tests (#6010)
* merge mock data preparation and loading
* address comments
* fix extra file creation
* remove tmp folder
* inline images meta creation in coco mock data
1 parent d9a6950 commit 08c8f0e

File tree

test/builtin_dataset_mocks.py
test/test_prototype_builtin_datasets.py

2 files changed: +94 −84 lines changed

test/builtin_dataset_mocks.py

Lines changed: 78 additions & 45 deletions
```diff
@@ -10,19 +10,18 @@
 import pathlib
 import pickle
 import random
+import shutil
 import unittest.mock
 import warnings
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter
 
 import numpy as np
-import PIL.Image
 import pytest
 import torch
 from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
-from torchvision._utils import sequence_to_str
 from torchvision.prototype import datasets
 
 make_tensor = functools.partial(_make_tensor, device="cpu")
```
```diff
@@ -62,27 +61,51 @@ def _parse_mock_info(self, mock_info):
 
         return mock_info
 
-    def prepare(self, config):
+    def load(self, config):
         # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in
         # test/test_prototype_builtin_datasets.py
         root = pathlib.Path(datasets.home()) / self.name
-        root.mkdir(exist_ok=True)
+        # We cannot place the mock data upfront in `root`. Loading a dataset calls `OnlineResource.load`. In turn,
+        # this will only download **and** preprocess if the file is not present. In other words, if we already place
+        # the file in `root` before the resource is loaded, we are effectively skipping the preprocessing.
+        # To avoid that we first place the mock data in a temporary directory and patch the download logic to move it to
+        # `root` only when it is requested.
+        tmp_mock_data_folder = root / "__mock__"
+        tmp_mock_data_folder.mkdir(parents=True)
+
+        mock_info = self._parse_mock_info(self.mock_data_fn(tmp_mock_data_folder, config))
+
+        def patched_download(resource, root, **kwargs):
+            src = tmp_mock_data_folder / resource.file_name
+            if not src.exists():
+                raise pytest.UsageError(
+                    f"Dataset '{self.name}' requires the file {resource.file_name} for {config}, "
+                    f"but it was not created by the mock data function."
+                )
 
-        mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
+            dst = root / resource.file_name
+            shutil.move(str(src), str(root))
 
-        with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"):
-            required_file_names = {
-                resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
-            }
-        available_file_names = {path.name for path in root.glob("*")}
-        missing_file_names = required_file_names - available_file_names
-        if missing_file_names:
+            return dst
+
+        with unittest.mock.patch(
+            "torchvision.prototype.datasets.utils._resource.OnlineResource.download", new=patched_download
+        ):
+            dataset = datasets.load(self.name, **config)
+
+        extra_files = list(tmp_mock_data_folder.glob("**/*"))
+        if extra_files:
             raise pytest.UsageError(
-                f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
-                f"for {config}, but they were not created by the mock data function."
+                (
+                    f"Dataset '{self.name}' created the following files for {config} in the mock data function, "
+                    f"but they were not loaded:\n\n"
+                )
+                + "\n".join(str(file.relative_to(tmp_mock_data_folder)) for file in extra_files)
             )
 
-        return mock_info
+        tmp_mock_data_folder.rmdir()
+
+        return dataset, mock_info
 
 
 def config_id(name, config):
```
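The heart of the new `load` helper is the patched download: mock files are staged in a `__mock__` side folder and only moved into `root` at the moment a resource requests them, so `OnlineResource.load` still exercises its full download-and-preprocess path. A self-contained sketch of that pattern — the `Resource` class and helper names here are illustrative stand-ins, not the torchvision API:

```python
import pathlib
import shutil
import tempfile
import unittest.mock


class Resource:
    """Illustrative stand-in for OnlineResource: download() fetches file_name into root."""

    def __init__(self, file_name):
        self.file_name = file_name

    def download(self, root):
        raise RuntimeError("tests must not hit the network")


def load_with_staged_file(resource, root, staging):
    # "Downloading" just moves the pre-staged mock file into `root`, so any
    # preprocessing that normally follows a download still runs.
    def patched_download(resource, root, **kwargs):
        src = staging / resource.file_name
        dst = pathlib.Path(root) / resource.file_name
        shutil.move(str(src), str(dst))
        return dst

    with unittest.mock.patch.object(Resource, "download", new=patched_download):
        return resource.download(root)


with tempfile.TemporaryDirectory() as tmp:
    tmp = pathlib.Path(tmp)
    staging, root = tmp / "__mock__", tmp / "root"
    staging.mkdir()
    root.mkdir()
    (staging / "data.tar").write_bytes(b"fake archive")  # the staged mock data

    print(load_with_staged_file(Resource("data.tar"), root, staging))  # .../root/data.tar
```

Patching with `new=` replaces the bound method wholesale, which is why `patched_download` receives the resource itself as its first argument.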
```diff
@@ -513,22 +536,6 @@ def imagenet(root, config):
 
 
 class CocoMockData:
-    @classmethod
-    def _make_images_archive(cls, root, name, *, num_samples):
-        image_paths = create_image_folder(
-            root, name, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_samples
-        )
-
-        images_meta = []
-        for path in image_paths:
-            with PIL.Image.open(path) as image:
-                width, height = image.size
-            images_meta.append(dict(file_name=path.name, id=int(path.stem), width=width, height=height))
-
-        make_zip(root, f"{name}.zip")
-
-        return images_meta
-
     @classmethod
     def _make_annotations_json(
         cls,
```
```diff
@@ -596,16 +603,38 @@ def generate(
         cls,
         root,
         *,
+        split,
         year,
         num_samples,
     ):
         annotations_dir = root / "annotations"
         annotations_dir.mkdir()
 
-        for split in ("train", "val"):
-            config_name = f"{split}{year}"
+        for split_ in ("train", "val"):
+            config_name = f"{split_}{year}"
+
+            images_meta = [
+                dict(
+                    file_name=f"{idx:012d}.jpg",
+                    id=idx,
+                    width=width,
+                    height=height,
+                )
+                for idx, (height, width) in enumerate(
+                    torch.randint(3, 11, size=(num_samples, 2), dtype=torch.int).tolist()
+                )
+            ]
+
+            if split_ == split:
+                create_image_folder(
+                    root,
+                    config_name,
+                    file_name_fn=lambda idx: images_meta[idx]["file_name"],
+                    num_examples=num_samples,
+                    size=lambda idx: (3, images_meta[idx]["height"], images_meta[idx]["width"]),
+                )
+                make_zip(root, f"{config_name}.zip")
 
-            images_meta = cls._make_images_archive(root, config_name, num_samples=num_samples)
             cls._make_annotations(
                 annotations_dir,
                 config_name,
```
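With `_make_images_archive` gone, the image metadata is generated up front from random sizes and only the requested split gets actual image files; the same `images_meta` then feeds both the annotation JSON and the `size=` callback of `create_image_folder`. The metadata part runs on its own, for instance:

```python
import torch

torch.manual_seed(0)  # only to make this demo reproducible
num_samples = 5

# Draw random (height, width) pairs once, then derive the COCO-style image
# metadata from them without ever opening an image file.
images_meta = [
    dict(file_name=f"{idx:012d}.jpg", id=idx, width=width, height=height)
    for idx, (height, width) in enumerate(
        torch.randint(3, 11, size=(num_samples, 2), dtype=torch.int).tolist()
    )
]
print(images_meta[0])  # e.g. {'file_name': '000000000000.jpg', 'id': 0, 'width': ..., 'height': ...}
```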
```diff
@@ -625,7 +654,7 @@ def generate(
     )
 )
 def coco(root, config):
-    return CocoMockData.generate(root, year=config["year"], num_samples=5)
+    return CocoMockData.generate(root, split=config["split"], year=config["year"], num_samples=5)
 
 
 class SBDMockData:
```
```diff
@@ -799,8 +828,11 @@ def add_bndbox(obj):
     def generate(cls, root, *, year, trainval):
         archive_folder = root
         if year == "2011":
-            archive_folder /= "TrainVal"
-        data_folder = archive_folder / "VOCdevkit" / f"VOC{year}"
+            archive_folder = root / "TrainVal"
+            data_folder = archive_folder / "VOCdevkit"
+        else:
+            archive_folder = data_folder = root / "VOCdevkit"
+        data_folder = data_folder / f"VOC{year}"
         data_folder.mkdir(parents=True, exist_ok=True)
 
         ids, num_samples_map = cls._make_split_files(data_folder, year=year, trainval=trainval)
```
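For VOC the archive is now built from `archive_folder` rather than `data_folder`, and the two only differ for the 2011 layout, which nests `VOCdevkit` under `TrainVal`. A small sketch (a plain function, not the mock's actual structure) makes the two layouts easy to compare:

```python
import pathlib


def voc_folders(root, year):
    # Mirrors the branch in the diff above.
    archive_folder = root
    if year == "2011":
        archive_folder = root / "TrainVal"
        data_folder = archive_folder / "VOCdevkit"
    else:
        archive_folder = data_folder = root / "VOCdevkit"
    return archive_folder, data_folder / f"VOC{year}"


root = pathlib.Path("root")
print(*voc_folders(root, "2011"))  # root/TrainVal  root/TrainVal/VOCdevkit/VOC2011
print(*voc_folders(root, "2012"))  # root/VOCdevkit root/VOCdevkit/VOC2012
```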
```diff
@@ -810,7 +842,7 @@ def generate(cls, root, *, year, trainval):
             (cls._make_detection_anns_folder, "Annotations", ".xml"),
         ]:
             make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids))
-        make_tar(root, (cls._TRAIN_VAL_FILE_NAMES if trainval else cls._TEST_FILE_NAMES)[year], data_folder)
+        make_tar(root, (cls._TRAIN_VAL_FILE_NAMES if trainval else cls._TEST_FILE_NAMES)[year], archive_folder)
 
         return num_samples_map
 
```
```diff
@@ -1091,8 +1123,10 @@ def _make_ann_file(path, num_examples, class_idx):
                     }
                 )
 
+    archive_folder = root / "GTSRB"
+
     if config["split"] == "train":
-        train_folder = root / "GTSRB" / "Training"
+        train_folder = archive_folder / "Training"
         train_folder.mkdir(parents=True)
 
         for class_idx in classes:
```
```diff
@@ -1107,9 +1141,9 @@ def _make_ann_file(path, num_examples, class_idx):
                 num_examples=num_examples_per_class,
                 class_idx=int(class_idx),
             )
-        make_zip(root, "GTSRB-Training_fixed.zip", train_folder)
+        make_zip(root, "GTSRB-Training_fixed.zip", archive_folder)
     else:
-        test_folder = root / "GTSRB" / "Final_Test"
+        test_folder = archive_folder / "Final_Test"
         test_folder.mkdir(parents=True)
 
         create_image_folder(
```
```diff
@@ -1119,7 +1153,7 @@ def _make_ann_file(path, num_examples, class_idx):
             num_examples=num_examples,
         )
 
-        make_zip(root, "GTSRB_Final_Test_Images.zip", test_folder)
+        make_zip(root, "GTSRB_Final_Test_Images.zip", archive_folder)
 
     _make_ann_file(
         path=root / "GT-final_test.csv",
```
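The GTSRB change follows the same idea: zip from the `GTSRB` parent folder instead of the split folder, so the archive members keep the `GTSRB/` prefix, which is the layout the change implies the real archives use. A stdlib sketch of the resulting member names (`make_zip` is the test helper from datasets_utils; this only illustrates the intended layout):

```python
import pathlib
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    train_folder = root / "GTSRB" / "Training"
    train_folder.mkdir(parents=True)
    (train_folder / "00000.ppm").write_bytes(b"\x00")

    # Archiving from root/GTSRB (not from the Training subfolder) keeps the
    # GTSRB/ prefix in the member names.
    with zipfile.ZipFile(root / "GTSRB-Training_fixed.zip", "w") as zf:
        for file in (root / "GTSRB").rglob("*"):
            zf.write(file, file.relative_to(root))

    with zipfile.ZipFile(root / "GTSRB-Training_fixed.zip") as zf:
        print(zf.namelist())  # e.g. ['GTSRB/Training/', 'GTSRB/Training/00000.ppm']
```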
```diff
@@ -1484,11 +1518,10 @@ def stanford_cars(root, config):
     num_samples = {"train": 5, "test": 7}[split]
     num_categories = 3
 
-    devkit = root / "devkit"
-    devkit.mkdir(parents=True)
-
     if split == "train":
         images_folder_name = "cars_train"
+        devkit = root / "devkit"
+        devkit.mkdir()
         annotations_mat_path = devkit / "cars_train_annos.mat"
     else:
         images_folder_name = "cars_test"
```

test/test_prototype_builtin_datasets.py

Lines changed: 16 additions & 39 deletions
```diff
@@ -56,18 +56,14 @@ def test_info(self, name):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_smoke(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         if not isinstance(dataset, datasets.utils.Dataset):
             raise AssertionError(f"Loading the dataset should return a Dataset, but got {type(dataset)} instead.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         try:
             sample = next(iter(dataset))
```
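On the test side, every occurrence of the two-step prepare-then-load dance collapses into a single call that returns both the dataset and the mock info. A hypothetical extra test written against the new helper, inside the same class and relying on the fixtures already defined in this file, would look like:

```python
    @parametrize_dataset_mocks(DATASET_MOCKS)
    def test_something(self, dataset_mock, config):
        # One call now builds the mock data on disk and loads the dataset;
        # tests that don't need the mock info simply discard it.
        dataset, mock_info = dataset_mock.load(config)

        first_sample = next(iter(dataset))
        assert isinstance(first_sample, dict)
```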
```diff
@@ -84,17 +80,13 @@ def test_sample(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_num_samples(self, dataset_mock, config):
-        mock_info = dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, mock_info = dataset_mock.load(config)
 
         assert len(list(dataset)) == mock_info["num_samples"]
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_no_vanilla_tensors(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
         if vanilla_tensors:
```
```diff
@@ -105,24 +97,20 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         next(iter(dataset.map(transforms.Identity())))
 
     @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_traversable(self, dataset_mock, config, only_datapipe):
-        dataset_mock.prepare(config)
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         traverse(dataset, only_datapipe=only_datapipe)
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_serializable(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         pickle.dumps(dataset)
 
```
```diff
@@ -135,8 +123,7 @@ def _collate_fn(self, batch):
     @pytest.mark.parametrize("num_workers", [0, 1])
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_data_loader(self, dataset_mock, config, num_workers):
-        dataset_mock.prepare(config)
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         dl = DataLoader(
             dataset,
```
```diff
@@ -153,17 +140,15 @@ def test_data_loader(self, dataset_mock, config, num_workers):
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
     def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
-
-        dataset_mock.prepare(config)
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
             raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_save_load(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
+
         sample = next(iter(dataset))
 
         with io.BytesIO() as buffer:
```
```diff
@@ -173,8 +158,7 @@ def test_save_load(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_infinite_buffer_size(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         for dp in extract_datapipes(dataset):
             if hasattr(dp, "buffer_size"):
```
```diff
@@ -184,18 +168,15 @@ def test_infinite_buffer_size(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_has_length(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         assert len(dataset) > 0
 
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
     def test_extra_label(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         sample = next(iter(dataset))
         for key, type in (
```
```diff
@@ -218,9 +199,7 @@ def test_label_matches_path(self, dataset_mock, config):
         if config["split"] != "train":
             return
 
-        dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         for sample in dataset:
             label_from_path = int(Path(sample["path"]).parent.name)
```
```diff
@@ -230,9 +209,7 @@ def test_label_matches_path(self, dataset_mock, config):
 @parametrize_dataset_mocks(DATASET_MOCKS["usps"])
 class TestUSPS:
     def test_sample_content(self, dataset_mock, config):
-        dataset_mock.prepare(config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
+        dataset, _ = dataset_mock.load(config)
 
         for sample in dataset:
             assert "image" in sample
```
