Commit 637da5a

Merge pull request #12 from ManfeiBai/patch-1
SD2 inference: pass mesh for flash attention with SPMD
2 parents 9720f01 + 8acfe07 commit 637da5a

1 file changed: +10 −5 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 5 deletions
@@ -20,11 +20,14 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging, is_torch_xla_available
+from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+
 if is_torch_xla_available():
+    import torch_xla.distributed.spmd as xs
+    import torch_xla.runtime as xr
     from torch_xla.experimental.custom_kernel import flash_attention
     XLA_AVAILABLE = True
 else:
@@ -2369,17 +2372,19 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-
+
+        mesh = xs.get_global_mesh()
         if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
             if attention_mask is not None:
                 attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                 # Convert mask to float and replace 0s with -inf and 1s with 0
                 attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
-
+
                 # Apply attention mask to key
                 key = key + attention_mask
             query /= math.sqrt(query.shape[3])
-            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
+            partition_spec = ('data', None, None, None) if xr.use_spmd() else None
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec, mesh=mesh)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -4317,4 +4322,4 @@ def __init__(self):
             PAGIdentitySelfAttnProcessor2_0,
             PAGCFGHunyuanAttnProcessor2_0,
             PAGHunyuanAttnProcessor2_0,
-        ]
+        ]
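
For context on how the patched call path is meant to be driven: the processor now reads the globally registered SPMD mesh via xs.get_global_mesh() and forwards it, together with a batch-sharding partition_spec, to the flash_attention Pallas kernel. The sketch below is a minimal, hypothetical setup assuming a TPU SPMD run with a one-dimensional 'data' mesh axis; the device count, tensor shapes, and mesh layout are illustrative assumptions, not part of this commit.

import math

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import flash_attention

# Assumption: SPMD mode is enabled once at startup, before any XLA tensors exist.
xr.use_spmd()

# Assumption: a 1-D mesh over all devices whose axis name matches the
# partition_spec ('data', None, None, None) used in the patched processor.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ("data",))
xs.set_global_mesh(mesh)

# Illustrative (batch, num_heads, seq_len, head_dim) tensors; the processor only
# takes the flash-attention branch when seq_len >= 4096.
device = xm.xla_device()
query = torch.randn(8, 8, 4096, 64, device=device)
key = torch.randn(8, 8, 4096, 64, device=device)
value = torch.randn(8, 8, 4096, 64, device=device)

# Mirror the processor's pre-scaling, since the kernel is called with its
# default softmax scale.
query /= math.sqrt(query.shape[3])

# Same call shape as the patched line: the batch dimension is sharded across
# the 'data' axis and the mesh is handed to the kernel explicitly.
hidden_states = flash_attention(
    query,
    key,
    value,
    causal=False,
    partition_spec=("data", None, None, None),
    mesh=xs.get_global_mesh(),
)

The intent of the conditional in the patch is that non-SPMD runs fall back to partition_spec=None, so single-device inference keeps the original kernel behavior while SPMD runs shard the batch dimension over the 'data' axis of the registered mesh.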
