From 509f488ffb90934c4ce0621b67b4a832f3987329 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 00:13:34 +0000 Subject: [PATCH 01/13] updated --- vllm/v1/engine/async_llm.py | 151 ++++++++++++++------------------- vllm/v1/engine/async_stream.py | 55 ------------ vllm/v1/engine/detokenizer.py | 4 + 3 files changed, 70 insertions(+), 140 deletions(-) delete mode 100644 vllm/v1/engine/async_stream.py diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index cfdbea8004c3..77d45c0b272c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -9,14 +9,13 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.outputs import RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.v1.engine.async_stream import AsyncStream from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor @@ -54,10 +53,8 @@ def __init__( lora_config=vllm_config.lora_config) self.tokenizer.ping() - # Request streams (map of request_id -> AsyncStream). - self.request_streams: Dict[str, AsyncStream] = {} - # List of cancelled request ids to be aborted. - self.client_aborted_requests: List[str] = [] + # Request streams (map of request_id -> queue). + self.rid_to_queues: Dict[str, asyncio.Queue] = {} # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( @@ -153,14 +150,14 @@ async def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + ) -> asyncio.Queue[RequestOutput]: """Add new request to the AsyncLLM.""" if self.detokenizer.is_request_active(request_id): raise ValueError(f"Request {request_id} already exists.") - # 1) Create a new AsyncStream for the request. - stream = self._add_request_to_streams(request_id) + # 1) Create a new output queue for the request. + q = self._add_request_to_queues(request_id) # 2) Convert input --> DetokenizerRequest / EngineCoreRequest. detokenizer_req, engine_core_req = self.processor.process_inputs( @@ -173,8 +170,7 @@ async def add_request( # 4) Add the EngineCoreRequest to EngineCore (separate process). await self.engine_core.add_request_async(engine_core_req) - # 5) Return the generator. - return stream.generator() + return q # TODO: we should support multiple prompts in one call, as you # can do with LLM.generate. So that for multi-prompt completion @@ -194,7 +190,7 @@ async def generate( """ Main function called by the API server to kick off a request * 1) Making an AsyncStream corresponding to the Request. - # 2) Processing the Input. + * 2) Processing the Input. * 3) Adding the Request to the Detokenizer. * 4) Adding the Request to the EngineCore (separate process). @@ -206,14 +202,15 @@ async def generate( returning the RequestOutput back to the caller. 
""" - # We start the output_handler on the first call to generate() so that - # we can call __init__ before the event loop starts, which enables us - # to handle startup failure gracefully in the OpenAI server. - if self.output_handler is None: - self.output_handler = asyncio.create_task( - self._run_output_handler()) - - async for output in await self.add_request( + try: + # We start the output_handler on the first call to generate() so + # we can call __init__ before the event loop, which enables us + # to handle startup failure gracefully in the OpenAI server. + if self.output_handler is None: + self.output_handler = asyncio.create_task( + self._run_output_handler()) + + q = await self.add_request( request_id, prompt, sampling_params, @@ -221,79 +218,58 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority, - ): - yield output - - def _finish_stream(self, request_id: str): - stream = self.request_streams.pop(request_id, None) - if stream is not None: - stream.finish() + ) - def _add_request_to_streams( + # The output_handler task pushes items into the queue. + # This task pulls from the queue and yields to caller. + while True: + # Note: drain queue without await if possible (avoids + # task switching under load which helps performance). + out = q.get_nowait() if q.qsize() > 0 else await q.get() + + # Note: both Detokenizer and EngineCore handle their + # own request cleanup based on finished. + if out.finished: + del self.rid_to_queue[request_id] + yield out + break + + yield out + + # If the request is disconnected by the client, the + # generate() task will be canceled. So, we abort the + # request if we end up here. + except asyncio.CancelledError: + await self.abort(request_id) + raise + + def _add_request_to_queues( self, request_id: str, - ) -> AsyncStream: + ) -> asyncio.Queue[RequestOutput]: - if request_id in self.request_streams: + if request_id in self.rid_to_queues: raise ValueError(f"Request id {request_id} already running.") - # Avoid streams having circular ref to parent AsyncLLM object. - aborted_reqs = self.client_aborted_requests - stream = AsyncStream(request_id, aborted_reqs.append) - self.request_streams[request_id] = stream + self.rid_to_queues[request_id] = asyncio.Queue() if self.log_requests: logger.info("Added request %s.", request_id) - return stream - - async def _process_cancellations(self) -> None: - """ - Process requests cancelled from user disconnecting. - - When a client disconnects, AsyncStream._cancel() is called. - We passed a callback to AsyncStream(), which appends to - self.client_aborted_requests. - - As a result, if any requests are canceled from the user side - the request_id will show up in self.client_aborted_requests. - """ - - # Avoid streams having circular ref to parent AsyncLLM object. - if not self.client_aborted_requests: - return - reqs_to_abort = self.client_aborted_requests.copy() - self.client_aborted_requests.clear() - - # Remove from Detokenizer. - self.detokenizer.abort_requests(reqs_to_abort) - - # Remove from RequestStreams. - for request_id in reqs_to_abort: - if self.log_requests: - logger.info("User-cancelled request %s.", request_id) - self._finish_stream(request_id) - - # Remove from EngineCore. 
- await self.engine_core.abort_requests_async(reqs_to_abort) + return self.rid_to_queues[request_id] def _process_request_outputs(self, request_outputs: List[RequestOutput]): - """Process outputs by putting them into per-request AsyncStreams.""" + """Process outputs by putting them into per-request queues.""" for request_output in request_outputs: request_id = request_output.request_id - assert request_id in self.request_streams - # Each request in the API server pulls from the per-request stream. - stream = self.request_streams.get(request_id) - if stream is not None: - stream.put(request_output) - - # If finished, remove from the tracker. - if request_output.finished: - if self.log_requests: - logger.info("Finished request %s.", request_id) - self._finish_stream(request_id) + # Note: it is possible a request was aborted and removed from + # the state due to client cancellations, so if we encounter a + # request id not in the state, we skip. + if request_id in self.rid_to_queues: + q = self.rid_to_queues[request_id] + q.put_nowait(request_output) async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" @@ -306,24 +282,29 @@ async def _run_output_handler(self): # 2) Detokenize based on the output. request_outputs, reqs_to_abort = self.detokenizer.step(outputs) - # 3) Put the RequestOutputs into the per-request AsyncStreams. + # 3) Put the RequestOutputs into the per-request queues. self._process_request_outputs(request_outputs) # 4) Abort any requests that finished due to stop strings. await self.engine_core.abort_requests_async(reqs_to_abort) - # 5) Abort any requests due to client cancellations. - await self._process_cancellations() - except BaseException as e: logger.error(e) raise e - # TODO: can we eliminate these? - async def abort(self, request_id: str) -> None: - # Note: Who Calls this? I dont think this is actually used. - raise ValueError("Not Supported on V1 yet.") + """Abort RequestId in self, detokenizer, and engine core.""" + + request_ids = [request_id] + await self.engine_core.abort_requests_async(request_ids) + self.detokenizer.abort_requests(request_ids) + + # If a request is finished while we await above, + # then it is possible that the request is already + # removed from the queues, so we do nothing if the + # request_id is no longer in the tracked queues. 
+ if request_id in self.rid_to_queues: + del self.rid_to_queues[request_id] def encode( self, diff --git a/vllm/v1/engine/async_stream.py b/vllm/v1/engine/async_stream.py deleted file mode 100644 index 35449238c325..000000000000 --- a/vllm/v1/engine/async_stream.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from typing import Any, AsyncGenerator, Callable, Optional, Type, Union - -from vllm.outputs import PoolingRequestOutput, RequestOutput - - -class AsyncStream: - """A stream of RequestOutputs or PoolingRequestOutputs for a request - that can be iterated over asynchronously via an async generator.""" - - STOP_ITERATION = Exception() # Sentinel - - def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: - self.request_id = request_id - self._cancel = cancel - self._queue: asyncio.Queue = asyncio.Queue() - self._finished = False - - def put(self, item: Union[RequestOutput, PoolingRequestOutput, - Exception]) -> None: - if not self._finished: - self._queue.put_nowait(item) - - def finish( - self, - exception: Optional[Union[BaseException, Type[BaseException]]] = None, - ) -> None: - if not self._finished: - self._finished = True - self._queue.put_nowait(exception if self._is_raisable(exception) - else AsyncStream.STOP_ITERATION) - - async def generator( - self - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: - finished = False - try: - while True: - result = await self._queue.get() - if self._is_raisable(result): - finished = True - if result == AsyncStream.STOP_ITERATION: - return - raise result - yield result - finally: - self._finished = True - if not finished: - self._cancel(self.request_id) - - @staticmethod - def _is_raisable(value: Any): - return isinstance(value, BaseException) or \ - (isinstance(value, type) and \ - issubclass(value, BaseException)) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 02f34e2b54dd..39f189266c64 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -223,6 +223,10 @@ def abort_requests( """Remove the request_ids from the Detokenizer.""" for request_id in request_ids: + # Note: it is possible that a request is "finished" + # in process of an abort call by AsyncLLM. So we + # simply do nothing if a request id is not in the + # active request states. self.request_states.pop(request_id, None) def add_request( From 1789162d014f2133a459071d86dc5ba61d023708 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 00:26:30 +0000 Subject: [PATCH 02/13] updated --- vllm/v1/engine/async_llm.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 77d45c0b272c..d8b3d707846c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -54,7 +54,7 @@ def __init__( self.tokenizer.ping() # Request streams (map of request_id -> queue). - self.rid_to_queues: Dict[str, asyncio.Queue] = {} + self.rid_to_queue: Dict[str, asyncio.Queue] = {} # Processor (converts Inputs --> EngineCoreRequests). 
self.processor = Processor( @@ -248,15 +248,15 @@ def _add_request_to_queues( request_id: str, ) -> asyncio.Queue[RequestOutput]: - if request_id in self.rid_to_queues: + if request_id in self.rid_to_queue: raise ValueError(f"Request id {request_id} already running.") - self.rid_to_queues[request_id] = asyncio.Queue() + self.rid_to_queue[request_id] = asyncio.Queue() if self.log_requests: logger.info("Added request %s.", request_id) - return self.rid_to_queues[request_id] + return self.rid_to_queue[request_id] def _process_request_outputs(self, request_outputs: List[RequestOutput]): """Process outputs by putting them into per-request queues.""" @@ -267,9 +267,8 @@ def _process_request_outputs(self, request_outputs: List[RequestOutput]): # Note: it is possible a request was aborted and removed from # the state due to client cancellations, so if we encounter a # request id not in the state, we skip. - if request_id in self.rid_to_queues: - q = self.rid_to_queues[request_id] - q.put_nowait(request_output) + if request_id in self.rid_to_queue: + self.rid_to_queue[request_id].put_nowait(request_output) async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" @@ -303,8 +302,8 @@ async def abort(self, request_id: str) -> None: # then it is possible that the request is already # removed from the queues, so we do nothing if the # request_id is no longer in the tracked queues. - if request_id in self.rid_to_queues: - del self.rid_to_queues[request_id] + if request_id in self.rid_to_queue: + del self.rid_to_queue[request_id] def encode( self, From 92398142ebef4a420db50eb9f07327d5a1956391 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 00:31:42 +0000 Subject: [PATCH 03/13] reduce logging time --- vllm/v1/engine/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 497d5db5b4c9..5bbf600921b5 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -32,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -LOGGING_TIME_S = 5000 +LOGGING_TIME_S = 5 class EngineCore: From 83acac64eea11b4b651fb2a6ac1f2ce44f86ef08 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 00:34:03 +0000 Subject: [PATCH 04/13] cleanup --- vllm/v1/engine/async_llm.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index d8b3d707846c..d4af2b484987 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -153,11 +153,10 @@ async def add_request( ) -> asyncio.Queue[RequestOutput]: """Add new request to the AsyncLLM.""" - if self.detokenizer.is_request_active(request_id): - raise ValueError(f"Request {request_id} already exists.") - # 1) Create a new output queue for the request. - q = self._add_request_to_queues(request_id) + if request_id in self.rid_to_queue: + raise ValueError(f"Request id {request_id} already running.") + self.rid_to_queue[request_id] = asyncio.Queue() # 2) Convert input --> DetokenizerRequest / EngineCoreRequest. detokenizer_req, engine_core_req = self.processor.process_inputs( @@ -170,7 +169,10 @@ async def add_request( # 4) Add the EngineCoreRequest to EngineCore (separate process). 
await self.engine_core.add_request_async(engine_core_req) - return q + if self.log_requests: + logger.info("Added request %s.", request_id) + + return self.rid_to_queue[request_id] # TODO: we should support multiple prompts in one call, as you # can do with LLM.generate. So that for multi-prompt completion @@ -243,21 +245,6 @@ async def generate( await self.abort(request_id) raise - def _add_request_to_queues( - self, - request_id: str, - ) -> asyncio.Queue[RequestOutput]: - - if request_id in self.rid_to_queue: - raise ValueError(f"Request id {request_id} already running.") - - self.rid_to_queue[request_id] = asyncio.Queue() - - if self.log_requests: - logger.info("Added request %s.", request_id) - - return self.rid_to_queue[request_id] - def _process_request_outputs(self, request_outputs: List[RequestOutput]): """Process outputs by putting them into per-request queues.""" From aefeb8499754785ef818694404063b2cd9f6a90e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 00:38:58 +0000 Subject: [PATCH 05/13] update comment --- vllm/v1/engine/async_llm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index d4af2b484987..ba2b8377759d 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -285,10 +285,8 @@ async def abort(self, request_id: str) -> None: await self.engine_core.abort_requests_async(request_ids) self.detokenizer.abort_requests(request_ids) - # If a request is finished while we await above, - # then it is possible that the request is already - # removed from the queues, so we do nothing if the - # request_id is no longer in the tracked queues. + # If a request finishes while we await then the request_id + # will be removed from the tracked queues before we get here. if request_id in self.rid_to_queue: del self.rid_to_queue[request_id] From 8fb01d122dcf776ce8b3f23301b2056845004a18 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 00:39:42 +0000 Subject: [PATCH 06/13] clean --- vllm/v1/engine/detokenizer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 39f189266c64..02f34e2b54dd 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -223,10 +223,6 @@ def abort_requests( """Remove the request_ids from the Detokenizer.""" for request_id in request_ids: - # Note: it is possible that a request is "finished" - # in process of an abort call by AsyncLLM. So we - # simply do nothing if a request id is not in the - # active request states. self.request_states.pop(request_id, None) def add_request( From aeefcf26d601f5e3e5037d53010d467c5a84da54 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 00:46:03 +0000 Subject: [PATCH 07/13] stash --- vllm/v1/engine/async_llm.py | 1 + vllm/v1/engine/core.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ba2b8377759d..b719dbe92ea8 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -242,6 +242,7 @@ async def generate( # generate() task will be canceled. So, we abort the # request if we end up here. 
except asyncio.CancelledError: + print("calling abort") await self.abort(request_id) raise diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5bbf600921b5..2bec2e2c2bba 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -32,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -LOGGING_TIME_S = 5 +LOGGING_TIME_S = 0.5 class EngineCore: From 1749da1210a2bb60cdfaf07675d8aac75326d62e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 01:17:09 +0000 Subject: [PATCH 08/13] updated --- vllm/v1/engine/async_llm.py | 1 - vllm/v1/engine/core.py | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b719dbe92ea8..ba2b8377759d 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -242,7 +242,6 @@ async def generate( # generate() task will be canceled. So, we abort the # request if we end up here. except asyncio.CancelledError: - print("calling abort") await self.abort(request_id) raise diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2bec2e2c2bba..d3f43c2cfecf 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -30,9 +30,10 @@ logger = init_logger(__name__) -POLLING_TIMEOUT_MS = 5000 -POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -LOGGING_TIME_S = 0.5 +POLLING_TIMEOUT_MS = 1000 +# POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 +POLLING_TIMEOUT_S = 1.0 +LOGGING_TIME_S = 1.0 class EngineCore: From e0ddb0525eb79fc65e4e4a9586bc0fec048b851b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 01:18:01 +0000 Subject: [PATCH 09/13] updated --- vllm/v1/engine/core.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d3f43c2cfecf..0aef61fc7f68 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -30,10 +30,9 @@ logger = init_logger(__name__) -POLLING_TIMEOUT_MS = 1000 -# POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -POLLING_TIMEOUT_S = 1.0 -LOGGING_TIME_S = 1.0 +POLLING_TIMEOUT_MS = 5000 +POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 +LOGGING_TIME_S = POLLING_TIMEOUT_S class EngineCore: From 60ef7aab05cb2b8736f1476141258e8eebbc0438 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 02:28:41 +0000 Subject: [PATCH 10/13] updated --- tests/models/registry.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 819ef957a07f..f5a37420a290 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -61,6 +61,8 @@ class _HfExamplesInfo: "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"), "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501 trust_remote_code=True), + "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501 + trust_remote_code=True), "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), From 92b05d7247bd4c0c511e1e03b5a756769f7c0856 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 27 Dec 2024 03:38:11 +0000 Subject: [PATCH 11/13] remove DetokenizerRequest --- tests/v1/engine/test_detokenizer.py | 46 +++++++++++++++-------------- vllm/v1/engine/__init__.py | 16 +--------- vllm/v1/engine/async_llm.py | 14 +++++---- 
vllm/v1/engine/detokenizer.py | 21 ++++++------- vllm/v1/engine/llm_engine.py | 12 ++++---- vllm/v1/engine/processor.py | 23 +++------------ 6 files changed, 55 insertions(+), 77 deletions(-) diff --git a/tests/v1/engine/test_detokenizer.py b/tests/v1/engine/test_detokenizer.py index 07f343666cb5..e0f5b9ee171a 100644 --- a/tests/v1/engine/test_detokenizer.py +++ b/tests/v1/engine/test_detokenizer.py @@ -3,9 +3,9 @@ import pytest from transformers import AutoTokenizer -from vllm.sampling_params import RequestOutputKind -from vllm.v1.engine import EngineCoreOutput -from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.v1.engine.detokenizer import Detokenizer TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) @@ -71,16 +71,17 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): # Make N requests. requests = [ - DetokenizerRequest( - request_id=f"request-{idx}", - prompt=prompt, - prompt_token_ids=prompt_tokens, - skip_special_tokens=False, - spaces_between_special_tokens=False, - output_kind=request_output_kind, - stop=[], - include_stop_str_in_output=False, - ) for idx, ( + EngineCoreRequest(request_id=f"request-{idx}", + prompt=prompt, + prompt_token_ids=prompt_tokens, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + )) + for idx, ( prompt, prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) ] @@ -133,18 +134,19 @@ def test_stop_string(include_stop_str_in_output: bool): # Make N requests. requests = [ - DetokenizerRequest( + EngineCoreRequest( request_id=f"request-{idx}", prompt=prompt, prompt_token_ids=prompt_tokens, - skip_special_tokens=False, - spaces_between_special_tokens=False, - output_kind=RequestOutputKind.DELTA, - stop=STOP_STRINGS, - include_stop_str_in_output=include_stop_str_in_output, - ) for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=RequestOutputKind.DELTA, + stop=STOP_STRINGS, + include_stop_str_in_output=include_stop_str_in_output, + )) for idx, ( + prompt, + prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) ] # Add requests to the detokenizer. 
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index cc0c7ea23469..f70464fc8829 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -6,21 +6,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind, SamplingParams - - -@dataclass -class DetokenizerRequest: - - request_id: str - prompt: Optional[str] - prompt_token_ids: List[int] - skip_special_tokens: bool - spaces_between_special_tokens: bool - output_kind: RequestOutputKind - - stop: List[str] - include_stop_str_in_output: bool +from vllm.sampling_params import SamplingParams @dataclass diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ba2b8377759d..5b75402bd62d 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -158,16 +158,18 @@ async def add_request( raise ValueError(f"Request id {request_id} already running.") self.rid_to_queue[request_id] = asyncio.Queue() - # 2) Convert input --> DetokenizerRequest / EngineCoreRequest. - detokenizer_req, engine_core_req = self.processor.process_inputs( - request_id, prompt, params, arrival_time, lora_request, - trace_headers, prompt_adapter_request, priority) + # 2) Convert Input --> Request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) # 3) Add the request to Detokenizer (this process). - self.detokenizer.add_request(detokenizer_req) + self.detokenizer.add_request(request) # 4) Add the EngineCoreRequest to EngineCore (separate process). - await self.engine_core.add_request_async(engine_core_req) + await self.engine_core.add_request_async(request) if self.log_requests: logger.info("Added request %s.", request_id) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 02f34e2b54dd..65be9e58e03c 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -8,7 +8,7 @@ from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest logger = init_logger(__name__) @@ -55,19 +55,19 @@ def output_token_ids(self) -> List[int]: def from_new_request( cls, tokenizer: AnyTokenizer, - request: DetokenizerRequest, + request: EngineCoreRequest, ) -> "IncrementalDetokenizer": tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( tokenizer=tokenizer, prompt_ids=request.prompt_token_ids, - skip_special_tokens=request.skip_special_tokens, + skip_special_tokens=request.sampling_params.skip_special_tokens, ) - stops = request.stop + stops = request.sampling_params.stop # Number of chars to hold back when stop strings are to be excluded # from streamed output. - if stops and not request.include_stop_str_in_output: + if stops and not request.sampling_params.include_stop_str_in_output: stop_buffer_length = max(len(s) for s in stops) - 1 else: stop_buffer_length = 0 @@ -79,13 +79,14 @@ def from_new_request( # NOTE(Nick): could we take ownership of it though? token_ids=request.prompt_token_ids.copy(), stop=stops, - include_stop_str_in_output=request.include_stop_str_in_output, + include_stop_str_in_output=request.sampling_params. 
+ include_stop_str_in_output, prefix_offset=prefix_offset, read_offset=read_offset, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=request. + skip_special_tokens=request.sampling_params.skip_special_tokens, + spaces_between_special_tokens=request.sampling_params. spaces_between_special_tokens, - output_kind=request.output_kind, + output_kind=request.sampling_params.output_kind, request_id=request.request_id, prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, @@ -227,7 +228,7 @@ def abort_requests( def add_request( self, - request: DetokenizerRequest, + request: EngineCoreRequest, ): """Add new request to the Detokenizer.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index b58f62778ffe..adb00f036b7e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -152,15 +152,17 @@ def add_request( ) -> None: # 1) Process raw inputs into the request. - detokenizer_req, engine_core_req = self.processor.process_inputs( - request_id, prompt, params, arrival_time, lora_request, - trace_headers, prompt_adapter_request, priority) + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) # 2) Add the request to Detokenizer. - self.detokenizer.add_request(detokenizer_req) + self.detokenizer.add_request(request) # 3) Add the request to EngineCore. - self.engine_core.add_request(engine_core_req) + self.engine_core.add_request(request) def step(self) -> List[RequestOutput]: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6ee8732bc902..5b5a5a61cea7 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,5 +1,5 @@ import time -from typing import Mapping, Optional, Tuple, Union +from typing import Mapping, Optional, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, @@ -13,7 +13,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest +from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient @@ -62,7 +62,7 @@ def process_inputs( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> Tuple[DetokenizerRequest, EngineCoreRequest]: + ) -> EngineCoreRequest: # TODO(woosuk): Support pooling models. # TODO(woosuk): Check max_logprobs @@ -123,20 +123,7 @@ def process_inputs( decoder_inputs.multi_modal_data, mm_hashes, decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs) - # Make Request for Detokenizer. - detokenizer_request = DetokenizerRequest( - request_id, - decoder_inputs.prompt, - decoder_inputs.prompt_token_ids, - sampling_params.skip_special_tokens, - sampling_params.spaces_between_special_tokens, - sampling_params.output_kind, - sampling_params.stop, - sampling_params.include_stop_str_in_output, - ) - - # Make Request for EngineCore. 
- engine_core_request = EngineCoreRequest( + return EngineCoreRequest( request_id, decoder_inputs.prompt, decoder_inputs.prompt_token_ids, @@ -149,8 +136,6 @@ def process_inputs( lora_request, ) - return detokenizer_request, engine_core_request - def _validate_model_inputs(self, inputs: ProcessorInputs): if is_encoder_decoder_inputs(inputs): # For encoder-decoder multimodal models, the max_prompt_len From b055925336bb7a371c618430de1a782417e48807 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 28 Dec 2024 15:03:24 +0000 Subject: [PATCH 12/13] fix tests Signed-off-by: rshaw@neuralmagic.com --- tests/v1/engine/test_detokenizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/v1/engine/test_detokenizer.py b/tests/v1/engine/test_detokenizer.py index e0f5b9ee171a..bba6e31c202a 100644 --- a/tests/v1/engine/test_detokenizer.py +++ b/tests/v1/engine/test_detokenizer.py @@ -74,13 +74,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): EngineCoreRequest(request_id=f"request-{idx}", prompt=prompt, prompt_token_ids=prompt_tokens, + arrival_time=0, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, output_kind=request_output_kind, stop=[], - include_stop_str_in_output=False, - )) + include_stop_str_in_output=False)) for idx, ( prompt, prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) @@ -138,6 +138,7 @@ def test_stop_string(include_stop_str_in_output: bool): request_id=f"request-{idx}", prompt=prompt, prompt_token_ids=prompt_tokens, + arrival_time=0, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, From c4ff199132f1b43f8a6cfa71190eb61c73fefb55 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 28 Dec 2024 15:05:11 +0000 Subject: [PATCH 13/13] signed Signed-off-by: rshaw@neuralmagic.com --- tests/v1/engine/test_detokenizer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/v1/engine/test_detokenizer.py b/tests/v1/engine/test_detokenizer.py index bba6e31c202a..aeae697ca32b 100644 --- a/tests/v1/engine/test_detokenizer.py +++ b/tests/v1/engine/test_detokenizer.py @@ -75,6 +75,11 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + eos_token_id=None, + lora_request=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, @@ -139,6 +144,11 @@ def test_stop_string(include_stop_str_in_output: bool): prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + eos_token_id=None, + lora_request=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False,
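
Note for readers skimming the series: the net effect of patches 01–05 is to drop the `AsyncStream` abstraction and give each request a plain `asyncio.Queue` that a single background output handler fills, with `generate()` draining its own queue and `CancelledError` triggering an abort. The standalone sketch below is only an illustration of that producer/consumer shape — toy names, a per-request producer task instead of vLLM's single `_run_output_handler()` loop, and a fake `Output` type — not the vLLM code itself.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Dict


@dataclass
class Output:
    """Toy stand-in for RequestOutput."""
    request_id: str
    text: str
    finished: bool = False


class QueueStreamer:
    """Minimal sketch of the per-request queue pattern adopted in the patches."""

    def __init__(self) -> None:
        # Map of request_id -> per-request output queue (rid_to_queue in the patches).
        self.rid_to_queue: Dict[str, asyncio.Queue] = {}

    def _push(self, out: Output) -> None:
        # Mirrors _process_request_outputs(): a request may already have been
        # aborted and removed, so silently skip unknown ids.
        if out.request_id in self.rid_to_queue:
            self.rid_to_queue[out.request_id].put_nowait(out)

    async def _producer(self, request_id: str, n: int) -> None:
        # Pretend to be the background loop pulling outputs from the engine.
        for i in range(n):
            await asyncio.sleep(0.01)
            self._push(Output(request_id, f"tok{i}", finished=(i == n - 1)))

    async def abort(self, request_id: str) -> None:
        # Engine-core / detokenizer aborts would go here; just drop the queue.
        self.rid_to_queue.pop(request_id, None)

    async def generate(self, request_id: str,
                       n: int = 3) -> AsyncGenerator[Output, None]:
        if request_id in self.rid_to_queue:
            raise ValueError(f"Request id {request_id} already running.")
        q: asyncio.Queue = asyncio.Queue()
        self.rid_to_queue[request_id] = q
        producer = asyncio.create_task(self._producer(request_id, n))
        try:
            while True:
                # Drain without awaiting when items are ready (fast path from
                # the patch, avoids task switching under load).
                out = q.get_nowait() if q.qsize() > 0 else await q.get()
                if out.finished:
                    del self.rid_to_queue[request_id]
                    yield out
                    break
                yield out
        except asyncio.CancelledError:
            # A client disconnect cancels this task; clean up like abort().
            await self.abort(request_id)
            raise
        finally:
            producer.cancel()


async def main() -> None:
    streamer = QueueStreamer()
    async for out in streamer.generate("request-0"):
        print(out.text, out.finished)


if __name__ == "__main__":
    asyncio.run(main())
```

Compared with the deleted `AsyncStream`, this shape needs no sentinel exception and no cancel callback: finishing a request is just removing its queue entry, and client disconnects are handled by the `CancelledError` path in `generate()`.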