 from torch import nn

 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
+from ..utils import deprecate, logging, is_torch_xla_available
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph

+if is_torch_xla_available():
+    from torch_xla.experimental.custom_kernel import flash_attention
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

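As a quick sanity check that this guard resolves the way you expect in a given environment, the availability helper can be called directly. A minimal sketch, assuming a working diffusers install (the helper is exported from diffusers.utils, the same module the diff imports it from):

# Prints whether torch_xla can be imported, i.e. whether the guarded
# flash_attention import above would be taken.
from diffusers.utils import is_torch_xla_available

print("torch_xla available:", is_torch_xla_available())
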
@@ -2364,9 +2369,21 @@ def __call__(

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+
+        if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+            if attention_mask is not None:
+                attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
+                # Convert mask to float and replace 0s with -inf and 1s with 0
+                attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
+
+                # Apply attention mask to key
+                key = key + attention_mask
+            query /= math.sqrt(query.shape[3])
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
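For context on the mask handling in the XLA branch: the flash_attention kernel is called without a mask argument, so the boolean attention mask is first converted to an additive float mask (0 where attention is allowed, -inf where it is blocked) and folded into the key, and the query is pre-scaled by 1/sqrt(head_dim), presumably because no softmax scale is passed to the kernel here. Below is a minimal, CPU-only sketch of just the mask conversion, using a made-up mask for illustration:

import torch

# Boolean padding mask: 1 = token may be attended to, 0 = padding.
bool_mask = torch.tensor([[1, 1, 1, 0]])

# Same conversion as in the diff: 0 -> -inf, 1 -> 0.0, used as an additive bias.
additive_mask = bool_mask.float().masked_fill(bool_mask == 0, float('-inf')).masked_fill(bool_mask == 1, float(0.0))

print(additive_mask)  # tensor([[0., 0., 0., -inf]])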