Skip to content

PR: Add UCF101 dataset tests #2548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions test/fakedata_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import torch
from common_utils import get_tmp_dir
import pickle
import random
from itertools import cycle
from torchvision.io.video import write_video


@contextlib.contextmanager
Expand Down Expand Up @@ -265,3 +268,47 @@ def voc_root():
f.write('test')

yield tmp_dir


@contextlib.contextmanager
def ucf101_root():
with get_tmp_dir() as tmp_dir:
ucf_dir = os.path.join(tmp_dir, 'UCF-101')
video_dir = os.path.join(ucf_dir, 'video')
annotations = os.path.join(ucf_dir, 'annotations')

os.makedirs(ucf_dir)
os.makedirs(video_dir)
os.makedirs(annotations)

fold_files = []
for split in {'train', 'test'}:
for fold in range(1, 4):
fold_file = '{:s}list{:02d}.txt'.format(split, fold)
fold_files.append(os.path.join(annotations, fold_file))

file_handles = [open(x, 'w') for x in fold_files]
file_iter = cycle(file_handles)

for i in range(0, 2):
current_class = 'class_{0}'.format(i + 1)
class_dir = os.path.join(video_dir, current_class)
os.makedirs(class_dir)
for group in range(0, 3):
for clip in range(0, 4):
# Save sample file
clip_name = 'v_{0}_g{1}_c{2}.avi'.format(
current_class, group, clip)
clip_path = os.path.join(class_dir, clip_name)
length = random.randrange(10, 21)
this_clip = torch.randint(
0, 256, (length * 25, 320, 240, 3), dtype=torch.uint8)
write_video(clip_path, this_clip, 25)
# Add to annotations
ann_file = next(file_iter)
ann_file.write('{0}\n'.format(
os.path.join(current_class, clip_name)))
# Close all file descriptors
for f in file_handles:
f.close()
yield (video_dir, annotations)
28 changes: 27 additions & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root, voc_root
cityscapes_root, svhn_root, voc_root, ucf101_root
import xml.etree.ElementTree as ET


Expand All @@ -19,6 +19,12 @@
except ImportError:
HAS_SCIPY = False

try:
import av
HAS_PYAV = True
except ImportError:
HAS_PYAV = False


class Tester(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
Expand Down Expand Up @@ -254,6 +260,26 @@ def test_voc_parse_xml(self, mock_download_extract):
}]
}})

@unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
def test_ucf101(self):
with ucf101_root() as (root, ann_root):
for split in {True, False}:
for fold in range(1, 4):
for length in {10, 15, 20}:
dataset = torchvision.datasets.UCF101(
root, ann_root, length, fold=fold, train=split)
self.assertGreater(len(dataset), 0)

video, audio, label = dataset[0]
self.assertEqual(video.size(), (length, 320, 240, 3))
self.assertEqual(audio.numel(), 0)
self.assertEqual(label, 0)

video, audio, label = dataset[len(dataset) - 1]
self.assertEqual(video.size(), (length, 320, 240, 3))
self.assertEqual(audio.numel(), 0)
self.assertEqual(label, 1)


if __name__ == '__main__':
unittest.main()