diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 7fc4776b0261..9b7486fb660a 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -210,20 +210,11 @@ def schedule(self) -> SchedulerOutput:
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
 
-            num_new_tokens = (request.num_tokens_with_spec +
-                              request.num_output_placeholders -
-                              request.num_computed_tokens)
-            if (0 < self.scheduler_config.long_prefill_token_threshold <
-                    num_new_tokens):
-                num_new_tokens = (
-                    self.scheduler_config.long_prefill_token_threshold)
-            num_new_tokens = min(num_new_tokens, token_budget)
-
-            # Make sure the input position does not exceed the max model len.
-            # This is necessary when using spec decoding.
-            num_new_tokens = min(
-                num_new_tokens,
-                self.max_model_len - 1 - request.num_computed_tokens)
+            num_new_tokens = self._calculate_num_new_tokens(
+                request,
+                request.num_computed_tokens,
+                token_budget,
+                is_running_request=True)
 
             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None
@@ -420,11 +411,11 @@ def schedule(self) -> SchedulerOutput:
                 # We use `request.num_tokens` instead of
                 # `request.num_prompt_tokens` to consider the resumed
                 # requests, which have output tokens.
-                num_new_tokens = request.num_tokens - num_computed_tokens
-                if (0 < self.scheduler_config.long_prefill_token_threshold
-                        < num_new_tokens):
-                    num_new_tokens = (
-                        self.scheduler_config.long_prefill_token_threshold)
+                num_new_tokens = self._calculate_num_new_tokens(
+                    request,
+                    num_computed_tokens,
+                    token_budget,
+                    is_running_request=False)
 
                 # chunked prefill has to be enabled explicitly to allow
                 # pooling requests to be chunked
@@ -433,8 +424,6 @@
                     self.waiting.pop_request()
                     skipped_waiting_requests.prepend_request(request)
                     continue
-
-                num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
 
                 # Schedule encoder inputs.
@@ -1021,6 +1010,38 @@ def update_from_output(
 
         return engine_core_outputs
 
+    def _calculate_num_new_tokens(
+        self,
+        request: Request,
+        num_computed_tokens: int,
+        token_budget: int,
+        is_running_request: bool = False,
+    ) -> int:
+        # Initial calculation differs between RUNNING and WAITING requests
+        if is_running_request:
+            # RUNNING requests: include speculative decoding and output
+            # placeholders
+            num_new_tokens = (request.num_tokens_with_spec +
+                              request.num_output_placeholders -
+                              num_computed_tokens)
+        else:
+            # WAITING requests: simple calculation
+            num_new_tokens = request.num_tokens - num_computed_tokens
+
+        if (0 < self.scheduler_config.long_prefill_token_threshold <
+                num_new_tokens):
+            num_new_tokens = (
+                self.scheduler_config.long_prefill_token_threshold)
+
+        num_new_tokens = min(num_new_tokens, token_budget)
+
+        # Make sure the input position does not exceed the max model len.
+        # This is necessary when using spec decoding.
+        num_new_tokens = min(num_new_tokens,
+                             self.max_model_len - 1 - num_computed_tokens)
+
+        return num_new_tokens
+
     def _update_request_with_output(
         self,
         request: Request,
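
For reference, below is a minimal standalone sketch (not part of the patch and not vLLM code) of the clamping order the consolidated helper applies: the long-prefill threshold cap, then the per-step token budget, then the max-model-len bound. The free function, its parameter names, and the numbers in the example are illustrative assumptions only.

# Illustrative sketch of the clamping order used by the new helper.
# All names and values here are assumptions, not vLLM APIs.

def calculate_num_new_tokens(num_tokens: int,
                             num_tokens_with_spec: int,
                             num_output_placeholders: int,
                             num_computed_tokens: int,
                             token_budget: int,
                             long_prefill_token_threshold: int,
                             max_model_len: int,
                             is_running_request: bool) -> int:
    if is_running_request:
        # RUNNING: count speculative tokens and output placeholders too.
        num_new_tokens = (num_tokens_with_spec + num_output_placeholders -
                          num_computed_tokens)
    else:
        # WAITING: remaining prompt (plus any resumed output) tokens.
        num_new_tokens = num_tokens - num_computed_tokens

    # Cap long prefills when the threshold is enabled (> 0).
    if 0 < long_prefill_token_threshold < num_new_tokens:
        num_new_tokens = long_prefill_token_threshold
    # Stay within this step's token budget.
    num_new_tokens = min(num_new_tokens, token_budget)
    # Keep the last input position below max_model_len
    # (matters for speculative decoding).
    num_new_tokens = min(num_new_tokens,
                         max_model_len - 1 - num_computed_tokens)
    return num_new_tokens


# Example: a RUNNING request with 4 uncomputed (speculative) tokens but a
# budget of 2; the token budget is the binding cap, so 2 tokens are scheduled.
assert calculate_num_new_tokens(num_tokens=100,
                                num_tokens_with_spec=104,
                                num_output_placeholders=0,
                                num_computed_tokens=100,
                                token_budget=2,
                                long_prefill_token_threshold=0,
                                max_model_len=2048,
                                is_running_request=True) == 2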