@@ -20,11 +20,14 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging, is_torch_xla_available
+from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+
 if is_torch_xla_available():
+    import torch_xla.distributed.spmd as xs
+    import torch_xla.runtime as xr
     from torch_xla.experimental.custom_kernel import flash_attention
     XLA_AVAILABLE = True
 else:
@@ -2369,17 +2372,19 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-
+
+        mesh = xs.get_global_mesh()
         if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
             if attention_mask is not None:
                 attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                 # Convert mask to float and replace 0s with -inf and 1s with 0
                 attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
-
+
                 # Apply attention mask to key
                 key = key + attention_mask
             query /= math.sqrt(query.shape[3])
-            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
+            partition_spec = ('data', None, None, None) if xr.use_spmd() else None
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec, mesh=mesh)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
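With this change, the sharding of `flash_attention` is driven by the global SPMD mesh instead of a hard-coded spec: `('data', None, None, None)` shards the batch dimension of the (batch, num_heads, seq_len, head_dim) tensors along the 'data' mesh axis, and the kernel runs unsharded when SPMD is not in use. A minimal host-side setup sketch, assuming torch_xla's SPMD utilities; the 1-D mesh shape and axis name here are illustrative, not taken from this change:

```python
import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Enable SPMD mode before tracing any XLA computation.
xr.use_spmd()

# Illustrative 1-D device mesh whose axis is named 'data', matching the
# ('data', None, None, None) partition spec applied to the batch dimension.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Register the mesh globally so xs.get_global_mesh() inside the attention
# processor returns it.
xs.set_global_mesh(mesh)
```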
@@ -4317,4 +4322,4 @@ def __init__(self):
     PAGIdentitySelfAttnProcessor2_0,
     PAGCFGHunyuanAttnProcessor2_0,
     PAGHunyuanAttnProcessor2_0,
-]
+]