Skip to content

Commit 8277823

Browse files
mthrok and facebook-github-bot
authored and committed
Support overwriting PTS in StreamWriter (#3135)
Summary: Pull Request resolved: #3135. Differential Revision: D43724273. Pulled By: mthrok. fbshipit-source-id: f89f3d15a065fe5b3a5ef150e34089e8cbcbc948
1 parent 1c2d182 commit 8277823

File tree

6 files changed

+79
-12
lines changed

6 files changed

+79
-12
lines changed

test/torchaudio_unittest/io/stream_writer_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,35 @@ def test_audio_pts_increment(self):
419419
num_samples += chunk.size(0)
420420
print(chunk.pts, expected)
421421
assert abs(chunk.pts - expected) < 1e-10
422+
423+
def test_video_pts_overwrite(self):
424+
"""Can overwrite PTS"""
425+
426+
ext = "mp4"
427+
num_frames = 256
428+
filename = f"test.{ext}"
429+
frame_rate = 10
430+
width, height = 96, 128
431+
432+
# Write data
433+
dst = self.get_dst(filename)
434+
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
435+
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
436+
437+
video = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8)
438+
reference_pts = [2 * i / frame_rate for i in range(num_frames)]
439+
with writer.open():
440+
for i, pts in enumerate(reference_pts):
441+
writer.write_video_chunk(0, video[i : i + 1], pts)
442+
443+
# check
444+
if self.test_fileobj:
445+
dst.flush()
446+
447+
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
448+
reader.add_video_stream(1)
449+
pts = [chunk.pts for (chunk,) in reader.stream()]
450+
assert len(pts) == len(reference_pts)
451+
452+
for val, ref in zip(pts, reference_pts):
453+
assert val == ref

torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,10 @@ EncodeProcess::EncodeProcess(
511511
src_frame(get_video_frame(format, codec_ctx)),
512512
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}
513513

514-
void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
514+
void EncodeProcess::process(
515+
AVMediaType type,
516+
const torch::Tensor& tensor,
517+
const c10::optional<double>& pts) {
515518
TORCH_CHECK(
516519
codec_ctx->codec_type == type,
517520
"Attempted to write ",
@@ -521,6 +524,10 @@ void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
521524
" stream.");
522525

523526
AVRational codec_tb = codec_ctx->time_base;
527+
if (pts) {
528+
src_frame->pts =
529+
static_cast<int64_t>(pts.value() * codec_tb.den / codec_tb.num);
530+
}
524531
for (const auto& frame : converter.convert(tensor)) {
525532
process_frame(frame);
526533
if (type == AVMEDIA_TYPE_VIDEO) {

torchaudio/csrc/ffmpeg/stream_writer/encode_process.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ class EncodeProcess {
3939
const c10::optional<std::string>& encoder_format,
4040
const c10::optional<std::string>& hw_accel);
4141

42-
void process(AVMediaType type, const torch::Tensor& tensor);
42+
void process(
43+
AVMediaType type,
44+
const torch::Tensor& tensor,
45+
const c10::optional<double>& pts);
4346

4447
void process_frame(AVFrame* src);
4548

torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,24 +198,30 @@ void StreamWriter::close() {
198198
}
199199
}
200200

201-
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
201+
void StreamWriter::write_audio_chunk(
202+
int i,
203+
const torch::Tensor& waveform,
204+
const c10::optional<double>& pts) {
202205
TORCH_CHECK(
203206
0 <= i && i < static_cast<int>(processes.size()),
204207
"Invalid stream index. Index must be in range of [0, ",
205208
processes.size(),
206209
"). Found: ",
207210
i);
208-
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform);
211+
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts);
209212
}
210213

211-
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
214+
void StreamWriter::write_video_chunk(
215+
int i,
216+
const torch::Tensor& frames,
217+
const c10::optional<double>& pts) {
212218
TORCH_CHECK(
213219
0 <= i && i < static_cast<int>(processes.size()),
214220
"Invalid stream index. Index must be in range of [0, ",
215221
processes.size(),
216222
"). Found: ",
217223
i);
218-
processes[i].process(AVMEDIA_TYPE_VIDEO, frames);
224+
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts);
219225
}
220226

221227
void StreamWriter::flush() {

torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,20 @@ class StreamWriter {
161161
/// @param i Stream index.
162162
/// @param frames Waveform tensor. Shape: ``(frame, channel)``.
163163
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
164-
void write_audio_chunk(int i, const torch::Tensor& chunk);
164+
void write_audio_chunk(
165+
int i,
166+
const torch::Tensor& frames,
167+
const c10::optional<double>& pts = {});
165168
/// Write video data
166169
/// @param i Stream index.
167170
/// @param frames Video/image tensor. Shape: ``(time, channel, height,
168171
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
169172
/// width and the number of channels)`` must match what was configured when
170173
/// calling ``add_video_stream()``.
171-
void write_video_chunk(int i, const torch::Tensor& chunk);
174+
void write_video_chunk(
175+
int i,
176+
const torch::Tensor& frames,
177+
const c10::optional<double>& pts = {});
172178
/// Flush the frames from encoders and write the frames to the destination.
173179
void flush();
174180
};

torchaudio/io/_stream_writer.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,17 +275,23 @@ def close(self):
275275
self._s.close()
276276
self._is_open = False
277277

278-
def write_audio_chunk(self, i: int, chunk: torch.Tensor):
278+
def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
279279
"""Write audio data
280280
281281
Args:
282282
i (int): Stream index.
283283
chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
284284
The ``dtype`` must match what was passed to :py:meth:`add_audio_stream` method.
285+
pts (float, optional, or None): If provided, overwrite the presentation timestamp.
286+
.. note::
287+
288+
The value of pts is converted to integer value expressed in basis of
289+
sample rate. Therefore, it is truncated to the nearest value of
290+
``n / sample_rate``.
285291
"""
286-
self._s.write_audio_chunk(i, chunk)
292+
self._s.write_audio_chunk(i, chunk, pts)
287293

288-
def write_video_chunk(self, i: int, chunk: torch.Tensor):
294+
def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
289295
"""Write video/image data
290296
291297
Args:
@@ -295,8 +301,15 @@ def write_video_chunk(self, i: int, chunk: torch.Tensor):
295301
The ``dtype`` must be ``torch.uint8``.
296302
The shape (height, width and the number of channels) must match
297303
what was configured when calling :py:meth:`add_video_stream`
304+
pts (float, optional or None): If provided, overwrite the presentation timestamp.
305+
306+
.. note::
307+
308+
The value of pts is converted to integer value expressed in basis of
309+
frame rate. Therefore, it is truncated to the nearest value of
310+
``n / frame_rate``.
298311
"""
299-
self._s.write_video_chunk(i, chunk)
312+
self._s.write_video_chunk(i, chunk, pts)
300313

301314
def flush(self):
302315
"""Flush the frames from encoders and write the frames to the destination."""

0 commit comments

Comments
 (0)