Skip to content

Commit a7b4bfd

Browse files
pmeierfmassa
andauthored
Add tests for UCF101 (#3411)
* enable default frames per clips for video test cases * add tests for UCF101 * remove old tests as well as fake data generation * better explain frames_per_clip overriding * lint Co-authored-by: Francisco Massa <[email protected]>
1 parent b7f3c81 commit a7b4bfd

File tree

3 files changed

+85
-70
lines changed

3 files changed

+85
-70
lines changed

test/datasets_utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,14 +496,44 @@ def new(fp, *args, **kwargs):
496496
class VideoDatasetTestCase(DatasetTestCase):
497497
"""Abstract base class for video dataset testcases.
498498
499-
- Overwrites the FEATURE_TYPES class attribute to expect two :class:`torch.Tensor` s for the video and audio as
499+
- Overwrites the 'FEATURE_TYPES' class attribute to expect two :class:`torch.Tensor` s for the video and audio as
500500
well as an integer label.
501-
- Overwrites the REQUIRED_PACKAGES class attribute to require PyAV (``av``).
501+
- Overwrites the 'REQUIRED_PACKAGES' class attribute to require PyAV (``av``).
502+
- Adds the 'DEFAULT_FRAMES_PER_CLIP' class attribute. If no 'frames_per_clip' is provided by 'inject_fake_data()'
503+
and it is the last parameter without a default value in the dataset constructor, the value of the
504+
'DEFAULT_FRAMES_PER_CLIP' class attribute is appended to the output.
502505
"""
503506

504507
FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
505508
REQUIRED_PACKAGES = ("av",)
506509

510+
DEFAULT_FRAMES_PER_CLIP = 1
511+
512+
def __init__(self, *args, **kwargs):
513+
super().__init__(*args, **kwargs)
514+
self.inject_fake_data = self._set_default_frames_per_clip(self.inject_fake_data)
515+
516+
def _set_default_frames_per_clip(self, inject_fake_data):
517+
argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
518+
args_without_default = argspec.args[1:-len(argspec.defaults)]
519+
frames_per_clip_last = args_without_default[-1] == "frames_per_clip"
520+
only_root_and_frames_per_clip = (len(args_without_default) == 2) and frames_per_clip_last
521+
522+
@functools.wraps(inject_fake_data)
523+
def wrapper(tmpdir, config):
524+
output = inject_fake_data(tmpdir, config)
525+
if isinstance(output, collections.abc.Sequence) and len(output) == 2:
526+
args, info = output
527+
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
528+
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
529+
return args, info
530+
elif isinstance(output, (int, dict)) and only_root_and_frames_per_clip:
531+
return (tmpdir, self.DEFAULT_FRAMES_PER_CLIP)
532+
else:
533+
return output
534+
535+
return wrapper
536+
507537

508538
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
509539
r"""Create a random uint8 tensor.

test/fakedata_generation.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -369,50 +369,6 @@ def _make_mat(file):
369369
yield root
370370

371371

372-
@contextlib.contextmanager
373-
def ucf101_root():
374-
with get_tmp_dir() as tmp_dir:
375-
ucf_dir = os.path.join(tmp_dir, 'UCF-101')
376-
video_dir = os.path.join(ucf_dir, 'video')
377-
annotations = os.path.join(ucf_dir, 'annotations')
378-
379-
os.makedirs(ucf_dir)
380-
os.makedirs(video_dir)
381-
os.makedirs(annotations)
382-
383-
fold_files = []
384-
for split in {'train', 'test'}:
385-
for fold in range(1, 4):
386-
fold_file = '{:s}list{:02d}.txt'.format(split, fold)
387-
fold_files.append(os.path.join(annotations, fold_file))
388-
389-
file_handles = [open(x, 'w') for x in fold_files]
390-
file_iter = cycle(file_handles)
391-
392-
for i in range(0, 2):
393-
current_class = 'class_{0}'.format(i + 1)
394-
class_dir = os.path.join(video_dir, current_class)
395-
os.makedirs(class_dir)
396-
for group in range(0, 3):
397-
for clip in range(0, 4):
398-
# Save sample file
399-
clip_name = 'v_{0}_g{1}_c{2}.avi'.format(
400-
current_class, group, clip)
401-
clip_path = os.path.join(class_dir, clip_name)
402-
length = random.randrange(10, 21)
403-
this_clip = torch.randint(
404-
0, 256, (length * 25, 320, 240, 3), dtype=torch.uint8)
405-
write_video(clip_path, this_clip, 25)
406-
# Add to annotations
407-
ann_file = next(file_iter)
408-
ann_file.write('{0}\n'.format(
409-
os.path.join(current_class, clip_name)))
410-
# Close all file descriptors
411-
for f in file_handles:
412-
f.close()
413-
yield (video_dir, annotations)
414-
415-
416372
@contextlib.contextmanager
417373
def places365_root(split="train-standard", small=False):
418374
VARIANTS = {

test/test_datasets.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.datasets import utils
1212
from common_utils import get_tmp_dir
1313
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
14-
cityscapes_root, svhn_root, ucf101_root, places365_root, widerface_root, stl10_root
14+
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
1515
import xml.etree.ElementTree as ET
1616
from urllib.request import Request, urlopen
1717
import itertools
@@ -22,6 +22,7 @@
2222
import torch
2323
import shutil
2424
import json
25+
import random
2526

2627

2728
try:
@@ -261,29 +262,6 @@ def test_svhn(self, mock_check):
261262
dataset = torchvision.datasets.SVHN(root, split="extra")
262263
self.generic_classification_dataset_test(dataset, num_images=2)
263264

264-
@unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
265-
def test_ucf101(self):
266-
cached_meta_data = None
267-
with ucf101_root() as (root, ann_root):
268-
for split in {True, False}:
269-
for fold in range(1, 4):
270-
for length in {10, 15, 20}:
271-
dataset = torchvision.datasets.UCF101(root, ann_root, length, fold=fold, train=split,
272-
num_workers=2, _precomputed_metadata=cached_meta_data)
273-
if cached_meta_data is None:
274-
cached_meta_data = dataset.metadata
275-
self.assertGreater(len(dataset), 0)
276-
277-
video, audio, label = dataset[0]
278-
self.assertEqual(video.size(), (length, 320, 240, 3))
279-
self.assertEqual(audio.numel(), 0)
280-
self.assertEqual(label, 0)
281-
282-
video, audio, label = dataset[len(dataset) - 1]
283-
self.assertEqual(video.size(), (length, 320, 240, 3))
284-
self.assertEqual(audio.numel(), 0)
285-
self.assertEqual(label, 1)
286-
287265
def test_places365(self):
288266
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
289267
with places365_root(split=split, small=small) as places365:
@@ -905,5 +883,56 @@ def test_captions(self):
905883
self.assertEqual(tuple(captions), tuple(info["captions"]))
906884

907885

886+
class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
887+
DATASET_CLASS = datasets.UCF101
888+
889+
CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
890+
891+
def inject_fake_data(self, tmpdir, config):
892+
tmpdir = pathlib.Path(tmpdir)
893+
894+
video_folder = tmpdir / "videos"
895+
os.makedirs(video_folder)
896+
video_files = self._create_videos(video_folder)
897+
898+
annotations_folder = annotations_folder = tmpdir / "annotations"
899+
os.makedirs(annotations_folder)
900+
num_examples = self._create_annotation_files(annotations_folder, video_files, config["fold"], config["train"])
901+
902+
return (str(video_folder), str(annotations_folder)), num_examples
903+
904+
def _create_videos(self, root, num_examples_per_class=3):
905+
def file_name_fn(cls, idx, clips_per_group=2):
906+
return f"v_{cls}_g{(idx // clips_per_group) + 1:02d}_c{(idx % clips_per_group) + 1:02d}.avi"
907+
908+
video_files = [
909+
datasets_utils.create_video_folder(root, cls, lambda idx: file_name_fn(cls, idx), num_examples_per_class)
910+
for cls in ("ApplyEyeMakeup", "YoYo")
911+
]
912+
return [path.relative_to(root) for path in itertools.chain(*video_files)]
913+
914+
def _create_annotation_files(self, root, video_files, fold, train):
915+
current_videos = random.sample(video_files, random.randrange(1, len(video_files) - 1))
916+
current_annotation = self._annotation_file_name(fold, train)
917+
self._create_annotation_file(root, current_annotation, current_videos)
918+
919+
other_videos = set(video_files) - set(current_videos)
920+
other_annotations = [
921+
self._annotation_file_name(fold, train) for fold, train in itertools.product((1, 2, 3), (True, False))
922+
]
923+
other_annotations.remove(current_annotation)
924+
for name in other_annotations:
925+
self._create_annotation_file(root, name, other_videos)
926+
927+
return len(current_videos)
928+
929+
def _annotation_file_name(self, fold, train):
930+
return f"{'train' if train else 'test'}list{fold:02d}.txt"
931+
932+
def _create_annotation_file(self, root, name, video_files):
933+
with open(pathlib.Path(root) / name, "w") as fh:
934+
fh.writelines(f"{file}\n" for file in sorted(video_files))
935+
936+
908937
if __name__ == "__main__":
909938
unittest.main()

0 commit comments

Comments
 (0)