
add metadata to video dataset classes. bug fix. more robustness #1376


Merged: 4 commits, Oct 3, 2019
17 changes: 6 additions & 11 deletions test/test_datasets_video_utils.py
@@ -6,7 +6,6 @@

 from torchvision import io
 from torchvision.datasets.video_utils import VideoClips, unfold
-from torchvision import get_video_backend

 from common_utils import get_tmp_dir

@@ -62,23 +61,22 @@ def test_unfold(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
             for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)

-            video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
+            video_clips = VideoClips(video_list, 6, 6)
             self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
             for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)

-            video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
+            video_clips = VideoClips(video_list, 6, 1)
             self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
             for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
                 video_idx, clip_idx = video_clips.get_clip_location(i)
@@ -87,9 +85,8 @@ def test_video_clips(self):

     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 3 * 3)
             indices = torch.tensor(list(iter(sampler)))
@@ -100,9 +97,8 @@ def test_video_sampler(self):

     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler_unequal(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 2 + 3 + 3)
             indices = list(iter(sampler))
@@ -120,11 +116,10 @@ def test_video_sampler_unequal(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
             num_frames = 4
             for fps in [1, 3, 4, 10]:
-                video_clips = VideoClips(video_list, num_frames, num_frames, fps, _backend=_backend)
+                video_clips = VideoClips(video_list, num_frames, num_frames, fps)
                 for i in range(video_clips.num_clips()):
                     video, audio, info, video_idx = video_clips.get_clip(i)
                     self.assertEqual(video.shape[0], num_frames)
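
The tests above no longer pass a _backend argument; VideoClips now consults the global torchvision backend itself. A minimal sketch of the new calling convention (the paths are hypothetical):

import torchvision
from torchvision.datasets.video_utils import VideoClips

torchvision.set_video_backend("pyav")
video_list = ["/tmp/a.avi", "/tmp/b.avi"]  # hypothetical video files
clips = VideoClips(video_list, clip_length_in_frames=5, frames_between_clips=5)
print(clips.num_clips())
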
7 changes: 6 additions & 1 deletion torchvision/__init__.py
@@ -1,3 +1,5 @@
+import warnings
+
 from torchvision import models
 from torchvision import datasets
 from torchvision import ops
@@ -57,7 +59,10 @@ def set_video_backend(backend):
         raise ValueError(
             "Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
         )
-    _video_backend = backend
+    if backend == "video_reader" and not io._HAS_VIDEO_OPT:
+        warnings.warn("video_reader video backend is not available")
+    else:
+        _video_backend = backend


 def get_video_backend():
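
A sketch of the new fallback in set_video_backend: requesting "video_reader" when the compiled video ops are missing now warns and keeps the current backend, instead of switching to a backend that cannot decode anything.

import warnings
import torchvision

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torchvision.set_video_backend("video_reader")

# If torchvision.io._HAS_VIDEO_OPT is False, a warning was recorded and the
# backend is unchanged; otherwise the switch succeeded silently.
print([str(w.message) for w in caught])
print(torchvision.get_video_backend())
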
15 changes: 13 additions & 2 deletions torchvision/datasets/hmdb51.py
@@ -1,9 +1,9 @@
 import glob
 import os

-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset


@@ -51,7 +51,8 @@ class HMDB51(VisionDataset):

     def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                  frame_rate=None, fold=1, train=True, transform=None,
-                 _precomputed_metadata=None):
+                 _precomputed_metadata=None, num_workers=1, _video_width=0,
+                 _video_height=0, _video_min_dimension=0, _audio_samples=0):
         super(HMDB51, self).__init__(root)
         if not 1 <= fold <= 3:
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -71,11 +72,21 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
+        self.video_clips_metadata = video_clips.metadata
         self.indices = self._select_fold(video_list, annotation_path, fold, train)
         self.video_clips = video_clips.subset(self.indices)
         self.transform = transform

+    @property
+    def metadata(self):
+        return self.video_clips_metadata
+
     def _select_fold(self, video_list, annotation_path, fold, train):
         target_tag = 1 if train else 2
         name = "*test_split{}.txt".format(fold)
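
Because video_clips_metadata is captured before the fold subset is taken, the new metadata property exposes the full scan of the dataset. Assuming _precomputed_metadata short-circuits the expensive per-video scan (hypothetical paths below), one fold's metadata can seed another:

from torchvision.datasets import HMDB51

fold1 = HMDB51("/data/hmdb51", "/data/hmdb51_splits", frames_per_clip=16, fold=1)
fold2 = HMDB51("/data/hmdb51", "/data/hmdb51_splits", frames_per_clip=16, fold=2,
               _precomputed_metadata=fold1.metadata)
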
15 changes: 13 additions & 2 deletions torchvision/datasets/kinetics.py
@@ -1,6 +1,6 @@
-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset


@@ -37,7 +37,9 @@ class Kinetics400(VisionDataset):
     """

     def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
-                 extensions=('avi',), transform=None, _precomputed_metadata=None):
+                 extensions=('avi',), transform=None, _precomputed_metadata=None,
+                 num_workers=1, _video_width=0, _video_height=0,
+                 _video_min_dimension=0, _audio_samples=0):
         super(Kinetics400, self).__init__(root)
         extensions = ('avi',)

@@ -52,9 +54,18 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
         self.transform = transform

+    @property
+    def metadata(self):
+        return self.video_clips.metadata
+
     def __len__(self):
         return self.video_clips.num_clips()

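
Kinetics400 gets the same metadata property, here backed directly by the VideoClips object since no subset is taken. A hedged sketch (hypothetical paths) of rebuilding the dataset with a denser clip stride without re-scanning the videos:

from torchvision.datasets import Kinetics400

coarse = Kinetics400("/data/kinetics/train", frames_per_clip=16, step_between_clips=16)
dense = Kinetics400("/data/kinetics/train", frames_per_clip=16, step_between_clips=1,
                    _precomputed_metadata=coarse.metadata)
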
15 changes: 13 additions & 2 deletions torchvision/datasets/ucf101.py
@@ -1,9 +1,9 @@
 import glob
 import os

-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset


@@ -44,7 +44,8 @@ class UCF101(VisionDataset):

     def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                  frame_rate=None, fold=1, train=True, transform=None,
-                 _precomputed_metadata=None):
+                 _precomputed_metadata=None, num_workers=1, _video_width=0,
+                 _video_height=0, _video_min_dimension=0, _audio_samples=0):
         super(UCF101, self).__init__(root)
         if not 1 <= fold <= 3:
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -64,11 +65,21 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
+        self.video_clips_metadata = video_clips.metadata
         self.indices = self._select_fold(video_list, annotation_path, fold, train)
         self.video_clips = video_clips.subset(self.indices)
         self.transform = transform

+    @property
+    def metadata(self):
+        return self.video_clips_metadata
+
     def _select_fold(self, video_list, annotation_path, fold, train):
         name = "train" if train else "test"
         name = "{}list{:02d}.txt".format(name, fold)
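
UCF101 mirrors the HMDB51 changes; in particular num_workers is now forwarded to VideoClips, so the initial metadata scan can run in parallel. A sketch with hypothetical paths:

from torchvision.datasets import UCF101

dataset = UCF101("/data/ucf101", "/data/ucfTrainTestlist",
                 frames_per_clip=16, fold=1, train=True, num_workers=8)
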
35 changes: 29 additions & 6 deletions torchvision/datasets/video_utils.py
@@ -68,10 +68,18 @@ class VideoClips(object):
             0 means that the data will be loaded in the main process. (default: 0)

     def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
-                 frame_rate=None, _precomputed_metadata=None, num_workers=0, _backend="pyav"):
+                 frame_rate=None, _precomputed_metadata=None, num_workers=0,
+                 _video_width=0, _video_height=0, _video_min_dimension=0,
+                 _audio_samples=0):
+        from torchvision import get_video_backend
+
         self.video_paths = video_paths
         self.num_workers = num_workers
-        self._backend = _backend
+        self._backend = get_video_backend()
+        self._video_width = _video_width
+        self._video_height = _video_height
+        self._video_min_dimension = _video_min_dimension
+        self._audio_samples = _audio_samples

         if _precomputed_metadata is None:
             self._compute_frame_pts()
@@ -145,6 +153,7 @@ def metadata(self):
             _metadata.update({"video_fps": self.video_fps})
         else:
             _metadata.update({"info": self.info})
+        return _metadata

     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
@@ -162,7 +171,11 @@ def subset(self, indices):
         else:
             metadata.update({"info": info})
         return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
-                          _precomputed_metadata=metadata)
+                          _precomputed_metadata=metadata, num_workers=self.num_workers,
+                          _video_width=self._video_width,
+                          _video_height=self._video_height,
+                          _video_min_dimension=self._video_min_dimension,
+                          _audio_samples=self._audio_samples)

     @staticmethod
     def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
@@ -206,9 +219,15 @@ def compute_clips(self, num_frames, step, frame_rate=None):
                 self.resampling_idxs.append(idxs)
         else:
             for video_pts, info in zip(self.video_pts, self.info):
-                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, info["video_fps"], frame_rate)
-                self.clips.append(clips)
-                self.resampling_idxs.append(idxs)
+                if "video_fps" in info:
+                    clips, idxs = self.compute_clips_for_video(
+                        video_pts, num_frames, step, info["video_fps"], frame_rate)
+                    self.clips.append(clips)
+                    self.resampling_idxs.append(idxs)
+                else:
+                    # properly handle the cases where video decoding fails
+                    self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
+                    self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

@@ -296,8 +315,12 @@ def get_clip(self, idx):
             )
             video, audio, info = _read_video_from_file(
                 video_path,
+                video_width=self._video_width,
+                video_height=self._video_height,
+                video_min_dimension=self._video_min_dimension,
                 video_pts_range=(video_start_pts, video_end_pts),
                 video_timebase=info["video_timebase"],
+                audio_samples=self._audio_samples,
                 audio_pts_range=(audio_start_pts, audio_end_pts),
                 audio_timebase=audio_timebase,
             )
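
With the compute_clips change, a video whose decoding failed (its info dict lacks a "video_fps" entry) now contributes zero clips instead of raising a KeyError. A minimal sketch of what the new branch appends for such a video:

import torch

num_frames = 16
info = {}  # hypothetical: decoding failed, so no "video_fps" key
if "video_fps" not in info:
    clips = torch.zeros(0, num_frames, dtype=torch.int64)
    resampling_idxs = torch.zeros(0, dtype=torch.int64)
print(clips.shape)  # torch.Size([0, 16]) -> this video yields no clips
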
4 changes: 2 additions & 2 deletions torchvision/io/__init__.py
@@ -1,8 +1,8 @@
 from .video import write_video, read_video, read_video_timestamps
-from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file
+from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file, _HAS_VIDEO_OPT


 __all__ = [
     'write_video', 'read_video', 'read_video_timestamps',
-    '_read_video_from_file', '_read_video_timestamps_from_file',
+    '_read_video_from_file', '_read_video_timestamps_from_file', '_HAS_VIDEO_OPT',
 ]