
[BugFix] pp cannot run successfully under NixlConnector #22976

Open · wants to merge 1 commit into base: main

26 changes: 17 additions & 9 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -23,8 +23,8 @@
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group)
+    get_pipeline_model_parallel_rank, get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size, get_tp_group)
 from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
@@ -238,7 +238,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.side_channel_port = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size *
+            vllm_config.parallel_config.pipeline_parallel_size)
         self.use_host_buffer = \
             vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
         logger.info("Initializing NIXL Scheduler %s", engine_id)
@@ -439,11 +440,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # NIXL handshake port.
         # NOTE(rob): Within a DP group, each DP rank gets its own
         # base port (which is sent in the KVTransferParams).
-        # Each TP rank listens/queries on the base_port + tp_rank.
+        # Each TP rank listens/queries on the base_port +
+        # pp_rank * tp_size + tp_rank.
+        self.pp_rank = get_pipeline_model_parallel_rank()
         self.side_channel_port: int = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size *
+            vllm_config.parallel_config.pipeline_parallel_size +
+            self.pp_rank * vllm_config.parallel_config.tensor_parallel_size)
 
         # Metadata.
         self.engine_id: EngineId = engine_id
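The port arithmetic above gives every worker its own NIXL side-channel port: the base port is offset by data_parallel_rank * tp_size * pp_size to separate DP groups, then by pp_rank * tp_size + tp_rank within the group. The sketch below is illustrative only (the function and variable names are hypothetical stand-ins, not vLLM code); it just checks that the fixed formula assigns distinct ports, whereas before this change all PP ranks sharing a TP rank collided on the same port.

# Hypothetical sketch of the side-channel port layout used above (not vLLM code).
def side_channel_port(base_port: int, dp_rank: int, pp_rank: int, tp_rank: int,
                      tp_size: int, pp_size: int) -> int:
    # base + dp_rank * (ranks per DP group) + pp_rank * tp_size + tp_rank
    return base_port + dp_rank * tp_size * pp_size + pp_rank * tp_size + tp_rank

if __name__ == "__main__":
    base_port, dp_size, pp_size, tp_size = 5600, 2, 2, 2  # stand-in values
    ports = [
        side_channel_port(base_port, dp, pp, tp, tp_size, pp_size)
        for dp in range(dp_size)
        for pp in range(pp_size)
        for tp in range(tp_size)
    ]
    # 8 workers -> 8 distinct ports (5600..5607); without the pp terms,
    # pp_rank 0 and pp_rank 1 would listen on the same ports.
    assert len(set(ports)) == dp_size * pp_size * tp_size
    print(sorted(ports))
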
@@ -545,6 +550,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
+        self.device_id = torch.cuda.current_device()
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
@@ -780,7 +786,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 region_len = self.num_blocks * self.block_len
                 # NOTE: use tp_rank for device_id since multi-node TP
                 # is rarely used.
-                caches_data.append((base_addr, region_len, self.tp_rank, ""))
+                caches_data.append((base_addr, region_len, self.device_id, ""))
                 kv_caches_base_addr.append(base_addr)
         self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
         self.num_regions = len(caches_data)
@@ -826,12 +832,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 # TODO: does device_id matter to DRAM?
-                blocks_data.append((addr, self.block_len, self.tp_rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.tp_rank)
+                blocks_data.append((addr, self.block_len, self.device_id))
+        logger.debug(
+            "Created %s blocks for src engine %s , tp rank %s, device id %s ",
+            len(blocks_data), self.engine_id, self.tp_rank, self.device_id)
 
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
                                                  self.nixl_memory_type)
+
         # NIXL_INIT_AGENT to be used for preparations of local descs.
         self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
             "NIXL_INIT_AGENT", descs)
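The device_id changes in this file address a related PP issue: NIXL memory regions and block descriptors were registered with self.tp_rank as the device id, but under pipeline parallelism a worker's TP rank is not necessarily the index of the GPU it runs on. The sketch below illustrates the mismatch under an assumed single-node layout where ranks map to GPUs with TP varying fastest; the layout and the helper name are assumptions for illustration only, while the PR itself simply switches to torch.cuda.current_device().

# Hypothetical sketch: why tp_rank is not a safe device index under PP.
# Assumes one node whose ranks map to GPUs with TP varying fastest
# (an illustration only; the PR uses torch.cuda.current_device() instead).
def local_gpu_index(pp_rank: int, tp_rank: int, tp_size: int) -> int:
    return pp_rank * tp_size + tp_rank

tp_size, pp_size = 2, 2
for pp_rank in range(pp_size):
    for tp_rank in range(tp_size):
        gpu = local_gpu_index(pp_rank, tp_rank, tp_size)
        note = "  <-- tp_rank != GPU index" if gpu != tp_rank else ""
        print(f"pp_rank={pp_rank} tp_rank={tp_rank} -> GPU {gpu}{note}")
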
5 changes: 5 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -1226,6 +1226,11 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_pipeline_model_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    return get_pp_group().rank_in_group
+
+
 def get_node_count() -> int:
     """Return the total number of nodes in the distributed environment. """
     assert _NODE_COUNT is not None, (
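For completeness, a short hedged usage sketch of the new accessor alongside the existing TP one, mirroring how the connector change above consumes it; it assumes vLLM's distributed groups have already been initialized.

# Illustrative only: combining the existing TP accessor with the new PP
# accessor once vLLM's distributed groups are initialized.
from vllm.distributed.parallel_state import (
    get_pipeline_model_parallel_rank, get_tensor_model_parallel_rank)

pp_rank = get_pipeline_model_parallel_rank()  # rank within the PP group
tp_rank = get_tensor_model_parallel_rank()    # rank within the TP group
# A per-worker offset such as pp_rank * tp_size + tp_rank (as used by the
# NixlConnector above) then distinguishes every worker in a DP group.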