
Commit 614cdfe

Use scaled_dot_product_attention in WavLM attention (#3252)
Summary: Fixes #3219. `torch.nn.MultiheadAttention` throws an error when it is called under `torch.no_grad()` with an attention mask. This pull request fixes that by replacing the call to the module's forward with an explicit in-projection followed by `torch.nn.functional.scaled_dot_product_attention` and the module's out-projection.

Pull Request resolved: #3252
Reviewed By: mthrok
Differential Revision: D44798634
Pulled By: nateanl
fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
1 parent d92216d commit 614cdfe
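
For context, here is a minimal, self-contained sketch of the idea behind the patch (toy sizes and tensors, not the torchaudio code itself): `nn.MultiheadAttention`'s self-attention forward is equivalent to its own in-projection, a call to `scaled_dot_product_attention`, and its out-projection, so the module's weights can be reused while bypassing the module-level forward path that trips over `torch.no_grad()` plus a mask.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads, bsz, seq_len = 16, 4, 2, 5  # toy sizes
head_dim = embed_dim // num_heads

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
x = torch.randn(bsz, seq_len, embed_dim)

with torch.no_grad():
    # Reference: the module's own forward (self-attention, no mask).
    ref, _ = mha(x, x, x, need_weights=False)

    # Manual decomposition reusing the module's weights, mirroring the
    # patched WavLM forward: in-projection -> SDPA -> out-projection.
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    shape = (bsz, seq_len, num_heads, head_dim)
    q, k, v = (t.view(shape).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).reshape(bsz, seq_len, embed_dim)
    out = mha.out_proj(out)

print(torch.allclose(ref, out, atol=1e-5))  # expected: True
```

Note that `scaled_dot_product_attention` applies `dropout_p` unconditionally, which is why the patched forward passes `self.dropout if self.training else 0.0` rather than relying on the module to disable dropout in eval mode.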

File tree

1 file changed (+22, -7 lines changed)

torchaudio/models/wav2vec2/wavlm_attention.py

Lines changed: 22 additions & 7 deletions
@@ -73,6 +73,7 @@ def __init__(
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
+        self.dropout = dropout
         self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)
 
         self.gru_rel_pos = gru_rel_pos
@@ -165,7 +166,7 @@ def forward(
 
         if self.rel_attn_embed is not None and position_bias is None:
             position_bias = self.compute_bias(seq_len, seq_len)
-            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, seq_len, seq_len)
+            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1)
 
         attn_mask_rel_pos: Optional[Tensor] = None
         if position_bias is not None:
@@ -178,11 +179,25 @@ def forward(
                     self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
                 ).chunk(2, dim=-1)
                 gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
-                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
-
-            attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
-
-        attn_output, _ = self.attention(
-            query, query, query, key_padding_mask=key_padding_mask, attn_mask=attn_mask_rel_pos, need_weights=False
+                attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias
+
+            attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))
+
+        query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
+        query, key, value = query_projected.chunk(3, -1)
+        shape = (bsz, seq_len, self.num_heads, self.head_dim)
+        query = query.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        key = key.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        value = value.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask_rel_pos,
+            dropout_p=dropout,
+            is_causal=False,
         )
+        attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim)
+        attn_output = self.attention.out_proj(attn_output)
         return attn_output, position_bias
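
The reshape changes above follow from the two mask conventions: `nn.MultiheadAttention` takes a 3-D `attn_mask` of shape `(batch * num_heads, seq_len, seq_len)`, while `scaled_dot_product_attention` accepts a 4-D additive mask of shape `(batch, num_heads, seq_len, seq_len)` that is added to the attention scores before the softmax. A rough illustration with made-up tensors (a stand-in for the gated relative position bias, not the WavLM tensors):

```python
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 5, 8  # toy sizes

# Stand-in for the gated relative position bias: a float, additive mask.
bias_4d = torch.randn(bsz, num_heads, seq_len, seq_len)

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# SDPA keeps the (batch, num_heads, seq_len, seq_len) layout and adds the mask
# to q @ k.transpose(-2, -1) / sqrt(head_dim) before the softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias_4d, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 5, 8])

# The pre-patch code flattened the same bias to the 3-D layout that
# nn.MultiheadAttention expects for attn_mask.
bias_3d = bias_4d.view(bsz * num_heads, seq_len, seq_len)
print(bias_3d.shape)  # torch.Size([8, 5, 5])
```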
