diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py
index ab4333b74be..cc984d3e693 100644
--- a/test/fakedata_generation.py
+++ b/test/fakedata_generation.py
@@ -7,6 +7,9 @@
 import torch
 from common_utils import get_tmp_dir
 import pickle
+import random
+from itertools import cycle
+from torchvision.io.video import write_video
 
 
 @contextlib.contextmanager
@@ -265,3 +268,49 @@ def voc_root():
             f.write('test')
 
         yield tmp_dir
+
+
+@contextlib.contextmanager
+def ucf101_root():
+    """Generate a minimal fake UCF-101 tree; yields (video_dir, annotations)."""
+    with get_tmp_dir() as tmp_dir:
+        ucf_dir = os.path.join(tmp_dir, 'UCF-101')
+        video_dir = os.path.join(ucf_dir, 'video')
+        annotations = os.path.join(ucf_dir, 'annotations')
+
+        os.makedirs(ucf_dir)
+        os.makedirs(video_dir)
+        os.makedirs(annotations)
+
+        fold_files = []
+        # Tuple (not set): keep fold-file order stable under hash randomization.
+        for split in ('train', 'test'):
+            for fold in range(1, 4):
+                fold_file = '{:s}list{:02d}.txt'.format(split, fold)
+                fold_files.append(os.path.join(annotations, fold_file))
+
+        file_handles = [open(x, 'w') for x in fold_files]
+        file_iter = cycle(file_handles)
+
+        for i in range(0, 2):
+            current_class = 'class_{0}'.format(i + 1)
+            class_dir = os.path.join(video_dir, current_class)
+            os.makedirs(class_dir)
+            for group in range(0, 3):
+                for clip in range(0, 4):
+                    # Save sample file
+                    clip_name = 'v_{0}_g{1}_c{2}.avi'.format(
+                        current_class, group, clip)
+                    clip_path = os.path.join(class_dir, clip_name)
+                    length = random.randrange(10, 21)
+                    this_clip = torch.randint(
+                        0, 256, (length * 25, 320, 240, 3), dtype=torch.uint8)
+                    write_video(clip_path, this_clip, 25)
+                    # Add to annotations
+                    ann_file = next(file_iter)
+                    ann_file.write('{0}\n'.format(
+                        os.path.join(current_class, clip_name)))
+        # Close all file descriptors
+        for f in file_handles:
+            f.close()
+        yield (video_dir, annotations)
diff --git a/test/test_datasets.py b/test/test_datasets.py
index 60445b1d98a..6a0dee4c4c8 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -9,7 +9,7 @@
 import torchvision
 from common_utils import get_tmp_dir
 from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root
+    cityscapes_root, svhn_root, voc_root, ucf101_root
 import xml.etree.ElementTree as ET
 
 
@@ -19,6 +19,12 @@
 except ImportError:
     HAS_SCIPY = False
 
+try:
+    import av
+    HAS_PYAV = True
+except ImportError:
+    HAS_PYAV = False
+
 
 class Tester(unittest.TestCase):
     def generic_classification_dataset_test(self, dataset, num_images=1):
@@ -254,6 +260,27 @@ def test_voc_parse_xml(self, mock_download_extract):
             }]
         }})
 
+    @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
+    def test_ucf101(self):
+        with ucf101_root() as (root, ann_root):
+            # Tuples (not sets): keep the sweep order deterministic.
+            for split in (True, False):
+                for fold in range(1, 4):
+                    for length in (10, 15, 20):
+                        dataset = torchvision.datasets.UCF101(
+                            root, ann_root, length, fold=fold, train=split)
+                        self.assertGreater(len(dataset), 0)
+
+                        video, audio, label = dataset[0]
+                        self.assertEqual(video.size(), (length, 320, 240, 3))
+                        self.assertEqual(audio.numel(), 0)
+                        self.assertEqual(label, 0)
+
+                        video, audio, label = dataset[len(dataset) - 1]
+                        self.assertEqual(video.size(), (length, 320, 240, 3))
+                        self.assertEqual(audio.numel(), 0)
+                        self.assertEqual(label, 1)
+
 
 if __name__ == '__main__':
     unittest.main()