diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 88e7072e6..c66b29030 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -20,11 +20,14 @@ from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging, is_torch_xla_available
+from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+
 if is_torch_xla_available():
+    import torch_xla.distributed.spmd as xs
+    import torch_xla.runtime as xr
     from torch_xla.experimental.custom_kernel import flash_attention
 
     XLA_AVAILABLE = True
 else:
@@ -2369,17 +2372,19 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-
+
         if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
             if attention_mask is not None:
                 attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                 # Convert mask to float and replace 0s with -inf and 1s with 0
                 attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
-
+
                 # Apply attention mask to key
                 key = key + attention_mask
             query /= math.sqrt(query.shape[3])
-            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
+            mesh = xs.get_global_mesh()
+            partition_spec = ('data', None, None, None) if xr.is_spmd() else None
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec, mesh=mesh)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
@@ -4317,4 +4322,4 @@ def __init__(self):
     PAGIdentitySelfAttnProcessor2_0,
     PAGCFGHunyuanAttnProcessor2_0,
     PAGHunyuanAttnProcessor2_0,
-]
+]
\ No newline at end of file
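
The sharded path above only takes effect when the runtime is in SPMD mode and a global mesh has been registered before the pipeline runs. Below is a minimal setup sketch, assuming a TPU host with torch_xla >= 2.3; the single "data" axis and the 1-D mesh shape are illustrative choices made to match the ('data', None, None, None) partition spec in the diff, not something this patch installs.

# Minimal SPMD setup sketch (run once, before building the diffusers pipeline).
import numpy as np

import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # switch the runtime into SPMD mode, so xr.is_spmd() reports True

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)

# 1-D mesh that shards only the batch ("data") axis; heads, sequence length and
# head dim stay replicated, matching the partition spec used by the processor.
mesh = xs.Mesh(device_ids, (num_devices,), ("data",))
xs.set_global_mesh(mesh)  # later picked up via xs.get_global_mesh()

With this in place, xs.get_global_mesh() inside the processor should return the mesh above and the flash attention call shards its inputs across the batch dimension; without it, partition_spec falls back to None and the kernel runs unsharded.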