@@ -9,7 +9,6 @@
 import torchvision
 from torchvision.datasets import utils
 from common_utils import get_tmp_dir
-from fakedata_generation import places365_root
 import xml.etree.ElementTree as ET
 from urllib.request import Request, urlopen
 import itertools
@@ -41,106 +40,6 @@
 HAS_PYAV = False


-class DatasetTestcase(unittest.TestCase):
-    def generic_classification_dataset_test(self, dataset, num_images=1):
-        self.assertEqual(len(dataset), num_images)
-        img, target = dataset[0]
-        self.assertTrue(isinstance(img, PIL.Image.Image))
-        self.assertTrue(isinstance(target, int))
-
-    def generic_segmentation_dataset_test(self, dataset, num_images=1):
-        self.assertEqual(len(dataset), num_images)
-        img, target = dataset[0]
-        self.assertTrue(isinstance(img, PIL.Image.Image))
-        self.assertTrue(isinstance(target, PIL.Image.Image))
-
-
-class Tester(DatasetTestcase):
-    def test_places365(self):
-        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
-            with places365_root(split=split, small=small) as places365:
-                root, data = places365
-
-                dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
-                self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
-
-    def test_places365_transforms(self):
-        expected_image = "image"
-        expected_target = "target"
-
-        def transform(image):
-            return expected_image
-
-        def target_transform(target):
-            return expected_target
-
-        with places365_root() as places365:
-            root, data = places365
-
-            dataset = torchvision.datasets.Places365(
-                root, transform=transform, target_transform=target_transform, download=True
-            )
-            actual_image, actual_target = dataset[0]
-
-            self.assertEqual(actual_image, expected_image)
-            self.assertEqual(actual_target, expected_target)
-
-    def test_places365_devkit_download(self):
-        for split in ("train-standard", "train-challenge", "val"):
-            with self.subTest(split=split):
-                with places365_root(split=split) as places365:
-                    root, data = places365
-
-                    dataset = torchvision.datasets.Places365(root, split=split, download=True)
-
-                    with self.subTest("classes"):
-                        self.assertSequenceEqual(dataset.classes, data["classes"])
-
-                    with self.subTest("class_to_idx"):
-                        self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])
-
-                    with self.subTest("imgs"):
-                        self.assertSequenceEqual(dataset.imgs, data["imgs"])
-
-    def test_places365_devkit_no_download(self):
-        for split in ("train-standard", "train-challenge", "val"):
-            with self.subTest(split=split):
-                with places365_root(split=split) as places365:
-                    root, data = places365
-
-                    with self.assertRaises(RuntimeError):
-                        torchvision.datasets.Places365(root, split=split, download=False)
-
-    def test_places365_images_download(self):
-        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
-            with self.subTest(split=split, small=small):
-                with places365_root(split=split, small=small) as places365:
-                    root, data = places365
-
-                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
-
-                    assert all(os.path.exists(item[0]) for item in dataset.imgs)
-
-    def test_places365_images_download_preexisting(self):
-        split = "train-standard"
-        small = False
-        images_dir = "data_large_standard"
-
-        with places365_root(split=split, small=small) as places365:
-            root, data = places365
-            os.mkdir(os.path.join(root, images_dir))
-
-            with self.assertRaises(RuntimeError):
-                torchvision.datasets.Places365(root, split=split, small=small, download=True)
-
-    def test_places365_repr_smoke(self):
-        with places365_root() as places365:
-            root, data = places365
-
-            dataset = torchvision.datasets.Places365(root, download=True)
-            self.assertIsInstance(repr(dataset), str)
-
-
 class STL10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.STL10
     ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
@@ -1763,5 +1662,96 @@ def inject_fake_data(self, tmpdir, config):
         return num_examples


+class Places365TestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.Places365
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        split=("train-standard", "train-challenge", "val"),
+        small=(False, True),
+    )
+    _CATEGORIES = "categories_places365.txt"
+    # {split: file}
+    _FILE_LISTS = {
+        "train-standard": "places365_train_standard.txt",
+        "train-challenge": "places365_train_challenge.txt",
+        "val": "places365_val.txt",
+    }
+    # {(split, small): folder_name}
+    _IMAGES = {
+        ("train-standard", False): "data_large_standard",
+        ("train-challenge", False): "data_large_challenge",
+        ("val", False): "val_large",
+        ("train-standard", True): "data_256_standard",
+        ("train-challenge", True): "data_256_challenge",
+        ("val", True): "val_256",
+    }
+    # (class, idx)
+    _CATEGORIES_CONTENT = (
+        ("/a/airfield", 0),
+        ("/a/apartment_building/outdoor", 8),
+        ("/b/badlands", 30),
+    )
+    # (file, idx)
+    _FILE_LIST_CONTENT = (
+        ("Places365_val_00000001.png", 0),
+        *((f"{category}/Places365_train_00000001.png", idx)
+          for category, idx in _CATEGORIES_CONTENT),
+    )
+
+    @staticmethod
+    def _make_txt(root, name, seq):
+        file = os.path.join(root, name)
+        with open(file, "w") as fh:
+            for text, idx in seq:
+                fh.write(f"{text} {idx}\n")
+
+    @staticmethod
+    def _make_categories_txt(root, name):
+        Places365TestCase._make_txt(root, name, Places365TestCase._CATEGORIES_CONTENT)
+
+    @staticmethod
+    def _make_file_list_txt(root, name):
+        Places365TestCase._make_txt(root, name, Places365TestCase._FILE_LIST_CONTENT)
+
+    @staticmethod
+    def _make_image(file_name, size):
+        os.makedirs(os.path.dirname(file_name), exist_ok=True)
+        PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file_name)
+
+    @staticmethod
+    def _make_devkit_archive(root, split):
+        Places365TestCase._make_categories_txt(root, Places365TestCase._CATEGORIES)
+        Places365TestCase._make_file_list_txt(root, Places365TestCase._FILE_LISTS[split])
+
+    @staticmethod
+    def _make_images_archive(root, split, small):
+        folder_name = Places365TestCase._IMAGES[(split, small)]
+        image_size = (256, 256) if small else (512, random.randint(512, 1024))
+        files, idcs = zip(*Places365TestCase._FILE_LIST_CONTENT)
+        images = [f.lstrip("/").replace("/", os.sep) for f in files]
+        for image in images:
+            Places365TestCase._make_image(os.path.join(root, folder_name, image), image_size)
+
+        return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)]
+
+    def inject_fake_data(self, tmpdir, config):
+        self._make_devkit_archive(tmpdir, config['split'])
+        return len(self._make_images_archive(tmpdir, config['split'], config['small']))
+
+    def test_classes(self):
+        classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
+        with self.create_dataset() as (dataset, _):
+            self.assertEqual(dataset.classes, classes)
+
+    def test_class_to_idx(self):
+        class_to_idx = dict(self._CATEGORIES_CONTENT)
+        with self.create_dataset() as (dataset, _):
+            self.assertEqual(dataset.class_to_idx, class_to_idx)
+
+    def test_images_download_preexisting(self):
+        with self.assertRaises(RuntimeError):
+            with self.create_dataset({'download': True}):
+                pass
+
+
 if __name__ == "__main__":
     unittest.main()
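
A quick way to see what the new fake-data helpers lay out on disk, outside the test harness: a minimal sketch, assuming it is run from torchvision's test/ directory so that test_datasets (and its local imports such as datasets_utils and common_utils) resolve; the temporary directory and the printed output are illustrative only, not part of the change above.

    import tempfile

    from test_datasets import Places365TestCase

    with tempfile.TemporaryDirectory() as tmpdir:
        # Writes categories_places365.txt and the file list for the chosen split into tmpdir.
        Places365TestCase._make_devkit_archive(tmpdir, "train-standard")
        # Creates the fake PNGs under tmpdir/data_large_standard/... and returns (path, class idx)
        # pairs; inject_fake_data() reports the length of this list as the dataset size.
        imgs = Places365TestCase._make_images_archive(tmpdir, "train-standard", small=False)
        for path, idx in imgs:
            print(idx, path)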