From 5201db75a71ff7834593e5f5f9ab8abcda99516a Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 16 Mar 2025 19:07:44 +0000
Subject: [PATCH 1/3] fix torchrun compatibility

Signed-off-by: hiyouga
---
 vllm/config.py                     | 3 ++-
 vllm/distributed/parallel_state.py | 8 +++++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 70cc0affe998..8a65c627e53a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -904,7 +904,8 @@ def get_layers_start_end_indices(
         else:
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_hidden_layers", 0)
-        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
+        pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
+                   ) % parallel_config.pipeline_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
         return start, end
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 86166dd5bb83..7cba572efeed 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -900,7 +900,13 @@ def initialize_model_parallel(
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
     if config is not None:
-        data_parallel_size = config.parallel_config.data_parallel_size
+        if (config.parallel_config.distributed_executor_backend ==
+                "external_launcher"):
+            # do not use config.parallel_config to avoid hanging
+            data_parallel_size = world_size // (pipeline_model_parallel_size *
+                                                tensor_model_parallel_size)
+        else:
+            data_parallel_size = config.parallel_config.data_parallel_size
 
     # the layout order is: DP x PP x TP
     # to get group_ranks for each dimension, transpose that dimension to the

From c183c36b30045824bc475f6b721c66484b4281e2 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 18 Mar 2025 17:51:58 +0800
Subject: [PATCH 2/3] add comment

Signed-off-by: youkaichao
---
 vllm/config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/config.py b/vllm/config.py
index 8a65c627e53a..1cc940c3c921 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -904,6 +904,7 @@ def get_layers_start_end_indices(
         else:
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_hidden_layers", 0)
+        # the layout order is: DP x PP x TP
         pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
                    ) % parallel_config.pipeline_parallel_size
         pp_size = parallel_config.pipeline_parallel_size

From 9ac845180770691e308529d39d0e90c5cefbbeae Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 18 Mar 2025 18:02:46 +0800
Subject: [PATCH 3/3] improve

Signed-off-by: youkaichao
---
 vllm/distributed/parallel_state.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 7cba572efeed..f897f1950e4c 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -897,14 +897,21 @@ def initialize_model_parallel(
         get_world_group().device_group)
 
     data_parallel_size = 1
+    has_external_dp = False
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
     if config is not None:
-        if (config.parallel_config.distributed_executor_backend ==
-                "external_launcher"):
-            # do not use config.parallel_config to avoid hanging
+        if config.parallel_config.world_size != world_size:
+            # detect external data parallelism.
+            # dp in vllm means all dp instances need to run together.
+            # if the world size does not match, it means this dp is external,
+            # and the dp instances can run independently, e.g. in rlhf workflow
+            # from https://github.com/volcengine/verl .
+            # in that case, we treat the rest dimensions as if they are
+            # data parallel, and create a dummy dp group that is not used.
             data_parallel_size = world_size // (pipeline_model_parallel_size *
                                                 tensor_model_parallel_size)
+            has_external_dp = True
         else:
             data_parallel_size = config.parallel_config.data_parallel_size
 
@@ -946,6 +953,12 @@ def initialize_model_parallel(
                                       2).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
+    if has_external_dp:
+        # create a dummy dp group that is not used actually,
+        # since this dp is external.
+        # a dummy dp group means every rank is a group itself.
+        # this way, no communication is needed, no memory is wasted.
+        group_ranks = [[x] for x in range(world_size)]
     _DP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
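For reference, here is a minimal, illustrative sketch (not part of the patches; the helper names rank_to_coords and dummy_dp_groups are made up) of the rank arithmetic both hunks rely on. Under the DP x PP x TP layout, a global rank decomposes into (dp_rank, pp_rank, tp_rank) with the same formula used for pp_rank in config.py, and the dummy DP groups created for external data parallelism simply put every rank in a group of its own so no DP communication ever happens.

    # sketch.py -- illustrative only; mirrors the arithmetic in the patches above.
    # The helper names here are hypothetical, not vLLM APIs.

    def rank_to_coords(rank: int, pp_size: int, tp_size: int) -> tuple[int, int, int]:
        """Decompose a global rank assuming the DP x PP x TP layout."""
        tp_rank = rank % tp_size
        pp_rank = (rank // tp_size) % pp_size  # same formula as pp_rank in config.py
        dp_rank = rank // (tp_size * pp_size)
        return dp_rank, pp_rank, tp_rank

    def dummy_dp_groups(world_size: int) -> list[list[int]]:
        """Each rank is its own DP group, so no communication or memory is needed."""
        return [[r] for r in range(world_size)]

    if __name__ == "__main__":
        # e.g. torchrun with 8 processes, TP=2, PP=2 -> external DP factor of 2
        world_size, tp_size, pp_size = 8, 2, 2
        for r in range(world_size):
            print(r, rank_to_coords(r, pp_size, tp_size))
        print(dummy_dp_groups(world_size))  # [[0], [1], ..., [7]]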