tests/distributed/test_pipeline_partition.py (24 additions, 0 deletions)
@@ -34,3 +34,27 @@ def _verify(partition_str, num_layers, pp_size, goldens):
     # Wrong number of layers
     with pytest.raises(ValueError):
         _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+
+
+@pytest.mark.parametrize(
+    "num_hidden_layers,pp_size,pp_rank,indices",
+    [
+        # pp_size 2
+        (2, 2, 0, (0, 1)),
+        (2, 2, 1, (1, 2)),
+        (3, 2, 0, (0, 2)),
+        (3, 2, 1, (2, 3)),
+        # pp_size 3
+        (3, 3, 0, (0, 1)),
+        (3, 3, 1, (1, 2)),
+        (3, 3, 2, (2, 3)),
+        (4, 3, 0, (0, 1)),
+        (4, 3, 1, (1, 3)),
+        (4, 3, 2, (3, 4)),
+        (5, 3, 0, (0, 2)),
+        (5, 3, 1, (2, 4)),
+        (5, 3, 2, (4, 5)),
+    ])
+def test_uneven_auto_partition(num_hidden_layers: int, pp_size: int,
+                               pp_rank: int, indices: tuple[int, int]):
+    assert indices == get_pp_indices(num_hidden_layers, pp_rank, pp_size)
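
The expected tuples above follow from the scheme implemented in `vllm/distributed/utils.py` below: each rank gets `num_hidden_layers // pp_size` layers, and any remainder is spread over the partitions just before the last one. As a quick cross-check, here is a minimal standalone sketch of that default path (the `_pp_indices` helper is illustrative only and ignores the `VLLM_PP_LAYER_PARTITION` override that the real function also honors):

```python
def _pp_indices(num_hidden_layers: int, pp_rank: int,
                pp_size: int) -> tuple[int, int]:
    """Illustrative re-implementation of the default auto-partition."""
    layers_per_partition = num_hidden_layers // pp_size
    partitions = [layers_per_partition] * pp_size
    # Spread the remainder over the partitions just before the last one,
    # working backwards from the second-to-last.
    for i in range(2, num_hidden_layers % pp_size + 2):
        partitions[-i] += 1
    start_layer = sum(partitions[:pp_rank])
    return (start_layer, start_layer + partitions[pp_rank])

# 5 layers over 3 ranks -> partitions [2, 2, 1], matching the table above.
assert [_pp_indices(5, r, 3) for r in range(3)] == [(0, 2), (2, 4), (4, 5)]
```
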
vllm/distributed/utils.py (22 additions, 8 deletions)
@@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                    pp_size: int) -> Tuple[int, int]:
     """Try to evenly distribute layers across partitions.
+
     If the number of layers is not divisible by the number of partitions,
-    the last partition will have the remaining layers.
+    the remaining layers are evenly distributed across all but the last
+    partition. The last partition is excluded because it often contains an
+    additional norm layer and we are attempting to balance compute.
+
+    If `pp_size > 2` and the number of remaining layers is
+    `0 < x <= pp_size - 2` then the remaining layers are evenly distributed
+    across the middle partitions. The first and last partitions are excluded
+    because they contain the input and output embeddings respectively and we
+    are attempting to reduce maximum memory consumption across partitions.
     """
     partition_list_str = envs.VLLM_PP_LAYER_PARTITION
     if partition_list_str is not None:
@@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
         if sum(partitions) != num_hidden_layers:
             raise ValueError(
                 f"{sum(partitions)=} does not match {num_hidden_layers=}.")
-        start_layer = sum(partitions[:pp_rank])
-        end_layer = start_layer + partitions[pp_rank]
     else:
         layers_per_partition = num_hidden_layers // pp_size
-        start_layer = pp_rank * layers_per_partition
-        end_layer = start_layer + layers_per_partition
-
-        if pp_rank == pp_size - 1:
-            end_layer = num_hidden_layers
+        partitions = [layers_per_partition for _ in range(pp_size)]
+
+        if remaining_layers := num_hidden_layers % pp_size:
+            for i in range(2, remaining_layers + 2):
+                partitions[-i] += 1
+            logger.info("Hidden layers were unevenly partitioned: %s",
+                        ",".join(str(p) for p in partitions))
+            logger.info("This can be manually overridden using the "
+                        "VLLM_PP_LAYER_PARTITION environment variable")
+
+    start_layer = sum(partitions[:pp_rank])
+    end_layer = start_layer + partitions[pp_rank]
 
     return (start_layer, end_layer)
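
For a concrete picture of the new behavior, here is a hedged usage sketch. It assumes a vLLM build where `get_pp_indices` is importable from `vllm.distributed.utils` and where `envs.VLLM_PP_LAYER_PARTITION` reads the environment lazily at call time:

```python
import os

from vllm.distributed.utils import get_pp_indices

# Auto-partition: 14 layers over 4 ranks -> [3, 4, 4, 3]; the two
# remainder layers land on the middle partitions, as the docstring says.
print([get_pp_indices(14, rank, 4) for rank in range(4)])
# [(0, 3), (3, 7), (7, 11), (11, 14)]

# Manual override: the environment variable takes precedence.
os.environ["VLLM_PP_LAYER_PARTITION"] = "4,4,3,3"
print([get_pp_indices(14, rank, 4) for rank in range(4)])
# [(0, 4), (4, 8), (8, 11), (11, 14)]
```

The asymmetry of the automatic scheme is deliberate: the first and last ranks also host the input and output embeddings, so routing the extra layers to the middle ranks reduces maximum memory consumption across partitions.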
