@@ -262,7 +262,7 @@ def __init__(

        self.embed_dim = embed_dim
        self.num_heads = num_heads
-        self.dropout = torch.nn.Dropout(dropout)
+        self.dropout = dropout
        self.head_dim = head_dim

        self.scaling = self.head_dim**-0.5
@@ -304,25 +304,14 @@ def forward(

        shape = (batch_size, length, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
+        k = self.k_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
-
-        # scale down q to avoid value overflow.
-        weights = (self.scaling * q) @ k  # B, nH, L, L
-        if attention_mask is not None:
-            weights += attention_mask
-        # subtracting a constant value from the tensor won't change the output of softmax.
-        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
-        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
-        weights = weights - weights.max(dim=-1, keepdim=True)[0]
-
-        weights = torch.nn.functional.softmax(weights, dim=-1)
-        weights = self.dropout(weights)
-
-        output = weights @ v  # B, nH, L, Hd
-        output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
-
-        output = self.out_proj(output)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
+        )
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        output = self.out_proj(attn_output)
        return output, None  # Necessary for compatibility with WavLMSelfAttention
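Note on the change: the first hunk stores dropout as a plain float because torch.nn.functional.scaled_dot_product_attention takes the probability directly through dropout_p, and the functional call does not consult the module's training mode, which is why the forward pass now gates it with self.training instead of relying on torch.nn.Dropout. Below is a minimal equivalence sketch, not part of the patch: tensor sizes and the zero additive mask are arbitrary placeholders, used only to check that the fused call reproduces the removed scale/mask/softmax/matmul path when dropout is disabled.

# Minimal equivalence sketch (not part of the patch). Sizes below are
# arbitrary placeholders; shapes follow the diff's comments (B, nH, L, Hd).
import torch

B, nH, L, Hd = 2, 4, 10, 16
q = torch.randn(B, nH, L, Hd)
k = torch.randn(B, nH, L, Hd)
v = torch.randn(B, nH, L, Hd)
attention_mask = torch.zeros(B, 1, L, L)  # additive mask, broadcast over heads

# Removed path: explicit scaling, additive mask, softmax, weighted sum.
# (The max-subtraction trick in the old code does not change the softmax output.)
scaling = Hd**-0.5
weights = (scaling * q) @ k.transpose(-2, -1)  # B, nH, L, L
weights = weights + attention_mask
weights = torch.nn.functional.softmax(weights, dim=-1)
manual = weights @ v  # B, nH, L, Hd

# New path: one fused kernel; dropout_p=0.0 keeps both sides deterministic.
fused = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
print(torch.allclose(manual, fused, atol=1e-5))  # expected: True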