3 changes: 2 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -70,7 +70,8 @@ def create_connector_v1(
         connector_module = importlib.import_module(connector_module_path)
         connector_cls = getattr(connector_module, connector_name)
         assert issubclass(connector_cls, KVConnectorBase_V1)
-        logger.info("Creating v1 connector with name: %s", connector_name)
+        logger.info("Creating v1 connector with name: %s and engine_id: %s",
+                    connector_name, kv_transfer_config.engine_id)
         # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
         # Scheduler connector:
         #   - Co-locate with scheduler process
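The updated call just passes the engine ID as a second lazy-formatting argument, which makes each connector traceable to its DP rank once core.py (below) appends the suffix. A minimal, self-contained sketch of the resulting log line; the connector name and engine_id values here are hypothetical:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("kv_connector.factory")

    # Hypothetical values for illustration only.
    connector_name = "NixlConnector"
    engine_id = "f8a2c9e4_dp0"  # engine_id after the DP suffix from core.py below

    # Mirrors the updated logger.info call in create_connector_v1.
    logger.info("Creating v1 connector with name: %s and engine_id: %s",
                connector_name, engine_id)
    # INFO:kv_connector.factory:Creating v1 connector with name: NixlConnector and engine_id: f8a2c9e4_dp0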
12 changes: 7 additions & 5 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -19,7 +19,7 @@
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group)
+    get_tp_group, get_world_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -334,6 +334,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.engine_id = engine_id
         self.rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
+        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()
 
         # KV Caches and nixl tracking data.
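The new world_rank is what makes the side-channel ports below collision-free. With data parallelism enabled, the tensor-parallel rank repeats in every DP replica, while the world-group rank is unique across all workers. A minimal sketch (not vLLM code), assuming a DP-major rank layout:

    # Toy illustration: TP rank repeats across DP replicas; world rank does not.
    TP_SIZE, DP_SIZE = 2, 2

    for world_rank in range(TP_SIZE * DP_SIZE):
        dp_rank, tp_rank = divmod(world_rank, TP_SIZE)
        print(f"world_rank={world_rank}  dp_rank={dp_rank}  tp_rank={tp_rank}")

    # world_rank=0  dp_rank=0  tp_rank=0
    # world_rank=1  dp_rank=0  tp_rank=1
    # world_rank=2  dp_rank=1  tp_rank=0   <- tp_rank collides with world_rank 0
    # world_rank=3  dp_rank=1  tp_rank=1   <- tp_rank collides with world_rank 1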
@@ -382,7 +383,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event, rank: int):
+                                 ready_event: threading.Event,
+                                 world_rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
         # to a better approach like an ETCD server in the future.
@@ -403,7 +405,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
         # NOTE(rob): we need each rank to have a unique port. This
         # hack to keeps us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
+        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
         path = make_zmq_path("tcp", host, port)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
@@ -422,7 +424,7 @@ def _nixl_handshake(self, host: str, port: int):
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.rank)
+        path = make_zmq_path("tcp", host, port + self.world_rank)
         logger.debug("Querying metadata on path: %s", path)
         with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
@@ -529,7 +531,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.rank),
+            args=(metadata, ready_event, self.world_rank),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
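The listener thread itself is unchanged apart from the rank it receives. For readers unfamiliar with the ready_event pattern used here, a self-contained toy version (plain pyzmq, not the NIXL metadata protocol):

    import threading
    import zmq

    def toy_listener(ready_event: threading.Event, port: int) -> None:
        """Toy stand-in for _nixl_handshake_listener: answer one query."""
        ctx = zmq.Context()
        sock = ctx.socket(zmq.ROUTER)
        sock.bind(f"tcp://*:{port}")
        ready_event.set()  # tell the spawning thread the socket is bound
        identity, empty, query = sock.recv_multipart()
        sock.send_multipart([identity, empty, b"agent-metadata-bytes"])
        sock.close()
        ctx.term()

    ready_event = threading.Event()
    listener_t = threading.Thread(target=toy_listener,
                                  args=(ready_event, 5600),  # hypothetical port
                                  daemon=True, name="toy_handshake_listener")
    listener_t.start()
    ready_event.wait()  # mirrors register_kv_caches: block until bound

    req = zmq.Context().socket(zmq.REQ)
    req.connect("tcp://127.0.0.1:5600")
    req.send(b"query")
    print(req.recv())  # b'agent-metadata-bytes'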
9 changes: 9 additions & 0 deletions vllm/v1/engine/core.py
@@ -700,6 +700,15 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
         assert dp_size > 1
         assert 0 <= local_dp_rank <= dp_rank < dp_size
 
+        if vllm_config.kv_transfer_config is not None:
+            # modify the engine_id and append the local_dp_rank to it to
+            # ensure that the kv_transfer_config is unique for each DP rank.
+            vllm_config.kv_transfer_config.engine_id = (
+                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
+            )
+            logger.debug("Setting kv_transfer_config.engine_id to %s",
+                         vllm_config.kv_transfer_config.engine_id)
+
         from vllm.platforms import current_platform
         device_control_env_var = current_platform.device_control_env_var
         world_size = vllm_config.parallel_config.world_size
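To make the effect of the suffixing concrete, a minimal sketch with stand-in types (not the real VllmConfig / KVTransferConfig classes; the base engine_id is hypothetical):

    from dataclasses import dataclass

    @dataclass
    class KVTransferConfig:  # stand-in for vLLM's KVTransferConfig
        engine_id: str

    def apply_dp_suffix(cfg: KVTransferConfig, local_dp_rank: int) -> None:
        # Mirrors _init_data_parallel: one unique engine_id per DP rank, so
        # two DP ranks never register the same engine identity.
        cfg.engine_id = f"{cfg.engine_id}_dp{local_dp_rank}"

    cfg = KVTransferConfig(engine_id="f8a2c9e4")  # hypothetical base id
    apply_dp_suffix(cfg, local_dp_rank=1)
    print(cfg.engine_id)  # f8a2c9e4_dp1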
Expand Down