Skip to content

Commit ff4abff

Browse files
authored
Audio decoding support: range-based core API (#538)
1 parent 374d950 commit ff4abff

14 files changed

+499
-53
lines changed

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,22 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63-
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) {
63+
int getNumChannels(const AVFrame* avFrame) {
6464
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
6565
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
66-
int numChannels = avCodecContext->ch_layout.nb_channels;
66+
return avFrame->ch_layout.nb_channels;
6767
#else
68-
int numChannels = avCodecContext->channels;
68+
return av_get_channel_layout_nb_channels(avFrame->channel_layout);
6969
#endif
70+
}
7071

71-
return static_cast<int64_t>(numChannels);
72+
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
73+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
74+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
75+
return avCodecContext->ch_layout.nb_channels;
76+
#else
77+
return avCodecContext->channels;
78+
#endif
7279
}
7380

7481
AVIOBytesContext::AVIOBytesContext(

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139139
int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142-
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext);
142+
int getNumChannels(const AVFrame* avFrame);
143+
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
143144

144145
// Returns true if sws_scale can handle unaligned data.
145146
bool canSwsScaleHandleUnalignedData();

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 116 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <cstdint>
99
#include <cstdio>
1010
#include <iostream>
11+
#include <limits>
1112
#include <sstream>
1213
#include <stdexcept>
1314
#include <string_view>
@@ -552,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) {
552553
containerMetadata_.allStreamMetadata[activeStreamIndex_];
553554
streamMetadata.sampleRate =
554555
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
555-
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
556+
streamMetadata.numChannels =
557+
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
556558
}
557559

558560
// --------------------------------------------------------------------------
@@ -567,6 +569,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
567569

568570
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
569571
std::optional<torch::Tensor> preAllocatedOutputTensor) {
572+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
570573
AVFrameStream avFrameStream = decodeAVFrame(
571574
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
572575
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
@@ -685,6 +688,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
685688
}
686689

687690
VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
691+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
688692
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
689693
double frameStartTime =
690694
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
@@ -757,7 +761,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
757761
double startSeconds,
758762
double stopSeconds) {
759763
validateActiveStream(AVMEDIA_TYPE_VIDEO);
760-
761764
const auto& streamMetadata =
762765
containerMetadata_.allStreamMetadata[activeStreamIndex_];
763766
TORCH_CHECK(
@@ -835,6 +838,68 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
835838
return frameBatchOutput;
836839
}
837840

841+
torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
842+
double startSeconds,
843+
std::optional<double> stopSecondsOptional) {
844+
validateActiveStream(AVMEDIA_TYPE_AUDIO);
845+
846+
double stopSeconds =
847+
stopSecondsOptional.value_or(std::numeric_limits<double>::max());
848+
849+
TORCH_CHECK(
850+
startSeconds <= stopSeconds,
851+
"Start seconds (" + std::to_string(startSeconds) +
852+
") must be less than or equal to stop seconds (" +
853+
std::to_string(stopSeconds) + ".");
854+
855+
if (startSeconds == stopSeconds) {
856+
// For consistency with video
857+
return torch::empty({0});
858+
}
859+
860+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861+
862+
// TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
863+
// We should remove it and seek back to the stream's beginning when needed.
864+
// See test_multiple_calls
865+
TORCH_CHECK(
866+
streamInfo.lastDecodedAvFramePts +
867+
streamInfo.lastDecodedAvFrameDuration <=
868+
secondsToClosestPts(startSeconds, streamInfo.timeBase),
869+
"Audio decoder cannot seek backwards, or start from the last decoded frame.");
870+
871+
setCursorPtsInSeconds(startSeconds);
872+
873+
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
874+
// cat(). This would save a copy. We know the duration of the output and the
875+
// sample rate, so in theory we know the number of output samples.
876+
std::vector<torch::Tensor> tensors;
877+
878+
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
879+
auto finished = false;
880+
while (!finished) {
881+
try {
882+
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
883+
return cursor_ < avFrame->pts + getDuration(avFrame);
884+
});
885+
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
886+
tensors.push_back(frameOutput.data);
887+
} catch (const EndOfFileException& e) {
888+
finished = true;
889+
}
890+
891+
// If stopSeconds is in [begin, end] of the last decoded frame, we should
892+
// stop decoding more frames. Note that if we were to use [begin, end),
893+
// which may seem more natural, then we would decode the frame starting at
894+
// stopSeconds, which isn't what we want!
895+
auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
896+
streamInfo.lastDecodedAvFrameDuration;
897+
finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
898+
(stopPts <= lastDecodedAvFrameEnd);
899+
}
900+
return torch::cat(tensors, 1);
901+
}
902+
838903
// --------------------------------------------------------------------------
839904
// SEEKING APIs
840905
// --------------------------------------------------------------------------
@@ -871,6 +936,10 @@ I P P P I P P P I P P I P P I P
871936
(2) is more efficient than (1) if there is an I frame between x and y.
872937
*/
873938
bool VideoDecoder::canWeAvoidSeeking() const {
939+
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
940+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
941+
return true;
942+
}
874943
int64_t lastDecodedAvFramePts =
875944
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
876945
if (cursor_ < lastDecodedAvFramePts) {
@@ -897,7 +966,7 @@ bool VideoDecoder::canWeAvoidSeeking() const {
897966
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
898967
// the comment of canWeAvoidSeeking() for details.
899968
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
900-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
969+
validateActiveStream();
901970
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
902971

903972
decodeStats_.numSeeksAttempted++;
@@ -942,7 +1011,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
9421011

9431012
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
9441013
std::function<bool(AVFrame*)> filterFunction) {
945-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
1014+
validateActiveStream();
9461015

9471016
resetDecodeStats();
9481017

@@ -1071,13 +1140,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10711140
AVFrame* avFrame = avFrameStream.avFrame.get();
10721141
frameOutput.streamIndex = streamIndex;
10731142
auto& streamInfo = streamInfos_[streamIndex];
1074-
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
10751143
frameOutput.ptsSeconds = ptsToSeconds(
10761144
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
10771145
frameOutput.durationSeconds = ptsToSeconds(
10781146
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1079-
// TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1080-
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
1147+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1148+
convertAudioAVFrameToFrameOutputOnCPU(
1149+
avFrameStream, frameOutput, preAllocatedOutputTensor);
1150+
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
10811151
convertAVFrameToFrameOutputOnCPU(
10821152
avFrameStream, frameOutput, preAllocatedOutputTensor);
10831153
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
@@ -1253,6 +1323,45 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
12531323
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
12541324
}
12551325

1326+
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1327+
VideoDecoder::AVFrameStream& avFrameStream,
1328+
FrameOutput& frameOutput,
1329+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1330+
TORCH_CHECK(
1331+
!preAllocatedOutputTensor.has_value(),
1332+
"pre-allocated audio tensor not supported yet.");
1333+
1334+
const AVFrame* avFrame = avFrameStream.avFrame.get();
1335+
1336+
auto numSamples = avFrame->nb_samples; // per channel
1337+
auto numChannels = getNumChannels(avFrame);
1338+
torch::Tensor outputData =
1339+
torch::empty({numChannels, numSamples}, torch::kFloat32);
1340+
1341+
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1342+
// TODO-AUDIO Implement all formats.
1343+
switch (format) {
1344+
case AV_SAMPLE_FMT_FLTP: {
1345+
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1346+
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1347+
for (auto channel = 0; channel < numChannels;
1348+
++channel, outputChannelData += numBytesPerChannel) {
1349+
memcpy(
1350+
outputChannelData,
1351+
avFrame->extended_data[channel],
1352+
numBytesPerChannel);
1353+
}
1354+
break;
1355+
}
1356+
default:
1357+
TORCH_CHECK(
1358+
false,
1359+
"Unsupported audio format (yet!): ",
1360+
av_get_sample_fmt_name(format));
1361+
}
1362+
frameOutput.data = outputData;
1363+
}
1364+
12561365
// --------------------------------------------------------------------------
12571366
// OUTPUT ALLOCATION AND SHAPE CONVERSION
12581367
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,11 @@ class VideoDecoder {
221221
double startSeconds,
222222
double stopSeconds);
223223

224+
// TODO-AUDIO: Should accept sampleRate
225+
torch::Tensor getFramesPlayedInRangeAudio(
226+
double startSeconds,
227+
std::optional<double> stopSecondsOptional = std::nullopt);
228+
224229
class EndOfFileException : public std::runtime_error {
225230
public:
226231
explicit EndOfFileException(const std::string& msg)
@@ -379,6 +384,11 @@ class VideoDecoder {
379384
FrameOutput& frameOutput,
380385
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
381386

387+
void convertAudioAVFrameToFrameOutputOnCPU(
388+
AVFrameStream& avFrameStream,
389+
FrameOutput& frameOutput,
390+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
391+
382392
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
383393

384394
int convertAVFrameToTensorUsingSwsScale(

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ namespace facebook::torchcodec {
2525
// https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme
2626
TORCH_LIBRARY(torchcodec_ns, m) {
2727
m.impl_abstract_pystub(
28-
"torchcodec.decoders._core.video_decoder_ops",
29-
"//pytorch/torchcodec:torchcodec");
28+
"torchcodec.decoders._core.ops", "//pytorch/torchcodec:torchcodec");
3029
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3130
m.def(
3231
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
@@ -48,6 +47,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4847
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4948
m.def(
5049
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
50+
m.def(
51+
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> Tensor");
5152
m.def(
5253
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
5354
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
@@ -289,6 +290,14 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
289290
return makeOpsFrameBatchOutput(result);
290291
}
291292

293+
torch::Tensor get_frames_by_pts_in_range_audio(
294+
at::Tensor& decoder,
295+
double start_seconds,
296+
std::optional<double> stop_seconds) {
297+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
298+
return videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
299+
}
300+
292301
std::string quoteValue(const std::string& value) {
293302
return "\"" + value + "\"";
294303
}
@@ -540,6 +549,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
540549
m.impl("get_frames_at_indices", &get_frames_at_indices);
541550
m.impl("get_frames_in_range", &get_frames_in_range);
542551
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
552+
m.impl("get_frames_by_pts_in_range_audio", &get_frames_by_pts_in_range_audio);
543553
m.impl("get_frames_by_pts", &get_frames_by_pts);
544554
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
545555
m.impl(

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
112112
double start_seconds,
113113
double stop_seconds);
114114

115+
torch::Tensor get_frames_by_pts_in_range_audio(
116+
at::Tensor& decoder,
117+
double start_seconds,
118+
std::optional<double> stop_seconds = std::nullopt);
119+
115120
// For testing only. We need to implement this operation as a core library
116121
// function because what we're testing is round-tripping pts values as
117122
// double-precision floating point numbers from C++ to Python and back to C++.

src/torchcodec/decoders/_core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
get_container_metadata_from_header,
1313
VideoStreamMetadata,
1414
)
15-
from .video_decoder_ops import (
15+
from .ops import (
1616
_add_video_stream,
1717
_get_key_frame_indices,
1818
_test_frame_pts_equality,
@@ -27,6 +27,7 @@
2727
get_frames_at_indices,
2828
get_frames_by_pts,
2929
get_frames_by_pts_in_range,
30+
get_frames_by_pts_in_range_audio,
3031
get_frames_in_range,
3132
get_json_metadata,
3233
get_next_frame,

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414

15-
from torchcodec.decoders._core.video_decoder_ops import (
15+
from torchcodec.decoders._core.ops import (
1616
_get_container_json_metadata,
1717
_get_stream_json_metadata,
1818
create_from_file,

src/torchcodec/decoders/_core/video_decoder_ops.py renamed to src/torchcodec/decoders/_core/ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def load_torchcodec_extension():
7878
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
7979
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
8080
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
81+
get_frames_by_pts_in_range_audio = (
82+
torch.ops.torchcodec_ns.get_frames_by_pts_in_range_audio.default
83+
)
8184
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
8285
_test_frame_pts_equality = torch.ops.torchcodec_ns._test_frame_pts_equality.default
8386
_get_container_json_metadata = (
@@ -262,6 +265,17 @@ def get_frames_by_pts_in_range_abstract(
262265
)
263266

264267

268+
@register_fake("torchcodec_ns::get_frames_by_pts_in_range_audio")
269+
def get_frames_by_pts_in_range_audio_abstract(
270+
decoder: torch.Tensor,
271+
*,
272+
start_seconds: float,
273+
stop_seconds: Optional[float] = None,
274+
) -> torch.Tensor:
275+
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
276+
return torch.empty(image_size)
277+
278+
265279
@register_fake("torchcodec_ns::_get_key_frame_indices")
266280
def get_key_frame_indices_abstract(decoder: torch.Tensor) -> torch.Tensor:
267281
return torch.empty([], dtype=torch.int)

0 commit comments

Comments
 (0)