Skip to content

Commit 611421e

Browse files
authored
Allow sample_rate parameter to audio decoder (#551)
1 parent ae19a78 commit 611421e

14 files changed

+1110
-63
lines changed

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ def __init__(
2525
source: Union[str, Path, bytes, Tensor],
2626
*,
2727
stream_index: Optional[int] = None,
28+
sample_rate: Optional[int] = None,
2829
):
2930
self._decoder = create_decoder(source=source, seek_mode="approximate")
3031

31-
core.add_audio_stream(self._decoder, stream_index=stream_index)
32+
core.add_audio_stream(
33+
self._decoder, stream_index=stream_index, sample_rate=sample_rate
34+
)
3235

3336
(
3437
self.metadata,
@@ -39,6 +42,9 @@ def __init__(
3942
decoder=self._decoder, stream_index=stream_index, media_type="audio"
4043
)
4144
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
45+
self._desired_sample_rate = (
46+
sample_rate if sample_rate is not None else self.metadata.sample_rate
47+
)
4248

4349
def get_samples_played_in_range(
4450
self, start_seconds: float, stop_seconds: Optional[float] = None
@@ -75,11 +81,7 @@ def get_samples_played_in_range(
7581
# So we do some basic math to figure out the position of the view that
7682
# we'll return.
7783

78-
# TODO: sample_rate is either the original one from metadata, or the
79-
# user-specified one (NIY)
80-
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
81-
sample_rate = self.metadata.sample_rate
82-
84+
sample_rate = self._desired_sample_rate
8385
# TODO: metadata's sample_rate should probably not be Optional
8486
assert sample_rate is not None # mypy.
8587

@@ -94,7 +96,7 @@ def get_samples_played_in_range(
9496
output_pts_seconds = first_pts
9597

9698
num_samples = frames.shape[1]
97-
last_pts = first_pts + num_samples / self.metadata.sample_rate
99+
last_pts = first_pts + num_samples / sample_rate
98100
if stop_seconds is not None and stop_seconds < last_pts:
99101
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
100102
else:

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,21 @@ void setChannelLayout(
8686

8787
SwrContext* allocateSwrContext(
8888
UniqueAVCodecContext& avCodecContext,
89-
int sampleRate,
9089
AVSampleFormat sourceSampleFormat,
91-
AVSampleFormat desiredSampleFormat) {
90+
AVSampleFormat desiredSampleFormat,
91+
int sourceSampleRate,
92+
int desiredSampleRate) {
9293
SwrContext* swrContext = nullptr;
9394
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
9495
AVChannelLayout layout = avCodecContext->ch_layout;
9596
auto status = swr_alloc_set_opts2(
9697
&swrContext,
9798
&layout,
9899
desiredSampleFormat,
99-
sampleRate,
100+
desiredSampleRate,
100101
&layout,
101102
sourceSampleFormat,
102-
sampleRate,
103+
sourceSampleRate,
103104
0,
104105
nullptr);
105106

@@ -113,10 +114,10 @@ SwrContext* allocateSwrContext(
113114
nullptr,
114115
layout,
115116
desiredSampleFormat,
116-
sampleRate,
117+
desiredSampleRate,
117118
layout,
118119
sourceSampleFormat,
119-
sampleRate,
120+
sourceSampleRate,
120121
0,
121122
nullptr);
122123
#endif

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,10 @@ void setChannelLayout(
149149
const UniqueAVFrame& srcAVFrame);
150150
SwrContext* allocateSwrContext(
151151
UniqueAVCodecContext& avCodecContext,
152-
int sampleRate,
153152
AVSampleFormat sourceSampleFormat,
154-
AVSampleFormat desiredSampleFormat);
153+
AVSampleFormat desiredSampleFormat,
154+
int sourceSampleRate,
155+
int desiredSampleRate);
155156

156157
// Returns true if sws_scale can handle unaligned data.
157158
bool canSwsScaleHandleUnalignedData();

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 107 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -580,14 +580,18 @@ void VideoDecoder::addVideoStream(
580580
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
581581
}
582582

583-
void VideoDecoder::addAudioStream(int streamIndex) {
583+
void VideoDecoder::addAudioStream(
584+
int streamIndex,
585+
const AudioStreamOptions& audioStreamOptions) {
584586
TORCH_CHECK(
585587
seekMode_ == SeekMode::approximate,
586588
"seek_mode must be 'approximate' for audio streams.");
587589

588590
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
589591

590592
auto& streamInfo = streamInfos_[activeStreamIndex_];
593+
streamInfo.audioStreamOptions = audioStreamOptions;
594+
591595
auto& streamMetadata =
592596
containerMetadata_.allStreamMetadata[activeStreamIndex_];
593597
streamMetadata.sampleRate =
@@ -947,6 +951,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
947951
(stopPts <= lastDecodedAvFrameEnd);
948952
}
949953

954+
auto lastSamples = maybeFlushSwrBuffers();
955+
if (lastSamples.has_value()) {
956+
frames.push_back(*lastSamples);
957+
}
958+
950959
return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
951960
}
952961

@@ -1200,8 +1209,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12001209
getDuration(avFrame),
12011210
formatContext_->streams[activeStreamIndex_]->time_base);
12021211
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1203-
convertAudioAVFrameToFrameOutputOnCPU(
1204-
avFrame, frameOutput, preAllocatedOutputTensor);
1212+
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
12051213
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
12061214
convertAVFrameToFrameOutputOnCPU(
12071215
avFrame, frameOutput, preAllocatedOutputTensor);
@@ -1379,24 +1387,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13791387

13801388
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13811389
UniqueAVFrame& srcAVFrame,
1382-
FrameOutput& frameOutput,
1383-
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1384-
TORCH_CHECK(
1385-
!preAllocatedOutputTensor.has_value(),
1386-
"pre-allocated audio tensor not supported yet.");
1387-
1390+
FrameOutput& frameOutput) {
13881391
AVSampleFormat sourceSampleFormat =
13891392
static_cast<AVSampleFormat>(srcAVFrame->format);
13901393
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13911394

1395+
int sourceSampleRate = srcAVFrame->sample_rate;
1396+
int desiredSampleRate =
1397+
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
1398+
sourceSampleRate);
1399+
1400+
bool mustConvert =
1401+
(sourceSampleFormat != desiredSampleFormat ||
1402+
sourceSampleRate != desiredSampleRate);
1403+
13921404
UniqueAVFrame convertedAVFrame;
1393-
if (sourceSampleFormat != desiredSampleFormat) {
1394-
convertedAVFrame = convertAudioAVFrameSampleFormat(
1395-
srcAVFrame, sourceSampleFormat, desiredSampleFormat);
1405+
if (mustConvert) {
1406+
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
1407+
srcAVFrame,
1408+
sourceSampleFormat,
1409+
desiredSampleFormat,
1410+
sourceSampleRate,
1411+
desiredSampleRate);
13961412
}
1397-
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1398-
? convertedAVFrame
1399-
: srcAVFrame;
1413+
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
14001414

14011415
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
14021416
TORCH_CHECK(
@@ -1419,55 +1433,110 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
14191433
memcpy(
14201434
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel);
14211435
}
1436+
14221437
frameOutput.data = outputData;
14231438
}
14241439

1425-
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat(
1426-
const UniqueAVFrame& avFrame,
1440+
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
1441+
const UniqueAVFrame& srcAVFrame,
14271442
AVSampleFormat sourceSampleFormat,
1428-
AVSampleFormat desiredSampleFormat
1429-
1430-
) {
1443+
AVSampleFormat desiredSampleFormat,
1444+
int sourceSampleRate,
1445+
int desiredSampleRate) {
14311446
auto& streamInfo = streamInfos_[activeStreamIndex_];
1432-
const auto& streamMetadata =
1433-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1434-
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());
14351447

14361448
if (!streamInfo.swrContext) {
14371449
createSwrContext(
1438-
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1450+
streamInfo,
1451+
sourceSampleFormat,
1452+
desiredSampleFormat,
1453+
sourceSampleRate,
1454+
desiredSampleRate);
14391455
}
14401456

14411457
UniqueAVFrame convertedAVFrame(av_frame_alloc());
14421458
TORCH_CHECK(
14431459
convertedAVFrame,
14441460
"Could not allocate frame for sample format conversion.");
14451461

1446-
setChannelLayout(convertedAVFrame, avFrame);
1462+
setChannelLayout(convertedAVFrame, srcAVFrame);
14471463
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
1448-
convertedAVFrame->sample_rate = avFrame->sample_rate;
1449-
convertedAVFrame->nb_samples = avFrame->nb_samples;
1464+
convertedAVFrame->sample_rate = desiredSampleRate;
1465+
if (sourceSampleRate != desiredSampleRate) {
1466+
// Note that this is an upper bound on the number of output samples.
1467+
// `swr_convert()` will likely not fill convertedAVFrame with that many
1468+
// samples if sample rate conversion is needed. It will buffer the last few
1469+
// ones because those require future samples. That's also why we reset
1470+
// nb_samples after the call to `swr_convert()`.
1471+
// We could also use `swr_get_out_samples()` to determine the number of
1472+
// output samples, but empirically `av_rescale_rnd()` seems to provide a
1473+
// tighter bound.
1474+
convertedAVFrame->nb_samples = av_rescale_rnd(
1475+
swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) +
1476+
srcAVFrame->nb_samples,
1477+
desiredSampleRate,
1478+
sourceSampleRate,
1479+
AV_ROUND_UP);
1480+
} else {
1481+
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
1482+
}
14501483

14511484
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
14521485
TORCH_CHECK(
14531486
status == AVSUCCESS,
14541487
"Could not allocate frame buffers for sample format conversion: ",
14551488
getFFMPEGErrorStringFromErrorCode(status));
14561489

1457-
auto numSampleConverted = swr_convert(
1490+
auto numConvertedSamples = swr_convert(
14581491
streamInfo.swrContext.get(),
14591492
convertedAVFrame->data,
14601493
convertedAVFrame->nb_samples,
1461-
static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)),
1462-
avFrame->nb_samples);
1494+
static_cast<const uint8_t**>(
1495+
const_cast<const uint8_t**>(srcAVFrame->data)),
1496+
srcAVFrame->nb_samples);
14631497
TORCH_CHECK(
1464-
numSampleConverted > 0,
1498+
numConvertedSamples > 0,
14651499
"Error in swr_convert: ",
1466-
getFFMPEGErrorStringFromErrorCode(numSampleConverted));
1500+
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));
1501+
1502+
// See comment above about nb_samples
1503+
convertedAVFrame->nb_samples = numConvertedSamples;
14671504

14681505
return convertedAVFrame;
14691506
}
14701507

1508+
std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() {
1509+
// When sample rate conversion is involved, swresample buffers some of the
1510+
// samples in-between calls to swr_convert (see the libswresample docs).
1511+
// That's because the last few samples in a given frame require future samples
1512+
// from the next frame to be properly converted. This function flushes out the
1513+
// samples that are stored in swresample's buffers.
1514+
auto& streamInfo = streamInfos_[activeStreamIndex_];
1515+
if (!streamInfo.swrContext) {
1516+
return std::nullopt;
1517+
}
1518+
auto numRemainingSamples = // this is an upper bound
1519+
swr_get_out_samples(streamInfo.swrContext.get(), 0);
1520+
1521+
if (numRemainingSamples == 0) {
1522+
return std::nullopt;
1523+
}
1524+
1525+
torch::Tensor lastSamples = torch::empty(
1526+
{getNumChannels(streamInfo.codecContext), numRemainingSamples},
1527+
torch::kFloat32);
1528+
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr());
1529+
1530+
auto actualNumRemainingSamples = swr_convert(
1531+
streamInfo.swrContext.get(),
1532+
&lastSamplesData,
1533+
numRemainingSamples,
1534+
nullptr,
1535+
0);
1536+
return lastSamples.narrow(
1537+
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
1538+
}
1539+
14711540
// --------------------------------------------------------------------------
14721541
// OUTPUT ALLOCATION AND SHAPE CONVERSION
14731542
// --------------------------------------------------------------------------
@@ -1703,14 +1772,16 @@ void VideoDecoder::createSwsContext(
17031772

17041773
void VideoDecoder::createSwrContext(
17051774
StreamInfo& streamInfo,
1706-
int sampleRate,
17071775
AVSampleFormat sourceSampleFormat,
1708-
AVSampleFormat desiredSampleFormat) {
1776+
AVSampleFormat desiredSampleFormat,
1777+
int sourceSampleRate,
1778+
int desiredSampleRate) {
17091779
auto swrContext = allocateSwrContext(
17101780
streamInfo.codecContext,
1711-
sampleRate,
17121781
sourceSampleFormat,
1713-
desiredSampleFormat);
1782+
desiredSampleFormat,
1783+
sourceSampleRate,
1784+
desiredSampleRate);
17141785

17151786
auto status = swr_init(swrContext);
17161787
TORCH_CHECK(

0 commit comments

Comments
 (0)