
Commit 40f5dca ("fa")

1 parent c805f6b

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 3 deletions
@@ -2374,15 +2374,20 @@ def __call__(
         if attention_mask is not None:
             # Reshape hidden_states to (batch_size, num_heads, seq_len, head_dim)
             hidden_states = hidden_states.view(batch_size, attn.heads, -1, head_dim)
-
+
             # Ensure attention_mask is the correct shape (batch_size, 1, 1, seq_len)
-            attention_mask = attention_mask.view(batch_size, 1, 1, -1)
+            if attention_mask.shape != (batch_size, 1, 1, hidden_states.size(-2)):
+                attention_mask = attention_mask.view(batch_size, 1, 1, -1)
+                attention_mask = F.pad(attention_mask, (0, hidden_states.size(-2) - attention_mask.size(-1)), value=1)
 
             # Expand attention_mask to match hidden_states shape
             attention_mask = attention_mask.expand(-1, attn.heads, hidden_states.size(2), -1)
-
+
             # Apply the mask
             hidden_states = hidden_states * attention_mask.to(hidden_states.dtype)
+
+            # Reshape hidden_states back to its original shape
+            hidden_states = hidden_states.view(batch_size, -1, attn.heads * head_dim)
             # hidden_states = F.scaled_dot_product_attention(
             #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             # )
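Below is a minimal standalone sketch of the mask padding and broadcasting this commit introduces, assuming a 2-D attention_mask of shape (batch_size, mask_len) that may be shorter than the sequence length; the concrete sizes are illustrative, not taken from the diff:

    import torch
    import torch.nn.functional as F

    # Illustrative sizes (assumptions, not from the commit).
    batch_size, heads, seq_len = 2, 4, 6

    # Hypothetical mask covering only the first 4 of 6 key positions.
    attention_mask = torch.ones(batch_size, 4)

    if attention_mask.shape != (batch_size, 1, 1, seq_len):
        attention_mask = attention_mask.view(batch_size, 1, 1, -1)
        # Right-pad with value=1 ("attend") so the key axis matches seq_len.
        attention_mask = F.pad(
            attention_mask, (0, seq_len - attention_mask.size(-1)), value=1
        )

    # Broadcast across heads and query positions, as the diff's expand() does.
    attention_mask = attention_mask.expand(-1, heads, seq_len, -1)
    print(attention_mask.shape)  # torch.Size([2, 4, 6, 6])

The resulting (batch_size, heads, query_len, key_len) tensor is also the attn_mask shape that the commented-out F.scaled_dot_product_attention call in the diff would accept.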
