@@ -1354,5 +1354,92 @@ def test_feature_types(self, config):
1354
1354
self .FEATURE_TYPES = feature_types
1355
1355
1356
1356
1357
class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.Flickr8k`` running against generated fake data.

    The fake data is an image folder plus an HTML annotations file. The
    annotations file holds one table section per entry: a header row followed
    by a row with an unordered list of captions.
    """

    DATASET_CLASS = datasets.Flickr8k

    # NOTE(review): presumably the expected types of the items of one sample
    # (image, captions); the check itself lives in the base class.
    FEATURE_TYPES = (PIL.Image.Image, list)

    # Names of the image folder and the annotations file inside the
    # temporary directory.
    _IMAGES_FOLDER = "images"
    _ANNOTATIONS_FILE = "captions.html"

    def dataset_args(self, tmpdir, config):
        """Return the image folder and annotations file paths as strings.

        Both paths are located directly inside *tmpdir*. *config* is unused.
        """
        tmpdir = pathlib.Path(tmpdir)
        root = tmpdir / self._IMAGES_FOLDER
        ann_file = tmpdir / self._ANNOTATIONS_FILE
        return str(root), str(ann_file)

    def inject_fake_data(self, tmpdir, config):
        """Create the fake images and annotations file inside *tmpdir*.

        Returns a dict with the number of created images (``num_examples``)
        and the captions written for every image (``captions``). *config* is
        unused.
        """
        num_images = 3
        num_captions_per_image = 3

        tmpdir = pathlib.Path(tmpdir)

        images = self._create_images(tmpdir, self._IMAGES_FOLDER, num_images)
        self._create_annotations_file(tmpdir, self._ANNOTATIONS_FILE, images, num_captions_per_image)

        return dict(num_examples=num_images, captions=self._create_captions(num_captions_per_image))

    def _create_images(self, root, name, num_images):
        """Create *num_images* images in the folder *name* below *root*.

        Delegates to ``datasets_utils.create_image_folder`` with
        ``self._image_file_name`` as the file name callback and returns its
        result (used by the callers as an iterable of path-like objects).
        """
        return datasets_utils.create_image_folder(root, name, self._image_file_name, num_images)

    def _image_file_name(self, idx):
        """Return a random file name of the form ``<id>_<checksum>_<size>.jpg``.

        *idx* is not used here; subclasses may derive the name from it.
        """
        # ``image_id`` rather than ``id`` to avoid shadowing the builtin.
        image_id = datasets_utils.create_random_string(10, string.digits)
        # NOTE(review): presumably 10 characters drawn from digits and "abcdef";
        # confirm against ``create_random_string``.
        checksum = datasets_utils.create_random_string(10, string.digits, string.ascii_lowercase[:6])
        size = datasets_utils.create_random_string(1, "qwcko")
        return f"{image_id}_{checksum}_{size}.jpg"

    def _create_annotations_file(self, root, name, images, num_captions_per_image):
        """Write the HTML annotations file *name* inside *root*.

        A leading ``None`` entry is written before the real images; it
        produces a section whose header reads "Image Not Found".
        """
        with open(root / name, "w") as fh:
            fh.write("<table>")
            for image in (None, *images):
                self._add_image(fh, image, num_captions_per_image)
            fh.write("</table>")

    def _add_image(self, fh, image, num_captions_per_image):
        """Write the header row and the caption row for a single entry."""
        fh.write("<tr>")
        self._add_image_header(fh, image)
        fh.write("</tr><tr><td><ul>")
        self._add_image_captions(fh, num_captions_per_image)
        fh.write("</ul></td></tr>")

    def _add_image_header(self, fh, image=None):
        """Write the header cell: a link for *image* or a "not found" marker."""
        if image:
            # The part of the file name before the first underscore is used
            # as the last path component of the URL.
            url = f"http://www.flickr.com/photos/user/{image.name.split('_')[0]}/"
            data = f'<a href="{url}">{url}</a>'
        else:
            data = "Image Not Found"
        fh.write(f"<td>{data}</td>")

    def _add_image_captions(self, fh, num_captions_per_image):
        """Write one ``<li>`` item per caption (no closing tag is written)."""
        for caption in self._create_captions(num_captions_per_image):
            fh.write(f"<li>{caption}")

    def _create_captions(self, num_captions_per_image):
        """Return the captions ``"0"``, ``"1"``, ... as a list of strings."""
        return [str(idx) for idx in range(num_captions_per_image)]

    def test_captions(self):
        """The captions of the first sample match the injected captions."""
        with self.create_dataset() as (dataset, info):
            _, captions = dataset[0]
            self.assertSequenceEqual(captions, info["captions"])
1424
+
1425
+
1426
class Flickr30kTestCase(Flickr8kTestCase):
    """Test case for ``datasets.Flickr30k`` running against generated fake data.

    Reuses the fake-data machinery of ``Flickr8kTestCase`` but names the
    images by index and writes a plain-text, tab-separated annotations file
    instead of an HTML one.
    """

    DATASET_CLASS = datasets.Flickr30k

    FEATURE_TYPES = (PIL.Image.Image, list)

    _ANNOTATIONS_FILE = "captions.token"

    def _image_file_name(self, idx):
        """Return the file name ``<idx>.jpg`` for the image with index *idx*."""
        return f"{idx}.jpg"

    def _create_annotations_file(self, root, name, images, num_captions_per_image):
        """Write the annotations file *name* inside *root*.

        One line is written per (image, caption) pair in the form
        ``<file name>#<caption index><TAB><caption>``.
        """
        with open(root / name, "w") as fh:
            for image, (idx, caption) in itertools.product(
                images, enumerate(self._create_captions(num_captions_per_image))
            ):
                fh.write(f"{image.name}#{idx}\t{caption}\n")
1442
+
1443
+
1357
1444
# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments