
Conversation


@libertyeagle libertyeagle commented Oct 6, 2025

Purpose

PR #20775 introduced initial support for elastic expert parallelism (EP). This PR adds further optimizations toward Milestone 2 in #20323. Key features include:

  • Break down the scale-up/down logic into a state machine of multiple stages, whose execution is controlled in vllm/distributed/elastic_ep/elastic_state.py and vllm/distributed/elastic_ep/elastic_execute.py.
  • Newly started workers receive all weights (non-MoE modules and expert weights) from peer GPUs.
  • We no longer need to drop traffic during scale up/down. During scale-up, existing workers keep serving requests until the new workers are ready (non-expert weights have already been received, and the new workers are preparing to compile/warm up the model). Existing workers then progressively reconfigure to the new EP size in DPEngineCoreProc: in run_busy_loop, elastic_scaling_state.progress() advances the reconfiguration by one step whenever it is ready; otherwise the worker keeps serving requests (see the sketch after this list). Performing this progressive reconfiguration between forward steps also helps finish in-flight requests quickly, prevents requests from queuing up, and improves SLO attainment.
  • If elastic EP is enabled (--enable-elastic-ep), all EP/DP communicators are replaced by the stateless coordinator in vllm/distributed/stateless_coordinator.py, which is independent of torch.distributed's global state. We can therefore create standby communicators while keeping the current ones, so the bootstrap of new workers overlaps with request serving on existing workers. A single switch to the new communicators is then all that is needed once the new setup is ready.
  • For scale-up, the EPLB reshuffle is delayed until reconfiguration is finished. Newly joined workers can dispatch tokens to the original set of GPUs for expert computation, while experts are progressively reshuffled onto the newly joined GPUs.
  • Support for CUDA graphs, which is critical to performance, especially in decode. In this PR, existing workers destroy the compiled model and all captured CUDA graphs, then recompile and recapture all graphs; see switch_and_prepare() in vllm/distributed/elastic_ep/elastic_execute.py. Further CUDA graph optimizations will come in follow-up PRs.
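
The sketch below illustrates the staged-reconfiguration idea from the bullets above: scale-up is broken into stages, and progress() advances at most one stage per call, returning control to the busy loop so existing workers keep serving requests in between. This is a minimal, self-contained Python sketch rather than the PR's implementation; the stage names only loosely echo the ones visible in the diff (e.g., CREATE_STANDBY_GROUPS, TRANSFER_WEIGHTS), and _ready() stands in for the peer notifications used in the real code.

import enum


class ScaleUpStage(enum.Enum):
    CREATE_STANDBY_GROUPS = enum.auto()   # bootstrap standby communicators on the side
    WAIT_NEW_WORKERS_INIT = enum.auto()   # wait for new workers to finish initialization
    TRANSFER_WEIGHTS = enum.auto()        # send non-MoE and expert weights to new peers
    SWITCH_COMMUNICATORS = enum.auto()    # single switch from old groups to standby groups
    RECAPTURE_CUDA_GRAPHS = enum.auto()   # recompile the model and recapture graphs
    DONE = enum.auto()


class ElasticScalingState:
    def __init__(self) -> None:
        self.stage = ScaleUpStage.CREATE_STANDBY_GROUPS

    def _ready(self) -> bool:
        # Placeholder: a real implementation would check notifications from
        # peers (e.g., "all existing workers received the reconfigure request").
        return True

    def progress(self) -> bool:
        """Advance by at most one stage; return True once reconfiguration is done."""
        if self.stage is ScaleUpStage.DONE:
            return True
        if not self._ready():
            # Cannot continue yet: hand control back so the worker keeps serving.
            return False
        stages = list(ScaleUpStage)
        self.stage = stages[stages.index(self.stage) + 1]
        return self.stage is ScaleUpStage.DONE


# Simplified busy loop: one reconfiguration step per iteration, serving in between.
state = ElasticScalingState()
while not state.progress():
    pass  # serve one batch of requests here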

There are also some minor bug fixes including:

  • Fix Ray resource discovery and the engine ZMQ address when scaling from an intra-node to an inter-node setup.
  • Fix throughput logging not being reported after scale-up.

Test Plan

We test the performance before and after scale-up using Qwen/Qwen3-30B-A3B-Thinking-2507-FP8. The number of physical experts per GPU is set to 72. Note that the number of local physical experts remains the same during scale up/down, while the total number of redundant experts scales accordingly, which is the same assumption as in PR #20775. We use PPLX kernels (intra-node mode, which does not require NVSHMEM) and enable CUDA graphs with the default settings.
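As a concrete example (assuming 128 routed experts for this model; the expert count is our assumption, not stated in this PR): at EP=2 there are 2 × 72 = 144 physical experts, i.e. 16 redundant, and after scaling to EP=4 there are 4 × 72 = 288 physical experts, i.e. 160 redundant.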

MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
vllm serve $MODEL_NAME --trust-remote-code \
    --disable-log-requests \
    --host $HOST \
    --port $PORT \
    --tensor-parallel-size 1 \
    --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
    --max-model-len $MAX_MODEL_LEN \
    --no-enable-prefix-caching \
    --enable-expert-parallel \
    --enable-elastic-ep \
    --enable-eplb \
    --eplb-config.num_redundant_experts $NUM_REDUNDANT_EXPERTS \
    --eplb-config.window_size $EPLB_WINDOW_SIZE \
    --eplb-config.step_interval $EPLB_STEP_INTERVAL \
    --data-parallel-backend ray \
    --data-parallel-size $DATA_PARALLEL_SIZE \
    --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
    --data-parallel-address $LEADER_ADDRESS \
    --data-parallel-rpc-port 9876 \
    --data-parallel-start-rank 0

To scale up we use:

python examples/online_serving/elastic_ep/scale.py --host $HOST --port $PORT --new-dp-size $NEW_DATA_PARALLEL_SIZE

Test Results

We use the following benchmark script.

vllm bench serve \
    --model $MODEL_NAME \
    --host $HOST \
    --port $PORT \
    --dataset-name random \
    --random-input-len 256 \
    --random-output-len 128 \
    --num-prompts 512

Serving on 2 GPUs (EP=2, TP=1) before scaling up:

============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  15.85     
Total input tokens:                      130815    
Total generated tokens:                  65478     
Request throughput (req/s):              32.30     
Output token throughput (tok/s):         4131.03   
Peak output token throughput (tok/s):    17408.00  
Peak concurrent requests:                512.00    
Total Token throughput (tok/s):          12384.18  
---------------Time to First Token----------------
Mean TTFT (ms):                          6870.52   
Median TTFT (ms):                        7559.63   
P99 TTFT (ms):                           12107.77  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          69.94     
Median TPOT (ms):                        64.56     
P99 TPOT (ms):                           109.25    
---------------Inter-token Latency----------------
Mean ITL (ms):                           69.90     
Median ITL (ms):                         29.54     
P99 ITL (ms):                            1443.20   
==================================================

Serving on 4 GPUs (EP=4, TP=1) after scaling up:

============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  9.89      
Total input tokens:                      130815    
Total generated tokens:                  65415     
Request throughput (req/s):              51.75     
Output token throughput (tok/s):         6612.23   
Peak output token throughput (tok/s):    18802.00  
Peak concurrent requests:                512.00    
Total Token throughput (tok/s):          19835.17  
---------------Time to First Token----------------
Mean TTFT (ms):                          4089.23   
Median TTFT (ms):                        4812.20   
P99 TTFT (ms):                           6322.47   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.82     
Median TPOT (ms):                        44.26     
P99 TPOT (ms):                           62.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.91     
Median ITL (ms):                         27.23     
P99 ITL (ms):                            1481.01   
==================================================

Next Steps

  • PR 2/N: Support elastic EP kernels and weight communicators (e.g., P2P transfer engines like Mooncake and NIXL).
  • PR 3/N: CUDA graph capture cost optimization: incremental CUDA graph updates while serving traffic, and CUDA graph memory pool optimizations to minimize new memory allocation during graph updates.
  • PR N/N: Further cost optimizations (e.g., torch.compile cache management, incremental EPLB, and incremental non-expert weight transfer); support for more kernels (e.g., regular DeepEP); scheduler optimizations to migrate dispatched requests to newly started workers for load balancing; …

CC List

@abmfy @ruisearch42 @simon-mo @tlrmchlsmth @njhill @kouroshHakha


github-actions bot commented Oct 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of the fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Oct 6, 2025

mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @libertyeagle.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 6, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant optimizations for elastic expert parallelism, building upon initial support. The key changes include a new state machine for scaling up/down, peer-to-peer weight transfer for new workers, and progressive reconfiguration to avoid dropping traffic during scaling operations. The introduction of stateless communicators independent of torch.distributed's global state is a major architectural shift enabling these features. My review has identified a critical bug in the state machine logic and several high-severity issues related to fragile implementation details that could lead to future breakages. Overall, this is a substantial and well-structured contribution, but the identified issues should be addressed to ensure robustness and correctness.

Comment on lines +256 to +319
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop(0)

def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop(0)

def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop(0)


Severity: high

These methods use pop(0) to retrieve a port from a list without checking if the list is empty. If the port lists (_stateless_world_group_port_list, _stateless_dp_group_port_list, _stateless_ep_group_port_list) are exhausted for any reason, this will raise an IndexError and crash the process. While the logic in __post_init__ seems to pre-allocate the necessary ports, this design is fragile. A more robust implementation would be to check if the list is empty before popping and raise a more informative error message.
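A possible shape of the suggested guard, reusing the names from the snippet above (an illustrative sketch, not the PR's code):

def get_next_stateless_ep_group_port(self) -> list[int]:
    # Fail with a descriptive error instead of an opaque IndexError once the
    # pre-allocated port list is exhausted.
    if not self._stateless_ep_group_port_list:
        raise RuntimeError(
            "Stateless EP group port list is exhausted; expected enough ports "
            "to be pre-allocated in __post_init__."
        )
    return self._stateless_ep_group_port_list.pop(0)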

Comment on lines 98 to 119
# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None


Severity: high

The check _world.pg_map.get(cpu_group, None) is None relies on an internal, undocumented implementation detail of torch.distributed to determine if a process group is stateless. This is a brittle approach that could break with future PyTorch updates. It would be more robust to use an explicit mechanism to identify stateless groups, such as a custom process group class that carries this information, or passing a flag during initialization.
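One way to make the distinction explicit, sketched here with placeholder names (GroupHandle, is_stateless) rather than anything from the PR:

from dataclasses import dataclass
from typing import Any


@dataclass
class GroupHandle:
    # Carries an explicit stateless flag set at construction time, so callers
    # never need to probe torch.distributed.distributed_c10d._world.
    cpu_group: Any
    is_stateless: bool


def uses_stateless_group(handle: GroupHandle) -> bool:
    return handle.is_stateless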

Comment on lines +307 to +416
if op.op.__name__ == "isend":
self.send(op.tensor, op.group_peer, stream)
elif op.op.__name__ == "irecv":
self.recv(op.tensor, op.group_peer, stream)


Severity: high

Checking op.op.__name__ to determine the operation type is fragile. The name of a function can change, or it could be wrapped by a decorator, which would break this logic. It's more robust to check for function identity directly.

Suggested change
-if op.op.__name__ == "isend":
-    self.send(op.tensor, op.group_peer, stream)
-elif op.op.__name__ == "irecv":
-    self.recv(op.tensor, op.group_peer, stream)
+if op.op is torch.distributed.isend:
+    self.send(op.tensor, op.group_peer, stream)
+elif op.op is torch.distributed.irecv:
+    self.recv(op.tensor, op.group_peer, stream)

Comment on lines +143 to +148
if ep_group not in _world.pg_map:
ep_group = get_ep_group()


Severity: high

The check if ep_group not in _world.pg_map: relies on an internal implementation detail of PyTorch's distributed library (_world.pg_map) to detect stateless process groups. This is not a public API and is subject to change without notice, which makes this code brittle. A more robust approach, such as using a custom process group class or an explicit flag, should be used to differentiate between stateful and stateless groups.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.


@ruisearch42 ruisearch42 left a comment


First pass review

self.available_gpu_memory_for_kv_cache = -1

if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
self._elastic_scale_up_post_init()


This happens as part of init, rather than after init, maybe rename?

waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
block = not self.process_input_queue_non_block


nit: call it process_input_queue_block and there is no need to negate

Comment on lines +584 to +585
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
return self.elastic_scaling_executor.execute(execute_method, *args, **kwargs)


nit: elastic_scaling_executor -> elastic_ep_executor to be consistent with elastic_ep_execute?

timeout=timeout,
)

if isinstance(group_name, str):


if group_name: ?

Comment on lines +69 to +70
self.new_dp_group = (
self.engine_core.dp_group if worker_type == "new" else new_parallel_config


why is new_parallel_config assigned to self.new_dp_group?

return False

elif state == ScaleUpExistingWorkerState.CREATE_STANDBY_GROUPS:
# NOTE(yongji): wait for all exisiting workers to receive the request


Suggested change
-# NOTE(yongji): wait for all exisiting workers to receive the request
+# NOTE(yongji): wait for all existing workers to receive the request

notification_type == "NEW_WORKERS_INIT_READY"
and self.state == ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_INIT
):
self.waiting_for_notification = False


Do we really need this? Can we make it a property of self.state to simplify the logic here?

logger = init_logger(__name__)


class StatelessGroupCoordinator(GroupCoordinator):


docstring

TRANSFER_EXPERT_MAPPING = 2
WAIT_NEW_WORKERS_WEIGHTS_INIT = 3
TRANSFER_WEIGHTS = 4
SYNC_KV_CACHE_MEMORY = 5


nit: SYNC_KV_CACHE_SIZE?

)
logger.info("[Elastic EP] Broadcasted expert mapping to new workers")

def _sync_kv_cache_memory(self):


_sync_kv_cache_size?

Signed-off-by: Yongji Wu <[email protected]>