diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 8a29d065..8811cb5a 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -111,14 +111,18 @@ AudioEncoder::AudioEncoder( TORCH_CHECK( avFormatContext != nullptr, "Couldn't allocate AVFormatContext. ", - "Check the desired extension? ", + "The destination file is ", + fileName, + ", check the desired extension? ", getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); TORCH_CHECK( status >= 0, - "avio_open failed: ", + "avio_open failed. The destination file is ", + fileName, + ", make sure it's a valid path? ", getFFMPEGErrorStringFromErrorCode(status)); initializeEncoder(sampleRate, bitRate); @@ -139,7 +143,9 @@ AudioEncoder::AudioEncoder( TORCH_CHECK( avFormatContext != nullptr, "Couldn't allocate AVFormatContext. ", - "Check the desired extension? ", + "Check the desired format? Got format=", + formatName, + ". ", getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 1240d2d6..3507df44 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -161,6 +161,7 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. return torch.empty([], dtype=torch.long) +# TODO-ENCODING: rename wf to samples @register_fake("torchcodec_ns::encode_audio_to_file") def encode_audio_to_file_abstract( wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None diff --git a/src/torchcodec/encoders/__init__.py b/src/torchcodec/encoders/__init__.py new file mode 100644 index 00000000..51f5942b --- /dev/null +++ b/src/torchcodec/encoders/__init__.py @@ -0,0 +1 @@ +from ._audio_encoder import AudioEncoder # noqa diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py new file mode 100644 index 00000000..bee05d0a --- /dev/null +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -0,0 +1,53 @@ +from pathlib import Path +from typing import Optional, Union + +import torch +from torch import Tensor + +from torchcodec import _core + + +class AudioEncoder: + def __init__(self, samples: Tensor, *, sample_rate: int): + # Some of these checks are also done in C++: it's OK, they're cheap, and + # doing them here allows to surface them when the AudioEncoder is + # instantiated, rather than later when the encoding methods are called. + if not isinstance(samples, Tensor): + raise ValueError( + f"Expected samples to be a Tensor, got {type(samples) = }." + ) + if samples.ndim != 2: + raise ValueError(f"Expected 2D samples, got {samples.shape = }.") + if samples.dtype != torch.float32: + raise ValueError(f"Expected float32 samples, got {samples.dtype = }.") + if sample_rate <= 0: + raise ValueError(f"{sample_rate = } must be > 0.") + + self._samples = samples + self._sample_rate = sample_rate + + def to_file( + self, + dest: Union[str, Path], + *, + bit_rate: Optional[int] = None, + ) -> None: + _core.encode_audio_to_file( + wf=self._samples, + sample_rate=self._sample_rate, + filename=dest, + bit_rate=bit_rate, + ) + + def to_tensor( + self, + format: str, + *, + bit_rate: Optional[int] = None, + ) -> Tensor: + return _core.encode_audio_to_tensor( + wf=self._samples, + sample_rate=self._sample_rate, + format=format, + bit_rate=bit_rate, + ) diff --git a/test/test_encoders.py b/test/test_encoders.py new file mode 100644 index 00000000..a5ae5493 --- /dev/null +++ b/test/test_encoders.py @@ -0,0 +1,41 @@ +import re + +import pytest +import torch + +from torchcodec.encoders import AudioEncoder + + +class TestAudioEncoder: + + def test_bad_input(self): + with pytest.raises(ValueError, match="Expected samples to be a Tensor"): + AudioEncoder(samples=123, sample_rate=32_000) + with pytest.raises(ValueError, match="Expected 2D samples"): + AudioEncoder(samples=torch.rand(10), sample_rate=32_000) + with pytest.raises(ValueError, match="Expected float32 samples"): + AudioEncoder( + samples=torch.rand(10, 10, dtype=torch.float64), sample_rate=32_000 + ) + with pytest.raises(ValueError, match="sample_rate = 0 must be > 0"): + AudioEncoder(samples=torch.rand(10, 10), sample_rate=0) + + encoder = AudioEncoder(samples=torch.rand(2, 100), sample_rate=32_000) + + bad_path = "/bad/path.mp3" + with pytest.raises( + RuntimeError, + match=f"avio_open failed. The destination file is {bad_path}, make sure it's a valid path", + ): + encoder.to_file(dest=bad_path) + + bad_extension = "output.bad_extension" + with pytest.raises(RuntimeError, match="check the desired extension"): + encoder.to_file(dest=bad_extension) + + bad_format = "bad_format" + with pytest.raises( + RuntimeError, + match=re.escape(f"Check the desired format? Got format={bad_format}"), + ): + encoder.to_tensor(format=bad_format) diff --git a/test/test_ops.py b/test/test_ops.py index 6e53d27b..5fb4d350 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1134,7 +1134,7 @@ def test_bad_input(self, tmp_path): encode_audio_to_file( wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3" ) - with pytest.raises(RuntimeError, match="Check the desired extension"): + with pytest.raises(RuntimeError, match="check the desired extension"): encode_audio_to_file( wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension" )