Unify video metadata in VideoClips #1527

Merged · 3 commits · Oct 29, 2019
5 changes: 3 additions & 2 deletions test/test_datasets_video_utils.py
@@ -59,7 +59,7 @@ def test_unfold(self):
self.assertTrue(r.equal(expected))

@unittest.skipIf(not io.video._av_available(), "this test requires av")
- @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+ @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_video_clips(self):
with get_list_of_videos(num_videos=3) as video_list:
video_clips = VideoClips(video_list, 5, 5)
@@ -84,7 +84,7 @@ def test_video_clips(self):
self.assertEqual(clip_idx, c_idx)

@unittest.skipIf(not io.video._av_available(), "this test requires av")
- @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+ @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_video_clips_custom_fps(self):
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
num_frames = 4
@@ -94,6 +94,7 @@ def test_video_clips_custom_fps(self):
video, audio, info, video_idx = video_clips.get_clip(i)
self.assertEqual(video.shape[0], num_frames)
self.assertEqual(info["video_fps"], fps)
self.assertEqual(info, {"video_fps": fps})
# TODO add tests checking that the content is right

def test_compute_clips_for_video(self):
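Taken together, these tests pin down the new contract: get_clip always returns an info dict of the form {"video_fps": fps}, regardless of backend. A minimal end-to-end sketch of that contract (not part of this PR; the paths are hypothetical and a video backend must be available):

from torchvision.datasets.video_utils import VideoClips

# Hypothetical paths; any videos decodable by the active backend work.
video_list = ["a.mp4", "b.mp4", "c.mp4"]
clips = VideoClips(video_list, clip_length_in_frames=5, frames_between_clips=5)

for i in range(clips.num_clips()):
    video, audio, info, video_idx = clips.get_clip(i)
    # One metadata schema for both backends: always "video_fps",
    # plus "audio_fps" when an audio stream was read.
    assert "video_fps" in info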
11 changes: 11 additions & 0 deletions test/test_datasets_video_utils_opt.py
@@ -0,0 +1,11 @@
+ import unittest
+ from torchvision import set_video_backend
+ import test_datasets_video_utils
+
+
+ set_video_backend('video_reader')
+
+
+ if __name__ == '__main__':
+     suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils)
+     unittest.TextTestRunner(verbosity=1).run(suite)
1 change: 1 addition & 0 deletions test/test_io.py
@@ -183,6 +183,7 @@ def test_read_video_pts_unit_sec(self):

self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)
self.assertEqual(info, {"video_fps": 5})

def test_read_timestamps_pts_unit_sec(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
104 changes: 44 additions & 60 deletions torchvision/datasets/video_utils.py
@@ -5,6 +5,7 @@
from torchvision.io import (
_read_video_timestamps_from_file,
_read_video_from_file,
+ _probe_video_from_file
)
from torchvision.io import read_video_timestamps, read_video

@@ -71,11 +72,11 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
frame_rate=None, _precomputed_metadata=None, num_workers=0,
_video_width=0, _video_height=0, _video_min_dimension=0,
_audio_samples=0):
- from torchvision import get_video_backend

self.video_paths = video_paths
self.num_workers = num_workers
- self._backend = get_video_backend()

+ # these options are not valid for pyav backend
self._video_width = _video_width
self._video_height = _video_height
self._video_min_dimension = _video_min_dimension
@@ -89,87 +90,60 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1

def _compute_frame_pts(self):
self.video_pts = []
if self._backend == "pyav":
self.video_fps = []
else:
self.info = []
self.video_fps = []

# strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first
class DS(object):
- def __init__(self, x, _backend):
+ def __init__(self, x):
self.x = x
- self._backend = _backend

def __len__(self):
return len(self.x)

def __getitem__(self, idx):
if self._backend == "pyav":
return read_video_timestamps(self.x[idx])
else:
return _read_video_timestamps_from_file(self.x[idx])
return read_video_timestamps(self.x[idx])

import torch.utils.data
dl = torch.utils.data.DataLoader(
- DS(self.video_paths, self._backend),
+ DS(self.video_paths),
batch_size=16,
num_workers=self.num_workers,
collate_fn=lambda x: x)

with tqdm(total=len(dl)) as pbar:
for batch in dl:
pbar.update(1)
if self._backend == "pyav":
clips, fps = list(zip(*batch))
clips = [torch.as_tensor(c) for c in clips]
self.video_pts.extend(clips)
self.video_fps.extend(fps)
else:
video_pts, _audio_pts, info = list(zip(*batch))
video_pts = [torch.as_tensor(c) for c in video_pts]
self.video_pts.extend(video_pts)
self.info.extend(info)
clips, fps = list(zip(*batch))
clips = [torch.as_tensor(c) for c in clips]
self.video_pts.extend(clips)
self.video_fps.extend(fps)

def _init_from_metadata(self, metadata):
self.video_paths = metadata["video_paths"]
assert len(self.video_paths) == len(metadata["video_pts"])
self.video_pts = metadata["video_pts"]

if self._backend == "pyav":
assert len(self.video_paths) == len(metadata["video_fps"])
self.video_fps = metadata["video_fps"]
else:
assert len(self.video_paths) == len(metadata["info"])
self.info = metadata["info"]
assert len(self.video_paths) == len(metadata["video_fps"])
self.video_fps = metadata["video_fps"]

@property
def metadata(self):
_metadata = {
"video_paths": self.video_paths,
"video_pts": self.video_pts,
"video_fps": self.video_fps
}
if self._backend == "pyav":
_metadata.update({"video_fps": self.video_fps})
else:
_metadata.update({"info": self.info})
return _metadata

def subset(self, indices):
video_paths = [self.video_paths[i] for i in indices]
video_pts = [self.video_pts[i] for i in indices]
if self._backend == "pyav":
video_fps = [self.video_fps[i] for i in indices]
else:
info = [self.info[i] for i in indices]
video_fps = [self.video_fps[i] for i in indices]
metadata = {
"video_paths": video_paths,
"video_pts": video_pts,
"video_fps": video_fps
}
if self._backend == "pyav":
metadata.update({"video_fps": video_fps})
else:
metadata.update({"info": info})
return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
_precomputed_metadata=metadata, num_workers=self.num_workers,
_video_width=self._video_width,
@@ -212,22 +186,10 @@ def compute_clips(self, num_frames, step, frame_rate=None):
self.frame_rate = frame_rate
self.clips = []
self.resampling_idxs = []
if self._backend == "pyav":
for video_pts, fps in zip(self.video_pts, self.video_fps):
clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
self.clips.append(clips)
self.resampling_idxs.append(idxs)
else:
for video_pts, info in zip(self.video_pts, self.info):
if "video_fps" in info:
clips, idxs = self.compute_clips_for_video(
video_pts, num_frames, step, info["video_fps"], frame_rate)
self.clips.append(clips)
self.resampling_idxs.append(idxs)
else:
# properly handle the cases where video decoding fails
self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
for video_pts, fps in zip(self.video_pts, self.video_fps):
clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
self.clips.append(clips)
self.resampling_idxs.append(idxs)
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

@@ -287,12 +249,28 @@ def get_clip(self, idx):
video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx]

if self._backend == "pyav":
from torchvision import get_video_backend
backend = get_video_backend()

if backend == "pyav":
# check for invalid options
if self._video_width != 0:
raise ValueError("pyav backend doesn't support _video_width != 0")
if self._video_height != 0:
raise ValueError("pyav backend doesn't support _video_height != 0")
if self._video_min_dimension != 0:
raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
if self._audio_samples != 0:
raise ValueError("pyav backend doesn't support _audio_samples != 0")

if backend == "pyav":
start_pts = clip_pts[0].item()
end_pts = clip_pts[-1].item()
video, audio, info = read_video(video_path, start_pts, end_pts)
else:
- info = self.info[video_idx]
+ info = _probe_video_from_file(video_path)
+ video_fps = info["video_fps"]
+ audio_fps = None

video_start_pts = clip_pts[0].item()
video_end_pts = clip_pts[-1].item()
@@ -313,6 +291,7 @@ def get_clip(self, idx):
info["audio_timebase"],
math.ceil,
)
audio_fps = info["audio_sample_rate"]
video, audio, info = _read_video_from_file(
video_path,
video_width=self._video_width,
@@ -324,6 +303,11 @@ def get_clip(self, idx):
audio_pts_range=(audio_start_pts, audio_end_pts),
audio_timebase=audio_timebase,
)

info = {"video_fps": video_fps}
if audio_fps is not None:
info["audio_fps"] = audio_fps

if self.frame_rate is not None:
resampling_idx = self.resampling_idxs[video_idx][clip_idx]
if isinstance(resampling_idx, torch.Tensor):
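The refactored _compute_frame_pts keeps the strategy spelled out in its comment: wrap the path list in a dummy dataset so a DataLoader with an identity collate_fn can fan read_video_timestamps out across worker processes. A standalone sketch of that pattern (PathProbeDataset and probe_all are illustrative names, not torchvision API):

import torch.utils.data
from torchvision.io import read_video_timestamps


class PathProbeDataset(object):
    # Dummy dataset: item idx is the timestamp probe for path idx.
    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return read_video_timestamps(self.paths[idx])


def probe_all(paths, num_workers=0):
    loader = torch.utils.data.DataLoader(
        PathProbeDataset(paths),
        batch_size=16,
        num_workers=num_workers,
        collate_fn=lambda batch: batch,  # identity: keep (pts, fps) tuples as-is
    )
    results = []
    for batch in loader:
        results.extend(batch)
    return results  # one (video_pts, video_fps) tuple per input path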
9 changes: 8 additions & 1 deletion torchvision/io/_video_opt.py
@@ -383,7 +383,7 @@ def get_pts(time_base):
audio_timebase = info['audio_timebase']
audio_pts_range = get_pts(audio_timebase)

- return _read_video_from_file(
+ vframes, aframes, info = _read_video_from_file(
Inline review comment from a Contributor:

It seems a few arguments of _read_video_from_file, such as video_width, video_height, and video_min_dimension, are not used here and cannot be specified through the _read_video API. Video frame resizing is fast when done inside the video reader, so I would like to keep those arguments exposed. However, the pyav backend does not support such resizing. For now, you already raise a ValueError when they are non-zero; in the long run, we can consider applying a video resizing transform to keep the same behavior between the pyav and video_reader backends.

Reply from @fmassa (Member, Author), Oct 29, 2019:

My thinking is that I want to first properly benchmark the benefits of having the rescaling happen in ffmpeg, to get an idea of how much improvement it will bring us. Once we have the numbers, we should analyze them and decide how to expose the resizing functionality to users in a clean way; I think this should be done after December, though. Note that power users can still use the private API for specifying video_width etc., so this functionality is available but not part of the public API.
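To make the last point concrete, the private escape hatch would look roughly like this; a sketch only, assuming the video_reader backend is compiled in and the underscore-prefixed arguments keep their current (private, unstable) names:

import torchvision
from torchvision.datasets.video_utils import VideoClips

torchvision.set_video_backend('video_reader')

# Private, subject to change: ask the decoder itself to resize frames so
# the smaller dimension is 128 pixels, instead of resizing after decode.
clips = VideoClips(
    ["example.mp4"],  # hypothetical path
    clip_length_in_frames=16,
    frames_between_clips=16,
    _video_min_dimension=128,
)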

filename,
read_video_stream=True,
video_pts_range=video_pts_range,
@@ -392,6 +392,13 @@ def get_pts(time_base):
audio_pts_range=audio_pts_range,
audio_timebase=audio_timebase,
)
+ _info = {}
+ if has_video:
+     _info['video_fps'] = info['video_fps']
+ if has_audio:
+     _info['audio_fps'] = info['audio_sample_rate']
+
+ return vframes, aframes, _info


def _read_video_timestamps(filename, pts_unit='pts'):
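With this change, read_video reports metadata through the same dict shape on both backends. A quick check (a sketch; the path is hypothetical, and both keys appear only if the file has video and audio streams):

from torchvision.io import read_video

vframes, aframes, info = read_video("example.mp4", pts_unit='sec')
# Unified schema across pyav and video_reader:
#   info["video_fps"]  - frame rate, present when a video stream was read
#   info["audio_fps"]  - sample rate, present when an audio stream was read
print(vframes.shape, aframes.shape, info)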