Skip to content

Conversation

wseaton
Copy link
Contributor

@wseaton wseaton commented May 29, 2025

Stacked PR, needs to go in after: #18559 Edit: Not really stacked, these changes can go in on their own, PD+DP just won't work without the scheduling bugfixes.

Changes:

  • Engine IDs for KVConnector are now generated in the factory if they are unset, in the DP case this means each thread gets it's own unique id
  • Use world_rank instead of rank for the Nixl side channel port to ensure a unique side channel and no port collision in DP

The sum of these changes is that DP>1 now works w/ P/D enabled when using the NixlConnector.

Both DP1 and DP2 on the prefill server was tested against DP2 decode.

Using via the following configuration:

TP_SIZE := "1"
DP_SIZE := "2"

MODEL := "deepseek-ai/DeepSeek-V2-Lite"

PREFILL_GPUS := "2,3"
DECODE_GPUS := "6,7"


prefill:
    VLLM_SERVER_DEV_MODE=1 \
    VLLM_NIXL_SIDE_CHANNEL_PORT=$(just port 5557) \
    UCX_LOG_LEVEL=warn \
    CUDA_VISIBLE_DEVICES={{PREFILL_GPUS}} \
    VLLM_LOGGING_LEVEL="DEBUG" \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 \
    vllm serve {{MODEL}} \
      --port $(just port 8100) \
      --tensor-parallel-size {{TP_SIZE}} \
     --data-parallel-size {{DP_SIZE}} \
      --gpu-memory-utilization {{MEMORY_UTIL}} \
      --trust-remote-code \
      --max-model-len 2048 \
      --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

decode:
    PYTHONFAULTHANDLER=1 \
    VLLM_SERVER_DEV_MODE=1 \
    VLLM_NIXL_SIDE_CHANNEL_PORT=$(just port 5558) \
    CUDA_VISIBLE_DEVICES={{DECODE_GPUS}} \
    UCX_LOG_LEVEL=warn \
    NCCL_DEBUG=INFO \
    NIXL_LOG_LEVEL=DEBUG \
    VLLM_LOGGING_LEVEL="DEBUG" \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 \
    vllm serve {{MODEL}} \
      --port $(just port 8200) \
      --tensor-parallel-size {{TP_SIZE}} \
      --data-parallel-size {{DP_SIZE}} \
      --gpu-memory-utilization {{MEMORY_UTIL}} \
      --trust-remote-code \
      --max-model-len 2048 \
      --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@wseaton wseaton changed the title [P/D] get the NixlConnector working w/ DP [P/D] NixlConnector DP fixes May 29, 2025
@wseaton
Copy link
Contributor Author

wseaton commented May 29, 2025

#18576 was merged since I forked, let me see what the impact of this is on the unique engine_id situation, should be a quick test.

@wseaton
Copy link
Contributor Author

wseaton commented May 29, 2025

#18576 does not go far enough, DP ranks do not get unique engine ids. cc @tlrmchlsmth

(EngineCore_1 pid=4087553) INFO 05-29 09:54:35 [factory.py:82] Creating v1 connector with name: NixlConnector and engine_id: b58bcd96-e345-4209-be06-7238f99fe819
(EngineCore_1 pid=4087553) INFO 05-29 09:54:35 [nixl_connector.py:175] Initializing NIXL Scheduler b58bcd96-e345-4209-be06-7238f99fe819
(EngineCore_0 pid=4087551) INFO 05-29 09:54:35 [factory.py:82] Creating v1 connector with name: NixlConnector and engine_id: b58bcd96-e345-4209-be06-7238f99fe819
(EngineCore_0 pid=4087551) INFO 05-29 09:54:35 [nixl_connector.py:175] Initializing NIXL Scheduler b58bcd96-e345-4209-be06-7238f99fe819

wseaton added 2 commits May 29, 2025 10:01
Signed-off-by: Will Eaton <[email protected]>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label May 29, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) May 29, 2025 15:33
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @wseaton

@tlrmchlsmth tlrmchlsmth merged commit 64eaf5f into vllm-project:main May 29, 2025
77 checks passed
amitm02 pushed a commit to amitm02/vllm that referenced this pull request Jun 1, 2025
amitm02 pushed a commit to amitm02/vllm that referenced this pull request Jun 1, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants