Skip to content

Commit 83b50dd

Browse files
committed
Merge branch 'master' into flickr-tests
2 parents 7814f6b + f637c63 commit 83b50dd

File tree

3 files changed

+100
-14
lines changed

3 files changed

+100
-14
lines changed

test/datasets_utils.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,17 @@ def test_feature_types(self, config):
436436
with self.create_dataset(config) as (dataset, _):
437437
example = dataset[0]
438438

439-
actual = len(example)
440-
expected = len(self.FEATURE_TYPES)
441-
self.assertEqual(
442-
actual,
443-
expected,
444-
f"The number of the returned features does not match the the number of elements in in FEATURE_TYPES: "
445-
f"{actual} != {expected}",
446-
)
439+
if len(self.FEATURE_TYPES) > 1:
440+
actual = len(example)
441+
expected = len(self.FEATURE_TYPES)
442+
self.assertEqual(
443+
actual,
444+
expected,
445+
f"The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
446+
f"{actual} != {expected}",
447+
)
448+
else:
449+
example = (example,)
447450

448451
for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
449452
with self.subTest(idx=idx):
@@ -586,7 +589,13 @@ def create_image_file(
586589

587590
image = create_image_or_video_tensor(size)
588591
file = pathlib.Path(root) / name
589-
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file, **kwargs)
592+
593+
# torch (num_channels x height x width) -> PIL (width x height x num_channels)
594+
image = image.permute(2, 1, 0)
595+
# For grayscale images PIL doesn't use a channel dimension
596+
if image.shape[2] == 1:
597+
image = torch.squeeze(image, 2)
598+
PIL.Image.fromarray(image.numpy()).save(file, **kwargs)
590599
return file
591600

592601

test/test_datasets.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import torch.nn.functional as F
2828
import string
2929
import io
30+
import zipfile
3031

3132

3233
try:
@@ -1275,6 +1276,84 @@ def test_not_found_or_corrupted(self):
12751276
self.skipTest("The data is generated at creation and thus cannot be non-existent or corrupted.")
12761277

12771278

1279+
class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
1280+
DATASET_CLASS = datasets.PhotoTour
1281+
1282+
# The PhotoTour dataset returns examples with different features with respect to the 'train' parameter. Thus,
1283+
# we overwrite 'FEATURE_TYPES' with a dummy value to satisfy the initial checks of the base class. Furthermore, we
1284+
# overwrite the 'test_feature_types()' method to select the correct feature types before the test is run.
1285+
FEATURE_TYPES = ()
1286+
_TRAIN_FEATURE_TYPES = (torch.Tensor,)
1287+
_TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)
1288+
1289+
CONFIGS = datasets_utils.combinations_grid(train=(True, False))
1290+
1291+
_NAME = "liberty"
1292+
1293+
def dataset_args(self, tmpdir, config):
1294+
return tmpdir, self._NAME
1295+
1296+
def inject_fake_data(self, tmpdir, config):
1297+
tmpdir = pathlib.Path(tmpdir)
1298+
1299+
# In contrast to the original data, the fake images injected here comprise only a single patch. Thus,
1300+
# num_images == num_patches.
1301+
num_patches = 5
1302+
1303+
image_files = self._create_images(tmpdir, self._NAME, num_patches)
1304+
point_ids, info_file = self._create_info_file(tmpdir / self._NAME, num_patches)
1305+
num_matches, matches_file = self._create_matches_file(tmpdir / self._NAME, num_patches, point_ids)
1306+
1307+
self._create_archive(tmpdir, self._NAME, *image_files, info_file, matches_file)
1308+
1309+
return num_patches if config["train"] else num_matches
1310+
1311+
def _create_images(self, root, name, num_images):
1312+
# The images in the PhotoTour dataset comprises of multiple grayscale patches of 64 x 64 pixels. Thus, the
1313+
# smallest fake image is 64 x 64 pixels and comprises a single patch.
1314+
return datasets_utils.create_image_folder(
1315+
root, name, lambda idx: f"patches{idx:04d}.bmp", num_images, size=(1, 64, 64)
1316+
)
1317+
1318+
def _create_info_file(self, root, num_images):
1319+
point_ids = torch.randint(num_images, size=(num_images,)).tolist()
1320+
1321+
file = root / "info.txt"
1322+
with open(file, "w") as fh:
1323+
fh.writelines([f"{point_id} 0\n" for point_id in point_ids])
1324+
1325+
return point_ids, file
1326+
1327+
def _create_matches_file(self, root, num_patches, point_ids):
1328+
lines = [
1329+
f"{patch_id1} {point_ids[patch_id1]} 0 {patch_id2} {point_ids[patch_id2]} 0\n"
1330+
for patch_id1, patch_id2 in itertools.combinations(range(num_patches), 2)
1331+
]
1332+
1333+
file = root / "m50_100000_100000_0.txt"
1334+
with open(file, "w") as fh:
1335+
fh.writelines(lines)
1336+
1337+
return len(lines), file
1338+
1339+
def _create_archive(self, root, name, *files):
1340+
archive = root / f"{name}.zip"
1341+
with zipfile.ZipFile(archive, "w") as zip:
1342+
for file in files:
1343+
zip.write(file, arcname=file.relative_to(root))
1344+
1345+
return archive
1346+
1347+
@datasets_utils.test_all_configs
1348+
def test_feature_types(self, config):
1349+
feature_types = self.FEATURE_TYPES
1350+
self.FEATURE_TYPES = self._TRAIN_FEATURE_TYPES if config["train"] else self._TEST_FEATURE_TYPES
1351+
try:
1352+
super().test_feature_types.__wrapped__(self, config)
1353+
finally:
1354+
self.FEATURE_TYPES = feature_types
1355+
1356+
12781357
class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
12791358
DATASET_CLASS = datasets.Flickr8k
12801359

torchvision/datasets/phototour.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,7 @@ def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.T
121121
return data1, data2, m[2]
122122

123123
def __len__(self) -> int:
124-
if self.train:
125-
return self.lens[self.name]
126-
return len(self.matches)
124+
return len(self.data if self.train else self.matches)
127125

128126
def _check_datafile_exists(self) -> bool:
129127
return os.path.exists(self.data_file)
@@ -194,8 +192,8 @@ def find_files(_data_dir: str, _image_ext: str) -> List[str]:
194192

195193
for fpath in list_files:
196194
img = Image.open(fpath)
197-
for y in range(0, 1024, 64):
198-
for x in range(0, 1024, 64):
195+
for y in range(0, img.height, 64):
196+
for x in range(0, img.width, 64):
199197
patch = img.crop((x, y, x + 64, y + 64))
200198
patches.append(PIL2array(patch))
201199
return torch.ByteTensor(np.array(patches[:n]))

0 commit comments

Comments
 (0)