
Splash Attention is Broken on TPU Pods and does not follow keras.config.disable_flash_attention() #21116

Closed
@chaosmaster142857

Description

Describe the bug

  • Splash attention breaks many small KerasHub models on TPU v4 pods when using DataParallel. For instance, I have tested a few SigLIP models and the CLIP model, and the same crash occurs in each case.

To Reproduce
Colab Link

Expected behavior
The model should not crash with an error and should run the SigLIP model properly:
Keras:
20.6% that image 0 is 'This is a photo of 2 cats'
0.0% that image 1 is 'This is a photo of 2 dogs'
Loss: 6.132134437561035

Additional context

Would you like to help us fix it?

  • Frankly, I'm not sure how to fix the splash attention implementation in JAX's Pallas, so I think that having a way to disable it on TPUs and fall back to normal attention would be helpful.
  • I was able to get it running by using normal attention in JAX:

return jax.nn.dot_product_attention(
    query,
    key,
    value,
    bias=bias,
    mask=mask,
    scale=scale,
    is_causal=is_causal,
    implementation="xla",
)

in place of the splash attention call.

This simply falls back to the standard XLA attention implementation instead of the fused splash attention kernel.
