[Cherry-pick] Use scaled_dot_product_attention in WavLM attention (#3252, #3265) (#3264)
* Use scaled_dot_product_attention in WavLM attention (#3252)
Summary:
Fixes #3219.
`torch.nn.MultiheadAttention` throws an error when an attention mask is supplied under `torch.no_grad()`. This pull request fixes it by replacing that forward path with `torch.nn.functional.scaled_dot_product_attention`.
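As a rough illustration (not the actual torchaudio code), the sketch below shows the kind of call the attention forward is switched to: `torch.nn.functional.scaled_dot_product_attention` accepts an additive mask and works inside `torch.no_grad()`. The tensor shapes are made up for the example.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only.
batch, num_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# Additive float mask broadcastable to (batch, num_heads, seq_len, seq_len).
attn_mask = torch.zeros(batch, num_heads, seq_len, seq_len)

# Works under no_grad with a mask, unlike the nn.MultiheadAttention path.
with torch.no_grad():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```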
Pull Request resolved: #3252
Reviewed By: mthrok
Differential Revision: D44798634
Pulled By: nateanl
fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
* Merge key_padding_mask into attn_mask_rel_pos in WavLM (#3265)
Summary:
When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` into a single mask for the `scaled_dot_product_attention` function.
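A minimal sketch of the idea, assuming `attn_mask_rel_pos` is an additive float bias of shape `(batch * num_heads, seq_len, seq_len)` and `key_padding_mask` is a boolean tensor of shape `(batch, seq_len)` with `True` marking padded positions (these shapes are assumptions, not taken from the torchaudio implementation):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
batch, num_heads, seq_len, head_dim = 2, 4, 8, 16
attn_mask_rel_pos = torch.randn(batch * num_heads, seq_len, seq_len)
key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # last two key positions are padding

# Expand the padding mask over heads and query positions, then fold it into
# the relative-position bias as -inf so a single attn_mask can be passed to
# scaled_dot_product_attention.
padding = key_padding_mask.view(batch, 1, 1, seq_len).expand(-1, num_heads, seq_len, -1)
padding = padding.reshape(batch * num_heads, seq_len, seq_len)
merged_mask = attn_mask_rel_pos.masked_fill(padding, float("-inf"))

q = torch.randn(batch * num_heads, seq_len, head_dim)
k = torch.randn(batch * num_heads, seq_len, head_dim)
v = torch.randn(batch * num_heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=merged_mask)
print(out.shape)  # torch.Size([16, 8, 16])
```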
Pull Request resolved: #3265
Reviewed By: hwangjeff
Differential Revision: D44901093
Pulled By: nateanl
fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70