Commit 0c263a9

Author: Caroline Chen
Replace existing prototype RNNT Loss (#1479)
Replace the prototype RNNT loss implementation (which used warp-transducer) with one that has no external library dependencies.
1 parent b5d8027 commit 0c263a9

34 files changed: +2698 −561 lines
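
For orientation, here is a minimal usage sketch of the replacement loss. The import path comes from the renamed docs below, and the argument order and shapes mirror test_basic_backward and the NumPy reference implementation in this diff, so treat the details as assumptions rather than a definitive API:

    # Hedged sketch, not part of this commit. Assumed shapes:
    # logits (B, max_T, max_U + 1, D); blank defaults to -1 (the last vocab index).
    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss

    B, T, U, D = 2, 10, 3, 4
    logits = torch.rand(B, T, U + 1, D, requires_grad=True)
    targets = torch.randint(0, D - 1, (B, U), dtype=torch.int32)  # avoid the blank index
    logit_lengths = torch.full((B,), T, dtype=torch.int32)
    target_lengths = torch.full((B,), U, dtype=torch.int32)

    loss = RNNTLoss()(logits, targets, logit_lengths, target_lengths)
    loss.backward()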

.gitmodules

Lines changed: 0 additions & 4 deletions
@@ -1,7 +1,3 @@
-[submodule "third_party/warp_transducer/submodule"]
-	path = third_party/transducer/submodule
-	url = https://github.com/HawkAaron/warp-transducer
-	ignore = dirty
 [submodule "kaldi"]
 	path = third_party/kaldi/submodule
 	url = https://github.com/kaldi-asr/kaldi

docs/source/index.rst

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@ Features described in this documentation are classified by release status:
 *Prototype:* These features are typically not available as part of
 binary distributions like PyPI or Conda, except sometimes behind run-time
 flags, and are at an early stage for feedback and testing.
-
+
 
 The :mod:`torchaudio` package consists of I/O, popular datasets and common audio transformations.
 
@@ -39,9 +39,9 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
    compliance.kaldi
    kaldi_io
    utils
-   transducer
-
-
+   rnnt_loss
+
+
 .. toctree::
    :maxdepth: 1
    :caption: PyTorch Libraries

docs/source/transducer.rst renamed to docs/source/rnnt_loss.rst

Lines changed: 3 additions & 3 deletions
@@ -1,14 +1,14 @@
 .. role:: hidden
    :class: hidden-section
 
-torchaudio.prototype.transducer
+torchaudio.prototype.rnnt_loss
 ===============================
 
-.. currentmodule:: torchaudio.prototype.transducer
+.. currentmodule:: torchaudio.prototype.rnnt_loss
 
 .. note::
 
-   The RNN transducer loss is a prototype feature, see `here <https://pytorch.org/audio>`_ to learn more about the nomenclature. It is only available within the nightlies, and also needs to be imported explicitly using: :code:`from torchaudio.prototype.transducer import rnnt_loss, RNNTLoss`.
+   The RNN transducer loss is a prototype feature, see `here <https://pytorch.org/audio>`_ to learn more about the nomenclature. It is only available within the nightlies, and also needs to be imported explicitly using: :code:`from torchaudio.prototype.rnnt_loss import rnnt_loss, RNNTLoss`.
 
 rnnt_loss
 ---------

examples/libtorchaudio/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ SET(BUILD_LIBTORCHAUDIO ON CACHE BOOL "Build libtorchaudio")
 SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio")
 
 SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio")
-SET(BUILD_TRANSDUCER OFF CACHE BOOL "Build Python binding")
+SET(BUILD_TRANSDUCER OFF CACHE BOOL "Build transducer into libtorchaudio")
 SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding")
 
 find_package(Torch REQUIRED)

test/torchaudio_unittest/rnnt/__init__.py

Whitespace-only changes.
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
import numpy as np
import torch


class _NumpyTransducer(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        log_probs,
        logit_lengths,
        target_lengths,
        targets,
        blank=-1,
    ):
        device = log_probs.device
        log_probs = log_probs.cpu().data.numpy()
        logit_lengths = logit_lengths.cpu().data.numpy()
        target_lengths = target_lengths.cpu().data.numpy()
        targets = targets.cpu().data.numpy()

        gradients, costs, _, _ = __class__.compute(
            log_probs=log_probs,
            logit_lengths=logit_lengths,
            target_lengths=target_lengths,
            targets=targets,
            blank=blank,
        )

        costs = torch.FloatTensor(costs).to(device=device)
        gradients = torch.FloatTensor(gradients).to(device=device)
        ctx.grads = torch.autograd.Variable(gradients)

        return costs

    @staticmethod
    def backward(ctx, output_gradients):
        return ctx.grads, None, None, None, None, None, None, None, None

    @staticmethod
    def compute_alpha_one_sequence(log_probs, targets, blank=-1):
        max_T, max_U, D = log_probs.shape
        alpha = np.zeros((max_T, max_U), dtype=np.float32)
        for t in range(1, max_T):
            alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank]

        for u in range(1, max_U):
            alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]]

        for t in range(1, max_T):
            for u in range(1, max_U):
                skip = alpha[t - 1, u] + log_probs[t - 1, u, blank]
                emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]]
                alpha[t, u] = np.logaddexp(skip, emit)

        cost = -(alpha[-1, -1] + log_probs[-1, -1, blank])
        return alpha, cost

    @staticmethod
    def compute_beta_one_sequence(log_probs, targets, blank=-1):
        max_T, max_U, D = log_probs.shape
        beta = np.zeros((max_T, max_U), dtype=np.float32)
        beta[-1, -1] = log_probs[-1, -1, blank]

        for t in reversed(range(max_T - 1)):
            beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank]

        for u in reversed(range(max_U - 1)):
            beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]]

        for t in reversed(range(max_T - 1)):
            for u in reversed(range(max_U - 1)):
                skip = beta[t + 1, u] + log_probs[t, u, blank]
                emit = beta[t, u + 1] + log_probs[t, u, targets[u]]
                beta[t, u] = np.logaddexp(skip, emit)

        cost = -beta[0, 0]
        return beta, cost

    @staticmethod
    def compute_gradients_one_sequence(
        log_probs, alpha, beta, targets, blank=-1
    ):
        max_T, max_U, D = log_probs.shape
        gradients = np.full(log_probs.shape, float("-inf"))
        cost = -beta[0, 0]

        gradients[-1, -1, blank] = alpha[-1, -1]

        gradients[:-1, :, blank] = alpha[:-1, :] + beta[1:, :]

        for u, l in enumerate(targets):
            gradients[:, u, l] = alpha[:, u] + beta[:, u + 1]

        gradients = -(np.exp(gradients + log_probs + cost))
        return gradients

    @staticmethod
    def compute(
        log_probs,
        logit_lengths,
        target_lengths,
        targets,
        blank=-1,
    ):
        gradients = np.zeros_like(log_probs)
        B_tgt, max_T, max_U, D = log_probs.shape
        B_src = logit_lengths.shape[0]

        H = int(B_tgt / B_src)

        alphas = np.zeros((B_tgt, max_T, max_U))
        betas = np.zeros((B_tgt, max_T, max_U))
        betas.fill(float("-inf"))
        alphas.fill(float("-inf"))
        costs = np.zeros(B_tgt)
        for b_tgt in range(B_tgt):
            b_src = int(b_tgt / H)
            T = int(logit_lengths[b_src])
            # NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1
            U = int(target_lengths[b_tgt]) + 1

            seq_log_probs = log_probs[b_tgt, :T, :U, :]
            seq_targets = targets[b_tgt, : int(target_lengths[b_tgt])]
            alpha, alpha_cost = __class__.compute_alpha_one_sequence(
                log_probs=seq_log_probs, targets=seq_targets, blank=blank
            )

            beta, beta_cost = __class__.compute_beta_one_sequence(
                log_probs=seq_log_probs, targets=seq_targets, blank=blank
            )

            seq_gradients = __class__.compute_gradients_one_sequence(
                log_probs=seq_log_probs,
                alpha=alpha,
                beta=beta,
                targets=seq_targets,
                blank=blank,
            )
            np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2)
            gradients[b_tgt, :T, :U, :] = seq_gradients
            costs[b_tgt] = beta_cost
            alphas[b_tgt, :T, :U] = alpha
            betas[b_tgt, :T, :U] = beta

        return gradients, costs, alphas, betas


class NumpyTransducerLoss(torch.nn.Module):
    def __init__(self, blank=-1):
        super().__init__()
        self.blank = blank

    def forward(
        self,
        logits,
        logit_lengths,
        target_lengths,
        targets,
    ):
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        return _NumpyTransducer.apply(
            log_probs,
            logit_lengths,
            target_lengths,
            targets,
            self.blank,
        )
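
The file above is a pure-NumPy reference implementation of the transducer loss from Graves, "Sequence Transduction with Recurrent Neural Networks" (arXiv:1211.3711): alpha[t, u] is the log-probability of emitting the first u target labels within the first t frames, beta[t, u] is the matching suffix log-probability, and the gradient with respect to log_probs[t, u, k] is -exp(alpha[t, u] + log_probs[t, u, k] + beta', with beta' = beta[t + 1, u] for blank or beta[t, u + 1] for a label, plus the cost). A small self-check sketch, with shapes chosen to satisfy compute() above (log_probs is (B, max_T, max_U, D) where max_U is the target length plus 1; everything here is an illustrative assumption, not part of the diff):

    # Hedged sketch: exercise the NumPy reference loss end to end.
    import torch

    torch.manual_seed(0)
    B, T, U, D = 1, 4, 2, 3
    logits = torch.rand(B, T, U + 1, D, requires_grad=True)
    targets = torch.randint(0, D - 1, (B, U), dtype=torch.int32)  # keep blank (-1) out of targets
    logit_lengths = torch.full((B,), T, dtype=torch.int32)
    target_lengths = torch.full((B,), U, dtype=torch.int32)

    costs = NumpyTransducerLoss(blank=-1)(
        logits, logit_lengths, target_lengths, targets
    )  # per-sequence costs, shape (B,)
    costs.sum().backward()  # backward replays the gradients stashed in ctx.grads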
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
import torch
from torchaudio_unittest import common_utils
from .utils import skipIfNoTransducer
from .rnnt_loss_impl import RNNTLossTest


@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
    device = torch.device('cpu')
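
Splitting the test matrix (rnnt_loss_impl.py, below) from this device-binding class means other devices only need a new binding. A hypothetical CUDA counterpart, not part of this diff (common_utils.skipIfNoCuda is assumed to exist in the test harness), might look like:

    import torch
    from torchaudio_unittest import common_utils
    from .utils import skipIfNoTransducer
    from .rnnt_loss_impl import RNNTLossTest


    @skipIfNoTransducer
    @common_utils.skipIfNoCuda
    class TestRNNTLossCuda(RNNTLossTest, common_utils.PytorchTestCase):
        device = torch.device('cuda')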
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import numpy as np
from torchaudio.prototype.rnnt_loss import RNNTLoss

from .utils import (
    compute_with_numpy_transducer,
    compute_with_pytorch_transducer,
    get_B1_T10_U3_D4_data,
    get_data_basic,
    get_numpy_data_B1_T2_U3_D5,
    get_numpy_data_B2_T4_U3_D3,
    get_numpy_random_data,
    numpy_to_torch,
)


class RNNTLossTest:
    def _test_costs_and_gradients(
        self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
    ):
        logits_shape = data["logits"].shape
        for reuse_logits_for_grads in [False, True]:
            with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads):
                costs, gradients = compute_with_pytorch_transducer(
                    data=data, reuse_logits_for_grads=reuse_logits_for_grads
                )
                np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
                self.assertEqual(logits_shape, gradients.shape)
                if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
                    for b in range(len(gradients)):
                        T = data["logit_lengths"][b]
                        U = data["target_lengths"][b]
                        for t in range(gradients.shape[1]):
                            for u in range(gradients.shape[2]):
                                np.testing.assert_allclose(
                                    gradients[b, t, u],
                                    ref_gradients[b, t, u],
                                    atol=atol,
                                    rtol=rtol,
                                    err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
                                )

    def test_basic_backward(self):
        rnnt_loss = RNNTLoss()
        logits, targets, logit_lengths, target_lengths = get_data_basic(self.device)
        loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
        loss.backward()

    def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
        data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
            dtype=np.float32
        )
        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
        self._test_costs_and_gradients(
            data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
        )

    def test_costs_and_gradients_B1_T2_U3_D5_fp16(self):
        data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
            dtype=np.float16
        )
        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
        self._test_costs_and_gradients(
            data=data,
            ref_costs=ref_costs,
            ref_gradients=ref_gradients,
            atol=1e-3,
            rtol=1e-2,
        )

    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
        data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(
            dtype=np.float32
        )
        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
        self._test_costs_and_gradients(
            data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
        )

    def test_costs_and_gradients_B2_T4_U3_D3_fp16(self):
        data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(
            dtype=np.float16
        )
        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
        self._test_costs_and_gradients(
            data=data,
            ref_costs=ref_costs,
            ref_gradients=ref_gradients,
            atol=1e-3,
            rtol=1e-2,
        )

    def test_costs_and_gradients_random_data_with_numpy_fp32(self):
        seed = 777
        for i in range(5):
            data = get_numpy_random_data(dtype=np.float32, seed=(seed + i))
            data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
            ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
            self._test_costs_and_gradients(
                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
            )

    def test_rnnt_nonfused_log_softmax(self):
        for random in [False, True]:
            data = get_B1_T10_U3_D4_data(
                random=random,
            )
            data = numpy_to_torch(
                data=data, device=self.device, requires_grad=True
            )
            data["fused_log_softmax"] = False
            ref_costs, ref_gradients = compute_with_numpy_transducer(
                data=data
            )
            self._test_costs_and_gradients(
                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
            )
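
test_rnnt_nonfused_log_softmax drives the loss with fused_log_softmax disabled, i.e. the caller normalizes the logits and the loss consumes log-probabilities directly. A hedged sketch of that call pattern (the fused_log_softmax constructor argument is inferred from the data dict above and is not confirmed elsewhere in this diff; shapes follow the get_B1_T10_U3_D4_data naming, B=1, T=10, U=3, D=4):

    # Hedged sketch: non-fused path, caller applies log_softmax itself.
    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss

    loss_fn = RNNTLoss(fused_log_softmax=False)
    logits = torch.rand(1, 10, 4, 4, requires_grad=True)  # (B, T, U + 1, D)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    targets = torch.randint(0, 3, (1, 3), dtype=torch.int32)
    logit_lengths = torch.tensor([10], dtype=torch.int32)
    target_lengths = torch.tensor([3], dtype=torch.int32)
    loss = loss_fn(log_probs, targets, logit_lengths, target_lengths)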
