Skip to content

Commit d4fee19

Browse files
committed
Fix path-like object support in FFmpeg dispatcher
In dispatcher mode, the FFmpeg backend does not handle path-like objects (e.g. `pathlib.Path`), and the C++ implementation raises an error. This commit fixes it by normalizing path-like objects to strings with `os.path.normpath` before they are passed to the backend.
1 parent 5053aa7 commit d4fee19

File tree

4 files changed

+51
-3
lines changed

4 files changed

+51
-3
lines changed

test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import io
22
import itertools
33
import os
4+
import pathlib
45
import tarfile
56
from contextlib import contextmanager
67
from functools import partial
@@ -35,6 +36,25 @@
3536
class TestInfo(TempDirMixin, PytorchTestCase):
3637
_info = partial(get_info_func(), backend="ffmpeg")
3738

39+
def test_pathlike(self):
40+
"""FFmpeg dispatcher can query audio data from pathlike object"""
41+
sample_rate = 16000
42+
dtype = "float32"
43+
num_channels = 2
44+
duration = 1
45+
46+
path = self.get_temp_path("data.wav")
47+
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
48+
save_wav(path, data, sample_rate)
49+
50+
info = self._info(pathlib.Path(path))
51+
assert info.sample_rate == sample_rate
52+
assert info.num_frames == sample_rate * duration
53+
assert info.num_channels == num_channels
54+
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
55+
assert info.encoding == get_encoding("wav", dtype)
56+
57+
3858
@parameterized.expand(
3959
list(
4060
itertools.product(

test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import itertools
3+
import pathlib
34
import tarfile
45
from functools import partial
56

@@ -125,6 +126,21 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
125126
class TestLoad(LoadTestBase):
126127
"""Test the correctness of `self._load` for various formats"""
127128

129+
def test_pathlike(self):
130+
"""FFmpeg dispatcher can load waveform from pathlike object"""
131+
sample_rate = 16000
132+
dtype = "float32"
133+
num_channels = 2
134+
duration = 1
135+
136+
path = self.get_temp_path("data.wav")
137+
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
138+
save_wav(path, data, sample_rate)
139+
140+
waveform, sr = self._load(pathlib.Path(path))
141+
self.assertEqual(sr, sample_rate)
142+
self.assertEqual(waveform, data)
143+
128144
@parameterized.expand(
129145
list(
130146
itertools.product(

test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import os
3+
import pathlib
34
import subprocess
45
import sys
56
from functools import partial
@@ -146,6 +147,17 @@ def assert_save_consistency(
146147
@skipIfNoExec("ffmpeg")
147148
@skipIfNoFFmpeg
148149
class SaveTest(SaveTestBase):
150+
def test_pathlike(self):
151+
"""FFmpeg dispatcher can save audio data to pathlike object"""
152+
sample_rate = 16000
153+
dtype = "float32"
154+
num_channels = 2
155+
duration = 1
156+
157+
path = self.get_temp_path("data.wav")
158+
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
159+
self._save(pathlib.Path(path), data, sample_rate)
160+
149161
@nested_params(
150162
["path", "fileobj", "bytesio"],
151163
[

torchaudio/_backend/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s
8282
if hasattr(uri, "read"):
8383
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
8484
else:
85-
metadata = info_audio(uri, format)
85+
metadata = info_audio(os.path.normpath(uri), format)
8686
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
8787
metadata.encoding = _map_encoding(metadata.encoding)
8888
return metadata
@@ -108,7 +108,7 @@ def load(
108108
buffer_size,
109109
)
110110
else:
111-
return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format)
111+
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
112112

113113
@staticmethod
114114
def save(
@@ -122,7 +122,7 @@ def save(
122122
buffer_size: int = 4096,
123123
) -> None:
124124
save_audio(
125-
uri,
125+
os.path.normpath(uri),
126126
src,
127127
sample_rate,
128128
channels_first,

0 commit comments

Comments
 (0)