@@ -20,14 +20,15 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging, is_torch_xla_available
+from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+
 if is_torch_xla_available():
-    from torch_xla.experimental.custom_kernel import flash_attention
     import torch_xla.distributed.spmd as xs
     import torch_xla.runtime as xr
+    from torch_xla.experimental.custom_kernel import flash_attention
     XLA_AVAILABLE = True
 else:
     XLA_AVAILABLE = False
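The import hunk above is a lint-style cleanup: `is_torch_xla_available` is re-sorted alphabetically within the `..utils` import, the `torch_xla` imports are reordered, and a second blank line is added before the top-level `if`. The guard itself follows a common optional-backend pattern: probe for `torch_xla` once at import time and record the result in a module-level flag. Below is a minimal self-contained sketch of that pattern, using a plain try/except instead of diffusers' `is_torch_xla_available()` helper; the `attention` function is hypothetical, not a diffusers API.

```python
import torch

# Probe once at import time; torch_xla may or may not be installed.
try:
    import torch_xla.core.xla_model as xm  # noqa: F401

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Hypothetical call site: branch on the cached flag instead of re-importing.
    if XLA_AVAILABLE:
        # An XLA-specific fused kernel could be dispatched here.
        pass
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)
```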
@@ -2378,7 +2379,7 @@ def __call__(
             attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
             # Convert mask to float and replace 0s with -inf and 1s with 0
             attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
-
+
             # Apply attention mask to key
             key = key + attention_mask
             query /= math.sqrt(query.shape[3])
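For context on the whitespace-only hunk above: the surrounding code turns a 0/1 attention mask into an additive float mask (0 → -inf, 1 → 0). This processor then adds that mask to `key` before calling the XLA flash-attention kernel, but the more common use of the same conversion is as an additive bias on the attention logits. The following self-contained sketch shows that standard usage; all shapes and variable names are illustrative, not taken from diffusers.

```python
import math

import torch

batch_size, heads, seq_len, head_dim = 2, 4, 8, 16
query = torch.randn(batch_size, heads, seq_len, head_dim)
key = torch.randn(batch_size, heads, seq_len, head_dim)
value = torch.randn(batch_size, heads, seq_len, head_dim)

# Hypothetical 0/1 padding mask: 1 = real token, 0 = padding.
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, -2:] = 0  # pretend the last two positions are padding

# Same conversion as in the hunk: broadcastable shape, then 0 -> -inf, 1 -> 0.
attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
attention_mask = (
    attention_mask.float()
    .masked_fill(attention_mask == 0, float('-inf'))
    .masked_fill(attention_mask == 1, float(0.0))
)

# Standard scaled dot-product attention with the mask added to the logits:
# exp(-inf) = 0, so padded key positions receive zero attention weight.
scores = query @ key.transpose(-2, -1) / math.sqrt(head_dim) + attention_mask
weights = scores.softmax(dim=-1)
out = weights @ value
assert float(weights[..., -2:].sum()) == 0.0  # padding attracted no attention
```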
@@ -4321,4 +4322,4 @@ def __init__(self):
     PAGIdentitySelfAttnProcessor2_0,
     PAGCFGHunyuanAttnProcessor2_0,
     PAGHunyuanAttnProcessor2_0,
-]
+]