Skip to content

Commit ae63bd0

Browse files
authored
Revert "Revert "Ported places365 dataset's tests to the new test framework (#3705)" (#3718)" (#3731)
This reverts commit d419558.
1 parent 03f94a6 commit ae63bd0

File tree

2 files changed

+91
-201
lines changed

2 files changed

+91
-201
lines changed

test/fakedata_generation.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -208,103 +208,3 @@ def _make_annotations_archive(root):
208208
_make_annotations_archive(root_base)
209209

210210
yield root
211-
212-
213-
@contextlib.contextmanager
def places365_root(split="train-standard", small=False):
    """Yield a temporary root directory holding a fake Places365 dataset.

    The fixture writes a categories file, a file list and a handful of black
    images for the requested configuration, packs them into tar archives and
    patches the metadata class attributes of
    ``torchvision.datasets.places365.Places365`` so that they describe the
    generated archives.

    Args:
        split: One of ``"train-standard"``, ``"train-challenge"`` or ``"val"``.
        small: If ``True`` the 256x256 image variant is generated, otherwise
            images of height 512 and a random width in ``[512, 1024]``.

    Yields:
        A tuple ``(root, data)`` where ``root`` is the temporary directory and
        ``data`` is a dict with the keys ``"class_to_idx"``, ``"classes"`` and
        ``"imgs"`` describing the expected content of the dataset.
    """
    VARIANTS = {
        "train-standard": "standard",
        "train-challenge": "challenge",
        "val": "standard",
    }
    # {split: file}
    DEVKITS = {
        "train-standard": "filelist_places365-standard.tar",
        "train-challenge": "filelist_places365-challenge.tar",
        "val": "filelist_places365-standard.tar",
    }
    CATEGORIES = "categories_places365.txt"
    # {split: file}
    FILE_LISTS = {
        "train-standard": "places365_train_standard.txt",
        "train-challenge": "places365_train_challenge.txt",
        # The validation split has its own file list. The name is only used
        # for the generated file and the mocked metadata below, so it has to
        # be self-consistent rather than equal to the train list.
        "val": "places365_val.txt",
    }
    # {(split, small): (archive, folder_default, folder_renamed)}
    IMAGES = {
        ("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
        ("val", False): ("val_large.tar", "val_large", "val_large"),
        ("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
        ("val", True): ("val_256.tar", "val_256", "val_256"),
    }

    # (class, idx)
    CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
    # (file, idx)
    FILE_LIST_CONTENT = (
        ("Places365_val_00000001.png", 0),
        *((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
    )

    def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
        """Return the dotted path of ``attr`` on the Places365 class."""
        return f"{partial}.{attr}"

    def make_txt(root, name, seq):
        """Write one ``"<string> <idx>"`` line per item; return ``(name, md5)``."""
        file = os.path.join(root, name)
        with open(file, "w") as fh:
            for string, idx in seq:
                fh.write(f"{string} {idx}\n")
        return name, compute_md5(file)

    def make_categories_txt(root, name):
        """Write the fake categories file."""
        return make_txt(root, name, CATEGORIES_CONTENT)

    def make_file_list_txt(root, name):
        """Write the fake file list."""
        return make_txt(root, name, FILE_LIST_CONTENT)

    def make_image(file, size):
        """Save an all-black RGB image of ``size`` at ``file``, creating parent folders."""
        os.makedirs(os.path.dirname(file), exist_ok=True)
        PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)

    def make_devkit_archive(stack, root, split):
        """Create the devkit tar for ``split`` and patch the matching metadata."""
        archive = DEVKITS[split]
        files = []

        meta = make_categories_txt(root, CATEGORIES)
        mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
        files.append(meta[0])

        meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
        mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
        files.extend([item[0] for item in meta.values()])

        # NOTE(review): make_tar presumably returns the (archive, md5) pair the
        # dataset expects as metadata -- confirm against its definition.
        meta = {VARIANTS[split]: make_tar(root, archive, *files)}
        mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)

    def make_images_archive(stack, root, split, small):
        """Create the images tar and return the expected ``(path, idx)`` samples."""
        archive, folder_default, folder_renamed = IMAGES[(split, small)]

        image_size = (256, 256) if small else (512, random.randint(512, 1024))
        files, idcs = zip(*FILE_LIST_CONTENT)
        images = [file.lstrip("/").replace("/", os.sep) for file in files]
        for image in images:
            make_image(os.path.join(root, folder_default, image), image_size)

        meta = {(split, small): make_tar(root, archive, folder_default)}
        mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)

        # Images are written below ``folder_default`` but the expected paths
        # use ``folder_renamed`` -- presumably the dataset renames the folder
        # after extraction; verify against Places365.
        return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]

    with contextlib.ExitStack() as stack, get_tmp_dir() as root:
        make_devkit_archive(stack, root, split)
        class_to_idx = dict(CATEGORIES_CONTENT)
        classes = list(class_to_idx)

        data = {"class_to_idx": class_to_idx, "classes": classes}
        data["imgs"] = make_images_archive(stack, root, split, small)

        # NOTE(review): presumably removes everything except the tar archives
        # so the dataset has to extract them itself -- confirm in clean_dir.
        clean_dir(root, ".tar$")

        yield root, data

test/test_datasets.py

Lines changed: 91 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import torchvision
1010
from torchvision.datasets import utils
1111
from common_utils import get_tmp_dir
12-
from fakedata_generation import places365_root
1312
import xml.etree.ElementTree as ET
1413
from urllib.request import Request, urlopen
1514
import itertools
@@ -41,106 +40,6 @@
4140
HAS_PYAV = False
4241

4342

44-
class DatasetTestcase(unittest.TestCase):
    """Base test case with generic checks for ``(image, target)`` datasets."""

    def generic_classification_dataset_test(self, dataset, num_images=1):
        """Check the length of ``dataset`` and the types of its first sample.

        The first sample has to be a ``(PIL image, int)`` pair.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        # assertIsInstance reports the offending type on failure, which
        # assertTrue(isinstance(...)) does not.
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        """Check the length of ``dataset`` and the types of its first sample.

        The first sample has to be a ``(PIL image, PIL image)`` pair.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)
56-
57-
58-
class Tester(DatasetTestcase):
59-
def test_places365(self):
60-
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
61-
with places365_root(split=split, small=small) as places365:
62-
root, data = places365
63-
64-
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
65-
self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
66-
67-
def test_places365_transforms(self):
68-
expected_image = "image"
69-
expected_target = "target"
70-
71-
def transform(image):
72-
return expected_image
73-
74-
def target_transform(target):
75-
return expected_target
76-
77-
with places365_root() as places365:
78-
root, data = places365
79-
80-
dataset = torchvision.datasets.Places365(
81-
root, transform=transform, target_transform=target_transform, download=True
82-
)
83-
actual_image, actual_target = dataset[0]
84-
85-
self.assertEqual(actual_image, expected_image)
86-
self.assertEqual(actual_target, expected_target)
87-
88-
def test_places365_devkit_download(self):
89-
for split in ("train-standard", "train-challenge", "val"):
90-
with self.subTest(split=split):
91-
with places365_root(split=split) as places365:
92-
root, data = places365
93-
94-
dataset = torchvision.datasets.Places365(root, split=split, download=True)
95-
96-
with self.subTest("classes"):
97-
self.assertSequenceEqual(dataset.classes, data["classes"])
98-
99-
with self.subTest("class_to_idx"):
100-
self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])
101-
102-
with self.subTest("imgs"):
103-
self.assertSequenceEqual(dataset.imgs, data["imgs"])
104-
105-
def test_places365_devkit_no_download(self):
106-
for split in ("train-standard", "train-challenge", "val"):
107-
with self.subTest(split=split):
108-
with places365_root(split=split) as places365:
109-
root, data = places365
110-
111-
with self.assertRaises(RuntimeError):
112-
torchvision.datasets.Places365(root, split=split, download=False)
113-
114-
def test_places365_images_download(self):
115-
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
116-
with self.subTest(split=split, small=small):
117-
with places365_root(split=split, small=small) as places365:
118-
root, data = places365
119-
120-
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
121-
122-
assert all(os.path.exists(item[0]) for item in dataset.imgs)
123-
124-
def test_places365_images_download_preexisting(self):
125-
split = "train-standard"
126-
small = False
127-
images_dir = "data_large_standard"
128-
129-
with places365_root(split=split, small=small) as places365:
130-
root, data = places365
131-
os.mkdir(os.path.join(root, images_dir))
132-
133-
with self.assertRaises(RuntimeError):
134-
torchvision.datasets.Places365(root, split=split, small=small, download=True)
135-
136-
def test_places365_repr_smoke(self):
137-
with places365_root() as places365:
138-
root, data = places365
139-
140-
dataset = torchvision.datasets.Places365(root, download=True)
141-
self.assertIsInstance(repr(dataset), str)
142-
143-
14443
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
14544
DATASET_CLASS = datasets.STL10
14645
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
@@ -1763,5 +1662,96 @@ def inject_fake_data(self, tmpdir, config):
17631662
return num_examples
17641663

17651664

1665+
class Places365TestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.Places365`` that run against generated fake data.

    ``inject_fake_data`` writes the categories file, the file list of the
    requested split and a few black images directly into the temporary
    directory. No archives are created.
    """

    DATASET_CLASS = datasets.Places365
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train-standard", "train-challenge", "val"),
        small=(False, True),
    )
    _CATEGORIES = "categories_places365.txt"
    # {split: file}
    _FILE_LISTS = {
        "train-standard": "places365_train_standard.txt",
        "train-challenge": "places365_train_challenge.txt",
        "val": "places365_val.txt",
    }
    # {(split, small): folder_name}
    _IMAGES = {
        ("train-standard", False): "data_large_standard",
        ("train-challenge", False): "data_large_challenge",
        ("val", False): "val_large",
        ("train-standard", True): "data_256_standard",
        ("train-challenge", True): "data_256_challenge",
        ("val", True): "val_256",
    }
    # (class, idx)
    _CATEGORIES_CONTENT = (
        ("/a/airfield", 0),
        ("/a/apartment_building/outdoor", 8),
        ("/b/badlands", 30),
    )
    # (file, idx)
    _FILE_LIST_CONTENT = (
        ("Places365_val_00000001.png", 0),
        *((f"{category}/Places365_train_00000001.png", idx)
          for category, idx in _CATEGORIES_CONTENT),
    )

    @staticmethod
    def _make_txt(root, name, seq):
        """Write one ``"<text> <idx>"`` line per item of ``seq`` to ``root/name``."""
        file = os.path.join(root, name)
        with open(file, "w") as fh:
            for text, idx in seq:
                fh.write(f"{text} {idx}\n")

    # The helpers below are classmethods rather than staticmethods that spell
    # out ``Places365TestCase``: class data is looked up through ``cls``, so a
    # subclass overriding e.g. ``_CATEGORIES_CONTENT`` is respected. Existing
    # call sites (``Places365TestCase._helper(...)`` / ``self._helper(...)``)
    # keep working unchanged.
    @classmethod
    def _make_categories_txt(cls, root, name):
        """Write the fake categories file ``root/name``."""
        cls._make_txt(root, name, cls._CATEGORIES_CONTENT)

    @classmethod
    def _make_file_list_txt(cls, root, name):
        """Write the fake file list ``root/name``."""
        cls._make_txt(root, name, cls._FILE_LIST_CONTENT)

    @staticmethod
    def _make_image(file_name, size):
        """Save an all-black RGB image of ``size`` at ``file_name``, creating parent folders."""
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file_name)

    @classmethod
    def _make_devkit_archive(cls, root, split):
        """Write the categories file and the file list of ``split`` into ``root``."""
        cls._make_categories_txt(root, cls._CATEGORIES)
        cls._make_file_list_txt(root, cls._FILE_LISTS[split])

    @classmethod
    def _make_images_archive(cls, root, split, small):
        """Create the fake images and return the list of ``(path, idx)`` samples."""
        folder_name = cls._IMAGES[(split, small)]
        image_size = (256, 256) if small else (512, random.randint(512, 1024))
        files, idcs = zip(*cls._FILE_LIST_CONTENT)
        # Entries derived from a category start with "/"; strip it so that the
        # path can be joined below ``folder_name``.
        images = [f.lstrip("/").replace("/", os.sep) for f in files]
        for image in images:
            cls._make_image(os.path.join(root, folder_name, image), image_size)

        return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)]

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir`` for ``config`` and return the number of images created."""
        self._make_devkit_archive(tmpdir, config['split'])
        return len(self._make_images_archive(tmpdir, config['split'], config['small']))

    def test_classes(self):
        """``dataset.classes`` lists the categories in file order."""
        classes = [category for category, _ in self._CATEGORIES_CONTENT]
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.classes, classes)

    def test_class_to_idx(self):
        """``dataset.class_to_idx`` maps every category to its index."""
        class_to_idx = dict(self._CATEGORIES_CONTENT)
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.class_to_idx, class_to_idx)

    def test_images_download_preexisting(self):
        """``download=True`` raises ``RuntimeError`` when the images are already present."""
        # NOTE(review): presumably create_dataset injects the fake data before
        # constructing the dataset, so the image folder already exists --
        # confirm in datasets_utils.
        with self.assertRaises(RuntimeError):
            with self.create_dataset({'download': True}):
                pass
1754+
1755+
17661756
# Run the tests of this module through unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)