From 25d1431584fccc32c979cdffd178f9a68ccc11fc Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 28 Jan 2025 12:47:52 -0800 Subject: [PATCH 1/4] Update attention_processor.py with inference spmd mesh for flahattention under Pei's guidence --- src/diffusers/models/attention_processor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 88e7072e6..ec2d2d600 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -26,6 +26,7 @@ if is_torch_xla_available(): from torch_xla.experimental.custom_kernel import flash_attention + import torch_xla.distributed.spmd as xs XLA_AVAILABLE = True else: XLA_AVAILABLE = False @@ -2369,7 +2370,9 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - + + mesh = xs.get_global_mesh() + print("mesh: ", mesh) if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]): if attention_mask is not None: attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1]) @@ -2379,7 +2382,7 @@ def __call__( # Apply attention mask to key key = key + attention_mask query /= math.sqrt(query.shape[3]) - hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None)) + hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None), mesh=mesh) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From 2ff42756d6a55e8283482f3b4350186efd759a86 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:44:45 -0800 Subject: [PATCH 2/4] Update attention_processor.py --- src/diffusers/models/attention_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ec2d2d600..8899ac7f7 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -26,7 +26,7 @@ if is_torch_xla_available(): from torch_xla.experimental.custom_kernel import flash_attention - import torch_xla.distributed.spmd as xs + import torch_xla.runtime as xr XLA_AVAILABLE = True else: XLA_AVAILABLE = False @@ -2372,7 +2372,6 @@ def __call__( # TODO: add support for attn.scale when we move to Torch 2.1 mesh = xs.get_global_mesh() - print("mesh: ", mesh) if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]): if attention_mask is not None: attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1]) @@ -2382,7 +2381,8 @@ def __call__( # Apply attention mask to key key = key + attention_mask query /= math.sqrt(query.shape[3]) - hidden_states = flash_attention(query, key, value, causal=False, partition_spec=('data', None, None, None), mesh=mesh) + partition_spec = ('data', None, None, None) if xr.use_spmd() else None + hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec, mesh=mesh) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From c141e0a97c7a870c65194b112a0fa3e0bc0c3506 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:49:31 -0800 Subject: [PATCH 3/4] Update attention_processor.py --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8899ac7f7..5e133699f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -26,6 +26,7 @@ if is_torch_xla_available(): from torch_xla.experimental.custom_kernel import flash_attention + import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr XLA_AVAILABLE = True else: From 8acfe073b730884ec853012d8d1f19d634d2e346 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 28 Jan 2025 23:10:20 +0000 Subject: [PATCH 4/4] format change --- src/diffusers/models/attention_processor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5e133699f..c66b29030 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -20,14 +20,15 @@ from torch import nn from ..image_processor import IPAdapterMaskProcessor -from ..utils import deprecate, logging, is_torch_xla_available +from ..utils import deprecate, is_torch_xla_available, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph + if is_torch_xla_available(): - from torch_xla.experimental.custom_kernel import flash_attention import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr + from torch_xla.experimental.custom_kernel import flash_attention XLA_AVAILABLE = True else: XLA_AVAILABLE = False @@ -2378,7 +2379,7 @@ def __call__( attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1]) # Convert mask to float and replace 0s with -inf and 1s with 0 attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0)) - + # Apply attention mask to key key = key + attention_mask query /= math.sqrt(query.shape[3]) @@ -4321,4 +4322,4 @@ def __init__(self): PAGIdentitySelfAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0, -] +] \ No newline at end of file