Describe the bug
- Splash Attention breaks many small KerasHub models on TPU v4 pods when using DataParallel. For instance, I have tested a few SigLIP models and the CLIP model, and they all crash the same way.
To Reproduce
Colab Link
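In case an inline snippet is easier to follow than the notebook, here is a minimal sketch of the kind of setup that crashes for me. The backbone class, preset name, and input keys are assumptions (they may not match the exact models I tested); the Colab above is the authoritative reproduction.
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import numpy as np
import keras
import keras_hub

# Shard the batch dimension across every available TPU device.
devices = keras.distribution.list_devices()
keras.distribution.set_distribution(
    keras.distribution.DataParallel(devices=devices)
)

# Any small SigLIP (or CLIP) preset hits the same crash for me.
model = keras_hub.models.SigLIPBackbone.from_preset("siglip_base_patch16_224")

# Dummy preprocessed inputs; the exact keys depend on the preset.
batch = {
    "images": np.zeros((8, 224, 224, 3), dtype="float32"),
    "token_ids": np.zeros((8, 64), dtype="int32"),
}
model.predict(batch)  # fails inside the TPU Splash Attention kernel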
Expected behavior
The model should not crash and should run the SigLIP example correctly, producing output like:
Keras:
20.6% that image 0 is 'This is a photo of 2 cats'
0.0% that image 1 is 'This is a photo of 2 dogs'
Loss: 6.132134437561035
Additional context
- This is reproducible on TPU v4-64 (and similar sizes) and TPU v2-8 (Colab).
- I have tried the fix from "[Splash Attention] Remove Unnecessary head_dim_v Constraint and Update Scratch Array Shapes #27427" (jax-ml/jax#27461), but it gives me a different error instead. It isn't merged yet, so changes may still happen.
Would you like to help us fix it?
- Frankly, I'm not sure how to fix the Splash Attention implementation in JAX's Pallas, so I think having a way to disable it on TPUs and fall back to normal attention would be helpful.
- I was able to get it running by forcing normal (XLA) attention in JAX, adding:
return jax.nn.dot_product_attention(
    query,
    key,
    value,
    bias=bias,
    mask=mask,
    scale=scale,
    is_causal=is_causal,
    implementation="xla",
)
just before keras/keras/src/backend/jax/nn.py, line 1181 (commit 6d26efb).
This just returns early with JAX's standard XLA attention implementation instead of going through the TPU Splash Attention kernel.
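- For anyone who wants a stopgap without editing the installed package, the same idea can be applied as a monkey patch. This is only a sketch: the patch point (keras.src.backend.jax.nn) and the keyword arguments Keras forwards to it are assumptions that may differ between Keras versions.
import jax
from keras.src.backend.jax import nn as jax_backend_nn

def _xla_dot_product_attention(query, key, value, bias=None, mask=None,
                               scale=None, is_causal=False, **unused_kwargs):
    # Ignore any flash/splash-specific kwargs and always use the XLA kernel.
    return jax.nn.dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        scale=scale,
        is_causal=is_causal,
        implementation="xla",
    )

# Replace the backend function whose Splash Attention path crashes on TPU.
jax_backend_nn.dot_product_attention = _xla_dot_product_attention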