Skip to content

Commit 23295fb

Browse files
authored
PR: Add UCF101 dataset tests (#2548)
* Add fake data generator for UCF101
* Minor error correction
* Reduce total number of categories
* Fix naming
* Increase length
* Store in uint8
* Close fds
* Add assertGreater
* Add dimension tests
* Use numel instead of size
* Iterate over folds and splits
1 parent c2bbefc commit 23295fb

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

test/fakedata_generation.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import torch
88
from common_utils import get_tmp_dir
99
import pickle
10+
import random
11+
from itertools import cycle
12+
from torchvision.io.video import write_video
1013

1114

1215
@contextlib.contextmanager
@@ -265,3 +268,47 @@ def voc_root():
265268
f.write('test')
266269

267270
yield tmp_dir
271+
272+
273+
@contextlib.contextmanager
def ucf101_root():
    """Build a throw-away fake UCF101 directory tree and yield its locations.

    Layout created under a temporary directory::

        UCF-101/
            video/class_<k>/v_class_<k>_g<group>_c<clip>.avi
            annotations/{train,test}list{01,02,03}.txt

    Two classes are generated, each with 3 groups of 4 clips (24 clips in
    total). Every clip is random ``uint8`` data written at 25 fps with a
    random duration of 10 to 20 seconds. Clips are assigned to the six
    annotation files in round-robin order, so each file lists four clips,
    two from each class.

    Yields:
        tuple: ``(video_dir, annotations)``, the video root directory and
        the directory holding the split files. Both live inside the
        temporary directory managed by ``get_tmp_dir``.
    """
    with get_tmp_dir() as tmp_dir:
        ucf_dir = os.path.join(tmp_dir, 'UCF-101')
        video_dir = os.path.join(ucf_dir, 'video')
        annotations = os.path.join(ucf_dir, 'annotations')

        os.makedirs(ucf_dir)
        os.makedirs(video_dir)
        os.makedirs(annotations)

        # A tuple (not a set) keeps the file order, and therefore the
        # clip-to-fold assignment, reproducible across interpreter runs.
        fold_files = []
        for split in ('train', 'test'):
            for fold in range(1, 4):
                fold_file = '{:s}list{:02d}.txt'.format(split, fold)
                fold_files.append(os.path.join(annotations, fold_file))

        # ExitStack closes every annotation file even if writing a video
        # fails part-way through, so no descriptors are leaked.
        with contextlib.ExitStack() as stack:
            file_handles = [
                stack.enter_context(open(path, 'w')) for path in fold_files]
            file_iter = cycle(file_handles)

            for i in range(2):
                current_class = 'class_{0}'.format(i + 1)
                class_dir = os.path.join(video_dir, current_class)
                os.makedirs(class_dir)
                for group in range(3):
                    for clip in range(4):
                        # Save sample file
                        clip_name = 'v_{0}_g{1}_c{2}.avi'.format(
                            current_class, group, clip)
                        clip_path = os.path.join(class_dir, clip_name)
                        # Duration in seconds; frame count is length * fps.
                        length = random.randrange(10, 21)
                        this_clip = torch.randint(
                            0, 256, (length * 25, 320, 240, 3),
                            dtype=torch.uint8)
                        write_video(clip_path, this_clip, 25)
                        # Add to annotations
                        ann_file = next(file_iter)
                        ann_file.write('{0}\n'.format(
                            os.path.join(current_class, clip_name)))

        # All annotation files are flushed and closed before handing the
        # directories to the caller.
        yield (video_dir, annotations)

test/test_datasets.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torchvision
1010
from common_utils import get_tmp_dir
1111
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
12-
cityscapes_root, svhn_root, voc_root
12+
cityscapes_root, svhn_root, voc_root, ucf101_root
1313
import xml.etree.ElementTree as ET
1414

1515

@@ -19,6 +19,12 @@
1919
except ImportError:
2020
HAS_SCIPY = False
2121

22+
try:
23+
import av
24+
HAS_PYAV = True
25+
except ImportError:
26+
HAS_PYAV = False
27+
2228

2329
class Tester(unittest.TestCase):
2430
def generic_classification_dataset_test(self, dataset, num_images=1):
@@ -254,6 +260,26 @@ def test_voc_parse_xml(self, mock_download_extract):
254260
}]
255261
}})
256262

263+
@unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
def test_ucf101(self):
    """Check UCF101 loading on the fake dataset for every split, fold and clip length.

    For each combination the dataset must be non-empty, its first sample
    must carry label 0 and its last sample label 1, each video clip must
    have size ``(length, 320, 240, 3)`` and the audio tensor must be empty.
    """
    with ucf101_root() as (root, ann_root):
        for is_train in {True, False}:
            for fold_idx in range(1, 4):
                for clip_len in {10, 15, 20}:
                    ucf = torchvision.datasets.UCF101(
                        root, ann_root, clip_len, fold=fold_idx, train=is_train)
                    self.assertGreater(len(ucf), 0)

                    expected_size = (clip_len, 320, 240, 3)
                    # (sample index, expected label): first and last samples.
                    checks = ((0, 0), (len(ucf) - 1, 1))
                    for sample_idx, expected_label in checks:
                        video, audio, label = ucf[sample_idx]
                        self.assertEqual(video.size(), expected_size)
                        self.assertEqual(audio.numel(), 0)
                        self.assertEqual(label, expected_label)
282+
257283

258284
if __name__ == '__main__':
259285
unittest.main()

0 commit comments

Comments
 (0)