
Commit c805f6b

Commit message: fa
1 parent f08dc92 commit c805f6b

File tree: 1 file changed, +22 −3 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 22 additions & 3 deletions
@@ -24,6 +24,7 @@
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+from torch_xla.experimental.custom_kernel import flash_attention
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
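The hunk above adds torch_xla's flash_attention as an unconditional module-level import, right next to the availability helpers pulled in from ..utils.import_utils. As a hedged sketch only (not part of this commit), a guarded import in the same spirit could use a plain try/except so the module still loads when torch_xla is not installed:

try:
    from torch_xla.experimental.custom_kernel import flash_attention  # requires torch_xla
except ImportError:
    # torch_xla not installed / not an XLA environment
    flash_attention = None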
@@ -2364,9 +2365,27 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+
+        # print("piz_test6")
+        # logger.warning(f"piz_debug2: query shape: {query.shape}, key shape: {key.shape}, value shape: {value.shape}")
+        hidden_states = flash_attention(query, key, value, causal=False)
+
+        # Apply the attention mask if provided
+        if attention_mask is not None:
+            # Reshape hidden_states to (batch_size, num_heads, seq_len, head_dim)
+            hidden_states = hidden_states.view(batch_size, attn.heads, -1, head_dim)
+
+            # Ensure attention_mask is the correct shape (batch_size, 1, 1, seq_len)
+            attention_mask = attention_mask.view(batch_size, 1, 1, -1)
+
+            # Expand attention_mask to match hidden_states shape
+            attention_mask = attention_mask.expand(-1, attn.heads, hidden_states.size(2), -1)
+
+            # Apply the mask
+            hidden_states = hidden_states * attention_mask.to(hidden_states.dtype)
+        # hidden_states = F.scaled_dot_product_attention(
+        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        # )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
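The second hunk swaps F.scaled_dot_product_attention for torch_xla's Pallas flash-attention kernel and applies the attention mask to the kernel's output rather than passing it into the softmax. Below is a minimal, self-contained sketch of that pattern, assuming a TPU/XLA runtime with torch_xla installed; the tensor sizes and the simplified (batch_size, 1, seq_len, 1) mask broadcast are illustrative choices, not code from this commit.

# Minimal sketch (not the diffusers processor itself): run torch_xla's
# flash_attention kernel in place of F.scaled_dot_product_attention,
# then zero out masked positions in its output.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()  # assumes a TPU/XLA device is available

batch_size, num_heads, seq_len, head_dim = 2, 8, 1024, 128
query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)

# (batch_size, seq_len) keep/drop mask: 1 = keep the token, 0 = mask it out.
attention_mask = torch.ones(batch_size, seq_len, device=device)

# Full (non-causal) attention, as in the diff; the kernel takes and returns
# tensors in (batch, num_heads, seq_len, head_dim) layout.
hidden_states = flash_attention(query, key, value, causal=False)

if attention_mask is not None:
    # Broadcast the per-token mask over heads and head_dim, zeroing the
    # output rows of masked positions after attention has been computed.
    mask = attention_mask.view(batch_size, 1, seq_len, 1)
    hidden_states = hidden_states * mask.to(hidden_states.dtype)

# Back to (batch, seq_len, num_heads * head_dim), as in the processor.
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)

Note that this multiplicative masking zeroes rows of the attention output, whereas the attn_mask argument of F.scaled_dot_product_attention biases the softmax itself before the weighted sum is taken.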
