Skip to content

Refactor test backend #719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ The following test modules are defined for corresponding `torchaudio` module/fun

## Adding test

The following is the current practice of torchaudio test suite.

1. Unless the tests are related to I/O, use synthetic data. [`common_utils`](./common_utils.py) has some data generator functions.
1. When you add a new test case, use `common_utils.TorchaudioTestCase` as base class unless you are writing tests that are common to CPU / CUDA.
- Set class memeber `dtype`, `device` and `backend` for the desired behavior.
- If you do not set `backend` value in your test suite, then I/O functions will be unassigned and attempt to load/save file will fail.
- For `backend` value, in addition to available backends, you can also provide the value "default" and backend will be picked automatically based on availability.
Comment on lines +52 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would merge these two bullet points

1. If you are writing tests that should pass on diffrent dtype/devices, write a common class inheriting `common_utils.TestBaseMixin`, then inherit `common_utils.PytorchTestCase` and define class attributes (`dtype` / `device` / `backend`) there. See [Torchscript consistency test implementation](./torchscript_consistency_impl.py) and test definitions for [CPU](./torchscript_consistency_cpu_test.py) and [CUDA](./torchscript_consistency_cuda_test.py) devices.
1. For numerically comparing Tensors, use `assertEqual` method from `common_utils.PytorchTestCase` class. This method has a better support for a wide variety of Tensor types.

When you add a new feature(functional/transform), consider the following

1. When you add a new feature, please make it Torchscript-able and batch-consistent unless it degrades the performance. Please add the tests to see if the new feature meet these requirements.
Expand Down
52 changes: 29 additions & 23 deletions test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import tempfile
import unittest
from typing import Iterable, Union
from contextlib import contextmanager
from typing import Union
from shutil import copytree

import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -55,24 +54,14 @@ def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
return torch.tensor(arr).float().view(size) / m


@contextmanager
def AudioBackendScope(new_backend):
previous_backend = torchaudio.get_audio_backend()
try:
torchaudio.set_audio_backend(new_backend)
yield
finally:
torchaudio.set_audio_backend(previous_backend)


def filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3
test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3')

def supports_mp3(backend):
torchaudio.set_audio_backend(backend)
try:
with AudioBackendScope(backend):
torchaudio.load(test_filepath)
torchaudio.load(test_filepath)
return True
except (RuntimeError, ImportError):
return False
Expand All @@ -83,21 +72,38 @@ def supports_mp3(backend):
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)


def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend

torchaudio.set_audio_backend(be)


class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
device = None
backend = None

def setUp(self):
super().setUp()
set_audio_backend(self.backend)

skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')

class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass

def common_test_class_parameters(
dtypes: Iterable[str] = ("float32", "float64"),
devices: Iterable[str] = ("cpu", "cuda"),
):
for device in devices:
for dtype in dtypes:
yield {"device": torch.device(device), "dtype": getattr(torch, dtype)}

skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')


def get_whitenoise(
Expand Down
14 changes: 8 additions & 6 deletions test/functional_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
from .functional_impl import Lfilter


class TestLFilterFloat32(Lfilter, common_utils.TestCase):
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestLFilterFloat64(Lfilter, common_utils.TestCase):
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestComputeDeltas(unittest.TestCase):
class TestComputeDeltas(common_utils.TorchaudioTestCase):
"""Test suite for correctness of compute_deltas"""
def test_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
Expand Down Expand Up @@ -57,7 +57,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
_compare_estimate(sound, estimate)


class TestIstft(unittest.TestCase):
class TestIstft(common_utils.TorchaudioTestCase):
"""Test suite for correctness of istft with various input"""
number_of_trials = 100

Expand Down Expand Up @@ -273,7 +273,9 @@ def test_linearity_of_istft4(self):
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)


class TestDetectPitchFrequency(unittest.TestCase):
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
backend = 'default'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is the only one that needs a backend to be specified?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are couple of them, the less class requires backend, the better. yet still the backend is reset for the each class.


def test_pitch(self):
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
Expand All @@ -294,7 +296,7 @@ def test_pitch(self):
self.assertFalse(s)


class TestDB_to_amplitude(unittest.TestCase):
class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
def test_DB_to_amplitude(self):
# Make some noise
x = torch.rand(1000)
Expand Down
4 changes: 2 additions & 2 deletions test/functional_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


@common_utils.skipIfNoCuda
class TestLFilterFloat32(Lfilter, common_utils.TestCase):
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestLFilterFloat64(Lfilter, common_utils.TestCase):
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
4 changes: 2 additions & 2 deletions test/kaldi_compatibility_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from .kaldi_compatibility_impl import Kaldi


class TestKaldiFloat32(Kaldi, common_utils.TestCase):
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestKaldiFloat64(Kaldi, common_utils.TestCase):
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
4 changes: 2 additions & 2 deletions test/kaldi_compatibility_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


@common_utils.skipIfNoCuda
class TestKaldiFloat32(Kaldi, common_utils.TestCase):
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestKaldiFloat64(Kaldi, common_utils.TestCase):
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
2 changes: 2 additions & 0 deletions test/kaldi_compatibility_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def _load_params(path):


class Kaldi(common_utils.TestBaseMixin):
backend = 'sox'

def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
Expand Down
10 changes: 6 additions & 4 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import torchaudio
from torchaudio._internal.module_utils import is_module_available

from . import common_utils

class BackendSwitch:

class BackendSwitchMixin:
"""Test set/get_audio_backend works"""
backend = None
backend_module = None
Expand All @@ -21,20 +23,20 @@ def test_switch(self):
assert torchaudio.info == self.backend_module.info


class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = None
backend_module = torchaudio.backend.no_backend


@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
class TestBackendSwitch_SoX(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox'
backend_module = torchaudio.backend.sox_backend


@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
class TestBackendSwitch_soundfile(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend
9 changes: 6 additions & 3 deletions test/test_batch_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import unittest

import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F

from . import common_utils


class TestFunctional(TestCase):
class TestFunctional(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test functions defined in `functional` module"""
def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
Expand Down Expand Up @@ -98,12 +98,15 @@ def test_sliding_window_cmn(self):
self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=False, norm_vars=False)

def test_vad(self):
common_utils.set_audio_backend('default')
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)


class TestTransforms(TestCase):
class TestTransforms(common_utils.TorchaudioTestCase):
backend = 'default'

"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
Expand Down
13 changes: 7 additions & 6 deletions test/test_compliance_kaldi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import math
import os
import math
import unittest

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import unittest

from . import common_utils
from .compliance import utils as compliance_utils
from .common_utils import AudioBackendScope, BACKENDS


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
Expand Down Expand Up @@ -46,7 +46,10 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
window[f, s] = wave[s_in_wave]


class Test_Kaldi(unittest.TestCase):
@common_utils.skipIfNoSoxBackend
class Test_Kaldi(common_utils.TorchaudioTestCase):
backend = 'sox'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: only one the method in tests below needs sox, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_wav yields different value between sox backend and soundfile backend so this test does not work for soundfile backend.


test_filepath = common_utils.get_asset_path('kaldi_file.wav')
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
kaldi_output_dir = common_utils.get_asset_path('kaldi')
Expand Down Expand Up @@ -162,8 +165,6 @@ def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_resample_waveform(self):
def get_output_fn(sound, args):
output = kaldi.resample_waveform(sound, args[1], args[2])
Expand Down
8 changes: 4 additions & 4 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch.utils.data import Dataset, DataLoader

from . import common_utils
from .common_utils import AudioBackendScope, BACKENDS


class TORCHAUDIODS(Dataset):
Expand All @@ -28,9 +27,10 @@ def __len__(self):
return len(self.data)


class Test_DataLoader(unittest.TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
class Test_DataLoader(common_utils.TorchaudioTestCase):
backend = 'sox'

@common_utils.skipIfNoSoxBackend
def test_1(self):
expected_size = (2, 1, 16000)
ds = TORCHAUDIODS()
Expand Down
Loading