
Conversation

robertgshaw2-redhat (Collaborator) commented May 23, 2025:

SUMMARY:

  • In the case of a full prefix cache hit locally on the D worker, we leak memory on the P worker side: we never call send_notif because we skip update_state_after_alloc (a sketch of the worker-side path follows below).
  • Also fixes the cache hit path, which was passing the wrong block IDs.
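
A minimal sketch of the worker-side path the first bullet refers to; read_blocks and the two callables are hypothetical stand-ins for the _read_blocks internals in nixl_connector.py, which differ in signature and detail:

def read_blocks(local_block_ids, send_notif, start_transfer):
    """Sketch of the D-worker read path; both callables are hypothetical."""
    if not local_block_ids:
        # Full prefix cache hit on the D worker: nothing to pull, but
        # the P worker still holds this request's blocks. Sending the
        # notification is what lets it free them; skipping it is the
        # leak described above.
        send_notif()
        return
    # Otherwise, pull the unhashed blocks from the remote; the
    # completion handshake notifies the P worker as usual.
    start_transfer(local_block_ids)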


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 label May 23, 2025
robertgshaw2-redhat (Collaborator, Author):

@njhill - can you let me know if this works okay with multi-connector?

Comment on lines +241 to +248

+ # If remote_blocks and num_external_tokens = 0, we have
+ # a full prefix cache hit on the D worker. We need to call
+ # send_notif in _read_blocks to free the memory on the P.
+ local_block_ids = (blocks.get_unhashed_block_ids()
+                    if num_external_tokens > 0 else [])
  # Get unhashed blocks to pull from remote.
  self._reqs_need_recv[request.request_id] = (
-     request, blocks.get_unhashed_block_ids())
+     request, local_block_ids)
njhill (Member):

@robertgshaw2-redhat I'm still not sure that this part or the change to always call update_state_after_alloc is needed. I'd already added logic for this case in get_num_new_matched_tokens above:

# NOTE: if count is 0 here, we have less than block_size
# tokens to pull after subtracting the local prefix cache hit.
# The remote only sends fully computed blocks, so there is
# nothing to transfer but we still need to notify the
# prefill worker so that the remote blocks are freed.
if all(p in params for p in ("remote_engine_id", "remote_host",
                             "remote_port")):
    self._reqs_need_recv[request.request_id] = (request, [])

I can see that the other two fixes below in build_connector_meta and _read_blocks are of course needed though.

If you think it's better to have this logic in this method then we can remove it from the other one. But again I feel it's logically clearer to not call update_state_after_alloc if 0 was returned from get_num_new_matched_tokens.
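
For readers following along, a rough sketch of how the scheduler-side recv list reaches the worker; ReqMeta and ConnectorMeta here are hypothetical stand-ins for the metadata types in nixl_connector.py:

from dataclasses import dataclass, field

@dataclass
class ReqMeta:                      # hypothetical stand-in
    local_block_ids: list

@dataclass
class ConnectorMeta:                # hypothetical stand-in
    reqs_to_recv: dict = field(default_factory=dict)

def build_connector_meta(reqs_need_recv):
    meta = ConnectorMeta()
    for req_id, (_request, block_ids) in reqs_need_recv.items():
        # An empty block_ids list must survive this hand-off too, so
        # the worker still notifies the P side on a full cache hit.
        meta.reqs_to_recv[req_id] = ReqMeta(local_block_ids=block_ids)
    reqs_need_recv.clear()
    return meta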

robertgshaw2-redhat (Collaborator, Author) commented May 24, 2025:

I think that get_num_new_matched_tokens should be a pure function. Adding a side effect to it is surprising given the method's name, and it means behavior differs depending on whether the request can actually be scheduled. This is causing a bug right now:

  • If allocate_slots returns None, the request remains in the waiting queue. This causes us to add the request to reqs_need_recv more than once, so we call read_blocks twice and double-free on the P worker side. The same happens if the request is preempted (it gets re-added to waiting). The root cause is that we don't set do_remote_prefill=False on the request when it is added to reqs_need_recv from the get_num_new_matched_tokens function.

This is all evidence that putting a side effect into this function is a bad idea. update_state_after_alloc is where we should handle everything related to reqs_need_recv, so that all of the logic lives in a single place.
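
A condensed sketch of the approach being argued for; the names follow the comments above, but the body is illustrative rather than the actual implementation:

def update_state_after_alloc(request, local_block_ids, reqs_need_recv):
    # Only this hook touches reqs_need_recv, and it flips
    # do_remote_prefill off, so a request that stays in the waiting
    # queue (allocate_slots returned None) or is preempted cannot be
    # registered -- and later freed -- twice on the P worker.
    params = request.kv_transfer_params
    if not params or not params.get("do_remote_prefill"):
        return
    reqs_need_recv[request.request_id] = (request, local_block_ids)
    params["do_remote_prefill"] = False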

robertgshaw2-redhat (Collaborator, Author):

I removed those lines from get_num_new_matched_tokens

njhill (Member):

@robertgshaw2-redhat that makes sense, and I agree about the pure function thing. I had also noticed that this could result in a double free on the P worker side when the request can't be scheduled, which isn't ideal (though I think it would probably be harmless).

But to me, thinking from the point of view of a generic connector interface, it still feels a bit odd given the connector isn't offering any tokens. I guess we should very clearly document the semantics and expectations of the interface.

A related quirk: in the async load case, I think update_state_after_alloc will currently be called twice for a request (a second time once the request moves out of WAITING_FOR_REMOTE_KVS).

if count > 0:
    return count, True

# NOTE: if count is 0 here, we have less than block_size
# tokens to pull after subtracting the local prefix cache hit.
# The remote only sends fully computed blocks, so there is
# nothing to transfer but we still need to notify the
# prefill worker so that the remote blocks are freed.
robertgshaw2-redhat (Collaborator, Author):

this is now handled in update_state_after_alloc

njhill (Member) commented May 25, 2025:

@robertgshaw2-redhat changes will be needed in multi-connector too; I've pushed them to a branch, feel free to pull them into this PR: njhill@4150a41

njhill (Member) left a comment:

LGTM, with the multi-connector changes

njhill and others added 5 commits June 2, 2025 14:03
- Call get_num_new_matched_tokens for every connector.
- Call update_state_after_alloc for every connector, but with no blocks/tokens for all but the "chosen" connector (the first one to return non-zero tokens from get_num_new_matched_tokens); a sketch of this dispatch follows the sign-offs below.

Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
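
A minimal sketch of that dispatch, assuming nothing beyond the two commit bullets above; the class and bookkeeping names are illustrative, not vLLM's actual MultiConnector:

class MultiConnectorSketch:
    def __init__(self, connectors):
        self._connectors = connectors
        self._chosen = {}  # request_id -> index of the chosen connector

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        # Ask every connector; remember the first one to offer tokens.
        total, load_async = 0, False
        for i, c in enumerate(self._connectors):
            count, is_async = c.get_num_new_matched_tokens(
                request, num_computed_tokens)
            if count > 0 and request.request_id not in self._chosen:
                self._chosen[request.request_id] = i
                total, load_async = count, is_async
        return total, load_async

    def update_state_after_alloc(self, request, blocks,
                                 num_external_tokens):
        # Every connector gets the callback, but only the chosen one
        # sees the real blocks/tokens; the rest see nothing, while
        # still getting the chance to run their own cleanup logic.
        chosen = self._chosen.get(request.request_id)
        for i, c in enumerate(self._connectors):
            if i == chosen:
                c.update_state_after_alloc(request, blocks,
                                           num_external_tokens)
            else:
                c.update_state_after_alloc(request, None, 0)
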
njhill (Member) commented Jun 3, 2025:

I just have one more thing to fix up in the multi-connector test now that the semantics have changed.

mergify bot commented Jun 3, 2025:

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @robertgshaw2-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Jun 3, 2025
…ix-cache-hit

# Conflicts:
#	vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
mergify bot removed the needs-rebase label Jun 4, 2025
njhill requested a review from mgoin June 4, 2025 16:20
mgoin (Member) commented Jun 4, 2025:

/gemini review

gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request effectively addresses a memory leak in the Nixl connector related to full prefix cache hits and also fixes an issue with how cache hits were handled. The core of the fix involves ensuring update_state_after_alloc is consistently called across all relevant connectors, even when no external tokens are loaded, which allows for proper notification and resource cleanup on the P worker side. The changes in MultiConnector and NixlConnectorScheduler are key to this. Test cases have been updated appropriately to reflect these changes and the enhanced logging.

Overall, the changes look good and directly target the described issues. I have one point for clarification regarding a behavioral change in MultiConnector.get_num_new_matched_tokens.

Summary of Findings

  • Memory Leak Fix in Nixl Connector: The primary goal of this PR, fixing a memory leak on the Nixl P-worker during full prefix cache hits on the D-worker, appears to be successfully addressed. The core changes ensure that update_state_after_alloc is called for all relevant connector components, allowing for proper notifications and resource cleanup.
  • Behavioral Change in MultiConnector.get_num_new_matched_tokens: The get_num_new_matched_tokens method in MultiConnector now iterates through all sub-connectors, calling the method on each, even if a match was found earlier. Clarification on the necessity and impact of this change would be beneficial.
  • Test Coverage and Logging: The tests in test_multi_connector.py have been updated to reflect the new logic and include more detailed event logging, which is good for verifying the fix and aiding future debugging.

Merge Readiness

The pull request seems to address the reported memory leak effectively. The changes are logical and the tests have been updated accordingly. There is one point regarding a behavioral change in MultiConnector.get_num_new_matched_tokens that would benefit from clarification. Assuming this behavior is intended and understood, the PR appears to be in good shape for merging after addressing or clarifying that point. As an AI, I am not authorized to approve pull requests; this assessment is based on the code review.

njhill enabled auto-merge (squash) June 4, 2025 22:26
mergify bot commented Jun 4, 2025:

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @robertgshaw2-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Jun 4, 2025
…ix-cache-hit

Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
mergify bot removed the needs-rebase label Jun 4, 2025
njhill merged commit c56ed8b into vllm-project:main Jun 5, 2025
70 checks passed
leoli1208 pushed a commit to leoli1208/vllm that referenced this pull request Jul 22, 2025
Labels: ready, v1
4 participants