Commit 94cc4bd

nateanl authored and facebook-github-bot committed
Use scaled_dot_product_attention in Wav2vec2/HuBERT's SelfAttention (#3253)
Summary: Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency.

Pull Request resolved: #3253

Reviewed By: mthrok

Differential Revision: D44800353

Pulled By: nateanl

fbshipit-source-id: 41550d868c809099aadbe812b0ebe2c38121efb8
1 parent 5a5b0fc commit 94cc4bd
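
The following sketch is not part of the commit; it only illustrates the equivalence the summary relies on. With dropout disabled, `torch.nn.functional.scaled_dot_product_attention` reproduces the manual scale / mask / softmax / matmul path that the diff below removes. All tensor shapes and names here are illustrative.

```python
# Minimal sketch (not from the commit): with dropout_p=0.0, SDPA matches the
# manual attention path being removed in this diff.
import torch
import torch.nn.functional as F

batch, n_heads, length, head_dim = 2, 4, 16, 32
q = torch.randn(batch, n_heads, length, head_dim)
k = torch.randn(batch, n_heads, length, head_dim)
v = torch.randn(batch, n_heads, length, head_dim)
attention_mask = torch.zeros(batch, 1, length, length)  # additive float mask

# Manual path, as in the removed code.
scaling = head_dim**-0.5
weights = (scaling * q) @ k.transpose(-2, -1)              # B, nH, L, L
weights = weights + attention_mask
weights = weights - weights.max(dim=-1, keepdim=True)[0]   # overflow guard; softmax-invariant
weights = F.softmax(weights, dim=-1)
manual = weights @ v                                       # B, nH, L, Hd

# Fused path, as in the new code (dropout_p=0.0 outside training).
fused = F.scaled_dot_product_attention(
    q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

print(torch.allclose(manual, fused, atol=1e-6))  # expected: True
```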

File tree

1 file changed: +8 -19 lines changed

torchaudio/models/wav2vec2/components.py

Lines changed: 8 additions & 19 deletions
```diff
@@ -262,7 +262,7 @@ def __init__(

         self.embed_dim = embed_dim
         self.num_heads = num_heads
-        self.dropout = torch.nn.Dropout(dropout)
+        self.dropout = dropout
         self.head_dim = head_dim

         self.scaling = self.head_dim**-0.5
@@ -304,25 +304,14 @@ def forward(

         shape = (batch_size, length, self.num_heads, self.head_dim)
         q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
+        k = self.k_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
         v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-
-        # scale down q to avoid value overflow.
-        weights = (self.scaling * q) @ k  # B, nH, L, L
-        if attention_mask is not None:
-            weights += attention_mask
-        # subtracting a constant value from the tensor won't change the output of softmax.
-        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
-        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
-        weights = weights - weights.max(dim=-1, keepdim=True)[0]
-
-        weights = torch.nn.functional.softmax(weights, dim=-1)
-        weights = self.dropout(weights)
-
-        output = weights @ v  # B, nH, L, Hd
-        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
-
-        output = self.out_proj(output)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
+        )
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        output = self.out_proj(attn_output)
         return output, None  # Necessary for compatibility with WavLMSelAttention
```
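
For context, and not from the torchaudio source: a minimal, self-contained module following the same pattern as the updated `forward` above (project, call SDPA, merge heads, out-project). The dropout probability is gated on `self.training` because `scaled_dot_product_attention` is a functional API that always applies `dropout_p`, so the caller must pass 0.0 in eval mode. The class and variable names below are illustrative only.

```python
# Illustrative sketch mirroring the pattern adopted in this commit.
import torch
import torch.nn.functional as F
from torch import nn


class TinySelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = dropout  # stored as a float, as in the commit
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attention_mask=None):
        batch_size, length, _ = x.shape
        shape = (batch_size, length, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        k = self.k_proj(x).view(*shape).transpose(2, 1)
        v = self.v_proj(x).view(*shape).transpose(2, 1)
        # SDPA applies dropout internally, so gate it on the module's mode.
        dropout = self.dropout if self.training else 0.0
        attn = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
        )
        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        return self.out_proj(attn)


x = torch.randn(2, 16, 128)
print(TinySelfAttention(128, 4, dropout=0.1).eval()(x).shape)  # torch.Size([2, 16, 128])
```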