
Conversation

@hyoon1 (Contributor) commented Feb 25, 2025

Add custom paged attention kernels for AMD Navi 3x/4x GPU support, based on PR #12348.
Due to architectural differences from the MI (CDNA) series, the specific instructions and detailed logic have changed (mfma16 -> wmma16/wmma16_gfx12), so new kernels have been added for each architecture.

  • Supports cases where head_size == 128 and block_size == 16.
  • Does not support alibi_slopes or kv_cache_dtype == fp8.
  • Supports gqa_ratio up to 16, and shows performance gains over the existing kernel when gqa_ratio is 3 or higher; it is therefore enabled for gqa_ratio values between 3 and 16 (see the eligibility sketch after this list).
  • Fixed the paged attention unit test to pass on Navi.
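
For reference, here is a minimal sketch of the eligibility check implied by the constraints above (the function name and signature are hypothetical, not the PR's actual code):

```python
def use_navi_custom_paged_attention(gqa_ratio: int,
                                    head_size: int,
                                    block_size: int,
                                    alibi_slopes,
                                    kv_cache_dtype: str) -> bool:
    """Hypothetical gate for the custom Navi3x/4x paged attention kernel."""
    return (head_size == 128              # only head_size 128 is supported
            and block_size == 16          # only block_size 16 is supported
            and 3 <= gqa_ratio <= 16      # gains observed from gqa_ratio >= 3
            and alibi_slopes is None      # alibi_slopes not supported
            and kv_cache_dtype != "fp8")  # fp8 KV cache not supported
```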

Performance Gain
Script: python ./benchmarks/benchmark_throughput.py --model <path_to_model> --trust-remote-code --dataset <ShareGPT_V3_unfiltered_cleaned_split.json> --num_prompts 1000 --max-model-len 4096 --gpu-memory-utilization 0.95
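
In the tables below, the GQA Ratio column is the number of query heads per KV head. A minimal sketch of how it is derived (the helper name is ours, for illustration only):

```python
def gqa_ratio(num_attention_heads: int, num_key_value_heads: int) -> int:
    """Queries per KV head. E.g., Meta-Llama-3.1-8B-Instruct has 32
    attention heads and 8 KV heads, giving a GQA ratio of 4."""
    assert num_attention_heads % num_key_value_heads == 0
    return num_attention_heads // num_key_value_heads
```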

Navi 3

| Models | Num Heads | GQA Ratio | Output Token/s (original) | Output Token/s (custom) | Gain |
|---|---|---|---|---|---|
| glm-4-9b-chat | 32 | 16 | 991.43 | 1113.75 | 12.3% |
| chatglm3-6b | 32 | 16 | 1442.07 | 1554.23 | 7.8% |
| Meta-Llama-3.1-8B-Instruct | 32 | 4 | 1143.65 | 1221.75 | 6.8% |
| Llama-3.2-3B-Instruct | 24 | 3 | 2058.97 | 2146.62 | 4.3% |
| Qwen1.5-7B-Chat | 32 | 1 | 904.46 | 882.53 | -2.4% |

Navi 4

| Models | Num Heads | GQA Ratio | Output Token/s (original) | Output Token/s (custom) | Gain |
|---|---|---|---|---|---|
| glm-4-9b-chat | 32 | 16 | 1195.56 | 1433.13 | 19.9% |
| chatglm3-6b | 32 | 16 | 1750.34 | 1962.21 | 12.1% |
| Meta-Llama-3.1-8B-Instruct | 32 | 4 | 1405.42 | 1516.69 | 7.9% |
| Llama-3.2-3B-Instruct | 24 | 3 | 2419.31 | 2561.47 | 5.9% |
| Qwen1.5-7B-Chat | 32 | 1 | 765.6 | 761.3 | -0.6% |


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Feb 25, 2025
@hyoon1 force-pushed the custom_pa_navi branch 2 times, most recently from 47d29aa to 7128062, on March 3, 2025 at 20:37

mergify bot commented Mar 3, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @hyoon1.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 3, 2025
@mergify mergify bot removed the needs-rebase label Mar 3, 2025
@hyoon1 hyoon1 changed the title [ROCm] Enable custom paged attention kernel for Navi3x [ROCm] Enable custom paged attention kernel for Navi3/4 Mar 6, 2025

mergify bot commented Mar 6, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @hyoon1.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 6, 2025
@mergify mergify bot removed the needs-rebase label Mar 6, 2025
@hongxiayang hongxiayang added the rocm Related to AMD ROCm label Mar 10, 2025
@LucasWilkinson (Collaborator)

Do we know how this stacks up against the new AMD triton kernels? (cc @SageMoore for direction to the kernels)

@SageMoore (Contributor)

Unfortunately, that kernel is only integrated into V1 right now. We should definitely integrate that kernel into V0 and see what the performance is like, though.

Here's the kernel in question if you are interested: https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/chunked_prefill_paged_decode.py#L186

@liangshen68

@tlrmchlsmth and @WoosukKwon, could you please help review and approve this PR, which provides good e2e perf gains for SOTA models running on AMD Radeon GPUs with vLLM? Thanks.

@tlrmchlsmth (Member)

Hi @hyoon1, thank you for your contribution.

I am hesitant to review and accept this PR, mainly because it applies only to vLLM V0. We are imminently going to switch to V1 by default starting with the 0.8.0 release, which will bring large performance improvements. V1 natively uses chunked prefill, and as I understand it, this kernel doesn't fit easily into that case.

For V1, I think we have a good solution in the Triton kernels added in #14152, but I would also be interested in seeing how the kernels compare.

@hyoon1 (Contributor, Author) commented Mar 18, 2025

Hi @tlrmchlsmth, thanks for letting me know about the new V1 kernel.

As you mentioned, V1 brings significant performance improvements due to the new prefill/decode method. I therefore compared the performance of V0 with custom_paged_attention (CPA) applied against V1 with the new kernel.

The comparison showed that, as of now, V0 with CPA applied performs better on AMD Navi GPUs in terms of output token throughput. Although V1 does not yet seem fully optimized, we have customers who want high performance on AMD Navi GPUs. Therefore, until V1 surpasses the optimized V0 in performance, we want to offer the optimized V0 option. We also have other optimizations that can further improve performance, and we plan to submit more pull requests.

Here are benchmark results from Navi3x/Navi4x GPU system:

python ./benchmarks/benchmark_serving.py --model /models/Llama-3.1-8B-Instruct --dataset-name sharegpt --dataset-path /sharegpt/ShareGPT_V3_unfiltered_cleaned_split.json

Navi4 GPU
V1

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  229.43
Total input tokens:                      215196
Total generated tokens:                  197286
Request throughput (req/s):              4.36
Output token throughput (tok/s):         859.89
Total Token throughput (tok/s):          1797.85
---------------Time to First Token----------------
Mean TTFT (ms):                          85488.14
Median TTFT (ms):                        91342.26
P99 TTFT (ms):                           192146.97
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          246.50
Median TPOT (ms):                        192.40
P99 TPOT (ms):                           786.78
---------------Inter-token Latency----------------
Mean ITL (ms):                           181.05
Median ITL (ms):                         101.11
P99 ITL (ms):                            784.66
==================================================

V0 + this PR (CPA)

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  200.67
Total input tokens:                      215196
Total generated tokens:                  197999
Request throughput (req/s):              4.98
Output token throughput (tok/s):         986.71
Total Token throughput (tok/s):          2059.12
---------------Time to First Token----------------
Mean TTFT (ms):                          69061.01
Median TTFT (ms):                        64760.14
P99 TTFT (ms):                           159312.67
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          296.71
Median TPOT (ms):                        231.04
P99 TPOT (ms):                           1870.42
---------------Inter-token Latency----------------
Mean ITL (ms):                           213.11
Median ITL (ms):                         134.48
P99 ITL (ms):                            776.32
==================================================

Navi3 GPU
V1

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  212.18
Total input tokens:                      215196
Total generated tokens:                  198024
Request throughput (req/s):              4.71
Output token throughput (tok/s):         933.27
Total Token throughput (tok/s):          1947.48
---------------Time to First Token----------------
Mean TTFT (ms):                          43194.99
Median TTFT (ms):                        40183.75
P99 TTFT (ms):                           97066.20
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          532.21
Median TPOT (ms):                        504.83
P99 TPOT (ms):                           795.70
---------------Inter-token Latency----------------
Mean ITL (ms):                           363.82
Median ITL (ms):                         244.72
P99 ITL (ms):                            818.46
==================================================

V0 + this PR (CPA)

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  204.61
Total input tokens:                      215196
Total generated tokens:                  197859
Request throughput (req/s):              4.89
Output token throughput (tok/s):         966.99
Total Token throughput (tok/s):          2018.71
---------------Time to First Token----------------
Mean TTFT (ms):                          66983.01
Median TTFT (ms):                        62685.91
P99 TTFT (ms):                           156777.94
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          288.71
Median TPOT (ms):                        231.92
P99 TPOT (ms):                           1687.35
---------------Inter-token Latency----------------
Mean ITL (ms):                           212.65
Median ITL (ms):                         133.75
P99 ITL (ms):                            811.25
==================================================


mergify bot commented Apr 11, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @hyoon1.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@hyoon1 (Contributor, Author) commented Apr 22, 2025

Closing this PR in favor of new v1 support: #17004

@hyoon1 hyoon1 closed this Apr 22, 2025