diff --git a/vllm/config.py b/vllm/config.py
index 5fb9563fcf3a..b858ca567547 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -991,7 +991,7 @@ def __post_init__(self) -> None:
                 raise ValueError(f"worker-use-ray can't be used with "
                                  f"distributed executor backend "
                                  f"'{self.distributed_executor_backend}'.")
-        ray_only_devices = ["tpu", "hpu"]
+        ray_only_devices = ["tpu"]
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):
             if self.distributed_executor_backend is None:
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 60dccd7a0812..48e65cd9acfb 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -645,6 +645,10 @@ def _get_executor_cls(
             from vllm.executor.cpu_executor import CPUExecutorAsync
             executor_class = CPUExecutorAsync
         elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
+            if distributed_executor_backend == "mp":
+                from vllm.executor.multiproc_hpu_executor import (
+                    MultiprocessingHPUExecutorAsync)
+                executor_class = MultiprocessingHPUExecutorAsync
+            elif distributed_executor_backend == "ray":
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 6eca304b45f0..ee58866b1b68 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -477,7 +477,11 @@ def _get_executor_cls(cls,
             from vllm.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
         elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
+            if distributed_executor_backend == "mp":
+                from vllm.executor.multiproc_hpu_executor import (
+                    MultiprocessingHPUExecutor)
+                executor_class = MultiprocessingHPUExecutor
+            elif distributed_executor_backend == "ray":
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_hpu_executor import RayHPUExecutor
                 executor_class = RayHPUExecutor
diff --git a/vllm/executor/multiproc_hpu_executor.py b/vllm/executor/multiproc_hpu_executor.py
new file mode 100644
index 000000000000..fdedbbe2b3b4
--- /dev/null
+++ b/vllm/executor/multiproc_hpu_executor.py
@@ -0,0 +1,51 @@
+from typing import Callable, Optional, Tuple, Type
+
+import habana_frameworks.torch  # noqa: F401
+import torch
+
+from vllm.executor.multiproc_gpu_executor import (
+    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
+from vllm.logger import init_logger
+from vllm.utils import make_async
+from vllm.worker.worker_base import WorkerBase
+
+logger = init_logger(__name__)
+
+
+class MultiprocessingHPUExecutor(MultiprocessingGPUExecutor):
+    """Python multiprocessing-based multi-HPU executor"""
+
+    def _get_worker_module_and_class(
+            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
+        worker_class_fn = None
+        if self.speculative_config is not None:
+            module_name = "vllm.spec_decode.spec_decode_worker"
+            class_name = "create_spec_worker"
+        else:
+            module_name = "vllm.worker.hpu_worker"
+            class_name = "HPUWorker"
+        return (module_name, class_name, worker_class_fn)
+
+    def _check_executor_parameters(self):
+        world_size = self.parallel_config.world_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+        hpu_device_count = torch.hpu.device_count()
+        assert tensor_parallel_size <= hpu_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to less than max local hpu count ({hpu_device_count})")
+
+        assert world_size <= hpu_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"is less than max local hpu count ({hpu_device_count})")
+
+    def __del__(self):
+        self.shutdown()
+
+
+class MultiprocessingHPUExecutorAsync(MultiprocessingHPUExecutor,
+                                      MultiprocessingGPUExecutorAsync):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index fe475db6d3f5..d187643392d5 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -15,6 +15,7 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils.importing import HAS_TRITON
 from vllm.utils import cuda_is_initialized
 
@@ -291,6 +292,22 @@ def set_multiprocessing_worker_envs(parallel_config):
                        "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
+    if (current_platform.is_hpu()
+            and parallel_config.distributed_executor_backend == 'mp'
+            and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
+        if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) is not None:
+            logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork might "
+                           "cause application hangs on exit. Using "
+                           "VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
+                           "as it was explicitly requested.")
+        else:
+            logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork might "
+                           "cause application hangs on exit. Setting "
+                           "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
+                           "To override that behavior, please set "
+                           "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
+            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
     # Configure thread parallelism if OMP_NUM_THREADS isn't set
     #
     # Helps to avoid CPU contention. The default of spawning a thread per
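
Usage sketch (illustrative, not part of the diff): one way the new mp backend on HPU might be exercised end to end, assuming a Gaudi host with at least two HPUs and the Habana software stack installed. The model name and tensor_parallel_size=2 below are placeholders, not values taken from this change.

    from vllm import LLM, SamplingParams

    # Select the Python-multiprocessing executor instead of Ray for a multi-HPU run.
    llm = LLM(model="facebook/opt-125m",
              tensor_parallel_size=2,
              distributed_executor_backend="mp")

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)

With VLLM_WORKER_MULTIPROC_METHOD unset, the new check in multiproc_worker_utils.py should switch the worker start method to 'spawn' on HPU; setting VLLM_WORKER_MULTIPROC_METHOD=fork explicitly keeps fork and only logs a warning.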