
Splash Attention is Broken on TPU Pods and does not follow keras.config.disable_flash_attention() #21116


Closed
chaosmaster142857 opened this issue Apr 2, 2025 · 4 comments


@chaosmaster142857

Describe the bug

  • Splash Attention breaks many small KerasHub models on TPU v4 pods when using DataParallel. For instance, I have tested a few SigLIP models and the CLIP model, and the same crash occurs with each of them.

To Reproduce
Colab Link

Expected behavior
The model should not crash with an error and should run the SigLIP model properly, producing output like:
Keras:
20.6% that image 0 is 'This is a photo of 2 cats'
0.0% that image 1 is 'This is a photo of 2 dogs'
Loss: 6.132134437561035

Additional context

Would you like to help us fix it?

  • Frankly, I'm not sure how to fix the Splash Attention implementation in JAX's Pallas, so having a way to disable it on TPUs and fall back to normal attention would be helpful.
  • I was able to get it running by using normal attention in JAX:

# Early return: always use JAX's XLA attention instead of Splash Attention
return jax.nn.dot_product_attention(
    query,
    key,
    value,
    bias=bias,
    mask=mask,
    scale=scale,
    is_causal=is_causal,
    implementation="xla",
)

before the Splash Attention call. This simply returns the result of normal (XLA) attention instead of using Splash Attention.
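For reference, the "normal" attention that implementation="xla" falls back to is plain scaled dot-product attention. The NumPy sketch below mirrors that computation for the (batch, seq_len, num_heads, head_dim) layout that jax.nn.dot_product_attention uses, which can help sanity-check outputs on CPU. The function name reference_attention is hypothetical, bias and padding masks are omitted for brevity, and it is not a drop-in replacement for the Keras or JAX kernels.

```python
import numpy as np


def reference_attention(query, key, value, scale=None, is_causal=False):
    """Plain scaled dot-product attention (hypothetical helper, not a Keras/JAX API).

    query: (B, T, N, H); key and value: (B, S, N, H). Returns (B, T, N, H).
    """
    head_dim = query.shape[-1]
    if scale is None:
        # Default scaling of 1/sqrt(head_dim)
        scale = 1.0 / np.sqrt(head_dim)
    # Per-head attention logits: (B, N, T, S)
    logits = np.einsum("btnh,bsnh->bnts", query, key) * scale
    if is_causal:
        # Lower-triangular mask: query position t only sees key positions <= t
        t, s = logits.shape[-2], logits.shape[-1]
        causal = np.tril(np.ones((t, s), dtype=bool))
        logits = np.where(causal, logits, -np.inf)
    # Numerically stable softmax over the key axis
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values, back to (B, T, N, H)
    return np.einsum("bnts,bsnh->btnh", weights, value)


# Usage: random inputs with batch=2, seq_len=4, heads=3, head_dim=8
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 3, 8)) for _ in range(3))
out = reference_attention(q, k, v, is_causal=True)
```

With all-zero queries every key receives the same weight, so each output row is just the mean of the visible values; that makes a quick correctness check.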

@chaosmaster142857
Author

To fix this, you could also just check the value of use_flash_attention in the Keras config before choosing Splash Attention.
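The config-based fix suggested above boils down to letting an explicit opt-out win before the TPU kernel is chosen. A minimal sketch of that decision logic, assuming a tri-state flag (None for "auto", True for "enabled", False for "disabled", e.g. after keras.config.disable_flash_attention()); the function name and the "splash"/"xla" labels are hypothetical and are not Keras internals.

```python
from typing import Optional


def select_attention_impl(flash_attention: Optional[bool], on_tpu: bool) -> str:
    """Pick an attention kernel label (hypothetical helper, not a Keras API).

    flash_attention: None means "use the fused kernel when available",
    True means "enabled", False means "never use the fused kernel".
    """
    # An explicit opt-out must always win, including on TPU.
    if flash_attention is False:
        return "xla"
    # Otherwise, on TPU the fused kernel is Splash Attention.
    if on_tpu:
        return "splash"
    # No TPU: fall back to regular XLA attention in this sketch.
    return "xla"
```

Checking `is False` rather than falsiness keeps None (auto) distinct from an explicit disable, which is the distinction the issue title says was being ignored on TPU.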

@sonali-kumari1
Contributor

Hi @chaosmaster142857 -

I have tested your code on a v2-8 TPU (Colab) and was unable to reproduce the crash. The SigLIP model executed properly with the following output:

Keras:
20.5% that image 0 is 'This is a photo of 2 cats'
0.0% that image 1 is 'This is a photo of 2 dogs'
Loss: 6.136449337005615

I am attaching a gist for your reference. Thanks!

@chaosmaster142857
Author

Looks like this was fixed by #21254; recommending closure.

@sonali-kumari1
Contributor

Hi @chaosmaster142857, please feel free to close this issue. Thanks!


4 participants