diff --git a/examples/text_to_image/train_text_to_image_xla.py b/examples/text_to_image/train_text_to_image_xla.py
index 9c6e7c86a..37d9f381a 100644
--- a/examples/text_to_image/train_text_to_image_xla.py
+++ b/examples/text_to_image/train_text_to_image_xla.py
@@ -385,7 +385,7 @@ def main(args):
     server = xp.start_server(9012)
 
     num_devices = xr.global_runtime_device_count()
-    mesh = xs.get_1d_mesh('x')
+    mesh = xs.get_1d_mesh('data')
     xs.set_global_mesh(mesh)
 
     text_encoder = CLIPTextModel.from_pretrained(
@@ -521,9 +521,9 @@ def collate_fn(examples):
         device,
         input_sharding={
             "pixel_values": xs.ShardingSpec(
-                mesh, ("x", None, None, None), minibatch=True
+                mesh, ("data", None, None, None), minibatch=True
             ),
-            "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+            "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
         device_prefetch_size=args.device_prefetch_size,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 9f9bc5a46..88e7072e6 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -20,10 +20,15 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
+from ..utils import deprecate, logging, is_torch_xla_available
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+if is_torch_xla_available():
+    from torch_xla.experimental.custom_kernel import flash_attention
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -2364,9 +2369,21 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+
+        if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+            if attention_mask is not None:
+                attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
+                # Convert mask to float and replace 0s with -inf and 1s with 0
+                attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
+
+                # Apply attention mask to key
+                key = key + attention_mask
+            query /= math.sqrt(query.shape[3])
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None))
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
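
For context, here is a minimal sketch (not part of the patch) of how the renamed `data` mesh axis and the Pallas `flash_attention` kernel fit together. It assumes a TPU host with a recent `torch_xla` build; shapes, dtype, and the use of `xr.use_spmd()` / `xm.mark_step()` are illustrative choices, only the mesh and kernel calls come from the diff itself.

```python
# Standalone sketch: build the same 1-D SPMD mesh named "data" as the training
# script, then call torch_xla's flash_attention the way the new
# AttnProcessor2_0 branch does.
import math

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()                   # enable SPMD mode (assumed setup step)
mesh = xs.get_1d_mesh("data")   # single axis named "data", as in the script
xs.set_global_mesh(mesh)

device = xm.xla_device()
batch = xr.global_runtime_device_count()  # keep batch divisible by device count

# (batch, heads, seq_len, head_dim); the processor only takes the XLA path
# when seq_len >= 4096.
q = torch.randn(batch, 8, 4096, 64, dtype=torch.bfloat16, device=device)
k = torch.randn(batch, 8, 4096, 64, dtype=torch.bfloat16, device=device)
v = torch.randn(batch, 8, 4096, 64, dtype=torch.bfloat16, device=device)

# The kernel does not apply the 1/sqrt(head_dim) scaling that
# F.scaled_dot_product_attention applies, hence the explicit pre-scaling.
q = q / math.sqrt(q.shape[-1])

# partition_spec=("data", None, None, None) shards the batch dimension over
# the "data" mesh axis, matching the dataloader's input_sharding in the diff.
out = flash_attention(q, k, v, causal=False, partition_spec=("data", None, None, None))
xm.mark_step()
```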