Skip to content

Commit 0139808

Browse files
pmeierfmassa
andauthored
add tests for Flickr(8|30)k datasets (#3489)
* add tests for Flickr8k dataset * add tests for FLickr30k dataset * lint Co-authored-by: Francisco Massa <[email protected]>
1 parent 8fe439e commit 0139808

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

test/test_datasets.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,5 +1354,92 @@ def test_feature_types(self, config):
13541354
self.FEATURE_TYPES = feature_types
13551355

13561356

1357+
class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
1358+
DATASET_CLASS = datasets.Flickr8k
1359+
1360+
FEATURE_TYPES = (PIL.Image.Image, list)
1361+
1362+
_IMAGES_FOLDER = "images"
1363+
_ANNOTATIONS_FILE = "captions.html"
1364+
1365+
def dataset_args(self, tmpdir, config):
1366+
tmpdir = pathlib.Path(tmpdir)
1367+
root = tmpdir / self._IMAGES_FOLDER
1368+
ann_file = tmpdir / self._ANNOTATIONS_FILE
1369+
return str(root), str(ann_file)
1370+
1371+
def inject_fake_data(self, tmpdir, config):
1372+
num_images = 3
1373+
num_captions_per_image = 3
1374+
1375+
tmpdir = pathlib.Path(tmpdir)
1376+
1377+
images = self._create_images(tmpdir, self._IMAGES_FOLDER, num_images)
1378+
self._create_annotations_file(tmpdir, self._ANNOTATIONS_FILE, images, num_captions_per_image)
1379+
1380+
return dict(num_examples=num_images, captions=self._create_captions(num_captions_per_image))
1381+
1382+
def _create_images(self, root, name, num_images):
1383+
return datasets_utils.create_image_folder(root, name, self._image_file_name, num_images)
1384+
1385+
def _image_file_name(self, idx):
1386+
id = datasets_utils.create_random_string(10, string.digits)
1387+
checksum = datasets_utils.create_random_string(10, string.digits, string.ascii_lowercase[:6])
1388+
size = datasets_utils.create_random_string(1, "qwcko")
1389+
return f"{id}_{checksum}_{size}.jpg"
1390+
1391+
def _create_annotations_file(self, root, name, images, num_captions_per_image):
1392+
with open(root / name, "w") as fh:
1393+
fh.write("<table>")
1394+
for image in (None, *images):
1395+
self._add_image(fh, image, num_captions_per_image)
1396+
fh.write("</table>")
1397+
1398+
def _add_image(self, fh, image, num_captions_per_image):
1399+
fh.write("<tr>")
1400+
self._add_image_header(fh, image)
1401+
fh.write("</tr><tr><td><ul>")
1402+
self._add_image_captions(fh, num_captions_per_image)
1403+
fh.write("</ul></td></tr>")
1404+
1405+
def _add_image_header(self, fh, image=None):
1406+
if image:
1407+
url = f"http://www.flickr.com/photos/user/{image.name.split('_')[0]}/"
1408+
data = f'<a href="{url}">{url}</a>'
1409+
else:
1410+
data = "Image Not Found"
1411+
fh.write(f"<td>{data}</td>")
1412+
1413+
def _add_image_captions(self, fh, num_captions_per_image):
1414+
for caption in self._create_captions(num_captions_per_image):
1415+
fh.write(f"<li>{caption}")
1416+
1417+
def _create_captions(self, num_captions_per_image):
1418+
return [str(idx) for idx in range(num_captions_per_image)]
1419+
1420+
def test_captions(self):
1421+
with self.create_dataset() as (dataset, info):
1422+
_, captions = dataset[0]
1423+
self.assertSequenceEqual(captions, info["captions"])
1424+
1425+
1426+
class Flickr30kTestCase(Flickr8kTestCase):
1427+
DATASET_CLASS = datasets.Flickr30k
1428+
1429+
FEATURE_TYPES = (PIL.Image.Image, list)
1430+
1431+
_ANNOTATIONS_FILE = "captions.token"
1432+
1433+
def _image_file_name(self, idx):
1434+
return f"{idx}.jpg"
1435+
1436+
def _create_annotations_file(self, root, name, images, num_captions_per_image):
1437+
with open(root / name, "w") as fh:
1438+
for image, (idx, caption) in itertools.product(
1439+
images, enumerate(self._create_captions(num_captions_per_image))
1440+
):
1441+
fh.write(f"{image.name}#{idx}\t{caption}\n")
1442+
1443+
13571444
if __name__ == "__main__":
13581445
unittest.main()

0 commit comments

Comments
 (0)