Conversation

@yewentao256 (Member) commented Sep 19, 2025

Purpose

Context from @smarterclayton

Trying to test a very long context response: 128k input, 1 output token. I used DeepSeek-V3.1 and changed --max-model-len (although DeepSeek-V3.1 defaults to 128k).
When I try to run a single request I'm getting an OOM:

(EngineCore_DP0 pid=948) ERROR 09-16 14:10:49 [v1/engine/core.py:720] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 4.12 GiB. GPU 0 has a total capacity of 178.35 GiB of which 3.91 GiB is free. Including non-PyTorch memory, this process has 174.41 GiB memory in use. Of the allocated memory 154.86 GiB is allocated by PyTorch, with 3.79 GiB allocated in private pools (e.g., CUDA Graphs), and 1.95 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

I had 53 GB for KV cache per GPU on this DP=16, 2-node B200 config.

The main reason is that we allocate too much memory for MLA chunk padding (the chunked-prefill workspace); this PR fixes the issue.

Note: as the existing comment in the code says,

            # For long-context models try not to over-allocate limiting
            # kv-cache space, limiting it to 64k tokens,
            # which would result in the workspace being:
            #   2*(576)*(64*1024) = 144mb
            # (assuming 576 MLA head dim, and fp16)
            # which would result in up-projected context being
            #   2*(192*128)*(64*1024) = 3gb
            # (assuming 192 QK head dim, 128 heads, and fp16)

We should use 64 * 1024 instead of 128 * 1024 here, so this PR also brings the code back in line with the comment.
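
To make the arithmetic concrete, here is a minimal sketch of the sizing math implied by that comment (the 576 / 192 / 128 dimensions and the fp16 element size come from the comment above; the helper names are illustrative and not part of vLLM):

    # Sketch of the MLA chunked-prefill workspace arithmetic from the comment above.
    FP16_BYTES = 2

    def workspace_bytes(cap_tokens: int) -> int:
        # Latent (compressed) KV workspace allocated up front: 576 MLA head dim.
        return FP16_BYTES * 576 * cap_tokens

    def upprojected_bytes(cap_tokens: int) -> int:
        # Up-projected context materialized during chunked prefill: 192 QK head dim, 128 heads.
        return FP16_BYTES * 192 * 128 * cap_tokens

    for cap in (128 * 1024, 64 * 1024):
        print(f"cap={cap}: workspace={workspace_bytes(cap) / 2**20:.0f} MiB, "
              f"up-projected={upprojected_bytes(cap) / 2**30:.0f} GiB")
    # cap=131072: workspace=144 MiB, up-projected=6 GiB
    # cap=65536:  workspace=72 MiB,  up-projected=3 GiB

Halving the cap from 128k to 64k tokens halves both allocations.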

**Note that OOM is still expected if the context length grows even further on limited GPU memory; in that case, consider adding tensor parallelism or lowering --gpu-memory-utilization below the 0.9 default.**
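
For example, with the offline LLM API this could look roughly like the following (a minimal sketch; the tensor-parallel size and utilization value are illustrative, not taken from this PR):

    # Sketch: trade some KV-cache headroom for activation headroom when pushing
    # very long contexts. The exact values below are illustrative only.
    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3.1",
        tensor_parallel_size=8,        # add TP to spread weights/KV across more GPUs
        gpu_memory_utilization=0.85,   # below the 0.9 default, leaving more free memory
        max_model_len=131072,
    )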

Test

With this change, the 128k-input request now completes:

============ Serving Benchmark Result ============
Successful requests:                     1         
Benchmark duration (s):                  186.05    
Total input tokens:                      129999    
Total generated tokens:                  1         
Request throughput (req/s):              0.01      
Output token throughput (tok/s):         0.01      
Peak output token throughput (tok/s):    1.00      
Peak concurrent requests:                1.00      
Total Token throughput (tok/s):          698.74    
---------------Time to First Token----------------
Mean TTFT (ms):                          167585.20 
Median TTFT (ms):                        167585.20 
P99 TTFT (ms):                           167585.20 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           0.00      
---------------Inter-token Latency----------------
Mean ITL (ms):                           0.00      
Median ITL (ms):                         0.00      
P99 ITL (ms):                            0.00      
==================================================

Signed-off-by: yewentao256 <[email protected]>
@mergify bot added the v1 label Sep 19, 2025
@yewentao256 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 19, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a critical Out-Of-Memory (OOM) error for long-context inference with Multi-head Latent Attention (MLA) by reducing the chunked_prefill_workspace_size. While the fix is correct in principle, I've identified a potential issue where the change could lead to an AssertionError with certain configurations, causing a crash. I've provided a suggestion to make the logic more robust and prevent this failure. Overall, a good fix for the OOM problem.

@yewentao256 changed the title from [Bug] Fix Long Context OOM Issue to [Bug] Partially Fix Long Context OOM Issue Sep 19, 2025
@yewentao256 changed the title from [Bug] Partially Fix Long Context OOM Issue to [Bug] Fix Long Context OOM Issue Sep 19, 2025
@mgoin added this to the v0.10.3 milestone Sep 20, 2025
@yewentao256 (Member, Author) commented:

@LucasWilkinson CC

@LucasWilkinson (Collaborator) left a comment

LGTM; thanks!

@LucasWilkinson merged commit 4741239 into vllm-project:main Sep 23, 2025
60 of 61 checks passed
linfeng-yuan pushed a commit to linfeng-yuan/vllm that referenced this pull request Sep 23, 2025
@smarterclayton (Contributor) commented Sep 23, 2025

On a DP=16 prefill B200 DeepSeek-V3.1 config, where I should be able to handle 9 full-length context requests per DP rank, I'm now hitting the assertion at https://github.com/vllm-project/vllm/blame/273690a50ac2a5fa79fa7acc5077e49aa1af427e/vllm/v1/attention/backends/mla/common.py#L485:

(EngineCore_DP1 pid=1401) ERROR 09-23 09:15:38 [v1/engine/core.py:708]   File "/app/vllm/vllm/v1/worker/gpu_model_runner.py", line 3504, in create_attn_groups
(EngineCore_DP1 pid=1401) ERROR 09-23 09:15:38 [v1/engine/core.py:708]     attn_metadata_builders.append(attn_backend.get_builder_cls()(
(EngineCore_DP1 pid=1401) ERROR 09-23 09:15:38 [v1/engine/core.py:708]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=1401) ERROR 09-23 09:15:38 [v1/engine/core.py:708]   File "/app/vllm/vllm/v1/attention/backends/mla/common.py", line 485, in __init__
(EngineCore_DP1 pid=1401) ERROR 09-23 09:15:38 [v1/engine/core.py:708]     assert self.chunked_prefill_workspace_size >= \
(EngineCore_DP1 pid=1401) ERROR 09-23 09:15:38 [v1/engine/core.py:708]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=1401) ERROR 09-23 09:15:38 [v1/engine/core.py:708] AssertionError
(EngineCore_DP1 pid=1401) DEBUG 09-23 09:15:38 [v1/worker/gpu_worker.py:285] Initial free memory: 177.65 GiB; Requested memory: 0.90 (util), 160.52 GiB
(EngineCore_DP1 pid=1401) DEBUG 09-23 09:15:38 [v1/worker/gpu_worker.py:292] Free memory after profiling: 101.37 GiB (total), 84.23 GiB (within requested)
(EngineCore_DP1 pid=1401) DEBUG 09-23 09:15:38 [v1/worker/gpu_worker.py:298] Memory profiling takes 237.50 seconds. Total non KV cache memory: 83.24GiB; torch peak memory increase: 10.20GiB; non-torch forward increase memory: 12.25GiB; weights memory: 60.79GiB.
(EngineCore_DP1 pid=1401) INFO 09-23 09:15:38 [v1/worker/gpu_worker.py:299] Available KV cache memory: 77.28 GiB
(EngineCore_DP1 pid=1401) INFO 09-23 09:15:38 [v1/core/kv_cache_utils.py:1087] GPU KV cache size: 1,180,672 tokens
(EngineCore_DP1 pid=1401) INFO 09-23 09:15:38 [v1/core/kv_cache_utils.py:1091] Maximum concurrency for 131,072 tokens per request: 9.01x

Reducing the cap under the min() from 128k to 64k makes it less than num_seqs (1024) * block_size (128 from CUTLASS MLA).
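
Spelled out (a small sketch; the check paraphrases the assertion at common.py:485, and the 1024 / 128 values are the ones quoted above):

    # Why the assertion fires with the new 64k cap (values from this report).
    chunked_prefill_workspace_size = 64 * 1024   # 65,536 after this PR
    num_seqs = 1024                              # scheduler max sequences
    block_size = 128                             # CUTLASS MLA block size
    required = num_seqs * block_size             # 131,072
    # Paraphrase of the check that now fails at init:
    print(chunked_prefill_workspace_size >= required)   # False -> AssertionError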

I expected to be able to start this config with a long DeepSeek-V3 context after this change, but instead it exits immediately. I also can't start with 65536 max tokens (not sure why).

@yewentao256 deleted the wye-fix-long-context-oom-issue branch September 23, 2025 14:43
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
yewentao256 added a commit that referenced this pull request Oct 3, 2025
gjc0824 pushed a commit to gjc0824/vllm that referenced this pull request Oct 10, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025