diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9fd7772c..f177c19b 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -101,8 +101,7 @@ AudioEncoder::AudioEncoder( const torch::Tensor wf, int sampleRate, std::string_view fileName, - std::optional bitRate, - std::optional numChannels) + const AudioStreamOptions& audioStreamOptions) : wf_(validateWf(wf)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; @@ -126,7 +125,7 @@ AudioEncoder::AudioEncoder( ", make sure it's a valid path? ", getFFMPEGErrorStringFromErrorCode(status)); - initializeEncoder(sampleRate, bitRate, numChannels); + initializeEncoder(sampleRate, audioStreamOptions); } AudioEncoder::AudioEncoder( @@ -134,8 +133,7 @@ AudioEncoder::AudioEncoder( int sampleRate, std::string_view formatName, std::unique_ptr avioContextHolder, - std::optional bitRate, - std::optional numChannels) + const AudioStreamOptions& audioStreamOptions) : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; @@ -153,13 +151,12 @@ AudioEncoder::AudioEncoder( avFormatContext_->pb = avioContextHolder_->getAVIOContext(); - initializeEncoder(sampleRate, bitRate, numChannels); + initializeEncoder(sampleRate, audioStreamOptions); } void AudioEncoder::initializeEncoder( int sampleRate, - std::optional bitRate, - std::optional numChannels) { + const AudioStreamOptions& audioStreamOptions) { // We use the AVFormatContext's default codec for that // specific format/container. const AVCodec* avCodec = @@ -170,14 +167,17 @@ void AudioEncoder::initializeEncoder( TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); - if (bitRate.has_value()) { - TORCH_CHECK(*bitRate >= 0, "bit_rate=", *bitRate, " must be >= 0."); + auto desiredBitRate = audioStreamOptions.bitRate; + if (desiredBitRate.has_value()) { + TORCH_CHECK( + *desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0."); } // bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as // well when "-b:a" isn't specified. - avCodecContext_->bit_rate = bitRate.value_or(0); + avCodecContext_->bit_rate = desiredBitRate.value_or(0); - outNumChannels_ = static_cast(numChannels.value_or(wf_.sizes()[0])); + outNumChannels_ = + static_cast(audioStreamOptions.numChannels.value_or(wf_.sizes()[0])); validateNumChannels(*avCodec, outNumChannels_); // The avCodecContext layout defines the layout of the encoded output, it's // not related to the input sampes. diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 55a31e8a..08558b6b 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -2,6 +2,7 @@ #include #include "src/torchcodec/_core/AVIOBytesContext.h" #include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/StreamOptions.h" namespace facebook::torchcodec { class AudioEncoder { @@ -13,34 +14,30 @@ class AudioEncoder { // like passing 0, which results in choosing the minimum supported bit rate. // Passing 44_100 could result in output being 44000 if only 44000 is // supported. - // - // TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc. - // into an AudioStreamOptions struct, or similar. AudioEncoder( const torch::Tensor wf, + // TODO-ENCODING: update this comment when we support an output sample + // rate. This will become the input sample rate. // The *output* sample rate. We can't really decide for the user what it // should be. Particularly, the sample rate of the input waveform should // match this, and that's up to the user. If sample rates don't match, // encoding will still work but audio will be distorted. int sampleRate, std::string_view fileName, - std::optional bitRate = std::nullopt, - std::optional numChannels = std::nullopt); + const AudioStreamOptions& audioStreamOptions); AudioEncoder( const torch::Tensor wf, int sampleRate, std::string_view formatName, std::unique_ptr avioContextHolder, - std::optional bitRate = std::nullopt, - std::optional numChannels = std::nullopt); + const AudioStreamOptions& audioStreamOptions); void encode(); torch::Tensor encodeToTensor(); private: void initializeEncoder( int sampleRate, - std::optional bitRate = std::nullopt, - std::optional numChannels = std::nullopt); + const AudioStreamOptions& audioStreamOptions); void encodeInnerLoop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& srcAVFrame); @@ -50,8 +47,8 @@ class AudioEncoder { UniqueAVCodecContext avCodecContext_; int streamIndex_; UniqueSwrContext swrContext_; - // TODO-ENCODING: outNumChannels should just be part of an options struct, - // see other TODO above. + AudioStreamOptions audioStreamOptions; + int outNumChannels_ = -1; const torch::Tensor wf_; diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index ef250da0..d600aa0a 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -43,8 +43,11 @@ struct VideoStreamOptions { struct AudioStreamOptions { AudioStreamOptions() {} - std::optional sampleRate; + // Encoding only + std::optional bitRate; + // Decoding and encoding: std::optional numChannels; + std::optional sampleRate; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index c6e43d09..b25a84e3 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -393,8 +393,13 @@ void encode_audio_to_file( std::string_view file_name, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { + // TODO Fix implicit int conversion: + // https://github.com/pytorch/torchcodec/issues/679 + AudioStreamOptions audioStreamOptions; + audioStreamOptions.bitRate = bit_rate; + audioStreamOptions.numChannels = num_channels; AudioEncoder( - wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels) + wf, validateSampleRate(sample_rate), file_name, audioStreamOptions) .encode(); } @@ -405,13 +410,17 @@ at::Tensor encode_audio_to_tensor( std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { auto avioContextHolder = std::make_unique(); + // TODO Fix implicit int conversion: + // https://github.com/pytorch/torchcodec/issues/679 + AudioStreamOptions audioStreamOptions; + audioStreamOptions.bitRate = bit_rate; + audioStreamOptions.numChannels = num_channels; return AudioEncoder( wf, validateSampleRate(sample_rate), format, std::move(avioContextHolder), - bit_rate, - num_channels) + audioStreamOptions) .encodeToTensor(); }