add _backend argument to __init__() of class VideoClips #1363


Merged (5 commits) on Sep 24, 2019

17 changes: 11 additions & 6 deletions test/test_datasets_video_utils.py
@@ -6,6 +6,7 @@

from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend

from common_utils import get_tmp_dir

@@ -61,22 +62,23 @@ def test_unfold(self):
@unittest.skipIf(not io.video._av_available(), "this test requires av")
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_video_clips(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3) as video_list:
video_clips = VideoClips(video_list, 5, 5)
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx)

video_clips = VideoClips(video_list, 6, 6)
video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx)

video_clips = VideoClips(video_list, 6, 1)
video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
video_idx, clip_idx = video_clips.get_clip_location(i)
@@ -85,8 +87,9 @@

@unittest.skip("Moved to reference scripts for now")
def test_video_sampler(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5)
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
sampler = RandomClipSampler(video_clips, 3) # noqa: F821
self.assertEqual(len(sampler), 3 * 3)
indices = torch.tensor(list(iter(sampler)))
@@ -97,8 +100,9 @@

@unittest.skip("Moved to reference scripts for now")
def test_video_sampler_unequal(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5)
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
sampler = RandomClipSampler(video_clips, 3) # noqa: F821
self.assertEqual(len(sampler), 2 + 3 + 3)
indices = list(iter(sampler))
@@ -116,10 +120,11 @@ def test_video_sampler_unequal(self):
@unittest.skipIf(not io.video._av_available(), "this test requires av")
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_video_clips_custom_fps(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
num_frames = 4
for fps in [1, 3, 4, 10]:
video_clips = VideoClips(video_list, num_frames, num_frames, fps)
video_clips = VideoClips(video_list, num_frames, num_frames, fps, _backend=_backend)
for i in range(video_clips.num_clips()):
video, audio, info, video_idx = video_clips.get_clip(i)
self.assertEqual(video.shape[0], num_frames)
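Every test in this file follows the same pattern the hunks above show: query the active backend once with `get_video_backend()` and forward it through the new `_backend` keyword of `VideoClips`. A minimal sketch of that usage outside the test harness, assuming the listed video files exist (the paths are hypothetical; the tests generate temporary videos instead):

```python
from torchvision import get_video_backend
from torchvision.datasets.video_utils import VideoClips

# hypothetical paths; the tests create temporary videos via get_list_of_videos()
video_paths = ["clip_a.mp4", "clip_b.mp4", "clip_c.mp4"]

_backend = get_video_backend()  # "pyav" by default, or "video_reader" if configured
# 5-frame clips, stepping 5 frames between clips, decoded with the chosen backend
clips = VideoClips(video_paths, 5, 5, _backend=_backend)

print(clips.num_clips())
video, audio, info, video_idx = clips.get_clip(0)
```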
67 changes: 52 additions & 15 deletions test/test_io.py
@@ -4,6 +4,7 @@
import torch
import torchvision.datasets.utils as utils
import torchvision.io as io
from torchvision import get_video_backend
import unittest
import sys
import warnings
@@ -22,6 +23,20 @@
except ImportError:
av = None

_video_backend = get_video_backend()


def _read_video(filename, start_pts=0, end_pts=None):
if _video_backend == "pyav":
return io.read_video(filename, start_pts, end_pts)
else:
if end_pts is None:
end_pts = -1
return io._read_video_from_file(
filename,
video_pts_range=(start_pts, end_pts),
)


def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
@@ -44,7 +59,12 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
options = {'crf': '0'}

if video_codec is None:
video_codec = 'libx264'
if _video_backend == "pyav":
video_codec = 'libx264'
else:
# when video_codec is not set, we assume it is libx264rgb, which accepts
# RGB pixel formats as input instead of YUV
video_codec = 'libx264rgb'
if options is None:
options = {}

@@ -63,15 +83,16 @@ class Tester(unittest.TestCase):

def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name)

lv, _, info = _read_video(f_name)
self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)

def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)

if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
@@ -85,26 +106,35 @@ def test_read_timestamps(self):

def test_read_partial_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
for start in range(5):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue(s_data.equal(lv))

lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))
if _video_backend == "pyav":
# for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts
lv, _, _ = _read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
for start in range(0, 80, 20):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)
@@ -120,7 +150,12 @@ def test_read_packed_b_frames_divx_file(self):
url = "https://download.pytorch.org/vision_tests/io/" + name
try:
utils.download_url(url, temp_dir)
pts, fps = io.read_video_timestamps(f_name)
if _video_backend == "pyav":
pts, fps = io.read_video_timestamps(f_name)
else:
pts, _, info = io._read_video_timestamps_from_file(f_name)
fps = info["video_fps"]

self.assertEqual(pts, sorted(pts))
self.assertEqual(fps, 30)
except URLError:
@@ -130,8 +165,10 @@

def test_read_timestamps_from_packet(self):
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)

if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
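Beyond `_read_video`, the timestamp queries in these tests repeat the same backend dispatch several times: `io.read_video_timestamps` returns `(pts, fps)` under pyav, while the private `io._read_video_timestamps_from_file` returns a third value, an info dict whose `"video_fps"` entry carries the frame rate (as the DivX test reads it). A small helper in the same spirit, sketched only from the calls shown in this diff (the function name is ours, not torchvision's):

```python
import torchvision.io as io
from torchvision import get_video_backend


def read_timestamps_any_backend(filename):
    """Return (pts, fps) for a video file, whichever backend is active."""
    if get_video_backend() == "pyav":
        pts, fps = io.read_video_timestamps(filename)
        return pts, fps
    # the video_reader path returns three values; the last is an info dict
    # that contains the frame rate under "video_fps"
    pts, _, info = io._read_video_timestamps_from_file(filename)
    return pts, info["video_fps"]
```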
26 changes: 26 additions & 0 deletions torchvision/__init__.py
@@ -14,6 +14,8 @@

_image_backend = 'PIL'

_video_backend = "pyav"


def set_image_backend(backend):
"""
@@ -38,6 +40,30 @@ def get_image_backend():
return _image_backend


def set_video_backend(backend):
"""
Specifies the package used to decode videos.

Args:
backend (string): Name of the video backend. One of {'pyav', 'video_reader'}.
The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
binding for the FFmpeg libraries.
The :mod:`video_reader` package includes a native C++ implementation on
top of the FFmpeg libraries, and a Python API of TorchScript custom operator.
It generally decodes faster than pyav, but is perhaps less robust.
"""
global _video_backend
if backend not in ["pyav", "video_reader"]:
raise ValueError(
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
)
_video_backend = backend


def get_video_backend():
return _video_backend


def _is_tracing():
import torch
return torch._C._get_tracing_state()
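
With the setter/getter pair in place, switching decoders becomes a single process-wide call, and anything that reads `get_video_backend()` (including the `_backend` argument the tests pass to `VideoClips`) follows it. A short usage sketch, assuming the installed torchvision was actually built with the video_reader extension:

```python
import torchvision

torchvision.set_video_backend("video_reader")  # default is "pyav"
print(torchvision.get_video_backend())         # -> "video_reader"

# unknown backend names are rejected up front
try:
    torchvision.set_video_backend("opencv")
except ValueError as err:
    print(err)  # Invalid video backend 'opencv'. Options are 'pyav' and 'video_reader'
```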