From a2a7350bd388deb961ecc83fe5729162903f5e4b Mon Sep 17 00:00:00 2001
From: zyan3
Date: Mon, 23 Sep 2019 16:44:43 -0700
Subject: [PATCH 1/5] add _backend argument to __init__() of class VideoClips

---
 test/test_datasets_video_utils.py   |  14 +--
 test/test_io.py                     |  67 ++++++++++----
 torchvision/__init__.py             |  23 +++++
 torchvision/datasets/video_utils.py | 134 +++++++++++++++++++++++-----
 torchvision/io/__init__.py          |   1 +
 torchvision/io/_video_opt.py        |   1 +
 6 files changed, 199 insertions(+), 41 deletions(-)

diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index a9cb7ab50ef..f90b20ded70 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -6,9 +6,11 @@
 from torchvision import io
 from torchvision.datasets.video_utils import VideoClips, unfold
+from torchvision import set_video_backend, get_video_backend
 
 from common_utils import get_tmp_dir
 
+_backend = get_video_backend()
 
 @contextlib.contextmanager
 def get_list_of_videos(num_videos=5, sizes=None, fps=None):
@@ -62,21 +64,21 @@ def test_unfold(self):
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips(self):
         with get_list_of_videos(num_videos=3) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
+            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
             self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
             for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
-            video_clips = VideoClips(video_list, 6, 6)
+            video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
             self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
             for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
-            video_clips = VideoClips(video_list, 6, 1)
+            video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
             self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
             for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
                 video_idx, clip_idx = video_clips.get_clip_location(i)
@@ -86,7 +88,7 @@ def test_video_clips(self):
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler(self):
         with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
+            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 3 * 3)
             indices = torch.tensor(list(iter(sampler)))
@@ -98,7 +100,7 @@ def test_video_sampler(self):
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler_unequal(self):
         with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
+            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 2 + 3 + 3)
             indices = list(iter(sampler))
@@ -119,7 +121,7 @@ def test_video_sampler_unequal(self):
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
             num_frames = 4
             for fps in [1, 3, 4, 10]:
-                video_clips = VideoClips(video_list, num_frames, num_frames, fps)
+                video_clips = VideoClips(video_list, num_frames, num_frames, fps, _backend=_backend)
                 for i in range(video_clips.num_clips()):
                     video, audio, info, video_idx = video_clips.get_clip(i)
                     self.assertEqual(video.shape[0], num_frames)
diff --git a/test/test_io.py b/test/test_io.py
index 96c33a4be68..75e2300dfe0 100644
--- a/test/test_io.py
+++ b/test/test_io.py
@@ -4,6 +4,7 @@
 import torch
 import torchvision.datasets.utils as utils
 import torchvision.io as io
+from torchvision import get_video_backend
 import unittest
 import sys
 import warnings
@@ -22,6 +23,20 @@
 except ImportError:
     av = None
 
+_video_backend = get_video_backend()
+
+
+def _read_video(filename, start_pts=0, end_pts=None):
+    if _video_backend == "pyav":
+        return io.read_video(filename, start_pts, end_pts)
+    else:
+        if end_pts is None:
+            end_pts = -1
+        return io._read_video_from_file(
+            filename,
+            video_pts_range=(start_pts, end_pts),
+        )
+
 
 def _create_video_frames(num_frames, height, width):
     y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
@@ -44,7 +59,12 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
         options = {'crf': '0'}
 
     if video_codec is None:
-        video_codec = 'libx264'
+        if _video_backend == "pyav":
+            video_codec = 'libx264'
+        else:
+            # when video_codec is not set, we assume it is libx264rgb which accepts
+            # RGB pixel formats as input instead of YUV
+            video_codec = 'libx264rgb'
     if options is None:
         options = {}
 
@@ -63,15 +83,16 @@ class Tester(unittest.TestCase):
 
     def test_write_read_video(self):
         with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
-            lv, _, info = io.read_video(f_name)
-
+            lv, _, info = _read_video(f_name)
             self.assertTrue(data.equal(lv))
             self.assertEqual(info["video_fps"], 5)
 
     def test_read_timestamps(self):
         with temp_video(10, 300, 300, 5) as (f_name, data):
-            pts, _ = io.read_video_timestamps(f_name)
-
+            if _video_backend == "pyav":
+                pts, _ = io.read_video_timestamps(f_name)
+            else:
+                pts, _, _ = io._read_video_timestamps_from_file(f_name)
             # note: not all formats/codecs provide accurate information for computing the
             # timestamps. For the format that we use here, this information is available,
             # so we use it as a baseline
@@ -85,26 +106,35 @@ def test_read_timestamps(self):
 
     def test_read_partial_video(self):
         with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
-            pts, _ = io.read_video_timestamps(f_name)
+            if _video_backend == "pyav":
+                pts, _ = io.read_video_timestamps(f_name)
+            else:
+                pts, _, _ = io._read_video_timestamps_from_file(f_name)
             for start in range(5):
                 for l in range(1, 4):
-                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
+                    lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
                     s_data = data[start:(start + l)]
                     self.assertEqual(len(lv), l)
                     self.assertTrue(s_data.equal(lv))
 
-            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
-            self.assertEqual(len(lv), 4)
-            self.assertTrue(data[4:8].equal(lv))
+            if _video_backend == "pyav":
+                # the "video_reader" backend doesn't decode the closest early frame when
+                # the given start pts doesn't match any frame pts, so this check is pyav-only
+                lv, _, _ = _read_video(f_name, pts[4] + 1, pts[7])
+                self.assertEqual(len(lv), 4)
+                self.assertTrue(data[4:8].equal(lv))
 
     def test_read_partial_video_bframes(self):
         # do not use lossless encoding, to test the presence of B-frames
         options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
         with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
-            pts, _ = io.read_video_timestamps(f_name)
+            if _video_backend == "pyav":
+                pts, _ = io.read_video_timestamps(f_name)
+            else:
+                pts, _, _ = io._read_video_timestamps_from_file(f_name)
             for start in range(0, 80, 20):
                 for l in range(1, 4):
-                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
+                    lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
                     s_data = data[start:(start + l)]
                     self.assertEqual(len(lv), l)
                     self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)
@@ -120,7 +150,12 @@ def test_read_packed_b_frames_divx_file(self):
         url = "https://download.pytorch.org/vision_tests/io/" + name
         try:
             utils.download_url(url, temp_dir)
-            pts, fps = io.read_video_timestamps(f_name)
+            if _video_backend == "pyav":
+                pts, fps = io.read_video_timestamps(f_name)
+            else:
+                pts, _, info = io._read_video_timestamps_from_file(f_name)
+                fps = info["video_fps"]
+
             self.assertEqual(pts, sorted(pts))
             self.assertEqual(fps, 30)
         except URLError:
@@ -130,8 +165,10 @@ def test_read_packed_b_frames_divx_file(self):
 
     def test_read_timestamps_from_packet(self):
         with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
-            pts, _ = io.read_video_timestamps(f_name)
-
+            if _video_backend == "pyav":
+                pts, _ = io.read_video_timestamps(f_name)
+            else:
+                pts, _, _ = io._read_video_timestamps_from_file(f_name)
             # note: not all formats/codecs provide accurate information for computing the
             # timestamps. For the format that we use here, this information is available,
             # so we use it as a baseline
diff --git a/torchvision/__init__.py b/torchvision/__init__.py
index 84dbe4fa1ee..16a412a4155 100644
--- a/torchvision/__init__.py
+++ b/torchvision/__init__.py
@@ -14,6 +14,7 @@
 
 _image_backend = 'PIL'
 
+_video_backend = "pyav"
 
 def set_image_backend(backend):
     """
@@ -38,6 +39,28 @@ def get_image_backend():
     return _image_backend
 
 
+def set_video_backend(backend):
+    """
+    Specifies the package used to decode videos.
+
+    Args:
+        backend (string): Name of the video backend. One of {'pyav', 'video_reader'}.
+            The :mod:`pyav` package uses the third-party PyAV library. It is a Pythonic
+            binding for the FFmpeg libraries.
+            The :mod:`video_reader` package includes a native C++ implementation on
+            top of the FFmpeg libraries, and a Python API backed by a TorchScript
+            custom operator. It generally decodes faster than :mod:`pyav`, but is
+            perhaps less robust.
+    """
+    global _video_backend
+    if backend not in ["pyav", "video_reader"]:
+        raise ValueError(
+            "Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
+        )
+    _video_backend = backend
+
+def get_video_backend():
+    return _video_backend
+
 def _is_tracing():
     import torch
     return torch._C._get_tracing_state()
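
A minimal usage sketch of the backend-selection API added in the hunk above. This is illustrative only; switching to "video_reader" assumes torchvision was built with the native video_reader extension:

    import torchvision

    # the default backend is "pyav"
    print(torchvision.get_video_backend())  # "pyav"

    # switch to the native FFmpeg-based decoder
    # (assumes the video_reader C++ extension was built)
    torchvision.set_video_backend("video_reader")

    # anything else is rejected
    try:
        torchvision.set_video_backend("opencv")
    except ValueError as e:
        print(e)  # Invalid video backend 'opencv'. Options are 'pyav' and 'video_reader'
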
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index 1d60b6c72db..d68f52b9d99 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -1,11 +1,29 @@
 import bisect
+from fractions import Fraction
 import math
 
 import torch
+from torchvision.io import (
+    _read_video_timestamps_from_file,
+    _read_video_from_file,
+)
 from torchvision.io import read_video_timestamps, read_video
 
 from .utils import tqdm
 
 
+def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
+    """convert pts between different time bases
+    Args:
+        pts: presentation timestamp, float
+        timebase_from: original timebase. Fraction
+        timebase_to: new timebase. Fraction
+        round_func: rounding function.
+    """
+    new_pts = Fraction(pts, 1) * timebase_from / timebase_to
+    return round_func(new_pts)
+
+
+
 def unfold(tensor, size, step, dilation=1):
     """
     similar to tensor.unfold, but with the dilation
@@ -49,9 +67,11 @@ class VideoClips(object):
         on the resampled video
     """
     def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
-                 frame_rate=None, _precomputed_metadata=None, num_workers=1):
+                 frame_rate=None, _precomputed_metadata=None, num_workers=1,
+                 _backend="pyav"):
         self.video_paths = video_paths
         self.num_workers = num_workers
+        self.backend = _backend
         if _precomputed_metadata is None:
             self._compute_frame_pts()
         else:
@@ -60,23 +80,30 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
 
     def _compute_frame_pts(self):
         self.video_pts = []
-        self.video_fps = []
+        if self.backend == "pyav":
+            self.video_fps = []
+        else:
+            self.info = []
 
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         class DS(object):
-            def __init__(self, x):
+            def __init__(self, x, backend):
                 self.x = x
+                self.backend = backend
 
             def __len__(self):
                 return len(self.x)
 
             def __getitem__(self, idx):
-                return read_video_timestamps(self.x[idx])
+                if self.backend == "pyav":
+                    return read_video_timestamps(self.x[idx])
+                else:
+                    return _read_video_timestamps_from_file(self.x[idx])
 
         import torch.utils.data
         dl = torch.utils.data.DataLoader(
-            DS(self.video_paths),
+            DS(self.video_paths, self.backend),
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=lambda x: x)
@@ -84,25 +111,55 @@ def __getitem__(self, idx):
         with tqdm(total=len(dl)) as pbar:
             for batch in dl:
                 pbar.update(1)
-                clips, fps = list(zip(*batch))
-                clips = [torch.as_tensor(c) for c in clips]
-                self.video_pts.extend(clips)
-                self.video_fps.extend(fps)
+                if self.backend == "pyav":
+                    clips, fps = list(zip(*batch))
+                    clips = [torch.as_tensor(c) for c in clips]
+                    self.video_pts.extend(clips)
+                    self.video_fps.extend(fps)
+                else:
+                    video_pts, _audio_pts, info = list(zip(*batch))
+                    video_pts = [torch.as_tensor(c) for c in video_pts]
+                    self.video_pts.extend(video_pts)
+                    self.info.extend(info)
 
     def _init_from_metadata(self, metadata):
+        self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
-        assert len(self.video_paths) == len(metadata["video_fps"])
         self.video_pts = metadata["video_pts"]
-        self.video_fps = metadata["video_fps"]
+
+        if self.backend == "pyav":
+            assert len(self.video_paths) == len(metadata["video_fps"])
+            self.video_fps = metadata["video_fps"]
+        else:
+            assert len(self.video_paths) == len(metadata["info"])
+            self.info = metadata["info"]
+
+    @property
+    def metadata(self):
+        _metadata = {
+            "video_paths": self.video_paths,
+            "video_pts": self.video_pts,
+        }
+        if self.backend == "pyav":
+            _metadata.update({"video_fps": self.video_fps})
+        else:
+            _metadata.update({"info": self.info})
+        return _metadata
 
     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
-        video_fps = [self.video_fps[i] for i in indices]
+        if self.backend == "pyav":
+            video_fps = [self.video_fps[i] for i in indices]
+        else:
+            info = [self.info[i] for i in indices]
         metadata = {
+            "video_paths": video_paths,
             "video_pts": video_pts,
-            "video_fps": video_fps
         }
+        if self.backend == "pyav":
+            metadata.update({"video_fps": video_fps})
+        else:
+            metadata.update({"info": info})
         return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
                           _precomputed_metadata=metadata)
@@ -141,10 +198,16 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         self.frame_rate = frame_rate
         self.clips = []
         self.resampling_idxs = []
-        for video_pts, fps in zip(self.video_pts, self.video_fps):
-            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
-            self.clips.append(clips)
-            self.resampling_idxs.append(idxs)
+        if self.backend == "pyav":
+            for video_pts, fps in zip(self.video_pts, self.video_fps):
+                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+                self.clips.append(clips)
+                self.resampling_idxs.append(idxs)
+        else:
+            for video_pts, info in zip(self.video_pts, self.info):
+                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, info["video_fps"], frame_rate)
+                self.clips.append(clips)
+                self.resampling_idxs.append(idxs)
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
@@ -203,9 +266,40 @@ def get_clip(self, idx):
         video_idx, clip_idx = self.get_clip_location(idx)
         video_path = self.video_paths[video_idx]
         clip_pts = self.clips[video_idx][clip_idx]
-        start_pts = clip_pts[0].item()
-        end_pts = clip_pts[-1].item()
-        video, audio, info = read_video(video_path, start_pts, end_pts)
+
+        if self.backend == "pyav":
+            start_pts = clip_pts[0].item()
+            end_pts = clip_pts[-1].item()
+            video, audio, info = read_video(video_path, start_pts, end_pts)
+        else:
+            info = self.info[video_idx]
+
+            video_start_pts = clip_pts[0].item()
+            video_end_pts = clip_pts[-1].item()
+
+            audio_start_pts, audio_end_pts = 0, -1
+            audio_timebase = Fraction(0, 1)
+            if "audio_timebase" in info:
+                audio_timebase = info["audio_timebase"]
+                audio_start_pts = pts_convert(
+                    video_start_pts,
+                    info["video_timebase"],
+                    info["audio_timebase"],
+                    math.floor,
+                )
+                audio_end_pts = pts_convert(
+                    video_end_pts,
+                    info["video_timebase"],
+                    info["audio_timebase"],
+                    math.ceil,
+                )
+            video, audio, info = _read_video_from_file(
+                video_path,
+                video_pts_range=(video_start_pts, video_end_pts),
+                video_timebase=info["video_timebase"],
+                audio_pts_range=(audio_start_pts, audio_end_pts),
+                audio_timebase=audio_timebase,
+            )
         if self.frame_rate is not None:
             resampling_idx = self.resampling_idxs[video_idx][clip_idx]
             if isinstance(resampling_idx, torch.Tensor):
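
To make the timebase conversion in the hunk above concrete, here is a small worked example of pts_convert (the function is copied from the patch; the 1/30000 video and 1/44100 audio timebases are made-up values for illustration):

    from fractions import Fraction
    import math

    def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
        # pts * timebase_from is the timestamp in seconds (as an exact Fraction);
        # dividing by timebase_to re-expresses it in the target stream's units
        new_pts = Fraction(pts, 1) * timebase_from / timebase_to
        return round_func(new_pts)

    video_tb = Fraction(1, 30000)   # video stream timebase (illustrative)
    audio_tb = Fraction(1, 44100)   # audio stream timebase (illustrative)

    # video pts 90000 is 90000/30000 = 3.0 seconds; in audio units that is
    # 3.0 * 44100 = 132300
    print(pts_convert(90000, video_tb, audio_tb, math.floor))  # 132300
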
diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py
index 2e840b72895..978ac31555a 100644
--- a/torchvision/io/__init__.py
+++ b/torchvision/io/__init__.py
@@ -4,4 +4,5 @@
 
 __all__ = [
     'write_video', 'read_video', 'read_video_timestamps',
+    '_read_video_from_file', '_read_video_timestamps_from_file',
 ]
diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py
index f3edab1a957..26948fb2d25 100644
--- a/torchvision/io/_video_opt.py
+++ b/torchvision/io/_video_opt.py
@@ -10,6 +10,7 @@
 
 try:
     lib_dir = os.path.join(os.path.dirname(__file__), '..')
+    lib_dir = "/data/users/zyan3/github/py3/vision/torchvision"
     _, path, description = imp.find_module("video_reader", [lib_dir])
     torch.ops.load_library(path)
     _HAS_VIDEO_OPT = True

From b73e8265552ece5d3abee6d096fff4906168da5f Mon Sep 17 00:00:00 2001
From: zyan3
Date: Mon, 23 Sep 2019 16:49:56 -0700
Subject: [PATCH 2/5] minor fix

---
 torchvision/io/_video_opt.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py
index 26948fb2d25..f3edab1a957 100644
--- a/torchvision/io/_video_opt.py
+++ b/torchvision/io/_video_opt.py
@@ -10,7 +10,6 @@
 
 try:
     lib_dir = os.path.join(os.path.dirname(__file__), '..')
-    lib_dir = "/data/users/zyan3/github/py3/vision/torchvision"
     _, path, description = imp.find_module("video_reader", [lib_dir])
     torch.ops.load_library(path)
     _HAS_VIDEO_OPT = True

From c6c2c90571526395b04fac5c082637b04d6fad0b Mon Sep 17 00:00:00 2001
From: zyan3
Date: Mon, 23 Sep 2019 16:52:03 -0700
Subject: [PATCH 3/5] minor fix

---
 test/test_datasets_video_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index f90b20ded70..ef8e55003eb 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -6,7 +6,7 @@
 from torchvision import io
 from torchvision.datasets.video_utils import VideoClips, unfold
-from torchvision import set_video_backend, get_video_backend
+from torchvision import get_video_backend
 
 from common_utils import get_tmp_dir
 

From c4134f9a648d365b266938f629770519d14e7fc0 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 24 Sep 2019 11:00:49 -0300
Subject: [PATCH 4/5] Make backend private in VideoClips

---
 test/test_datasets_video_utils.py   |  5 ++++-
 torchvision/datasets/video_utils.py | 26 +++++++++++++-------------
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index ef8e55003eb..b6790d49cce 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -10,7 +10,6 @@
 
 from common_utils import get_tmp_dir
 
-_backend = get_video_backend()
 
 @contextlib.contextmanager
 def get_list_of_videos(num_videos=5, sizes=None, fps=None):
@@ -63,6 +62,7 @@ def test_unfold(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips(self):
+        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3) as video_list:
             video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
             self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
@@ -87,6 +87,7 @@ def test_video_clips(self):
 
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler(self):
+        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
             video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
@@ -99,6 +100,7 @@ def test_video_sampler(self):
 
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler_unequal(self):
+        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
             video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
@@ -118,6 +120,7 @@ def test_video_sampler_unequal(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
+        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
             num_frames = 4
             for fps in [1, 3, 4, 10]:
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index d68f52b9d99..5c7a4330636 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -71,7 +71,7 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
                  _backend="pyav"):
         self.video_paths = video_paths
         self.num_workers = num_workers
-        self.backend = _backend
+        self._backend = _backend
         if _precomputed_metadata is None:
             self._compute_frame_pts()
         else:
@@ -80,7 +80,7 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
 
     def _compute_frame_pts(self):
         self.video_pts = []
-        if self.backend == "pyav":
+        if self._backend == "pyav":
             self.video_fps = []
         else:
             self.info = []
@@ -88,22 +88,22 @@ def _compute_frame_pts(self):
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         class DS(object):
-            def __init__(self, x, backend):
+            def __init__(self, x, _backend):
                 self.x = x
-                self.backend = backend
+                self._backend = _backend
 
             def __len__(self):
                 return len(self.x)
 
             def __getitem__(self, idx):
-                if self.backend == "pyav":
+                if self._backend == "pyav":
                     return read_video_timestamps(self.x[idx])
                 else:
                     return _read_video_timestamps_from_file(self.x[idx])
 
         import torch.utils.data
         dl = torch.utils.data.DataLoader(
-            DS(self.video_paths, self.backend),
+            DS(self.video_paths, self._backend),
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=lambda x: x)
@@ -111,7 +111,7 @@ def __getitem__(self, idx):
         with tqdm(total=len(dl)) as pbar:
             for batch in dl:
                 pbar.update(1)
-                if self.backend == "pyav":
+                if self._backend == "pyav":
                     clips, fps = list(zip(*batch))
                     clips = [torch.as_tensor(c) for c in clips]
                     self.video_pts.extend(clips)
@@ -127,7 +127,7 @@ def _init_from_metadata(self, metadata):
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
 
-        if self.backend == "pyav":
+        if self._backend == "pyav":
             assert len(self.video_paths) == len(metadata["video_fps"])
             self.video_fps = metadata["video_fps"]
         else:
@@ -140,7 +140,7 @@ def metadata(self):
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
         }
-        if self.backend == "pyav":
+        if self._backend == "pyav":
             _metadata.update({"video_fps": self.video_fps})
         else:
             _metadata.update({"info": self.info})
@@ -148,7 +148,7 @@ def metadata(self):
     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
-        if self.backend == "pyav":
+        if self._backend == "pyav":
             video_fps = [self.video_fps[i] for i in indices]
         else:
             info = [self.info[i] for i in indices]
@@ -156,7 +156,7 @@ def subset(self, indices):
             "video_paths": video_paths,
             "video_pts": video_pts,
         }
-        if self.backend == "pyav":
+        if self._backend == "pyav":
             metadata.update({"video_fps": video_fps})
         else:
             metadata.update({"info": info})
@@ -198,7 +198,7 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         self.frame_rate = frame_rate
         self.clips = []
         self.resampling_idxs = []
-        if self.backend == "pyav":
+        if self._backend == "pyav":
            for video_pts, fps in zip(self.video_pts, self.video_fps):
                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
                self.clips.append(clips)
@@ -267,7 +267,7 @@ def get_clip(self, idx):
         video_path = self.video_paths[video_idx]
         clip_pts = self.clips[video_idx][clip_idx]
 
-        if self.backend == "pyav":
+        if self._backend == "pyav":
             start_pts = clip_pts[0].item()
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)

From d40913d94abd5597891af5ff75c3be24c6b8091d Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 24 Sep 2019 11:23:11 -0300
Subject: [PATCH 5/5] Fix lint

---
 torchvision/__init__.py             | 3 +++
 torchvision/datasets/video_utils.py | 1 -
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchvision/__init__.py b/torchvision/__init__.py
index 16a412a4155..2fd780497d3 100644
--- a/torchvision/__init__.py
+++ b/torchvision/__init__.py
@@ -16,6 +16,7 @@
 
 _video_backend = "pyav"
 
+
 def set_image_backend(backend):
     """
     Specifies the package used to load images.
@@ -58,9 +59,11 @@ def set_video_backend(backend):
         )
     _video_backend = backend
 
+
 def get_video_backend():
     return _video_backend
 
+
 def _is_tracing():
     import torch
     return torch._C._get_tracing_state()
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index 5c7a4330636..b638c569a87 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -23,7 +23,6 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
     return round_func(new_pts)
 
 
-
 def unfold(tensor, size, step, dilation=1):
     """
     similar to tensor.unfold, but with the dilation
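
Taken together, the series lets callers pick a decoding backend once and pass it through to VideoClips. A minimal end-to-end sketch, assuming a couple of hypothetical video files on disk (the _backend keyword is private API, as the leading underscore suggests):

    from torchvision import get_video_backend
    from torchvision.datasets.video_utils import VideoClips

    video_paths = ["a.mp4", "b.mp4"]  # hypothetical input files

    clips = VideoClips(
        video_paths,
        clip_length_in_frames=16,
        frames_between_clips=1,
        _backend=get_video_backend(),
    )
    print(clips.num_clips())

    # the metadata can be cached and fed back to skip the expensive pts scan
    meta = clips.metadata
    clips2 = VideoClips(video_paths, 16, 1, _precomputed_metadata=meta)

    # a clip is (video, audio, info, video_idx)
    video, audio, info, video_idx = clips.get_clip(0)
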