
Commit 54e084f

hiyouga authored and youkaichao committed
[Bugfix] torchrun compatibility (#14899)
Signed-off-by: hiyouga <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
1 parent 9e8f089 commit 54e084f

File tree

2 files changed: +23 -2 lines changed


vllm/config.py

Lines changed: 3 additions & 1 deletion
@@ -904,7 +904,9 @@ def get_layers_start_end_indices(
         else:
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_hidden_layers", 0)
-        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
+        # the layout order is: DP x PP x TP
+        pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
+                   ) % parallel_config.pipeline_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
         return start, end
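
For context, a minimal sketch (not part of the commit; the helper name decompose_rank is made up for illustration) of how a global rank maps to (dp, pp, tp) coordinates under the DP x PP x TP layout, and why pp_rank needs the extra modulo once a data-parallel dimension sits in front of the pipeline dimension:

# Illustrative only: decompose a global rank under a DP x PP x TP layout.
def decompose_rank(rank: int, pp_size: int, tp_size: int) -> tuple[int, int, int]:
    """Return (dp_rank, pp_rank, tp_rank) for a DP x PP x TP rank layout."""
    tp_rank = rank % tp_size
    pp_rank = (rank // tp_size) % pp_size  # the modulo added by this commit
    dp_rank = rank // (tp_size * pp_size)
    return dp_rank, pp_rank, tp_rank

# With DP=2, PP=2, TP=2, global rank 5 maps to (dp=1, pp=0, tp=1).
# The old expression `rank // tp_size` would yield pp_rank=2, which is
# out of range for pp_size=2 whenever more than one DP replica exists.
assert decompose_rank(5, pp_size=2, tp_size=2) == (1, 0, 1)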

vllm/distributed/parallel_state.py

Lines changed: 20 additions & 1 deletion
@@ -897,10 +897,23 @@ def initialize_model_parallel(
         get_world_group().device_group)
 
     data_parallel_size = 1
+    has_external_dp = False
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
     if config is not None:
-        data_parallel_size = config.parallel_config.data_parallel_size
+        if config.parallel_config.world_size != world_size:
+            # detect external data parallelism.
+            # dp in vllm means all dp instances need to run together.
+            # if the world size does not match, it means this dp is external,
+            # and the dp instances can run independently, e.g. in rlhf workflow
+            # from https://github.com/volcengine/verl .
+            # in that case, we treat the rest dimensions as if they are
+            # data parallel, and create a dummy dp group that is not used.
+            data_parallel_size = world_size // (pipeline_model_parallel_size *
+                                                tensor_model_parallel_size)
+            has_external_dp = True
+        else:
+            data_parallel_size = config.parallel_config.data_parallel_size
 
     # the layout order is: DP x PP x TP
     # to get group_ranks for each dimension, transpose that dimension to the
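
As a rough, self-contained illustration of the check above (the numbers are assumptions, not taken from the commit): when vLLM is launched via torchrun, the torch.distributed world size can be larger than the world size ParallelConfig derives from its PP and TP settings, and the leftover factor is then treated as external data parallelism:

# Assumed example values; plain Python, not vLLM API calls.
world_size = 8                        # e.g. torchrun with 8 processes
pipeline_model_parallel_size = 1
tensor_model_parallel_size = 2
config_world_size = (pipeline_model_parallel_size *
                     tensor_model_parallel_size)  # ParallelConfig's view: 2

data_parallel_size = 1
has_external_dp = False
if config_world_size != world_size:
    # the remaining factor is treated as (external) data parallelism
    data_parallel_size = world_size // (pipeline_model_parallel_size *
                                        tensor_model_parallel_size)
    has_external_dp = True

assert data_parallel_size == 4 and has_external_dp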
@@ -940,6 +953,12 @@ def initialize_model_parallel(
                                       2).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
+    if has_external_dp:
+        # create a dummy dp group that is not used actually,
+        # since this dp is external.
+        # a dummy dp group means every rank is a group itself.
+        # this way, no communication is needed, no memory is wasted.
+        group_ranks = [[x] for x in range(world_size)]
     _DP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
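
For illustration only (the group sizes below are assumed, not from the commit), here is the difference between a real DP grouping and the dummy one created above, for world_size = 4 with TP = 2, PP = 1:

world_size = 4
# Real DP grouping under DP x PP x TP with DP=2, PP=1, TP=2:
# ranks sharing the same (pp, tp) coordinate end up in one DP group.
real_dp_group_ranks = [[0, 2], [1, 3]]

# Dummy DP grouping when dp is external: every rank is its own group,
# so no cross-rank DP communication happens and no communicator memory is used.
dummy_dp_group_ranks = [[x] for x in range(world_size)]
assert dummy_dp_group_ranks == [[0], [1], [2], [3]]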
