[Core] Freeze gc during cuda graph capture to speed up init #21146
Conversation
Signed-off-by: Codex <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request aims to speed up CUDA graph capture by disabling garbage collection within the capture loop. This is achieved by introducing a new context manager. My review focuses on improving the implementation of this context manager to use more idiomatic and direct APIs for controlling garbage collection, which enhances code clarity and maintainability.
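The review above doesn't include the code it has in mind. As a rough sketch of what "more idiomatic and direct APIs for controlling garbage collection" could look like (an assumption on my part, not the reviewer's actual suggestion), the standard-library gc module can toggle automatic collection directly instead of monkeypatching gc.collect:

    import gc
    from contextlib import contextmanager

    @contextmanager
    def automatic_gc_disabled():
        # Hypothetical helper, not code from this PR: disables *automatic*
        # collections for the duration of the block.
        was_enabled = gc.isenabled()
        gc.disable()
        try:
            yield
        finally:
            if was_enabled:
                gc.enable()

Note that gc.disable() only stops automatic collections; explicit gc.collect() calls made by torch during graph capture still execute, which is why the discussion below centers on patching or throttling gc.collect itself.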
Signed-off-by: mgoin <[email protected]>
While I agree that maybe we need to do gc.collect only once for all the graphs (rather than once per graph), I think a more proper fix is to modify:

    stack.enter_context(patch("gc.collect", lambda: None))
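For context, the line quoted above appears to be the existing suppression in the piecewise backend. A minimal self-contained sketch of how that ExitStack/patch pattern suppresses collections (the capture routine here is a placeholder, not the real backend code):

    import contextlib
    from unittest.mock import patch

    def run_graph_capture() -> None:
        # Placeholder for the real capture routine; torch's cudagraph capture
        # path calls gc.collect() internally.
        import gc
        gc.collect()  # a no-op while the patch below is active

    with contextlib.ExitStack() as stack:
        # While the stack is active, any lookup of gc.collect resolves to the
        # no-op lambda, so per-graph collections are skipped.
        stack.enter_context(patch("gc.collect", lambda: None))
        run_graph_capture()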
cc @youkaichao @zou3519
On the PyTorch side, the motivation of pytorch/pytorch#158193 is that pytorch shouldn't force people to gc.collect when doing cuda graph recording; users should get the flexibility to choose. For vLLM, doing GC at least once at the beginning of all piecewise captures seems good. The idea is that we want to free up memory so that we have enough memory to do the CUDAGraph captures. Doing GC once per shape is a reasonable thing to do, though -- the capture of each shape is a forward pass through the model, and it may be possible for there to be a reference cycle somewhere that holds onto memory. But maybe we should consider those situations to be bugs and fix them; the startup time savings are significant.
Maybe we could patch the function to only call gc.collect every N invocations, where N could be like 10 or higher? Or use a timer like the PyTorch PR. I think the piecewise method certainly isn't working, and we should generalize it to full graphs too.
I tested a version that only calls gc.collect every N invocations.
Function:

    import gc
    from contextlib import contextmanager
    from unittest.mock import patch

    @contextmanager
    def suppress_gc_collect(call_interval: int):
        """
        Reduce `gc.collect` frequency to speed up CUDA graph capture.
        Only calls the original gc.collect every N invocations.
        """
        original_gc_collect = gc.collect
        call_count = 0

        def throttled_gc_collect():
            nonlocal call_count
            call_count += 1
            if call_count % call_interval == 0:
                return original_gc_collect()
            return None

        with patch("gc.collect", throttled_gc_collect):
            yield
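A hypothetical usage of the helper above (assuming its imports are in scope; the interval, sizes, and capture function are illustrative, not taken from the PR):

    def capture_one_graph(batch_size: int) -> None:
        # Stand-in for a single cudagraph capture; the real loop lives in
        # vLLM's model runner.
        print(f"capturing graph for batch size {batch_size}")

    with suppress_gc_collect(call_interval=10):
        for batch_size in (1, 2, 4, 8, 16):
            capture_one_graph(batch_size)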
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: mgoin <[email protected]>
@WoosukKwon @zou3519 @njhill I implemented the collect-then-freeze approach, which seems to provide the same benefit. PTAL
@mgoin nice! Looks like it may still be slightly slower than disabling explicit gc.collect. We had actually already intended to do a final …
    @contextmanager
    def freeze_gc():
        # Optimize garbage collection during CUDA graph capture.
        # Clean up, then freeze all remaining objects from being included
        # in future collections.
        gc.collect()
        should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
        if should_freeze:
            gc.freeze()
        try:
            yield
        finally:
            if should_freeze:
                gc.unfreeze()
Actually, I think we should have this in utils.py or something like that. The model runner is becoming bloated.
Okay, I'll try to do a separate PR that will consolidate this with the current implementation in the piecewise backend.
Summary
Speed up cudagraph capture loops by calling gc.freeze before capture. This speeds up cudagraph capture a huge amount, especially for small models: Qwen3-0.6B goes from 35s to 2s. For the "proper" approach we should possibly use pytorch/pytorch#158193 in a future torch release.
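As a quick illustration of the mechanism (not code from the PR): gc.freeze() moves every currently tracked object into a permanent generation that later collections skip, which is what makes the gc.collect() calls issued during capture cheap.

    import gc

    gc.collect()                  # clean up cyclic garbage first, as the PR does
    print(gc.get_freeze_count())  # 0: nothing is frozen yet
    gc.freeze()                   # move all tracked objects to the permanent
                                  # generation; later collections skip them
    print(gc.get_freeze_count())  # now equal to the number of frozen objects
    gc.unfreeze()                 # restore normal collection once capture is done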
Testing
Before
After
https://chatgpt.com/codex/tasks/task_e_687972e21944832987a7bb6219d4c65b