|
27 | 27 | import torch.nn.functional as F
|
28 | 28 | import string
|
29 | 29 | import io
|
| 30 | +import zipfile |
30 | 31 |
|
31 | 32 |
|
32 | 33 | try:
|
@@ -1275,6 +1276,84 @@ def test_not_found_or_corrupted(self):
|
1275 | 1276 | self.skipTest("The data is generated at creation and thus cannot be non-existent or corrupted.")
|
1276 | 1277 |
|
1277 | 1278 |
|
| 1279 | +class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase): |
| 1280 | + DATASET_CLASS = datasets.PhotoTour |
| 1281 | + |
| 1282 | + # The PhotoTour dataset returns examples with different features with respect to the 'train' parameter. Thus, |
| 1283 | + # we overwrite 'FEATURE_TYPES' with a dummy value to satisfy the initial checks of the base class. Furthermore, we |
| 1284 | + # overwrite the 'test_feature_types()' method to select the correct feature types before the test is run. |
| 1285 | + FEATURE_TYPES = () |
| 1286 | + _TRAIN_FEATURE_TYPES = (torch.Tensor,) |
| 1287 | + _TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor) |
| 1288 | + |
| 1289 | + CONFIGS = datasets_utils.combinations_grid(train=(True, False)) |
| 1290 | + |
| 1291 | + _NAME = "liberty" |
| 1292 | + |
| 1293 | + def dataset_args(self, tmpdir, config): |
| 1294 | + return tmpdir, self._NAME |
| 1295 | + |
| 1296 | + def inject_fake_data(self, tmpdir, config): |
| 1297 | + tmpdir = pathlib.Path(tmpdir) |
| 1298 | + |
| 1299 | + # In contrast to the original data, the fake images injected here comprise only a single patch. Thus, |
| 1300 | + # num_images == num_patches. |
| 1301 | + num_patches = 5 |
| 1302 | + |
| 1303 | + image_files = self._create_images(tmpdir, self._NAME, num_patches) |
| 1304 | + point_ids, info_file = self._create_info_file(tmpdir / self._NAME, num_patches) |
| 1305 | + num_matches, matches_file = self._create_matches_file(tmpdir / self._NAME, num_patches, point_ids) |
| 1306 | + |
| 1307 | + self._create_archive(tmpdir, self._NAME, *image_files, info_file, matches_file) |
| 1308 | + |
| 1309 | + return num_patches if config["train"] else num_matches |
| 1310 | + |
| 1311 | + def _create_images(self, root, name, num_images): |
| 1312 | + # The images in the PhotoTour dataset comprises of multiple grayscale patches of 64 x 64 pixels. Thus, the |
| 1313 | + # smallest fake image is 64 x 64 pixels and comprises a single patch. |
| 1314 | + return datasets_utils.create_image_folder( |
| 1315 | + root, name, lambda idx: f"patches{idx:04d}.bmp", num_images, size=(1, 64, 64) |
| 1316 | + ) |
| 1317 | + |
| 1318 | + def _create_info_file(self, root, num_images): |
| 1319 | + point_ids = torch.randint(num_images, size=(num_images,)).tolist() |
| 1320 | + |
| 1321 | + file = root / "info.txt" |
| 1322 | + with open(file, "w") as fh: |
| 1323 | + fh.writelines([f"{point_id} 0\n" for point_id in point_ids]) |
| 1324 | + |
| 1325 | + return point_ids, file |
| 1326 | + |
| 1327 | + def _create_matches_file(self, root, num_patches, point_ids): |
| 1328 | + lines = [ |
| 1329 | + f"{patch_id1} {point_ids[patch_id1]} 0 {patch_id2} {point_ids[patch_id2]} 0\n" |
| 1330 | + for patch_id1, patch_id2 in itertools.combinations(range(num_patches), 2) |
| 1331 | + ] |
| 1332 | + |
| 1333 | + file = root / "m50_100000_100000_0.txt" |
| 1334 | + with open(file, "w") as fh: |
| 1335 | + fh.writelines(lines) |
| 1336 | + |
| 1337 | + return len(lines), file |
| 1338 | + |
| 1339 | + def _create_archive(self, root, name, *files): |
| 1340 | + archive = root / f"{name}.zip" |
| 1341 | + with zipfile.ZipFile(archive, "w") as zip: |
| 1342 | + for file in files: |
| 1343 | + zip.write(file, arcname=file.relative_to(root)) |
| 1344 | + |
| 1345 | + return archive |
| 1346 | + |
| 1347 | + @datasets_utils.test_all_configs |
| 1348 | + def test_feature_types(self, config): |
| 1349 | + feature_types = self.FEATURE_TYPES |
| 1350 | + self.FEATURE_TYPES = self._TRAIN_FEATURE_TYPES if config["train"] else self._TEST_FEATURE_TYPES |
| 1351 | + try: |
| 1352 | + super().test_feature_types.__wrapped__(self, config) |
| 1353 | + finally: |
| 1354 | + self.FEATURE_TYPES = feature_types |
| 1355 | + |
| 1356 | + |
1278 | 1357 | class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
|
1279 | 1358 | DATASET_CLASS = datasets.Flickr8k
|
1280 | 1359 |
|
|
0 commit comments