24 | 24 | from ..utils.import_utils import is_torch_npu_available, is_xformers_available
25 | 25 | from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
26 | 26 |
   | 27 | +from torch_xla.experimental.custom_kernel import flash_attention
27 | 28 |
28 | 29 | logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
29 | 30 |
@@ -2364,9 +2365,14 @@ def __call__(
2364 | 2365 |
2365 | 2366 |         # the output of sdp = (batch, num_heads, seq_len, head_dim)
2366 | 2367 |         # TODO: add support for attn.scale when we move to Torch 2.1
2367 |      | -        hidden_states = F.scaled_dot_product_attention(
2368 |      | -            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2369 |      | -        )
     | 2368 | +        if attention_mask is not None:
     | 2369 | +            # The flash attention kernel imported above does not take an attention
     | 2370 | +            # mask, so keep the SDPA path whenever a mask is provided.
     | 2371 | +            hidden_states = F.scaled_dot_product_attention(
     | 2372 | +                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
     | 2373 | +            )
     | 2374 | +        else:
     | 2375 | +            hidden_states = flash_attention(query, key, value, causal=False)
2370 | 2376 |
2371 | 2377 |         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2372 | 2378 |         hidden_states = hidden_states.to(query.dtype)
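For context, the change swaps `F.scaled_dot_product_attention` for the Pallas flash attention kernel that ships with torch_xla and only runs on TPU. Below is a minimal, self-contained sketch of the dispatch this points toward. It assumes diffusers' existing `is_torch_xla_available()` helper to guard the TPU-only import (the hunk above imports `flash_attention` unconditionally, which would fail wherever torch_xla is not installed), and it assumes the kernel applies no softmax scaling of its own, so the query is pre-scaled by `1/sqrt(head_dim)` to match SDPA's default. The `dispatch_attention` helper and the `XLA_AVAILABLE` flag are illustrative names, not part of the diffusers API.

```python
import torch.nn.functional as F

from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    # TPU-only Pallas kernel; keep the import behind the availability check so
    # the module still loads on machines without torch_xla.
    from torch_xla.experimental.custom_kernel import flash_attention

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def dispatch_attention(query, key, value, attention_mask=None):
    """Run attention over (batch, heads, seq_len, head_dim) tensors.

    Uses the torch_xla flash attention kernel when it is available and no
    attention mask is needed; otherwise falls back to SDPA.
    """
    head_dim = query.shape[-1]

    if XLA_AVAILABLE and attention_mask is None:
        # Assumption: the Pallas kernel does not apply the 1/sqrt(head_dim)
        # softmax scaling that SDPA applies by default, so scale the query
        # here to keep the two code paths numerically comparable.
        query = query * head_dim**-0.5
        return flash_attention(query, key, value, causal=False)

    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
```

Falling back to SDPA whenever a mask is present keeps behaviour unchanged for masked batches, since the kernel call in the diff takes no mask. Later torch_xla releases add segment-id and partition-spec arguments to `flash_attention`, which can cover padding masks and SPMD sharding; they are left out here to keep the sketch close to the diff.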