Skip to content

Commit f1e5e91

Browse files
authored
Add AudioEncoder public Python API (#692)
1 parent d6b2d69 commit f1e5e91

File tree

6 files changed

+106
-4
lines changed

6 files changed

+106
-4
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,18 @@ AudioEncoder::AudioEncoder(
111111
TORCH_CHECK(
112112
avFormatContext != nullptr,
113113
"Couldn't allocate AVFormatContext. ",
114-
"Check the desired extension? ",
114+
"The destination file is ",
115+
fileName,
116+
", check the desired extension? ",
115117
getFFMPEGErrorStringFromErrorCode(status));
116118
avFormatContext_.reset(avFormatContext);
117119

118120
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
119121
TORCH_CHECK(
120122
status >= 0,
121-
"avio_open failed: ",
123+
"avio_open failed. The destination file is ",
124+
fileName,
125+
", make sure it's a valid path? ",
122126
getFFMPEGErrorStringFromErrorCode(status));
123127

124128
initializeEncoder(sampleRate, bitRate);
@@ -139,7 +143,9 @@ AudioEncoder::AudioEncoder(
139143
TORCH_CHECK(
140144
avFormatContext != nullptr,
141145
"Couldn't allocate AVFormatContext. ",
142-
"Check the desired extension? ",
146+
"Check the desired format? Got format=",
147+
formatName,
148+
". ",
143149
getFFMPEGErrorStringFromErrorCode(status));
144150
avFormatContext_.reset(avFormatContext);
145151

src/torchcodec/_core/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
161161
return torch.empty([], dtype=torch.long)
162162

163163

164+
# TODO-ENCODING: rename wf to samples
164165
@register_fake("torchcodec_ns::encode_audio_to_file")
165166
def encode_audio_to_file_abstract(
166167
wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None

src/torchcodec/encoders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._audio_encoder import AudioEncoder # noqa
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from pathlib import Path
2+
from typing import Optional, Union
3+
4+
import torch
5+
from torch import Tensor
6+
7+
from torchcodec import _core
8+
9+
10+
class AudioEncoder:
    """Encode a tensor of audio samples to a file or to an in-memory tensor.

    Args:
        samples: The audio samples to encode. Must be a 2D ``torch.float32``
            tensor.
        sample_rate: The sample rate of ``samples``. Must be strictly
            positive.

    Raises:
        ValueError: If ``samples`` is not a ``Tensor``, is not 2D, is not
            ``float32``, or if ``sample_rate`` is not strictly positive.
    """

    def __init__(self, samples: Tensor, *, sample_rate: int):
        # Some of these checks are also done in C++. That's OK: they're cheap,
        # and doing them here surfaces errors when the AudioEncoder is
        # instantiated, rather than later when the encoding methods are called.
        if not isinstance(samples, Tensor):
            raise ValueError(
                f"Expected samples to be a Tensor, got {type(samples) = }."
            )
        if samples.ndim != 2:
            raise ValueError(f"Expected 2D samples, got {samples.shape = }.")
        if samples.dtype != torch.float32:
            raise ValueError(f"Expected float32 samples, got {samples.dtype = }.")
        if sample_rate <= 0:
            raise ValueError(f"{sample_rate = } must be > 0.")

        self._samples = samples
        self._sample_rate = sample_rate

    def to_file(
        self,
        dest: Union[str, Path],
        *,
        bit_rate: Optional[int] = None,
    ) -> None:
        """Encode the samples and write them to ``dest``.

        Args:
            dest: Path of the output file, as a ``str`` or ``pathlib.Path``.
            bit_rate: Optional bit rate, forwarded unchanged to the underlying
                encoding op.
        """
        _core.encode_audio_to_file(
            wf=self._samples,
            sample_rate=self._sample_rate,
            # The underlying op declares ``filename`` as a str, so a Path must
            # be converted explicitly before crossing the op boundary.
            filename=str(dest),
            bit_rate=bit_rate,
        )

    def to_tensor(
        self,
        format: str,
        *,
        bit_rate: Optional[int] = None,
    ) -> Tensor:
        """Encode the samples and return the encoded result as a tensor.

        Args:
            format: Name of the output format, forwarded unchanged to the
                underlying encoding op.
            bit_rate: Optional bit rate, forwarded unchanged to the underlying
                encoding op.

        Returns:
            The tensor returned by the underlying encoding op.
        """
        return _core.encode_audio_to_tensor(
            wf=self._samples,
            sample_rate=self._sample_rate,
            format=format,
            bit_rate=bit_rate,
        )

test/test_encoders.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import re
2+
3+
import pytest
4+
import torch
5+
6+
from torchcodec.encoders import AudioEncoder
7+
8+
9+
class TestAudioEncoder:
    """Tests for the public ``AudioEncoder`` Python API."""

    def test_bad_input(self):
        """Check that invalid constructor and encoding inputs raise clear errors."""
        # Constructor-time validation.
        with pytest.raises(ValueError, match="Expected samples to be a Tensor"):
            AudioEncoder(samples=123, sample_rate=32_000)
        with pytest.raises(ValueError, match="Expected 2D samples"):
            AudioEncoder(samples=torch.rand(10), sample_rate=32_000)
        with pytest.raises(ValueError, match="Expected float32 samples"):
            AudioEncoder(
                samples=torch.rand(10, 10, dtype=torch.float64), sample_rate=32_000
            )
        with pytest.raises(ValueError, match="sample_rate = 0 must be > 0"):
            AudioEncoder(samples=torch.rand(10, 10), sample_rate=0)

        encoder = AudioEncoder(samples=torch.rand(2, 100), sample_rate=32_000)

        # Errors raised when encoding is attempted.
        bad_path = "/bad/path.mp3"
        with pytest.raises(
            RuntimeError,
            # `match` is applied as a regex: escape the message so that "." and
            # the characters of the path are matched literally.
            match=re.escape(
                f"avio_open failed. The destination file is {bad_path}, "
                "make sure it's a valid path"
            ),
        ):
            encoder.to_file(dest=bad_path)

        bad_extension = "output.bad_extension"
        with pytest.raises(RuntimeError, match="check the desired extension"):
            encoder.to_file(dest=bad_extension)

        bad_format = "bad_format"
        with pytest.raises(
            RuntimeError,
            match=re.escape(f"Check the desired format? Got format={bad_format}"),
        ):
            encoder.to_tensor(format=bad_format)

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,7 @@ def test_bad_input(self, tmp_path):
11341134
encode_audio_to_file(
11351135
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
11361136
)
1137-
with pytest.raises(RuntimeError, match="Check the desired extension"):
1137+
with pytest.raises(RuntimeError, match="check the desired extension"):
11381138
encode_audio_to_file(
11391139
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
11401140
)

0 commit comments

Comments
 (0)