diff --git a/examples/pipeline_tacotron2/README.md b/examples/pipeline_tacotron2/README.md new file mode 100644 index 0000000000..8cb6c7c31b --- /dev/null +++ b/examples/pipeline_tacotron2/README.md @@ -0,0 +1 @@ +This is an example pipeline for text-to-speech using Tacotron2. diff --git a/examples/pipeline_tacotron2/loss.py b/examples/pipeline_tacotron2/loss.py new file mode 100644 index 0000000000..38f4b8bbcf --- /dev/null +++ b/examples/pipeline_tacotron2/loss.py @@ -0,0 +1,82 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ***************************************************************************** + +from typing import Tuple + +from torch import nn, Tensor + + +class Tacotron2Loss(nn.Module): + """Tacotron2 loss function modified from: + https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py + """ + + def __init__(self): + super().__init__() + + self.mse_loss = nn.MSELoss(reduction="mean") + self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean") + + def forward( + self, + model_outputs: Tuple[Tensor, Tensor, Tensor], + targets: Tuple[Tensor, Tensor], + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Pass the input through the Tacotron2 loss. + + The original implementation was introduced in + *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions* + [:footcite:`shen2018natural`]. + + Args: + model_outputs (tuple of three Tensors): The outputs of the + Tacotron2. These outputs should include three items: + (1) the predicted mel spectrogram before the postnet (``mel_specgram``) + with shape (batch, mel, time). + (2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``) + with shape (batch, mel, time), and + (3) the stop token prediction (``gate_out``) with shape (batch, ). + targets (tuple of two Tensors): The ground truth mel spectrogram (batch, mel, time) and + stop token with shape (batch, ). + + Returns: + mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram + with shape ``torch.Size([])``. + mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and + ground truth mel spectrogram with shape ``torch.Size([])``. + gate_loss (Tensor): The mean binary cross entropy loss of + the prediction on the stop token with shape ``torch.Size([])``. + """ + mel_target, gate_target = targets[0], targets[1] + gate_target = gate_target.view(-1, 1) + + mel_specgram, mel_specgram_postnet, gate_out = model_outputs + gate_out = gate_out.view(-1, 1) + mel_loss = self.mse_loss(mel_specgram, mel_target) + mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target) + gate_loss = self.bce_loss(gate_out, gate_target) + return mel_loss, mel_postnet_loss, gate_loss diff --git a/test/torchaudio_unittest/example/tacotron2/__init__.py b/test/torchaudio_unittest/example/tacotron2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py new file mode 100644 index 0000000000..cb91655342 --- /dev/null +++ b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py @@ -0,0 +1,23 @@ +import torch + +from .tacotron2_loss_impl import ( + Tacotron2LossShapeTests, + Tacotron2LossTorchscriptTests, + Tacotron2LossGradcheckTests, +) +from torchaudio_unittest.common_utils import PytorchTestCase + + +class TestTacotron2LossShapeFloat32CPU(PytorchTestCase, Tacotron2LossShapeTests): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2TorchsciptFloat32CPU(PytorchTestCase, Tacotron2LossTorchscriptTests): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2GradcheckFloat64CPU(PytorchTestCase, Tacotron2LossGradcheckTests): + dtype = torch.float64 # gradcheck needs a higher numerical accuracy + device = torch.device("cpu") diff --git a/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py new file mode 100644 index 0000000000..9c1ae252c4 --- /dev/null +++ b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py @@ -0,0 +1,26 @@ +import torch + +from .tacotron2_loss_impl import ( + Tacotron2LossShapeTests, + Tacotron2LossTorchscriptTests, + Tacotron2LossGradcheckTests, +) +from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase + + +@skipIfNoCuda +class TestTacotron2LossShapeFloat32CUDA(PytorchTestCase, Tacotron2LossShapeTests): + dtype = torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2TorchsciptFloat32CUDA(PytorchTestCase, Tacotron2LossTorchscriptTests): + dtype = torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2GradcheckFloat64CUDA(PytorchTestCase, Tacotron2LossGradcheckTests): + dtype = torch.float64 # gradcheck needs a higher numerical accuracy + device = torch.device("cuda") diff --git a/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py new file mode 100644 index 0000000000..6bb6d8474e --- /dev/null +++ b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py @@ -0,0 +1,110 @@ +import torch +from torch.autograd import gradcheck, gradgradcheck + +from pipeline_tacotron2.loss import Tacotron2Loss +from torchaudio_unittest.common_utils import TempDirMixin + + +class Tacotron2LossInputMixin(TempDirMixin): + + def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300): + mel_specgram = torch.rand( + n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device + ) + mel_specgram_postnet = torch.rand( + n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device + ) + gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device) + truth_mel_specgram = torch.rand( + n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device + ) + truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device) + + truth_mel_specgram.requires_grad = False + truth_gate_out.requires_grad = False + + return ( + mel_specgram, + mel_specgram_postnet, + gate_out, + truth_mel_specgram, + truth_gate_out, + ) + + +class Tacotron2LossShapeTests(Tacotron2LossInputMixin): + + def test_tacotron2_loss_shape(self): + """Validate the output shape of Tacotron2Loss.""" + n_batch = 16 + + ( + mel_specgram, + mel_specgram_postnet, + gate_out, + truth_mel_specgram, + truth_gate_out, + ) = self._get_inputs(n_batch=n_batch) + + mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()( + (mel_specgram, mel_specgram_postnet, gate_out), + (truth_mel_specgram, truth_gate_out) + ) + + self.assertEqual(mel_loss.size(), torch.Size([])) + self.assertEqual(mel_postnet_loss.size(), torch.Size([])) + self.assertEqual(gate_loss.size(), torch.Size([])) + + +class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin): + + def _assert_torchscript_consistency(self, fn, tensors): + path = self.get_temp_path("func.zip") + torch.jit.script(fn).save(path) + ts_func = torch.jit.load(path) + + output = fn(tensors[:3], tensors[3:]) + ts_output = ts_func(tensors[:3], tensors[3:]) + + self.assertEqual(ts_output, output) + + def test_tacotron2_loss_torchscript_consistency(self): + """Validate the torchscript consistency of Tacotron2Loss.""" + + loss_fn = Tacotron2Loss() + self._assert_torchscript_consistency(loss_fn, self._get_inputs()) + + +class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin): + + def test_tacotron2_loss_gradcheck(self): + """Performing gradient check on Tacotron2Loss.""" + ( + mel_specgram, + mel_specgram_postnet, + gate_out, + truth_mel_specgram, + truth_gate_out, + ) = self._get_inputs() + + mel_specgram.requires_grad_(True) + mel_specgram_postnet.requires_grad_(True) + gate_out.requires_grad_(True) + + def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out): + loss_fn = Tacotron2Loss() + return loss_fn( + (mel_specgram, mel_specgram_postnet, gate_out), + (truth_mel_specgram, truth_gate_out), + ) + + gradcheck( + _fn, + (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), + fast_mode=True, + ) + gradgradcheck( + _fn, + (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), + fast_mode=True, + )