diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index 06b3983ed68b..dce0b545c188 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -70,7 +70,8 @@ def create_connector_v1(
         connector_module = importlib.import_module(connector_module_path)
         connector_cls = getattr(connector_module, connector_name)
         assert issubclass(connector_cls, KVConnectorBase_V1)
-        logger.info("Creating v1 connector with name: %s", connector_name)
+        logger.info("Creating v1 connector with name: %s and engine_id: %s",
+                    connector_name, kv_transfer_config.engine_id)
         # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
         # Scheduler connector:
         #   - Co-locate with scheduler process
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 6303d77ad305..f02434aeb5ca 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -19,7 +19,7 @@
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group)
+    get_tp_group, get_world_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -334,6 +334,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.engine_id = engine_id
         self.rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
+        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()
 
         # KV Caches and nixl tracking data.
@@ -382,7 +383,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event, rank: int):
+                                 ready_event: threading.Event,
+                                 world_rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
         # to a better approach like an ETCD server in the future.
@@ -403,7 +405,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
         # NOTE(rob): we need each rank to have a unique port. This
         # hack to keeps us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
+        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
         path = make_zmq_path("tcp", host, port)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
@@ -422,7 +424,7 @@ def _nixl_handshake(self, host: str, port: int):
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.rank)
+        path = make_zmq_path("tcp", host, port + self.world_rank)
         logger.debug("Querying metadata on path: %s", path)
         with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
@@ -529,7 +531,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.rank),
+            args=(metadata, ready_event, self.world_rank),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 740ba60fe231..1a392bdf435c 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -700,6 +700,15 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
         assert dp_size > 1
         assert 0 <= local_dp_rank <= dp_rank < dp_size
 
+        if vllm_config.kv_transfer_config is not None:
+            # modify the engine_id and append the local_dp_rank to it to ensure
+            # that the kv_transfer_config is unique for each DP rank.
+            vllm_config.kv_transfer_config.engine_id = (
+                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
+            )
+            logger.debug("Setting kv_transfer_config.engine_id to %s",
+                         vllm_config.kv_transfer_config.engine_id)
+
         from vllm.platforms import current_platform
         device_control_env_var = current_platform.device_control_env_var
         world_size = vllm_config.parallel_config.world_size