Skip to content

Commit c553272

Browse files
committed
fix
1 parent 8277823 commit c553272

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

test/torchaudio_unittest/io/stream_writer_test.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
import torchaudio
35

@@ -420,25 +422,38 @@ def test_audio_pts_increment(self):
420422
print(chunk.pts, expected)
421423
assert abs(chunk.pts - expected) < 1e-10
422424

423-
def test_video_pts_overwrite(self):
425+
@parameterized.expand([
426+
(10, 100),
427+
(15, 150),
428+
(24, 240),
429+
(25, 200),
430+
(30, 300),
431+
(50, 500),
432+
(60, 600),
433+
# PTS value conversion involves float <-> int conversion, which can
434+
# introduce rounding error.
435+
# This test is a spot-check for the popular 29.97 Hz frame rate
436+
(30000/1001, 10010),
437+
])
438+
def test_video_pts_overwrite(self, frame_rate, num_frames):
424439
"""Can overwrite PTS"""
425440

426441
ext = "mp4"
427-
num_frames = 256
428442
filename = f"test.{ext}"
429-
frame_rate = 10
430-
width, height = 96, 128
443+
width, height = 8, 8
431444

432445
# Write data
433446
dst = self.get_dst(filename)
434447
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
435448
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
436449

437-
video = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8)
438-
reference_pts = [2 * i / frame_rate for i in range(num_frames)]
450+
video = torch.zeros((1, 3, height, width), dtype=torch.uint8)
451+
reference_pts = []
439452
with writer.open():
440-
for i, pts in enumerate(reference_pts):
441-
writer.write_video_chunk(0, video[i : i + 1], pts)
453+
for i in range(num_frames):
454+
pts = i / frame_rate
455+
reference_pts.append(pts)
456+
writer.write_video_chunk(0, video, pts)
442457

443458
# check
444459
if self.test_fileobj:
@@ -450,4 +465,7 @@ def test_video_pts_overwrite(self):
450465
assert len(pts) == len(reference_pts)
451466

452467
for val, ref in zip(pts, reference_pts):
453-
assert val == ref
468+
# torch provides isclose, but we don't know if converting floats to tensor
469+
# could introduce a discrepancy, so we compare floats and use math.isclose
470+
# for that.
471+
assert math.isclose(val, ref)

torchaudio/csrc/ffmpeg/CMakeLists.txt

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,10 @@ set(
1616
stream_reader/sink.cpp
1717
stream_reader/stream_processor.cpp
1818
stream_reader/stream_reader.cpp
19+
stream_writer/encode_process.cpp
1920
stream_writer/encoder.cpp
20-
stream_writer/converter.cpp
21-
stream_writer/output_stream.cpp
22-
stream_writer/audio_converter.cpp
23-
stream_writer/audio_output_stream.cpp
24-
stream_writer/video_converter.cpp
25-
stream_writer/video_output_stream.cpp
2621
stream_writer/stream_writer.cpp
22+
stream_writer/tensor_converter.cpp
2723
compat.cpp
2824
utils.cpp
2925
)

torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,8 @@ void EncodeProcess::process(
525525

526526
AVRational codec_tb = codec_ctx->time_base;
527527
if (pts) {
528-
src_frame->pts =
529-
static_cast<int64_t>(pts.value() * codec_tb.den / codec_tb.num);
528+
double pts_val = pts.value() * codec_tb.den / codec_tb.num;
529+
src_frame->pts = static_cast<int64_t>(std::round(pts_val));
530530
}
531531
for (const auto& frame : converter.convert(tensor)) {
532532
process_frame(frame);

0 commit comments

Comments
 (0)