@@ -580,14 +580,18 @@ void VideoDecoder::addVideoStream(
580
580
videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
581
581
}
582
582
583
- void VideoDecoder::addAudioStream (int streamIndex) {
583
+ void VideoDecoder::addAudioStream (
584
+ int streamIndex,
585
+ const AudioStreamOptions& audioStreamOptions) {
584
586
TORCH_CHECK (
585
587
seekMode_ == SeekMode::approximate,
586
588
" seek_mode must be 'approximate' for audio streams." );
587
589
588
590
addStream (streamIndex, AVMEDIA_TYPE_AUDIO);
589
591
590
592
auto & streamInfo = streamInfos_[activeStreamIndex_];
593
+ streamInfo.audioStreamOptions = audioStreamOptions;
594
+
591
595
auto & streamMetadata =
592
596
containerMetadata_.allStreamMetadata [activeStreamIndex_];
593
597
streamMetadata.sampleRate =
@@ -947,6 +951,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
947
951
(stopPts <= lastDecodedAvFrameEnd);
948
952
}
949
953
954
+ auto lastSamples = maybeFlushSwrBuffers ();
955
+ if (lastSamples.has_value ()) {
956
+ frames.push_back (*lastSamples);
957
+ }
958
+
950
959
return AudioFramesOutput{torch::cat (frames, 1 ), firstFramePtsSeconds};
951
960
}
952
961
@@ -1200,8 +1209,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
1200
1209
getDuration (avFrame),
1201
1210
formatContext_->streams [activeStreamIndex_]->time_base );
1202
1211
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1203
- convertAudioAVFrameToFrameOutputOnCPU (
1204
- avFrame, frameOutput, preAllocatedOutputTensor);
1212
+ convertAudioAVFrameToFrameOutputOnCPU (avFrame, frameOutput);
1205
1213
} else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
1206
1214
convertAVFrameToFrameOutputOnCPU (
1207
1215
avFrame, frameOutput, preAllocatedOutputTensor);
@@ -1379,24 +1387,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
1379
1387
1380
1388
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1381
1389
UniqueAVFrame& srcAVFrame,
1382
- FrameOutput& frameOutput,
1383
- std::optional<torch::Tensor> preAllocatedOutputTensor) {
1384
- TORCH_CHECK (
1385
- !preAllocatedOutputTensor.has_value (),
1386
- " pre-allocated audio tensor not supported yet." );
1387
-
1390
+ FrameOutput& frameOutput) {
1388
1391
AVSampleFormat sourceSampleFormat =
1389
1392
static_cast <AVSampleFormat>(srcAVFrame->format );
1390
1393
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1391
1394
1395
+ int sourceSampleRate = srcAVFrame->sample_rate ;
1396
+ int desiredSampleRate =
1397
+ streamInfos_[activeStreamIndex_].audioStreamOptions .sampleRate .value_or (
1398
+ sourceSampleRate);
1399
+
1400
+ bool mustConvert =
1401
+ (sourceSampleFormat != desiredSampleFormat ||
1402
+ sourceSampleRate != desiredSampleRate);
1403
+
1392
1404
UniqueAVFrame convertedAVFrame;
1393
- if (sourceSampleFormat != desiredSampleFormat) {
1394
- convertedAVFrame = convertAudioAVFrameSampleFormat (
1395
- srcAVFrame, sourceSampleFormat, desiredSampleFormat);
1405
+ if (mustConvert) {
1406
+ convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate (
1407
+ srcAVFrame,
1408
+ sourceSampleFormat,
1409
+ desiredSampleFormat,
1410
+ sourceSampleRate,
1411
+ desiredSampleRate);
1396
1412
}
1397
- const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1398
- ? convertedAVFrame
1399
- : srcAVFrame;
1413
+ const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
1400
1414
1401
1415
AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1402
1416
TORCH_CHECK (
@@ -1419,55 +1433,110 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1419
1433
memcpy (
1420
1434
outputChannelData, avFrame->extended_data [channel], numBytesPerChannel);
1421
1435
}
1436
+
1422
1437
frameOutput.data = outputData;
1423
1438
}
1424
1439
1425
- UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat (
1426
- const UniqueAVFrame& avFrame ,
1440
+ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate (
1441
+ const UniqueAVFrame& srcAVFrame ,
1427
1442
AVSampleFormat sourceSampleFormat,
1428
- AVSampleFormat desiredSampleFormat
1429
-
1430
- ) {
1443
+ AVSampleFormat desiredSampleFormat,
1444
+ int sourceSampleRate,
1445
+ int desiredSampleRate ) {
1431
1446
auto & streamInfo = streamInfos_[activeStreamIndex_];
1432
- const auto & streamMetadata =
1433
- containerMetadata_.allStreamMetadata [activeStreamIndex_];
1434
- int sampleRate = static_cast <int >(streamMetadata.sampleRate .value ());
1435
1447
1436
1448
if (!streamInfo.swrContext ) {
1437
1449
createSwrContext (
1438
- streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1450
+ streamInfo,
1451
+ sourceSampleFormat,
1452
+ desiredSampleFormat,
1453
+ sourceSampleRate,
1454
+ desiredSampleRate);
1439
1455
}
1440
1456
1441
1457
UniqueAVFrame convertedAVFrame (av_frame_alloc ());
1442
1458
TORCH_CHECK (
1443
1459
convertedAVFrame,
1444
1460
" Could not allocate frame for sample format conversion." );
1445
1461
1446
- setChannelLayout (convertedAVFrame, avFrame );
1462
+ setChannelLayout (convertedAVFrame, srcAVFrame );
1447
1463
convertedAVFrame->format = static_cast <int >(desiredSampleFormat);
1448
- convertedAVFrame->sample_rate = avFrame->sample_rate ;
1449
- convertedAVFrame->nb_samples = avFrame->nb_samples ;
1464
+ convertedAVFrame->sample_rate = desiredSampleRate;
1465
+ if (sourceSampleRate != desiredSampleRate) {
1466
+ // Note that this is an upper bound on the number of output samples.
1467
+ // `swr_convert()` will likely not fill convertedAVFrame with that many
1468
+ // samples if sample rate conversion is needed. It will buffer the last few
1469
+ // ones because those require future samples. That's also why we reset
1470
+ // nb_samples after the call to `swr_convert()`.
1471
+ // We could also use `swr_get_out_samples()` to determine the number of
1472
+ // output samples, but empirically `av_rescale_rnd()` seems to provide a
1473
+ // tighter bound.
1474
+ convertedAVFrame->nb_samples = av_rescale_rnd (
1475
+ swr_get_delay (streamInfo.swrContext .get (), sourceSampleRate) +
1476
+ srcAVFrame->nb_samples ,
1477
+ desiredSampleRate,
1478
+ sourceSampleRate,
1479
+ AV_ROUND_UP);
1480
+ } else {
1481
+ convertedAVFrame->nb_samples = srcAVFrame->nb_samples ;
1482
+ }
1450
1483
1451
1484
auto status = av_frame_get_buffer (convertedAVFrame.get (), 0 );
1452
1485
TORCH_CHECK (
1453
1486
status == AVSUCCESS,
1454
1487
" Could not allocate frame buffers for sample format conversion: " ,
1455
1488
getFFMPEGErrorStringFromErrorCode (status));
1456
1489
1457
- auto numSampleConverted = swr_convert (
1490
+ auto numConvertedSamples = swr_convert (
1458
1491
streamInfo.swrContext .get (),
1459
1492
convertedAVFrame->data ,
1460
1493
convertedAVFrame->nb_samples ,
1461
- static_cast <const uint8_t **>(const_cast <const uint8_t **>(avFrame->data )),
1462
- avFrame->nb_samples );
1494
+ static_cast <const uint8_t **>(
1495
+ const_cast <const uint8_t **>(srcAVFrame->data )),
1496
+ srcAVFrame->nb_samples );
1463
1497
TORCH_CHECK (
1464
- numSampleConverted > 0 ,
1498
+ numConvertedSamples > 0 ,
1465
1499
" Error in swr_convert: " ,
1466
- getFFMPEGErrorStringFromErrorCode (numSampleConverted));
1500
+ getFFMPEGErrorStringFromErrorCode (numConvertedSamples));
1501
+
1502
+ // See comment above about nb_samples
1503
+ convertedAVFrame->nb_samples = numConvertedSamples;
1467
1504
1468
1505
return convertedAVFrame;
1469
1506
}
1470
1507
1508
+ std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers () {
1509
+ // When sample rate conversion is involved, swresample buffers some of the
1510
+ // samples in-between calls to swr_convert (see the libswresample docs).
1511
+ // That's because the last few samples in a given frame require future samples
1512
+ // from the next frame to be properly converted. This function flushes out the
1513
+ // samples that are stored in swresample's buffers.
1514
+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1515
+ if (!streamInfo.swrContext ) {
1516
+ return std::nullopt;
1517
+ }
1518
+ auto numRemainingSamples = // this is an upper bound
1519
+ swr_get_out_samples (streamInfo.swrContext .get (), 0 );
1520
+
1521
+ if (numRemainingSamples == 0 ) {
1522
+ return std::nullopt;
1523
+ }
1524
+
1525
+ torch::Tensor lastSamples = torch::empty (
1526
+ {getNumChannels (streamInfo.codecContext ), numRemainingSamples},
1527
+ torch::kFloat32 );
1528
+ uint8_t * lastSamplesData = static_cast <uint8_t *>(lastSamples.data_ptr ());
1529
+
1530
+ auto actualNumRemainingSamples = swr_convert (
1531
+ streamInfo.swrContext .get (),
1532
+ &lastSamplesData,
1533
+ numRemainingSamples,
1534
+ nullptr ,
1535
+ 0 );
1536
+ return lastSamples.narrow (
1537
+ /* dim=*/ 1 , /* start=*/ 0 , /* length=*/ actualNumRemainingSamples);
1538
+ }
1539
+
1471
1540
// --------------------------------------------------------------------------
1472
1541
// OUTPUT ALLOCATION AND SHAPE CONVERSION
1473
1542
// --------------------------------------------------------------------------
@@ -1703,14 +1772,16 @@ void VideoDecoder::createSwsContext(
1703
1772
1704
1773
void VideoDecoder::createSwrContext (
1705
1774
StreamInfo& streamInfo,
1706
- int sampleRate,
1707
1775
AVSampleFormat sourceSampleFormat,
1708
- AVSampleFormat desiredSampleFormat) {
1776
+ AVSampleFormat desiredSampleFormat,
1777
+ int sourceSampleRate,
1778
+ int desiredSampleRate) {
1709
1779
auto swrContext = allocateSwrContext (
1710
1780
streamInfo.codecContext ,
1711
- sampleRate,
1712
1781
sourceSampleFormat,
1713
- desiredSampleFormat);
1782
+ desiredSampleFormat,
1783
+ sourceSampleRate,
1784
+ desiredSampleRate);
1714
1785
1715
1786
auto status = swr_init (swrContext);
1716
1787
TORCH_CHECK (
0 commit comments