Commit 8acfe07

format change
1 parent c141e0a commit 8acfe07

File tree

1 file changed: +5 -4 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 4 deletions
@@ -20,14 +20,15 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging, is_torch_xla_available
+from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+
 if is_torch_xla_available():
-    from torch_xla.experimental.custom_kernel import flash_attention
     import torch_xla.distributed.spmd as xs
     import torch_xla.runtime as xr
+    from torch_xla.experimental.custom_kernel import flash_attention
     XLA_AVAILABLE = True
 else:
     XLA_AVAILABLE = False
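
This first hunk is a formatter-driven, isort-style reordering: the names in the `from ..utils import ...` line are alphabetized, and inside the guarded block the `from ... import ...` statement moves below the plain `import` statements. The `is_torch_xla_available()` guard itself is what keeps the module importable when torch_xla is absent; below is a minimal sketch of that idiom using only the standard library (the helper name is illustrative, not diffusers' actual implementation):

import importlib.util

def torch_xla_installed() -> bool:
    # Illustrative stand-in for diffusers' is_torch_xla_available():
    # probe for the package without actually importing it.
    return importlib.util.find_spec("torch_xla") is not None

if torch_xla_installed():
    import torch_xla.runtime as xr  # safe: the package is present

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False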
@@ -2378,7 +2379,7 @@ def __call__(
         attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
         # Convert mask to float and replace 0s with -inf and 1s with 0
         attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
-
+
         # Apply attention mask to key
         key = key + attention_mask
         query /= math.sqrt(query.shape[3])

(The -/+ pair is a whitespace-only change: trailing spaces stripped from the blank line.)
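
For context, the surrounding code turns a 0/1 padding mask into an additive mask: kept positions (1) map to 0.0 and masked positions (0) map to -inf, so that anything downstream of a softmax assigns them zero weight. A minimal sketch of the same transform on a toy tensor (shapes are illustrative; note that in this processor the additive mask is added to `key` itself rather than to attention logits):

import torch

# 0/1 keep/drop mask over three key positions, broadcastable over heads/queries
attention_mask = torch.tensor([[1, 1, 0]]).view(1, 1, 1, 3)

# Same chained transform as in the hunk above: 0 -> -inf, 1 -> 0.0.
# Both comparisons see the original 0/1 tensor because the right-hand
# side is evaluated in full before the name is rebound.
attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
print(attention_mask)  # tensor([[[[0., 0., -inf]]]])

# Added to attention logits, the -inf entries vanish after softmax
logits = torch.zeros(1, 1, 1, 3) + attention_mask
print(logits.softmax(dim=-1))  # tensor([[[[0.5000, 0.5000, 0.0000]]]])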
@@ -4321,4 +4322,4 @@ def __init__(self):
             PAGIdentitySelfAttnProcessor2_0,
             PAGCFGHunyuanAttnProcessor2_0,
             PAGHunyuanAttnProcessor2_0,
-        ]
+        ]

(Also whitespace-only: the closing bracket line changes only in invisible characters, e.g. trailing whitespace or a final newline.)
