@@ -2374,15 +2374,20 @@ def __call__(
         if attention_mask is not None:
             # Reshape hidden_states to (batch_size, num_heads, seq_len, head_dim)
             hidden_states = hidden_states.view(batch_size, attn.heads, -1, head_dim)
-
+
             # Ensure attention_mask is the correct shape (batch_size, 1, 1, seq_len)
-            attention_mask = attention_mask.view(batch_size, 1, 1, -1)
+            if attention_mask.shape != (batch_size, 1, 1, hidden_states.size(-2)):
+                attention_mask = attention_mask.view(batch_size, 1, 1, -1)
+                attention_mask = F.pad(attention_mask, (0, hidden_states.size(-2) - attention_mask.size(-1)), value=1)

             # Expand attention_mask to match hidden_states shape
             attention_mask = attention_mask.expand(-1, attn.heads, hidden_states.size(2), -1)
-
+
             # Apply the mask
             hidden_states = hidden_states * attention_mask.to(hidden_states.dtype)
+
+            # Reshape hidden_states back to its original shape
+            hidden_states = hidden_states.view(batch_size, -1, attn.heads * head_dim)
             # hidden_states = F.scaled_dot_product_attention(
             #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             # )
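
Below is a minimal standalone sketch (not part of the diff) of how the new reshape-and-pad branch behaves; the toy shapes are hypothetical and chosen only so that F.pad has to extend the mask, which it fills with 1s (i.e., padded positions stay unmasked):

import torch
import torch.nn.functional as F

# Hypothetical toy shapes, not taken from the PR, chosen only to trace the new branch.
batch_size, heads, head_dim = 2, 4, 8
seq_len, mask_len = 6, 5  # mask shorter than the sequence, so the F.pad path fires

hidden_states = torch.randn(batch_size, seq_len, heads * head_dim)
attention_mask = torch.ones(batch_size, mask_len)

# Same reshape as in the diff: (batch_size, heads, seq_len, head_dim)
hidden_states = hidden_states.view(batch_size, heads, -1, head_dim)

# New branch from the diff: only reshape/pad when the mask is not already
# (batch_size, 1, 1, seq_len); padded positions are filled with value=1.
if attention_mask.shape != (batch_size, 1, 1, hidden_states.size(-2)):
    attention_mask = attention_mask.view(batch_size, 1, 1, -1)
    attention_mask = F.pad(
        attention_mask, (0, hidden_states.size(-2) - attention_mask.size(-1)), value=1
    )

print(attention_mask.shape)  # torch.Size([2, 1, 1, 6])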