Skip to content

Commit 8d2f6f8

Browse files
mthrokfacebook-github-bot
authored andcommitted
Support overwriting PTS in StreamWriter (#3135)
Summary: Pull Request resolved: #3135 Reviewed By: xiaohui-zhang Differential Revision: D43724273 Pulled By: mthrok fbshipit-source-id: 9b52823618948945a26e57d5b3deccbf5f9268c1
1 parent 3212a25 commit 8d2f6f8

File tree

7 files changed

+135
-22
lines changed

7 files changed

+135
-22
lines changed

test/torchaudio_unittest/io/stream_reader_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_chunktensor(self):
5959
w.add_audio_stream(8000, 2)
6060
with w.open():
6161
w.write_audio_chunk(0, c)
62+
w.write_audio_chunk(0, c, c.pts)
6263

6364

6465
################################################################################

test/torchaudio_unittest/io/stream_writer_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
import torchaudio
35

@@ -435,3 +437,53 @@ def test_audio_pts_increment(self):
435437
num_samples += chunk.size(0)
436438
print(chunk.pts, expected)
437439
assert abs(chunk.pts - expected) < 1e-10
440+
441+
@parameterized.expand(
442+
[
443+
(10, 100),
444+
(15, 150),
445+
(24, 240),
446+
(25, 200),
447+
(30, 300),
448+
(50, 500),
449+
(60, 600),
450+
# PTS value conversion involves float <-> int conversion, which can
451+
# introduce rounding error.
452+
# This test is a spot-check for popular 29.97 Hz
453+
(30000 / 1001, 10010),
454+
]
455+
)
456+
def test_video_pts_overwrite(self, frame_rate, num_frames):
457+
"""Can overwrite PTS"""
458+
459+
ext = "mp4"
460+
filename = f"test.{ext}"
461+
width, height = 8, 8
462+
463+
# Write data
464+
dst = self.get_dst(filename)
465+
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
466+
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
467+
468+
video = torch.zeros((1, 3, height, width), dtype=torch.uint8)
469+
reference_pts = []
470+
with writer.open():
471+
for i in range(num_frames):
472+
pts = i / frame_rate
473+
reference_pts.append(pts)
474+
writer.write_video_chunk(0, video, pts)
475+
476+
# check
477+
if self.test_fileobj:
478+
dst.flush()
479+
480+
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
481+
reader.add_video_stream(1)
482+
pts = [chunk.pts for (chunk,) in reader.stream()]
483+
assert len(pts) == len(reference_pts)
484+
485+
for val, ref in zip(pts, reference_pts):
486+
# torch provides isclose, but we don't know if converting floats to tensor
487+
# could introduce a descrepancy, so we compare floats and use math.isclose
488+
# for that.
489+
assert math.isclose(val, ref)

torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ AVFramePtr get_audio_frame(
218218
AVFramePtr frame{};
219219
frame->pts = 0;
220220
frame->format = src_fmt;
221-
// note: channels attribute is not required for encoding, but TensorConverter
222-
// refers to it
221+
// Note: `channels` attribute is not required for encoding, but
222+
// TensorConverter refers to it
223223
frame->channels = num_channels;
224224
frame->channel_layout = codec_ctx->channel_layout;
225225
frame->sample_rate = sample_rate;
@@ -461,6 +461,10 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
461461
av_err2string(ret),
462462
").");
463463
}
464+
// Note: `nb_samples` attribute is not used for video, but we set it
465+
// anyways so that we can make the logic of PTS increment agnostic to
466+
// audio and video.
467+
frame->nb_samples = 1;
464468
frame->pts = 0;
465469
return frame;
466470
}
@@ -511,24 +515,29 @@ EncodeProcess::EncodeProcess(
511515
src_frame(get_video_frame(format, codec_ctx)),
512516
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}
513517

514-
void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
518+
void EncodeProcess::process(
519+
AVMediaType type,
520+
const torch::Tensor& tensor,
521+
const c10::optional<double>& pts) {
515522
TORCH_CHECK(
516523
codec_ctx->codec_type == type,
517524
"Attempted to write ",
518525
av_get_media_type_string(type),
519526
" to ",
520527
av_get_media_type_string(codec_ctx->codec_type),
521528
" stream.");
522-
523-
AVRational codec_tb = codec_ctx->time_base;
529+
if (pts) {
530+
AVRational tb = codec_ctx->time_base;
531+
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
532+
if (src_frame->pts > val) {
533+
TORCH_WARN_ONCE(
534+
"The provided PTS value is smaller than the next expected value.");
535+
}
536+
src_frame->pts = val;
537+
}
524538
for (const auto& frame : converter.convert(tensor)) {
525539
process_frame(frame);
526-
if (type == AVMEDIA_TYPE_VIDEO) {
527-
frame->pts += 1;
528-
} else {
529-
AVRational sr_tb{1, codec_ctx->sample_rate};
530-
frame->pts += av_rescale_q(frame->nb_samples, sr_tb, codec_tb);
531-
}
540+
frame->pts += frame->nb_samples;
532541
}
533542
}
534543

torchaudio/csrc/ffmpeg/stream_writer/encode_process.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ class EncodeProcess {
3939
const c10::optional<std::string>& encoder_format,
4040
const c10::optional<std::string>& hw_accel);
4141

42-
void process(AVMediaType type, const torch::Tensor& tensor);
42+
void process(
43+
AVMediaType type,
44+
const torch::Tensor& tensor,
45+
const c10::optional<double>& pts);
4346

4447
void process_frame(AVFrame* src);
4548

torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,26 +200,32 @@ void StreamWriter::close() {
200200
is_open = false;
201201
}
202202

203-
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
203+
void StreamWriter::write_audio_chunk(
204+
int i,
205+
const torch::Tensor& waveform,
206+
const c10::optional<double>& pts) {
204207
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
205208
TORCH_CHECK(
206209
0 <= i && i < static_cast<int>(processes.size()),
207210
"Invalid stream index. Index must be in range of [0, ",
208211
processes.size(),
209212
"). Found: ",
210213
i);
211-
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform);
214+
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts);
212215
}
213216

214-
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
217+
void StreamWriter::write_video_chunk(
218+
int i,
219+
const torch::Tensor& frames,
220+
const c10::optional<double>& pts) {
215221
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
216222
TORCH_CHECK(
217223
0 <= i && i < static_cast<int>(processes.size()),
218224
"Invalid stream index. Index must be in range of [0, ",
219225
processes.size(),
220226
"). Found: ",
221227
i);
222-
processes[i].process(AVMEDIA_TYPE_VIDEO, frames);
228+
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts);
223229
}
224230

225231
void StreamWriter::flush() {

torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,42 @@ class StreamWriter {
162162
/// @param i Stream index.
163163
/// @param chunk Waveform tensor. Shape: ``(frame, channel)``.
164164
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
165-
void write_audio_chunk(int i, const torch::Tensor& chunk);
165+
/// @param pts
166+
/// @parblock
167+
/// Presentation timestamp. If provided, it overwrites the PTS of
168+
/// the first frame with the provided one. Otherwise, PTS are incremented per
169+
/// an inverse of sample rate. Only values exceed the PTS values processed
170+
/// internally.
171+
///
172+
/// __NOTE__: The provided value is converted to integer value expressed
173+
/// in basis of sample rate.
174+
/// Therefore, it is truncated to the nearest value of ``n / sample_rate``.
175+
/// @endparblock
176+
void write_audio_chunk(
177+
int i,
178+
const torch::Tensor& frames,
179+
const c10::optional<double>& pts = {});
166180
/// Write video data
167181
/// @param i Stream index.
168182
/// @param chunk Video/image tensor. Shape: ``(time, channel, height,
169183
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
170184
/// width and the number of channels)`` must match what was configured when
171185
/// calling ``add_video_stream()``.
172-
void write_video_chunk(int i, const torch::Tensor& chunk);
186+
/// @param pts
187+
/// @parblock
188+
/// Presentation timestamp. If provided, it overwrites the PTS of
189+
/// the first frame with the provided one. Otherwise, PTS are incremented per
190+
/// an inverse of frame rate. Only values exceed the PTS values processed
191+
/// internally.
192+
///
193+
/// __NOTE__: The provided value is converted to integer value expressed
194+
/// in basis of frame rate.
195+
/// Therefore, it is truncated to the nearest value of ``n / frame_rate``.
196+
/// @endparblock
197+
void write_video_chunk(
198+
int i,
199+
const torch::Tensor& frames,
200+
const c10::optional<double>& pts = {});
173201
/// Flush the frames from encoders and write the frames to the destination.
174202
void flush();
175203
};

torchaudio/io/_stream_writer.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,17 +275,24 @@ def close(self):
275275
self._s.close()
276276
self._is_open = False
277277

278-
def write_audio_chunk(self, i: int, chunk: torch.Tensor):
278+
def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
279279
"""Write audio data
280280
281281
Args:
282282
i (int): Stream index.
283283
chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
284284
The ``dtype`` must match what was passed to :py:meth:`add_audio_stream` method.
285+
pts (float, optional, or None): If provided, overwrite the presentation timestamp.
286+
287+
.. note::
288+
289+
The provided value is converted to integer value expressed in basis of
290+
sample rate. Therefore, it is truncated to the nearest value of
291+
``n / sample_rate``.
285292
"""
286-
self._s.write_audio_chunk(i, chunk)
293+
self._s.write_audio_chunk(i, chunk, pts)
287294

288-
def write_video_chunk(self, i: int, chunk: torch.Tensor):
295+
def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
289296
"""Write video/image data
290297
291298
Args:
@@ -295,8 +302,15 @@ def write_video_chunk(self, i: int, chunk: torch.Tensor):
295302
The ``dtype`` must be ``torch.uint8``.
296303
The shape (height, width and the number of channels) must match
297304
what was configured when calling :py:meth:`add_video_stream`
305+
pts (float, optional or None): If provided, overwrite the presentation timestamp.
306+
307+
.. note::
308+
309+
The provided value is converted to integer value expressed in basis of
310+
frame rate. Therefore, it is truncated to the nearest value of
311+
``n / frame_rate``.
298312
"""
299-
self._s.write_video_chunk(i, chunk)
313+
self._s.write_video_chunk(i, chunk, pts)
300314

301315
def flush(self):
302316
"""Flush the frames from encoders and write the frames to the destination."""

0 commit comments

Comments
 (0)