Skip to content

Commit 693e0ae

Browse files
authored
Fixed missing audio with pyav backend (#4064)
1 parent bdc88f5 commit 693e0ae

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

test/test_video_reader.py

Lines changed: 30 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import collections
2+
import itertools
23
import math
34
import os
45
import unittest
@@ -1243,16 +1244,39 @@ def test_invalid_file(self):
12431244
with self.assertRaises(RuntimeError):
12441245
io.read_video('foo.mp4')
12451246

1246-
def test_audio_present(self):
1247-
"""Test if audio frames are returned with video_reader backend."""
1248-
set_video_backend('video_reader')
1247+
def test_audio_present_pts(self):
1248+
"""Test if audio frames are returned with pts unit."""
1249+
backends = ['video_reader', 'pyav']
1250+
start_offsets = [0, 1000]
1251+
end_offsets = [3000, None]
1252+
for test_video, _ in test_videos.items():
1253+
full_path = os.path.join(VIDEO_DIR, test_video)
1254+
container = av.open(full_path)
1255+
if container.streams.audio:
1256+
for backend, start_offset, end_offset in itertools.product(
1257+
backends, start_offsets, end_offsets):
1258+
set_video_backend(backend)
1259+
_, audio, _ = io.read_video(
1260+
full_path, start_offset, end_offset, pts_unit='pts')
1261+
self.assertGreaterEqual(audio.shape[0], 1)
1262+
self.assertGreaterEqual(audio.shape[1], 1)
1263+
1264+
def test_audio_present_sec(self):
1265+
"""Test if audio frames are returned with sec unit."""
1266+
backends = ['video_reader', 'pyav']
1267+
start_offsets = [0, 0.1]
1268+
end_offsets = [0.3, None]
12491269
for test_video, _ in test_videos.items():
12501270
full_path = os.path.join(VIDEO_DIR, test_video)
12511271
container = av.open(full_path)
12521272
if container.streams.audio:
1253-
_, audio, _ = io.read_video(full_path)
1254-
self.assertGreaterEqual(audio.shape[0], 1)
1255-
self.assertGreaterEqual(audio.shape[1], 1)
1273+
for backend, start_offset, end_offset in itertools.product(
1274+
backends, start_offsets, end_offsets):
1275+
set_video_backend(backend)
1276+
_, audio, _ = io.read_video(
1277+
full_path, start_offset, end_offset, pts_unit='sec')
1278+
self.assertGreaterEqual(audio.shape[0], 1)
1279+
self.assertGreaterEqual(audio.shape[1], 1)
12561280

12571281

12581282
if __name__ == "__main__":

torchvision/io/video.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -283,22 +283,25 @@ def read_video(
283283
info = {}
284284
video_frames = []
285285
audio_frames = []
286+
audio_timebase = _video_opt.default_timebase
286287

287288
try:
288289
with av.open(filename, metadata_errors="ignore") as container:
290+
if container.streams.audio:
291+
audio_timebase = container.streams.audio[0].time_base
289292
time_base = _video_opt.default_timebase
290293
if container.streams.video:
291294
time_base = container.streams.video[0].time_base
292295
elif container.streams.audio:
293296
time_base = container.streams.audio[0].time_base
294297
# video_timebase is the default time_base
295-
start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec(
298+
start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(
296299
start_pts, end_pts, pts_unit, time_base)
297300
if container.streams.video:
298301
video_frames = _read_from_stream(
299302
container,
300-
start_pts_sec,
301-
end_pts_sec,
303+
start_pts,
304+
end_pts,
302305
pts_unit,
303306
container.streams.video[0],
304307
{"video": 0},
@@ -311,8 +314,8 @@ def read_video(
311314
if container.streams.audio:
312315
audio_frames = _read_from_stream(
313316
container,
314-
start_pts_sec,
315-
end_pts_sec,
317+
start_pts,
318+
end_pts,
316319
pts_unit,
317320
container.streams.audio[0],
318321
{"audio": 0},
@@ -334,6 +337,10 @@ def read_video(
334337
if aframes_list:
335338
aframes = np.concatenate(aframes_list, 1)
336339
aframes = torch.as_tensor(aframes)
340+
if pts_unit == 'sec':
341+
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
342+
if end_pts != float("inf"):
343+
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
337344
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
338345
else:
339346
aframes = torch.empty((1, 0), dtype=torch.float32)

0 commit comments

Comments
 (0)