#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string_view>
@@ -552,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) {
552
553
containerMetadata_.allStreamMetadata [activeStreamIndex_];
553
554
streamMetadata.sampleRate =
554
555
static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
555
- streamMetadata.numChannels = getNumChannels (streamInfo.codecContext );
556
+ streamMetadata.numChannels =
557
+ static_cast <int64_t >(getNumChannels (streamInfo.codecContext ));
556
558
}
557
559
558
560
// --------------------------------------------------------------------------
@@ -567,6 +569,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
567
569
568
570
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal (
569
571
std::optional<torch::Tensor> preAllocatedOutputTensor) {
572
+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
570
573
AVFrameStream avFrameStream = decodeAVFrame (
571
574
[this ](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
572
575
return convertAVFrameToFrameOutput (avFrameStream, preAllocatedOutputTensor);
@@ -685,6 +688,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
685
688
}
686
689
687
690
VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt (double seconds) {
691
+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
688
692
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
689
693
double frameStartTime =
690
694
ptsToSeconds (streamInfo.lastDecodedAvFramePts , streamInfo.timeBase );
@@ -757,7 +761,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
757
761
double startSeconds,
758
762
double stopSeconds) {
759
763
validateActiveStream (AVMEDIA_TYPE_VIDEO);
760
-
761
764
const auto & streamMetadata =
762
765
containerMetadata_.allStreamMetadata [activeStreamIndex_];
763
766
TORCH_CHECK (
@@ -835,6 +838,68 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
835
838
return frameBatchOutput;
836
839
}
837
840
841
+ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio (
842
+ double startSeconds,
843
+ std::optional<double > stopSecondsOptional) {
844
+ validateActiveStream (AVMEDIA_TYPE_AUDIO);
845
+
846
+ double stopSeconds =
847
+ stopSecondsOptional.value_or (std::numeric_limits<double >::max ());
848
+
849
+ TORCH_CHECK (
850
+ startSeconds <= stopSeconds,
851
+ " Start seconds (" + std::to_string (startSeconds) +
852
+ " ) must be less than or equal to stop seconds (" +
853
+ std::to_string (stopSeconds) + " ." );
854
+
855
+ if (startSeconds == stopSeconds) {
856
+ // For consistency with video
857
+ return torch::empty ({0 });
858
+ }
859
+
860
+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861
+
862
+ // TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
863
+ // We should remove it and seek back to the stream's beginning when needed.
864
+ // See test_multiple_calls
865
+ TORCH_CHECK (
866
+ streamInfo.lastDecodedAvFramePts +
867
+ streamInfo.lastDecodedAvFrameDuration <=
868
+ secondsToClosestPts (startSeconds, streamInfo.timeBase ),
869
+ " Audio decoder cannot seek backwards, or start from the last decoded frame." );
870
+
871
+ setCursorPtsInSeconds (startSeconds);
872
+
873
+ // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
874
+ // cat(). This would save a copy. We know the duration of the output and the
875
+ // sample rate, so in theory we know the number of output samples.
876
+ std::vector<torch::Tensor> tensors;
877
+
878
+ auto stopPts = secondsToClosestPts (stopSeconds, streamInfo.timeBase );
879
+ auto finished = false ;
880
+ while (!finished) {
881
+ try {
882
+ AVFrameStream avFrameStream = decodeAVFrame ([this ](AVFrame* avFrame) {
883
+ return cursor_ < avFrame->pts + getDuration (avFrame);
884
+ });
885
+ auto frameOutput = convertAVFrameToFrameOutput (avFrameStream);
886
+ tensors.push_back (frameOutput.data );
887
+ } catch (const EndOfFileException& e) {
888
+ finished = true ;
889
+ }
890
+
891
+ // If stopSeconds is in [begin, end] of the last decoded frame, we should
892
+ // stop decoding more frames. Note that if we were to use [begin, end),
893
+ // which may seem more natural, then we would decode the frame starting at
894
+ // stopSeconds, which isn't what we want!
895
+ auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
896
+ streamInfo.lastDecodedAvFrameDuration ;
897
+ finished |= (streamInfo.lastDecodedAvFramePts ) <= stopPts &&
898
+ (stopPts <= lastDecodedAvFrameEnd);
899
+ }
900
+ return torch::cat (tensors, 1 );
901
+ }
902
+
838
903
// --------------------------------------------------------------------------
// SEEKING APIs
// --------------------------------------------------------------------------
@@ -871,6 +936,10 @@ I P P P I P P P I P P I P P I P
871
936
(2) is more efficient than (1) if there is an I frame between x and y.
872
937
*/
873
938
bool VideoDecoder::canWeAvoidSeeking () const {
939
+ const StreamInfo& streamInfo = streamInfos_.at (activeStreamIndex_);
940
+ if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
941
+ return true ;
942
+ }
874
943
int64_t lastDecodedAvFramePts =
875
944
streamInfos_.at (activeStreamIndex_).lastDecodedAvFramePts ;
876
945
if (cursor_ < lastDecodedAvFramePts) {
@@ -897,7 +966,7 @@ bool VideoDecoder::canWeAvoidSeeking() const {
897
966
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
898
967
// the comment of canWeAvoidSeeking() for details.
899
968
void VideoDecoder::maybeSeekToBeforeDesiredPts () {
900
- validateActiveStream (AVMEDIA_TYPE_VIDEO );
969
+ validateActiveStream ();
901
970
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
902
971
903
972
decodeStats_.numSeeksAttempted ++;
@@ -942,7 +1011,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
942
1011
943
1012
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
944
1013
std::function<bool (AVFrame*)> filterFunction) {
945
- validateActiveStream (AVMEDIA_TYPE_VIDEO );
1014
+ validateActiveStream ();
946
1015
947
1016
resetDecodeStats ();
948
1017
@@ -1071,13 +1140,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
1071
1140
AVFrame* avFrame = avFrameStream.avFrame .get ();
1072
1141
frameOutput.streamIndex = streamIndex;
1073
1142
auto & streamInfo = streamInfos_[streamIndex];
1074
- TORCH_CHECK (streamInfo.stream ->codecpar ->codec_type == AVMEDIA_TYPE_VIDEO);
1075
1143
frameOutput.ptsSeconds = ptsToSeconds (
1076
1144
avFrame->pts , formatContext_->streams [streamIndex]->time_base );
1077
1145
frameOutput.durationSeconds = ptsToSeconds (
1078
1146
getDuration (avFrame), formatContext_->streams [streamIndex]->time_base );
1079
- // TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1080
- if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
1147
+ if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1148
+ convertAudioAVFrameToFrameOutputOnCPU (
1149
+ avFrameStream, frameOutput, preAllocatedOutputTensor);
1150
+ } else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
1081
1151
convertAVFrameToFrameOutputOnCPU (
1082
1152
avFrameStream, frameOutput, preAllocatedOutputTensor);
1083
1153
} else if (streamInfo.videoStreamOptions .device .type () == torch::kCUDA ) {
@@ -1253,6 +1323,45 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
1253
1323
filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
1254
1324
}
1255
1325
1326
+ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1327
+ VideoDecoder::AVFrameStream& avFrameStream,
1328
+ FrameOutput& frameOutput,
1329
+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
1330
+ TORCH_CHECK (
1331
+ !preAllocatedOutputTensor.has_value (),
1332
+ " pre-allocated audio tensor not supported yet." );
1333
+
1334
+ const AVFrame* avFrame = avFrameStream.avFrame .get ();
1335
+
1336
+ auto numSamples = avFrame->nb_samples ; // per channel
1337
+ auto numChannels = getNumChannels (avFrame);
1338
+ torch::Tensor outputData =
1339
+ torch::empty ({numChannels, numSamples}, torch::kFloat32 );
1340
+
1341
+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1342
+ // TODO-AUDIO Implement all formats.
1343
+ switch (format) {
1344
+ case AV_SAMPLE_FMT_FLTP: {
1345
+ uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1346
+ auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1347
+ for (auto channel = 0 ; channel < numChannels;
1348
+ ++channel, outputChannelData += numBytesPerChannel) {
1349
+ memcpy (
1350
+ outputChannelData,
1351
+ avFrame->extended_data [channel],
1352
+ numBytesPerChannel);
1353
+ }
1354
+ break ;
1355
+ }
1356
+ default :
1357
+ TORCH_CHECK (
1358
+ false ,
1359
+ " Unsupported audio format (yet!): " ,
1360
+ av_get_sample_fmt_name (format));
1361
+ }
1362
+ frameOutput.data = outputData;
1363
+ }
1364
+
1256
1365
// --------------------------------------------------------------------------
1257
1366
// OUTPUT ALLOCATION AND SHAPE CONVERSION
1258
1367
// --------------------------------------------------------------------------
0 commit comments