CUDA: attention sinks for mma FlashAttention #15157

Merged

Conversation

JohannesGaessler
Collaborator

This PR adds support for attention sinks to the CUDA FlashAttention kernel that uses tensor cores (mma). On Turing or newer GPUs, attention sinks are now fully supported.
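
For readers unfamiliar with how sinks interact with FlashAttention, here is a minimal sketch (illustrative only, not the kernel's actual code) of folding a per-head sink logit into the online-softmax state once all KV tiles have been processed; the names `apply_attn_sink`, `kq_max`, `kq_sum`, and `vkq` are placeholders:

```cuda
// A sink is an extra learned logit that enters the softmax denominator but has
// no value vector, so only the running maximum, the running sum, and the scale
// of the accumulated output need to be adjusted at the end.
__device__ void apply_attn_sink(const float sink, float & kq_max, float & kq_sum,
                                float * vkq, const int head_size) {
    const float kq_max_new = fmaxf(kq_max, sink);
    const float scale_old  = expf(kq_max - kq_max_new); // rescale the previous state

    kq_sum = kq_sum*scale_old + expf(sink - kq_max_new);
    for (int i = 0; i < head_size; ++i) {
        vkq[i] *= scale_old; // the numerator gets no new term, only the rescale
    }
    kq_max = kq_max_new;
}
```

The final result is still the accumulated output divided by `kq_sum`, so the sink only dampens the attention weights.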

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-fa-mma-sink-2 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 1 | pp16384 | 140.73 | 157.42 | 1.12 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 2 | pp16384 | 127.88 | 144.44 | 1.13 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 4 | pp16384 | 189.83 | 234.50 | 1.24 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 8 | pp16384 | 254.20 | 345.49 | 1.36 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 16 | pp16384 | 322.00 | 486.15 | 1.51 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 32 | pp16384 | 380.96 | 635.92 | 1.67 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 64 | pp16384 | 504.90 | 1076.35 | 2.13 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 128 | pp16384 | 619.46 | 1816.28 | 2.93 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 256 | pp16384 | 676.09 | 2622.94 | 3.88 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 512 | pp16384 | 690.35 | 3337.20 | 4.83 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 1024 | pp16384 | 708.09 | 3888.88 | 5.49 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 2048 | pp16384 | 681.68 | 4099.78 | 6.01 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 4096 | pp16384 | 682.19 | 4106.01 | 6.02 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 8192 | pp16384 | 682.25 | 4101.10 | 6.01 |
| RTX 3090 | gpt-oss 20B MXFP4 MoE | 16384 | pp16384 | 681.61 | 4086.37 | 6.00 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 1 | pp16384 | 238.95 | 244.15 | 1.02 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 2 | pp16384 | 199.57 | 208.26 | 1.04 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 4 | pp16384 | 322.24 | 358.46 | 1.11 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 8 | pp16384 | 476.46 | 566.93 | 1.19 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 16 | pp16384 | 641.87 | 840.95 | 1.31 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 32 | pp16384 | 818.55 | 1190.88 | 1.45 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 64 | pp16384 | 1138.38 | 2024.07 | 1.78 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 128 | pp16384 | 1477.22 | 3453.34 | 2.34 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 256 | pp16384 | 1650.90 | 4900.11 | 2.97 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 512 | pp16384 | 1704.01 | 6090.15 | 3.57 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 1024 | pp16384 | 1730.74 | 6650.04 | 3.84 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 2048 | pp16384 | 1651.59 | 6596.91 | 3.99 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 4096 | pp16384 | 1651.71 | 6582.91 | 3.99 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 8192 | pp16384 | 1648.47 | 6581.47 | 3.99 |
| RTX 4090 | gpt-oss 20B MXFP4 MoE | 16384 | pp16384 | 1648.81 | 6570.34 | 3.98 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 1 | pp16384 | 217.68 | 221.46 | 1.02 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 2 | pp16384 | 115.67 | 120.47 | 1.04 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 4 | pp16384 | 178.51 | 196.02 | 1.10 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 8 | pp16384 | 252.54 | 291.57 | 1.15 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 16 | pp16384 | 324.61 | 397.50 | 1.22 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 32 | pp16384 | 389.73 | 508.91 | 1.31 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 64 | pp16384 | 541.84 | 795.13 | 1.47 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 128 | pp16384 | 705.35 | 1217.79 | 1.73 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 256 | pp16384 | 867.21 | 1811.41 | 2.09 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 512 | pp16384 | 967.18 | 2426.25 | 2.51 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 1024 | pp16384 | 1028.36 | 3076.55 | 2.99 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 2048 | pp16384 | 1007.56 | 3343.48 | 3.32 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 4096 | pp16384 | 1006.97 | 3339.11 | 3.32 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 8192 | pp16384 | 1006.49 | 3315.28 | 3.29 |
| 3x RTX 4090 | gpt-oss 120B MXFP4 MoE | 16384 | pp16384 | 1006.69 | 3335.57 | 3.31 |

The github-actions bot added the labels Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Aug 7, 2025
@slaren
Member

slaren commented Aug 7, 2025

The reason you get the same performance for ubatch sizes >= 2048 is likely that the default batch size of llama-bench is 2048 and you haven't changed it. You need to increase the batch size as well if you want to try ubatch sizes bigger than 2048; otherwise llama-bench will only submit batches of 2048 tokens to llama.cpp.
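
For example (a sketch; the exact flags are from memory, see `llama-bench --help`, and `model.gguf` is a placeholder), benchmarking a physical batch of 4096 tokens means raising both sizes:

```sh
# -b sets the logical batch size, -ub the physical (micro)batch size;
# leaving -b at its default of 2048 caps what actually reaches the backend.
./llama-bench -m model.gguf -p 16384 -b 4096 -ub 4096 -fa 1
```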

@JohannesGaessler
Collaborator Author

I thought we had at some point made it so that a physical batch size > 2048 would also raise the logical batch size, but I guess I misremembered.

@slaren
Member

slaren commented Aug 7, 2025

5090 (WSL)
| Model | Microbatch size | Test | t/s master | t/s cuda-fa-mma-sink-3 | Speedup |
| --- | --- | --- | --- | --- | --- |
| gpt-oss ?B MXFP4 MoE | 1 | pp2048 | 296.86 | 297.47 | 1.00 |
| gpt-oss ?B MXFP4 MoE | 2 | pp2048 | 135.28 | 131.37 | 0.97 |
| gpt-oss ?B MXFP4 MoE | 4 | pp2048 | 232.73 | 243.19 | 1.04 |
| gpt-oss ?B MXFP4 MoE | 8 | pp2048 | 439.78 | 416.96 | 0.95 |
| gpt-oss ?B MXFP4 MoE | 16 | pp2048 | 653.59 | 735.40 | 1.13 |
| gpt-oss ?B MXFP4 MoE | 32 | pp2048 | 1052.25 | 1039.31 | 0.99 |
| gpt-oss ?B MXFP4 MoE | 64 | pp2048 | 1566.95 | 2026.63 | 1.29 |
| gpt-oss ?B MXFP4 MoE | 128 | pp2048 | 2878.56 | 3554.60 | 1.23 |
| gpt-oss ?B MXFP4 MoE | 256 | pp2048 | 4202.15 | 5624.48 | 1.34 |
| gpt-oss ?B MXFP4 MoE | 512 | pp2048 | 5128.63 | 8271.12 | 1.61 |
| gpt-oss ?B MXFP4 MoE | 1024 | pp2048 | 6023.05 | 10134.61 | 1.68 |
| gpt-oss ?B MXFP4 MoE | 2048 | pp2048 | 5490.28 | 10833.91 | 1.97 |

@slaren
Member

slaren commented Aug 7, 2025

Not sure if it is caused by the changes in this branch, but I was trying the 120B model and it got stuck in a loop:
[screenshot: model output stuck in a repeating loop]

@ggerganov
Member

ggerganov commented Aug 7, 2025

> Not sure if it is caused by the changes in this branch, but I was trying the 120B model and it got stuck in a loop:

I was thinking today that the clamp with -9999 instead of -inf might not be correct and needs an extra look.

nvm: this was fixed when you fused the activation function

@slaren
Member

slaren commented Aug 7, 2025

It might have been a fluke, I regenerated the same prompt a few times and it didn't happen again.

@JohannesGaessler
Collaborator Author

When I set SOFTMAX_FTZ_THRESHOLD to -50 instead of -20, the perplexity on 1 chunk of wikitext goes down from 214.9317 to 191.5787. So maybe this is what's causing the issues, since the vector kernels don't use it. This was kind of a hack in the first place to avoid NaNs from FP16 underflows; if it's now causing issues with GPT-OSS, there may be a better way to handle this.
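
For context, a rough sketch of what such a flush-to-zero threshold amounts to (illustrative only, not the kernel's actual code): logits far enough below the running maximum are treated as contributing exactly zero, which avoids storing tiny FP16 values that can later underflow, at the cost of truncating the softmax tail:

```cuda
// Illustrative only: a softmax term whose logit is more than |threshold| below
// the running maximum would contribute at most exp(threshold) anyway, so it is
// flushed to zero instead of being kept as a tiny FP16 value.
__device__ float softmax_term_ftz(const float kq, const float kq_max, const float threshold) {
    const float diff = kq - kq_max; // always <= 0
    return diff > threshold ? expf(diff) : 0.0f;
}
```

With a threshold of -20, any term smaller than roughly exp(-20) ≈ 2e-9 of the largest one is dropped; -50 keeps much more of the tail.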

@JohannesGaessler
Collaborator Author

When I calculate perplexity for the entire wikitext-2 test set I get virtually no difference: 244.1368 for a threshold of -20, 243.8748 for a threshold of -50.

```diff
@@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
         }
         // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-        if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+        if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
```
Member

Is the !fp16_mma_available check needed here?

Collaborator Author

Not for model support, because I think the vector kernels cover all currently available models with attention sinks. But it enables running tests with attention sinks and head sizes != 64/128, so I thought it would be better to adjust.

JohannesGaessler merged commit 1425f58 into ggml-org:master on Aug 8, 2025
47 checks passed
@JohannesGaessler
Collaborator Author

Just by looking at the code I am not seeing anything that would be wrong with it. For the next few days I should be available with low latency, so I'll monitor the issue tracker to see whether there is a systematic problem.

@abrimogard

This may have caused some performance degradation. Anecdotally, I was running Qwen3-30B-A3B (Q5_XL) on a single RTX 3090: token generation speed was ~120 TPS before this PR and is ~60 TPS after building the latest master.

```diff
@@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

     // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks) {
+    if (sinks && !fp16_mma_available(cc)) {
```
Collaborator

@JohannesGaessler I think this breaks Volta, since fp16_mma_available is true there but the wmma kernel doesn't yet support attention sinks.

Collaborator Author

You're right, this should be turing_mma_available.
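
Presumably the follow-up fix is roughly this (a sketch of the intended condition, not the committed patch):

```cuda
// Route sinks to the mma kernel only on Turing or newer: Volta reports
// fp16_mma_available (wmma), but that kernel path does not implement sinks yet.
if (sinks && !turing_mma_available(cc)) {
    // ... fall back to the vector kernels, as before ...
}
```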

Thireus added a commit to Thireus/ik_llama.cpp that referenced this pull request Aug 11, 2025