Skip to content

Commit 8336258

Browse files
nateanlfacebook-github-bot
authored andcommitted
Add SourceSeparationBundle to prototype (#2440)
Summary: - Add SourceSeparationBundle class for the source separation pipeline. - Add `CONVTASNET_BASE_LIBRI2MIX` that is trained on the Libri2Mix dataset. - Add integration test with example mixture audio and expected scale-invariant signal-to-distortion ratio (Si-SDR) score. The test computes the Si-SDR score with permutation-invariant training (PIT) criterion for all permutations of sources and uses the highest value as the final output. The test verifies that the score is equal to or larger than the expected value. Pull Request resolved: #2440 Reviewed By: mthrok Differential Revision: D37997646 Pulled By: nateanl fbshipit-source-id: c951bcbbe8b7ed9553cb8793d6dc1ef90d5a29fe
1 parent 5c6e602 commit 8336258

File tree

4 files changed

+133
-0
lines changed

4 files changed

+133
-0
lines changed

test/integration_tests/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import pytest
24
import torch
35
import torchaudio
@@ -40,6 +42,11 @@ def ctc_decoder():
4042
"fr": "20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac",
4143
"it": "20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac",
4244
}
45+
# File name of the mixture audio asset fetched by the ``mixture_source`` fixture.
_MIXTURE_FILE = "mixture_3729-6852-0037_8463-287645-0000.wav"
46+
# File names of the clean reference assets fetched by the ``clean_sources`` fixture.
_CLEAN_FILES = [
47+
"s1_3729-6852-0037_8463-287645-0000.wav",
48+
"s2_3729-6852-0037_8463-287645-0000.wav",
49+
]
4350

4451

4552
@pytest.fixture
@@ -53,6 +60,21 @@ def sample_speech(tmp_path, lang):
5360
return path
5461

5562

63+
@pytest.fixture
def mixture_source():
    """Fetch the mixture audio test asset and return the path reported by ``download_asset``.

    The asset key is written with a literal ``/`` instead of ``os.path.join``
    so that the key is identical on every platform; ``os.path.join`` would
    insert a backslash on Windows.
    """
    return torchaudio.utils.download_asset(f"test-assets/{_MIXTURE_FILE}")
67+
68+
69+
@pytest.fixture
def clean_sources():
    """Fetch every clean reference audio asset and return their paths.

    The returned list follows the order of ``_CLEAN_FILES``. Asset keys are
    written with a literal ``/`` instead of ``os.path.join`` so that they are
    identical on every platform; ``os.path.join`` would insert a backslash on
    Windows.
    """
    return [torchaudio.utils.download_asset(f"test-assets/{name}") for name in _CLEAN_FILES]
76+
77+
5678
def pytest_addoption(parser):
5779
parser.addoption(
5880
"--use-tmp-hub-dir",
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
import sys
3+
4+
import torch
5+
import torchaudio
6+
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX
7+
8+
9+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
10+
from source_separation.utils.metrics import PIT, sdr
11+
12+
13+
def test_source_separation_models(mixture_source, clean_sources):
    """Integration test for the source separation pipeline.

    Given the mixture waveform with dimensions `(batch, 1, time)`, the pre-trained pipeline generates
    the separated sources Tensor with dimensions `(batch, num_sources, time)`.
    The test computes the scale-invariant signal-to-distortion ratio (Si-SDR) score in decibel (dB) with
    permutation invariant training (PIT) criterion. PIT computes Si-SDR scores between the estimated sources and the
    target sources for all permutations, then returns the highest values as the final output. The final
    Si-SDR score should be equal to or larger than the expected score.

    Args:
        mixture_source: Path to the mixture audio file (``mixture_source`` fixture).
        clean_sources: Paths to the clean reference audio files (``clean_sources`` fixture).
    """
    BUNDLE = CONVTASNET_BASE_LIBRI2MIX
    EXPECTED_SCORE = 8.1373  # expected Si-SDR score.
    model = BUNDLE.get_model()

    mixture_waveform, sample_rate = torchaudio.load(mixture_source)
    assert sample_rate == BUNDLE.sample_rate, "The sample rate of audio must match that in the bundle."

    clean_waveforms = []
    for source in clean_sources:
        clean_waveform, sample_rate = torchaudio.load(source)
        assert sample_rate == BUNDLE.sample_rate, "The sample rate of audio must match that in the bundle."
        clean_waveforms.append(clean_waveform)

    # Present the mixture as (batch=1, channel=1, time), the layout described above.
    mixture_waveform = mixture_waveform.reshape(1, 1, -1)
    # Inference only: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        estimated_sources = model(mixture_waveform)
    # Concatenate the references and add a batch dimension so they line up with
    # the estimates (assumes each clean file holds a single channel -- TODO confirm).
    clean_waveforms = torch.cat(clean_waveforms).unsqueeze(0)
    _sdr_pit = PIT(utility_func=sdr)
    sdr_values = _sdr_pit(estimated_sources, clean_waveforms)
    assert sdr_values >= EXPECTED_SCORE
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
2+
from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX
23

34

45
__all__ = [
6+
"CONVTASNET_BASE_LIBRI2MIX",
57
"EMFORMER_RNNT_BASE_MUSTC",
68
"EMFORMER_RNNT_BASE_TEDLIUM3",
79
]
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from dataclasses import dataclass
2+
from functools import partial
3+
from typing import Callable
4+
5+
import torch
6+
import torchaudio
7+
8+
from torchaudio.prototype.models import conv_tasnet_base
9+
10+
11+
@dataclass
class SourceSeparationBundle:
    """torchaudio.prototype.pipelines.SourceSeparationBundle()

    Dataclass that bundles components for performing source separation.

    Example
        >>> import torchaudio
        >>> from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX
        >>> import torch
        >>>
        >>> # Build the separation model.
        >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
        100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
        >>>
        >>> # Instantiate the test set of Libri2Mix dataset.
        >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
        >>>
        >>> # Apply source separation on mixture audio.
        >>> for i, data in enumerate(dataset):
        ...     sample_rate, mixture, clean_sources = data
        ...     # Make sure the shape of input suits the model requirement.
        ...     mixture = mixture.reshape(1, 1, -1)
        ...     estimated_sources = model(mixture)
        ...     score = si_snr_pit(estimated_sources, clean_sources)  # for demonstration
        ...     print(f"Si-SNR score is : {score}.")
        ...     break
        Si-SNR score is : 16.24.
    """

    # Asset key of the pre-trained weights; passed to ``torchaudio.utils.download_asset``.
    _model_path: str
    # Zero-argument callable that builds the model before weights are loaded.
    _model_factory_func: Callable[[], torch.nn.Module]
    # Value exposed through the ``sample_rate`` property.
    _sample_rate: int

    @property
    def sample_rate(self) -> int:
        """Sample rate (in cycles per second) of input waveforms.

        :type: int
        """
        return self._sample_rate

    def get_model(self) -> torch.nn.Module:
        """Construct the model and load the pre-trained weights.

        The weights file is obtained through ``torchaudio.utils.download_asset``
        using the bundle's model path.

        Returns:
            torch.nn.Module: The model built by the factory function, with the
            downloaded state dict loaded and switched to evaluation mode.
        """
        model = self._model_factory_func()
        path = torchaudio.utils.download_asset(self._model_path)
        # Deserialize onto CPU so a checkpoint saved from a GPU still loads on
        # CPU-only hosts; ``load_state_dict`` copies the values into the model's
        # own parameters regardless of where the loaded tensors live.
        # NOTE(review): ``torch.load`` unpickles the file, so the asset source must be trusted.
        state_dict = torch.load(path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        return model
60+
61+
62+
# Bundle instance: ConvTasNet base model configured for two sources, with
# 8000 as the bundle's sample rate.
CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
    _model_path="models/conv_tasnet_base_libri2mix.pt",
    _model_factory_func=partial(conv_tasnet_base, num_sources=2),
    _sample_rate=8000,
)
# Per-instance docstring so the documentation can describe this specific bundle.
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained ConvTasNet pipeline for source separation.
    The underlying model is constructed by :py:func:`torchaudio.prototype.models.conv_tasnet_base`
    and utilizes weights trained on Libri2Mix using training script ``lightning_train.py``
    `here <https://github.com/pytorch/audio/tree/main/examples/source_separation/>`__ with default arguments.
    Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
    """

0 commit comments

Comments
 (0)