diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 54c5af2fe366..6e50cc9bc933 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -413,6 +413,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
 
@@ -426,6 +427,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
 
@@ -442,6 +444,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -453,6 +456,9 @@ async def add_request_async(
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
+        if priority != 0 and not self.scheduler_config.policy == "priority":
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
         if arrival_time is None:
             arrival_time = time.time()
 
@@ -472,6 +478,7 @@ async def add_request_async(
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
+            priority=priority,
         )
 
     async def check_health_async(self) -> None:
@@ -822,6 +829,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -836,6 +844,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -853,6 +862,7 @@ async def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
@@ -870,6 +880,11 @@ async def add_request(
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")
 
+        if (priority != 0
+                and not self.engine.scheduler_config.policy == "priority"):
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
+
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
@@ -878,7 +893,9 @@ async def add_request(
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
+        )
 
         return stream.generator()
 
@@ -889,7 +906,8 @@ async def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
@@ -906,6 +924,8 @@ async def generate(
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                                             for generation, if any.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.
 
         Yields:
             The output `RequestOutput` objects from the LLMEngine
@@ -961,6 +981,7 @@ async def generate(
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 487255cb6b59..fd905e5a6255 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -786,7 +786,7 @@ def add_request(
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
 
-        if priority > 0 and not self.scheduler_config.policy == "priority":
+        if priority != 0 and not self.scheduler_config.policy == "priority":
             raise ValueError(f"Got priority {priority} but "
                              "Priority scheduling is not enabled.")
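For context, a minimal usage sketch of the new parameter (not part of the diff): with the engine built under the priority scheduling policy, callers can pass `priority` straight through `AsyncLLMEngine.generate()`, which this change threads down to `add_request()` and `add_request_async()`. The `scheduling_policy` engine-arg name and the model used below are assumptions for illustration; with any other policy, a non-zero `priority` now raises `ValueError` per the checks above.

```python
# Hedged sketch, not part of the PR: assumes AsyncEngineArgs exposes a
# `scheduling_policy` option for enabling the priority scheduler.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",  # placeholder model
                        scheduling_policy="priority"))  # assumed arg name

    # With this diff applied, `priority` flows generate() -> add_request()
    # -> add_request_async() -> scheduler. Passing a non-zero priority
    # without the priority policy raises ValueError.
    final = None
    async for output in engine.generate(
            "Hello, my name is",
            SamplingParams(max_tokens=16),
            request_id="req-0",
            priority=1,  # relative ordering depends on the scheduler policy
    ):
        final = output
    print(final.outputs[0].text)


asyncio.run(main())
```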