Commit ce8cef6

Merge branch 'release/2.0' into cherrypick-wav2vec2
2 parents: 614cdfe + e99de15

File tree

1 file changed (+8, -19)


torchaudio/models/wav2vec2/components.py

Lines changed: 8 additions & 19 deletions
@@ -262,7 +262,7 @@ def __init__(
 
         self.embed_dim = embed_dim
         self.num_heads = num_heads
-        self.dropout = torch.nn.Dropout(dropout)
+        self.dropout = dropout
         self.head_dim = head_dim
 
         self.scaling = self.head_dim**-0.5
@@ -304,25 +304,14 @@ def forward(
 
         shape = (batch_size, length, self.num_heads, self.head_dim)
         q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
+        k = self.k_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
         v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-
-        # scale down q to avoid value overflow.
-        weights = (self.scaling * q) @ k  # B, nH, L, L
-        if attention_mask is not None:
-            weights += attention_mask
-        # subtracting a constant value from the tensor won't change the output of softmax.
-        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
-        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
-        weights = weights - weights.max(dim=-1, keepdim=True)[0]
-
-        weights = torch.nn.functional.softmax(weights, dim=-1)
-        weights = self.dropout(weights)
-
-        output = weights @ v  # B, nH, L, Hd
-        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
-
-        output = self.out_proj(output)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
+        )
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        output = self.out_proj(attn_output)
         return output, None  # Necessary for compatibility with WavLMSelfAttention
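
For context, the deleted lines computed attention by hand: scale q, multiply by the transposed k, add the additive mask, subtract the row-wise max for numerical stability, softmax, dropout, then multiply by v. torch.nn.functional.scaled_dot_product_attention (available since PyTorch 2.0) performs the same computation in one fused call. That is also why self.dropout becomes a plain float (SDPA takes a dropout_p probability rather than an nn.Dropout module) and why k now stays in B, nH, L, Hd layout (SDPA transposes the key internally). Below is a minimal sketch, independent of torchaudio and with illustrative shapes, checking that the two paths agree in eval mode:

import torch
import torch.nn.functional as F

# Illustrative sizes: batch, heads, sequence length, head dimension.
B, nH, L, Hd = 2, 4, 10, 16
q = torch.randn(B, nH, L, Hd)
k = torch.randn(B, nH, L, Hd)
v = torch.randn(B, nH, L, Hd)
mask = torch.zeros(B, nH, L, L)  # additive float mask, as in the old code

# Old path (eval mode, so the dropout step is a no-op and is omitted).
weights = (Hd**-0.5 * q) @ k.transpose(-2, -1) + mask  # B, nH, L, L
weights = weights - weights.max(dim=-1, keepdim=True)[0]  # softmax is shift-invariant
manual = F.softmax(weights, dim=-1) @ v  # B, nH, L, Hd

# New path: one fused call; dropout_p=0.0 corresponds to self.training == False.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

torch.testing.assert_close(manual, fused, rtol=1e-4, atol=1e-5)

The self.training gate in the new code is load-bearing: unlike an nn.Dropout module, the functional SDPA call applies dropout whenever dropout_p > 0 regardless of train/eval state, so the commit zeroes it explicitly outside training.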