diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 688a249d..8e4e6c5d 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -77,10 +77,10 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3}) ) - make_torchcodec_library(libtorchcodec4 ffmpeg4) make_torchcodec_library(libtorchcodec7 ffmpeg7) make_torchcodec_library(libtorchcodec6 ffmpeg6) make_torchcodec_library(libtorchcodec5 ffmpeg5) + make_torchcodec_library(libtorchcodec4 ffmpeg4) else() message( @@ -97,6 +97,7 @@ else() libavformat libavcodec libavutil + libswresample libswscale ) diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp index b7dbd8ef..cb0152f0 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp @@ -60,7 +60,7 @@ int64_t getDuration(const AVFrame* frame) { #endif } -int getNumChannels(const AVFrame* avFrame) { +int getNumChannels(const UniqueAVFrame& avFrame) { #if LIBAVFILTER_VERSION_MAJOR > 8 || \ (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) return avFrame->ch_layout.nb_channels; @@ -78,6 +78,57 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) { #endif } +void setChannelLayout( + UniqueAVFrame& dstAVFrame, + const UniqueAVFrame& srcAVFrame) { +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + dstAVFrame->ch_layout = srcAVFrame->ch_layout; +#else + dstAVFrame->channel_layout = srcAVFrame->channel_layout; +#endif +} + +SwrContext* allocateSwrContext( + UniqueAVCodecContext& avCodecContext, + int sampleRate, + AVSampleFormat sourceSampleFormat, + AVSampleFormat desiredSampleFormat) { + SwrContext* swrContext = nullptr; +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + AVChannelLayout layout = avCodecContext->ch_layout; + auto status = swr_alloc_set_opts2( + &swrContext, + &layout, + desiredSampleFormat, + sampleRate, + &layout, + 
sourceSampleFormat, + sampleRate, + 0, + nullptr); + + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't create SwrContext: ", + getFFMPEGErrorStringFromErrorCode(status)); +#else + int64_t layout = static_cast(avCodecContext->channel_layout); + swrContext = swr_alloc_set_opts( + nullptr, + layout, + desiredSampleFormat, + sampleRate, + layout, + sourceSampleFormat, + sampleRate, + 0, + nullptr); +#endif + + TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext"); + return swrContext; +} + AVIOBytesContext::AVIOBytesContext( const void* data, size_t dataSize, diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index 88a81d18..955ea82d 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -22,6 +22,7 @@ extern "C" { #include #include #include +#include #include } @@ -67,6 +68,8 @@ using UniqueAVIOContext = std:: unique_ptr>; using UniqueSwsContext = std::unique_ptr>; +using UniqueSwrContext = + std::unique_ptr>; // These 2 classes share the same underlying AVPacket object. They are meant to // be used in tandem, like so: @@ -139,9 +142,18 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode); int64_t getDuration(const UniqueAVFrame& frame); int64_t getDuration(const AVFrame* frame); -int getNumChannels(const AVFrame* avFrame); +int getNumChannels(const UniqueAVFrame& avFrame); int getNumChannels(const UniqueAVCodecContext& avCodecContext); +void setChannelLayout( + UniqueAVFrame& dstAVFrame, + const UniqueAVFrame& srcAVFrame); +SwrContext* allocateSwrContext( + UniqueAVCodecContext& avCodecContext, + int sampleRate, + AVSampleFormat sourceSampleFormat, + AVSampleFormat desiredSampleFormat); + // Returns true if sws_scale can handle unaligned data. 
bool canSwsScaleHandleUnalignedData(); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 96f47c95..9871db64 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -23,6 +23,7 @@ extern "C" { #include <libavcodec/avcodec.h> #include <libavfilter/buffersink.h> #include <libavfilter/buffersrc.h> +#include <libswresample/swresample.h> #include <libswscale/swscale.h> } @@ -559,6 +560,12 @@ void VideoDecoder::addAudioStream(int streamIndex) { static_cast<int64_t>(streamInfo.codecContext->sample_rate); streamMetadata.numChannels = static_cast<int64_t>(getNumChannels(streamInfo.codecContext)); + + // FFmpeg docs say that the decoder will try to decode natively in this + // format, if it can. Docs don't say what the decoder does when it doesn't + // support that format, but it looks like it does nothing, so this probably + // doesn't hurt. + streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP; } // -------------------------------------------------------------------------- @@ -1350,37 +1357,89 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU( !preAllocatedOutputTensor.has_value(), "pre-allocated audio tensor not supported yet."); - const AVFrame* avFrame = avFrameStream.avFrame.get(); + AVSampleFormat sourceSampleFormat = + static_cast<AVSampleFormat>(avFrameStream.avFrame->format); + AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP; + + UniqueAVFrame convertedAVFrame; + if (sourceSampleFormat != desiredSampleFormat) { + convertedAVFrame = convertAudioAVFrameSampleFormat( + avFrameStream.avFrame, sourceSampleFormat, desiredSampleFormat); + } + const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat) + ? convertedAVFrame + : avFrameStream.avFrame; + + AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); + TORCH_CHECK( + format == desiredSampleFormat, + "Something went wrong, the frame didn't get converted to the desired format. 
", + "Desired format = ", + av_get_sample_fmt_name(desiredSampleFormat), + "source format = ", + av_get_sample_fmt_name(format)); auto numSamples = avFrame->nb_samples; // per channel auto numChannels = getNumChannels(avFrame); torch::Tensor outputData = torch::empty({numChannels, numSamples}, torch::kFloat32); - AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); - // TODO-AUDIO Implement all formats. - switch (format) { - case AV_SAMPLE_FMT_FLTP: { - uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr()); - auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); - for (auto channel = 0; channel < numChannels; - ++channel, outputChannelData += numBytesPerChannel) { - memcpy( - outputChannelData, - avFrame->extended_data[channel], - numBytesPerChannel); - } - break; - } - default: - TORCH_CHECK( - false, - "Unsupported audio format (yet!): ", - av_get_sample_fmt_name(format)); + uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr()); + auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); + for (auto channel = 0; channel < numChannels; + ++channel, outputChannelData += numBytesPerChannel) { + memcpy( + outputChannelData, avFrame->extended_data[channel], numBytesPerChannel); } frameOutput.data = outputData; } +UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat( + const UniqueAVFrame& avFrame, + AVSampleFormat sourceSampleFormat, + AVSampleFormat desiredSampleFormat + +) { + auto& streamInfo = streamInfos_[activeStreamIndex_]; + const auto& streamMetadata = + containerMetadata_.allStreamMetadata[activeStreamIndex_]; + int sampleRate = static_cast<int>(streamMetadata.sampleRate.value()); + + if (!streamInfo.swrContext) { + createSwrContext( + streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat); + } + + UniqueAVFrame convertedAVFrame(av_frame_alloc()); + TORCH_CHECK( + convertedAVFrame, + "Could not allocate frame for sample format conversion."); + + setChannelLayout(convertedAVFrame, avFrame); + 
convertedAVFrame->format = static_cast<int>(desiredSampleFormat); + convertedAVFrame->sample_rate = avFrame->sample_rate; + convertedAVFrame->nb_samples = avFrame->nb_samples; + + auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); + TORCH_CHECK( + status == AVSUCCESS, + "Could not allocate frame buffers for sample format conversion: ", + getFFMPEGErrorStringFromErrorCode(status)); + + auto numSampleConverted = swr_convert( + streamInfo.swrContext.get(), + convertedAVFrame->data, + convertedAVFrame->nb_samples, + static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)), + avFrame->nb_samples); + TORCH_CHECK( + numSampleConverted > 0, + "Error in swr_convert: ", + getFFMPEGErrorStringFromErrorCode(numSampleConverted)); + + return convertedAVFrame; +} + // -------------------------------------------------------------------------- // OUTPUT ALLOCATION AND SHAPE CONVERSION // -------------------------------------------------------------------------- @@ -1614,6 +1673,25 @@ void VideoDecoder::createSwsContext( streamInfo.swsContext.reset(swsContext); } +void VideoDecoder::createSwrContext( + StreamInfo& streamInfo, + int sampleRate, + AVSampleFormat sourceSampleFormat, + AVSampleFormat desiredSampleFormat) { + auto swrContext = allocateSwrContext( + streamInfo.codecContext, + sampleRate, + sourceSampleFormat, + desiredSampleFormat); + + auto status = swr_init(swrContext); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't initialize SwrContext: ", + getFFMPEGErrorStringFromErrorCode(status)); + streamInfo.swrContext.reset(swrContext); +} + // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 63e01899..f72f31ab 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -355,6 +355,7 @@ class 
VideoDecoder { FilterGraphContext filterGraphContext; ColorConversionLibrary colorConversionLibrary = FILTERGRAPH; UniqueSwsContext swsContext; + UniqueSwrContext swrContext; // Used to know whether a new FilterGraphContext or UniqueSwsContext should // be created before decoding a new frame. @@ -402,6 +403,11 @@ class VideoDecoder { const AVFrame* avFrame, torch::Tensor& outputTensor); + UniqueAVFrame convertAudioAVFrameSampleFormat( + const UniqueAVFrame& avFrame, + AVSampleFormat sourceSampleFormat, + AVSampleFormat desiredSampleFormat); + // -------------------------------------------------------------------------- // COLOR CONVERSION LIBRARIES HANDLERS CREATION // -------------------------------------------------------------------------- @@ -416,6 +422,12 @@ class VideoDecoder { const DecodedFrameContext& frameContext, const enum AVColorSpace colorspace); + void createSwrContext( + StreamInfo& streamInfo, + int sampleRate, + AVSampleFormat sourceSampleFormat, + AVSampleFormat desiredSampleFormat); + // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 3fc56c11..e68c22b6 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -1070,3 +1070,21 @@ def test_frame_start_is_not_zero(self): reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index) torch.testing.assert_close(samples.data, reference_frames) + + def test_single_channel(self): + asset = SINE_MONO_S32 + decoder = AudioDecoder(asset.path) + + samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=2) + assert samples.data.shape[0] == asset.num_channels == 1 + + def test_format_conversion(self): + asset = SINE_MONO_S32 + decoder = AudioDecoder(asset.path) + assert decoder.metadata.sample_format == asset.sample_format == "s32" 
+ + all_samples = decoder.get_samples_played_in_range(start_seconds=0) + assert all_samples.data.dtype == torch.float32 + + reference_frames = asset.get_frame_data_by_range(start=0, stop=asset.num_frames) + torch.testing.assert_close(all_samples.data, reference_frames) diff --git a/test/resources/sine_mono_s32.wav.stream0.all_frames.pt b/test/resources/sine_mono_s32.wav.stream0.all_frames.pt new file mode 100644 index 00000000..06cfe5b8 Binary files /dev/null and b/test/resources/sine_mono_s32.wav.stream0.all_frames.pt differ diff --git a/test/utils.py b/test/utils.py index c34ef51f..e376193a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -444,6 +444,9 @@ def sample_format(self) -> str: }, ) +# Note that the file itself is s32 sample format, but the reference frames are +# stored as fltp. We can add the s32 original reference frames once we support +# decoding to non-fltp format, but for now we don't need to. SINE_MONO_S32 = TestAudio( filename="sine_mono_s32.wav", default_stream_index=0,