Skip to content

Commit 6da13db

Browse files
committed
pytorch#446 add vanilla deepspeech model
1 parent dd76e9d commit 6da13db

File tree

4 files changed

+115
-1
lines changed

4 files changed

+115
-1
lines changed

docs/source/models.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,10 @@ The models subpackage contains definitions of models for addressing common audio
3131
.. autoclass:: WaveRNN
3232

3333
.. automethod:: forward
34+
35+
:hidden:`DeepSpeech`
36+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
37+
38+
.. autoclass:: DeepSpeech
39+
40+
.. automethod:: forward

test/torchaudio_unittest/models_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55
from parameterized import parameterized
6-
from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN
6+
from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN, DeepSpeech
77
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
88
from torchaudio_unittest import common_utils
99

@@ -174,3 +174,20 @@ def test_paper_configuration(self, num_sources, model_params):
174174
output = model(tensor)
175175

176176
assert output.shape == (batch_size, num_sources, num_frames)
177+
178+
179+
class TestDeepSpeech(common_utils.TorchaudioTestCase):
    """Smoke test for the ``DeepSpeech`` model."""

    def test_deepspeech(self):
        """Check the output shape for a (batch, channel, time, feature) input."""
        batch_size = 2
        num_features = 1
        num_channels = 1
        num_classes = 40
        input_length = 320

        # Build the model from the same feature count that sizes the input
        # tensor, so the two cannot drift apart when the parameter is changed.
        model = DeepSpeech(in_features=num_features, num_classes=num_classes)

        x = torch.rand(batch_size, num_channels, input_length, num_features)
        out = model(x)

        # The model returns time-major output: (time, batch, class).
        assert out.size() == (input_length, batch_size, num_classes)

torchaudio/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .wav2letter import Wav2Letter
22
from .wavernn import WaveRNN
33
from .conv_tasnet import ConvTasNet
4+
from .deepspeech import DeepSpeech
45

56
__all__ = [
67
'Wav2Letter',
78
'WaveRNN',
89
'ConvTasNet',
10+
'DeepSpeech',
911
]

torchaudio/models/deepspeech.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
__all__ = ["DeepSpeech"]
5+
6+
7+
class FullyConnected(nn.Module):
    """Linear layer followed by a clipped ReLU and, optionally, dropout.

    Args:
        in_features: Number of input features.
        hidden_size: Internal hidden unit size (number of output features).
        dropout: Dropout probability. A falsy value (e.g. ``0.0``) means no
            dropout layer is added.
        relu_max_clip: Upper bound applied to the ReLU output. (Default: ``20``)
    """

    def __init__(self,
                 in_features: int,
                 hidden_size: int,
                 dropout: float,
                 relu_max_clip: int = 20) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, hidden_size, bias=True)

        # Build one flat activation stack instead of wrapping a Sequential
        # inside another Sequential when dropout is enabled.
        activation = [nn.ReLU(), nn.Hardtanh(0, relu_max_clip)]
        if dropout:
            activation.append(nn.Dropout(dropout))
        self.nonlinearity = nn.Sequential(*activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer, then the clipped ReLU (and dropout, if any).

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``hidden_size``.
        """
        x = self.fc(x)
        x = self.nonlinearity(x)
        return x
35+
36+
37+
class DeepSpeech(nn.Module):
    """DeepSpeech model architecture from
    `"Deep Speech: Scaling up end-to-end speech recognition"
    <https://arxiv.org/abs/1412.5567>`_ paper.

    Args:
        in_features: Number of input features.
        hidden_size: Internal hidden unit size. (Default: ``2048``)
        num_classes: Number of output classes. (Default: ``40``)
        dropout: Dropout probability passed to every fully connected block.
            (Default: ``0.0``)
    """

    def __init__(self,
                 in_features: int,
                 hidden_size: int = 2048,
                 num_classes: int = 40,
                 dropout: float = 0.0) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.fc1 = FullyConnected(in_features, hidden_size, dropout)
        self.fc2 = FullyConnected(hidden_size, hidden_size, dropout)
        self.fc3 = FullyConnected(hidden_size, hidden_size, dropout)
        self.bi_rnn = nn.RNN(
            hidden_size, hidden_size, num_layers=1, nonlinearity='relu', bidirectional=True)
        self.fc4 = FullyConnected(hidden_size, hidden_size, dropout)
        self.out = nn.Sequential(
            nn.Linear(hidden_size, num_classes),
            nn.LogSoftmax(dim=2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of dimension (batch, channel, time, in_features).
                The channel dimension is dropped with ``squeeze(1)``, which
                only removes it when its size is 1.

        Returns:
            Tensor of dimension (time, batch, num_classes) holding the
            log-softmax over classes for each frame.
        """
        # N x C x T x F
        x = self.fc1(x)
        # N x C x T x H
        x = self.fc2(x)
        # N x C x T x H
        x = self.fc3(x)
        # N x C x T x H
        x = x.squeeze(1)
        # N x T x H
        x = x.transpose(0, 1)
        # T x N x H
        x, _ = self.bi_rnn(x)
        # T x N x 2H: the bidirectional RNN concatenates both directions on
        # the last axis. The fifth (non-recurrent) layer takes both the
        # forward and backward units as inputs, so sum the two halves.
        x = x[:, :, :self.hidden_size] + x[:, :, self.hidden_size:]
        # T x N x H
        x = self.fc4(x)
        # T x N x H
        x = self.out(x)
        # T x N x num_classes
        return x

0 commit comments

Comments
 (0)