Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/v1/engine/test_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
def _vllm_model(apc: bool, vllm_runner, monkeypatch):
"""Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1")
# TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
return vllm_runner(
MODEL,
dtype=DTYPE,
Expand Down
22 changes: 12 additions & 10 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,10 @@ def get_open_zmq_ipc_path() -> str:
return f"ipc://{base_rpc_path}/{uuid4()}"


def get_open_zmq_inproc_path() -> str:
return f"inproc://{uuid4()}"


def get_open_port() -> int:
"""
Get an open port for the vLLM process to listen on.
Expand Down Expand Up @@ -2108,12 +2112,12 @@ def get_exception_traceback():
def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
type: Any,
socket_type: Any,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""

mem = psutil.virtual_memory()
socket = ctx.socket(type)
socket = ctx.socket(socket_type)

# Calculate buffer size based on system memory
total_mem = mem.total / 1024**3
Expand All @@ -2127,29 +2131,27 @@ def make_zmq_socket(
else:
buf_size = -1 # Use system default buffer size

if type == zmq.constants.PULL:
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
elif type == zmq.constants.PUSH:
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
raise ValueError(f"Unknown Socket Type: {socket_type}")

return socket


@contextlib.contextmanager
def zmq_socket_ctx(
path: str,
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""

ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined]
ctx = zmq.Context() # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx, path, type)
yield make_zmq_socket(ctx, path, socket_type)

except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")
Expand Down
84 changes: 56 additions & 28 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket)
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
kill_process_tree, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
Expand Down Expand Up @@ -202,10 +202,11 @@ class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""

ctx: Union[zmq.Context, zmq.asyncio.Context] = None
ctx: Union[zmq.Context] = None
output_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None
input_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None
proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None

def __call__(self):
"""Clean up background resources."""
Expand All @@ -218,8 +219,13 @@ def __call__(self):
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.ctx is not None:
self.ctx.destroy(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
shutdown_sender.connect(self.shutdown_path)
# Send shutdown signal.
shutdown_sender.send(b'')


class MPClient(EngineCoreClient):
Expand Down Expand Up @@ -261,28 +267,23 @@ def sigusr1_handler(signum, frame):
self.decoder = MsgpackDecoder(EngineCoreOutputs)

# ZMQ setup.
self.ctx = (
zmq.asyncio.Context() # type: ignore[attr-defined]
if asyncio_mode else zmq.Context()) # type: ignore[attr-defined]
sync_ctx = zmq.Context()
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx

# This will ensure resources created so far are closed
# when the client is garbage collected, even if an
# exception is raised mid-construction.
resources = BackgroundResources(ctx=self.ctx)
self._finalizer = weakref.finalize(self, resources)
self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources)

# Paths and sockets for IPC.
output_path = get_open_zmq_ipc_path()
# Paths for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
resources.output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)

# Start EngineCore in background process.
resources.proc_handle = BackgroundProcHandle(
self.resources.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
output_path=self.output_path,
process_name="EngineCore",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
Expand All @@ -291,8 +292,10 @@ def sigusr1_handler(signum, frame):
"log_stats": log_stats,
})

self.output_socket = resources.output_socket
self.input_socket = resources.input_socket
# Create input socket.
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.input_socket = self.resources.input_socket
self.utility_results: dict[int, AnyFuture] = {}

def shutdown(self):
Expand Down Expand Up @@ -325,27 +328,48 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],

# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
output_socket = self.output_socket
ctx = self.ctx
output_path = self.output_path
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue

shutdown_path = get_open_zmq_inproc_path()
self.resources.shutdown_path = shutdown_path

def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR)
shutdown_socket.bind(shutdown_path)
out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
try:
poller = zmq.Poller()
poller.register(shutdown_socket)
poller.register(out_socket)
while True:
(frame, ) = output_socket.recv_multipart(copy=False)
socks = poller.poll()
if not socks:
continue
if len(socks) == 2 or socks[0][0] == shutdown_socket:
# shutdown signal, exit thread.
break

(frame, ) = out_socket.recv_multipart(copy=False)
outputs = decoder.decode(frame.buffer)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
else:
outputs_queue.put_nowait(outputs)
except zmq.error.ContextTerminated:
# Expected when the class is GC'd / during process termination.
pass
finally:
# Close sockets.
shutdown_socket.close(linger=0)
out_socket.close(linger=0)

# Process outputs from engine in separate thread.
Thread(target=process_outputs_socket, daemon=True).start()
self.output_queue_thread = Thread(target=process_outputs_socket,
name="EngineCoreOutputQueueThread",
daemon=True)
self.output_queue_thread.start()

def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get()
Expand Down Expand Up @@ -424,10 +448,13 @@ async def _start_output_queue_task(self):
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue()
output_socket = self.output_socket
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
self.resources.output_socket = output_socket

async def process_outputs_socket():
while True:
Expand All @@ -439,7 +466,8 @@ async def process_outputs_socket():
else:
outputs_queue.put_nowait(outputs)

self.queue_task = asyncio.create_task(process_outputs_socket())
self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask")

async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None:
Expand Down