From 605130c50ebfb09170f227421ccfe4cbd56e8daf Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 15 Oct 2019 11:44:12 +0200 Subject: [PATCH 1/3] Handle corrupted video headers in io --- test/test_io.py | 17 ++++++++++ torchvision/io/video.py | 74 ++++++++++++++++++++++++++--------------- 2 files changed, 64 insertions(+), 27 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index 9bfc2aa403e..c2b49609fe3 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -236,6 +236,23 @@ def test_read_partial_video_pts_unit_sec(self): self.assertEqual(len(lv), 4) self.assertTrue(data[4:8].equal(lv)) + def test_read_video_corrupted_file(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + f.write(b'This is not an mpg4 file') + video, audio, info = io.read_video(f.name) + self.assertIsInstance(video, torch.Tensor) + self.assertIsInstance(audio, torch.Tensor) + self.assertEqual(video.numel(), 0) + self.assertEqual(audio.numel(), 0) + self.assertEqual(info, {}) + + def test_read_video_timestamps_corrupted_file(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + f.write(b'This is not an mpg4 file') + video_pts, video_fps = io.read_video_timestamps(f.name) + self.assertEqual(video_pts, []) + self.assertIs(video_fps, None) + # TODO add tests for audio diff --git a/torchvision/io/video.py b/torchvision/io/video.py index a957e96e9f5..f61ee9a5392 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -193,25 +193,36 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): raise ValueError("end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)) - container = av.open(filename, metadata_errors='ignore') info = {} - video_frames = [] - if container.streams.video: - video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, - container.streams.video[0], {'video': 0}) - info["video_fps"] = float(container.streams.video[0].average_rate) audio_frames = [] - if container.streams.audio: - audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, - container.streams.audio[0], {'audio': 0}) - info["audio_fps"] = container.streams.audio[0].rate - container.close() + try: + container = av.open(filename, metadata_errors='ignore') + except av.AVError: + # TODO raise a warning? + pass + else: + if container.streams.video: + video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, + container.streams.video[0], {'video': 0}) + info["video_fps"] = float(container.streams.video[0].average_rate) + + if container.streams.audio: + audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, + container.streams.audio[0], {'audio': 0}) + info["audio_fps"] = container.streams.audio[0].rate + + container.close() vframes = [frame.to_rgb().to_ndarray() for frame in video_frames] aframes = [frame.to_ndarray() for frame in audio_frames] - vframes = torch.as_tensor(np.stack(vframes)) + + if vframes: + vframes = torch.as_tensor(np.stack(vframes)) + else: + vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) + if aframes: aframes = np.concatenate(aframes, 1) aframes = torch.as_tensor(aframes) @@ -255,21 +266,30 @@ def read_video_timestamps(filename, pts_unit='pts'): """ _check_av_available() - container = av.open(filename, metadata_errors='ignore') - video_frames = [] video_fps = None - if container.streams.video: - video_stream = container.streams.video[0] - video_time_base = video_stream.time_base - if _can_read_timestamps_from_packets(container): - # fast path - video_frames = [x for x in container.demux(video=0) if x.pts is not None] - else: - video_frames = _read_from_stream(container, 0, float("inf"), pts_unit, - video_stream, {'video': 0}) - video_fps = float(video_stream.average_rate) - container.close() + + try: + container = av.open(filename, metadata_errors='ignore') + except av.AVError: + # TODO add a warning + pass + else: + if container.streams.video: + video_stream = container.streams.video[0] + video_time_base = video_stream.time_base + if _can_read_timestamps_from_packets(container): + # fast path + video_frames = [x for x in container.demux(video=0) if x.pts is not None] + else: + video_frames = _read_from_stream(container, 0, float("inf"), pts_unit, + video_stream, {'video': 0}) + video_fps = float(video_stream.average_rate) + container.close() + + pts = [x.pts for x in video_frames] + if pts_unit == 'sec': - return [x.pts * video_time_base for x in video_frames], video_fps - return [x.pts for x in video_frames], video_fps + pts = [x * video_time_base for x in pts] + + return pts, video_fps From 3b1c370f3d1a9de2b8886beff11d8ea26eec88b6 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 15 Oct 2019 11:55:03 +0200 Subject: [PATCH 2/3] Catch exceptions while decoding partly-corrupted files --- torchvision/io/video.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index f61ee9a5392..efd70f9895c 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -124,13 +124,17 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str # print("Corrupted file?", container.name) return [] buffer_count = 0 - for idx, frame in enumerate(container.decode(**stream_name)): - frames[frame.pts] = frame - if frame.pts >= end_offset: - if should_buffer and buffer_count < max_buffer_size: - buffer_count += 1 - continue - break + try: + for idx, frame in enumerate(container.decode(**stream_name)): + frames[frame.pts] = frame + if frame.pts >= end_offset: + if should_buffer and buffer_count < max_buffer_size: + buffer_count += 1 + continue + break + except av.AVError: + # TODO add a warning + pass # ensure that the results are sorted wrt the pts result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset] if start_offset > 0 and start_offset not in frames: From 1130fd1d125e0bd6e1a560749d83a8381542a4f1 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 15 Oct 2019 14:03:21 +0200 Subject: [PATCH 3/3] Add more tests --- test/test_io.py | 18 ++++++++++++++++++ torchvision/io/video.py | 5 ++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/test/test_io.py b/test/test_io.py index c2b49609fe3..bcd1ce4232b 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -253,6 +253,24 @@ def test_read_video_timestamps_corrupted_file(self): self.assertEqual(video_pts, []) self.assertIs(video_fps, None) + def test_read_video_partially_corrupted_file(self): + with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data): + with open(f_name, 'r+b') as f: + size = os.path.getsize(f_name) + bytes_to_overwrite = size // 10 + # seek to the middle of the file + f.seek(5 * bytes_to_overwrite) + # corrupt 10% of the file from the middle + f.write(b'\xff' * bytes_to_overwrite) + # this exercises the container.decode assertion check + video, audio, info = io.read_video(f.name, pts_unit='sec') + # check that size is not equal to 5, but 3 + self.assertEqual(len(video), 3) + # but the valid decoded content is still correct + self.assertTrue(video[:3].equal(data[:3])) + # and the last few frames are wrong + self.assertFalse(video.equal(data)) + # TODO add tests for audio diff --git a/torchvision/io/video.py b/torchvision/io/video.py index efd70f9895c..5337e26e396 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -210,7 +210,10 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): if container.streams.video: video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, container.streams.video[0], {'video': 0}) - info["video_fps"] = float(container.streams.video[0].average_rate) + video_fps = container.streams.video[0].average_rate + # guard against potentially corrupted files + if video_fps is not None: + info["video_fps"] = float(video_fps) if container.streams.audio: audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,