diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index ccca068d367..2488edc613d 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -59,7 +59,7 @@ def test_unfold(self):
         self.assertTrue(r.equal(expected))
 
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
-    @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
     def test_video_clips(self):
         with get_list_of_videos(num_videos=3) as video_list:
             video_clips = VideoClips(video_list, 5, 5)
@@ -84,7 +84,7 @@ def test_video_clips(self):
                 self.assertEqual(clip_idx, c_idx)
 
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
-    @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
             num_frames = 4
@@ -94,6 +94,7 @@ def test_video_clips_custom_fps(self):
                 video, audio, info, video_idx = video_clips.get_clip(i)
                 self.assertEqual(video.shape[0], num_frames)
                 self.assertEqual(info["video_fps"], fps)
+                self.assertEqual(info, {"video_fps": fps})
             # TODO add tests checking that the content is right
 
     def test_compute_clips_for_video(self):
diff --git a/test/test_datasets_video_utils_opt.py b/test/test_datasets_video_utils_opt.py
new file mode 100644
index 00000000000..f94af400838
--- /dev/null
+++ b/test/test_datasets_video_utils_opt.py
@@ -0,0 +1,11 @@
+import unittest
+from torchvision import set_video_backend
+import test_datasets_video_utils
+
+
+set_video_backend('video_reader')
+
+
+if __name__ == '__main__':
+    suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils)
+    unittest.TextTestRunner(verbosity=1).run(suite)
diff --git a/test/test_io.py b/test/test_io.py
index 6063e250627..db292b73e0f 100644
--- a/test/test_io.py
+++ b/test/test_io.py
@@ -183,6 +183,7 @@ def test_read_video_pts_unit_sec(self):
 
             self.assertTrue(data.equal(lv))
             self.assertEqual(info["video_fps"], 5)
+            self.assertEqual(info, {"video_fps": 5})
 
     def test_read_timestamps_pts_unit_sec(self):
         with temp_video(10, 300, 300, 5) as (f_name, data):
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index 6a63ec58138..4037bcc7b3b 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -5,6 +5,7 @@
 from torchvision.io import (
     _read_video_timestamps_from_file,
     _read_video_from_file,
+    _probe_video_from_file
 )
 from torchvision.io import read_video_timestamps, read_video
 
@@ -71,11 +72,11 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
                  frame_rate=None, _precomputed_metadata=None, num_workers=0,
                  _video_width=0, _video_height=0, _video_min_dimension=0,
                  _audio_samples=0):
-        from torchvision import get_video_backend
 
         self.video_paths = video_paths
         self.num_workers = num_workers
-        self._backend = get_video_backend()
+
+        # these options are not valid for pyav backend
         self._video_width = _video_width
         self._video_height = _video_height
         self._video_min_dimension = _video_min_dimension
@@ -89,30 +90,23 @@ def _compute_frame_pts(self):
         self.video_pts = []
-        if self._backend == "pyav":
-            self.video_fps = []
-        else:
-            self.info = []
+        self.video_fps = []
 
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         class DS(object):
-            def __init__(self, x, _backend):
+            def __init__(self, x):
                 self.x = x
-                self._backend = _backend
 
             def __len__(self):
                 return len(self.x)
 
             def __getitem__(self, idx):
-                if self._backend == "pyav":
-                    return read_video_timestamps(self.x[idx])
-                else:
-                    return _read_video_timestamps_from_file(self.x[idx])
+                return read_video_timestamps(self.x[idx])
 
         import torch.utils.data
         dl = torch.utils.data.DataLoader(
-            DS(self.video_paths, self._backend),
+            DS(self.video_paths),
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=lambda x: x)
@@ -120,56 +114,36 @@ def __getitem__(self, idx):
         with tqdm(total=len(dl)) as pbar:
             for batch in dl:
                 pbar.update(1)
-                if self._backend == "pyav":
-                    clips, fps = list(zip(*batch))
-                    clips = [torch.as_tensor(c) for c in clips]
-                    self.video_pts.extend(clips)
-                    self.video_fps.extend(fps)
-                else:
-                    video_pts, _audio_pts, info = list(zip(*batch))
-                    video_pts = [torch.as_tensor(c) for c in video_pts]
-                    self.video_pts.extend(video_pts)
-                    self.info.extend(info)
+                clips, fps = list(zip(*batch))
+                clips = [torch.as_tensor(c) for c in clips]
+                self.video_pts.extend(clips)
+                self.video_fps.extend(fps)
 
     def _init_from_metadata(self, metadata):
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
-
-        if self._backend == "pyav":
-            assert len(self.video_paths) == len(metadata["video_fps"])
-            self.video_fps = metadata["video_fps"]
-        else:
-            assert len(self.video_paths) == len(metadata["info"])
-            self.info = metadata["info"]
+        assert len(self.video_paths) == len(metadata["video_fps"])
+        self.video_fps = metadata["video_fps"]
 
     @property
     def metadata(self):
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
+            "video_fps": self.video_fps
         }
-        if self._backend == "pyav":
-            _metadata.update({"video_fps": self.video_fps})
-        else:
-            _metadata.update({"info": self.info})
         return _metadata
 
     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
-        if self._backend == "pyav":
-            video_fps = [self.video_fps[i] for i in indices]
-        else:
-            info = [self.info[i] for i in indices]
+        video_fps = [self.video_fps[i] for i in indices]
         metadata = {
             "video_paths": video_paths,
             "video_pts": video_pts,
+            "video_fps": video_fps
         }
-        if self._backend == "pyav":
-            metadata.update({"video_fps": video_fps})
-        else:
-            metadata.update({"info": info})
         return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
                           _precomputed_metadata=metadata, num_workers=self.num_workers,
                           _video_width=self._video_width,
@@ -212,22 +186,10 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         self.frame_rate = frame_rate
         self.clips = []
         self.resampling_idxs = []
-        if self._backend == "pyav":
-            for video_pts, fps in zip(self.video_pts, self.video_fps):
-                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
-                self.clips.append(clips)
-                self.resampling_idxs.append(idxs)
-        else:
-            for video_pts, info in zip(self.video_pts, self.info):
-                if "video_fps" in info:
-                    clips, idxs = self.compute_clips_for_video(
-                        video_pts, num_frames, step, info["video_fps"], frame_rate)
-                    self.clips.append(clips)
-                    self.resampling_idxs.append(idxs)
-                else:
-                    # properly handle the cases where video decoding fails
-                    self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
-                    self.resampling_idxs.append(torch.zeros(0,
-                                                            dtype=torch.int64))
+        for video_pts, fps in zip(self.video_pts, self.video_fps):
+            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+            self.clips.append(clips)
+            self.resampling_idxs.append(idxs)
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
@@ -287,12 +249,28 @@ def get_clip(self, idx):
         video_path = self.video_paths[video_idx]
         clip_pts = self.clips[video_idx][clip_idx]
 
-        if self._backend == "pyav":
+        from torchvision import get_video_backend
+        backend = get_video_backend()
+
+        if backend == "pyav":
+            # check for invalid options
+            if self._video_width != 0:
+                raise ValueError("pyav backend doesn't support _video_width != 0")
+            if self._video_height != 0:
+                raise ValueError("pyav backend doesn't support _video_height != 0")
+            if self._video_min_dimension != 0:
+                raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
+            if self._audio_samples != 0:
+                raise ValueError("pyav backend doesn't support _audio_samples != 0")
+
+        if backend == "pyav":
             start_pts = clip_pts[0].item()
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = self.info[video_idx]
+            info = _probe_video_from_file(video_path)
+            video_fps = info["video_fps"]
+            audio_fps = None
 
             video_start_pts = clip_pts[0].item()
             video_end_pts = clip_pts[-1].item()
@@ -313,6 +291,7 @@ def get_clip(self, idx):
                     info["audio_timebase"],
                     math.ceil,
                 )
+                audio_fps = info["audio_sample_rate"]
             video, audio, info = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
@@ -324,6 +303,11 @@ def get_clip(self, idx):
                 audio_pts_range=(audio_start_pts, audio_end_pts),
                 audio_timebase=audio_timebase,
             )
+
+            info = {"video_fps": video_fps}
+            if audio_fps is not None:
+                info["audio_fps"] = audio_fps
+
         if self.frame_rate is not None:
             resampling_idx = self.resampling_idxs[video_idx][clip_idx]
             if isinstance(resampling_idx, torch.Tensor):
diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py
index 7dbab3f7f9d..fa215680363 100644
--- a/torchvision/io/_video_opt.py
+++ b/torchvision/io/_video_opt.py
@@ -383,7 +383,7 @@ def get_pts(time_base):
         audio_timebase = info['audio_timebase']
         audio_pts_range = get_pts(audio_timebase)
 
-    return _read_video_from_file(
+    vframes, aframes, info = _read_video_from_file(
         filename,
         read_video_stream=True,
         video_pts_range=video_pts_range,
@@ -392,6 +392,13 @@ def get_pts(time_base):
         audio_pts_range=audio_pts_range,
         audio_timebase=audio_timebase,
     )
+    _info = {}
+    if has_video:
+        _info['video_fps'] = info['video_fps']
+    if has_audio:
+        _info['audio_fps'] = info['audio_sample_rate']
+
+    return vframes, aframes, _info
 
 
 def _read_video_timestamps(filename, pts_unit='pts'):