63 changes: 42 additions & 21 deletions vllm/v1/core/sched/scheduler.py
@@ -210,20 +210,11 @@ def schedule(self) -> SchedulerOutput:
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]

num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)

# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
num_new_tokens = self._calculate_num_new_tokens(
request,
request.num_computed_tokens,
token_budget,
is_running_request=True)

# Schedule encoder inputs.
encoder_inputs_to_schedule = None
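
For clarity, here is a small self-contained sketch (all numbers are hypothetical, not taken from the PR) of the clamping sequence the replaced inline code performed for a RUNNING request: the raw new-token count is cut first by the long-prefill threshold, then by the remaining token budget, and finally so the last input position stays below max_model_len.

# Hypothetical standalone sketch of the RUNNING-path clamping order;
# the names mirror the scheduler fields but the values are made up.
num_tokens_with_spec = 96      # prompt + output + speculative tokens
num_output_placeholders = 4
num_computed_tokens = 80
long_prefill_token_threshold = 8
token_budget = 12
max_model_len = 100

num_new_tokens = (num_tokens_with_spec + num_output_placeholders
                  - num_computed_tokens)              # 20
if 0 < long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = long_prefill_token_threshold     # clamped to 8
num_new_tokens = min(num_new_tokens, token_budget)    # still 8
# Keep the input position below max_model_len (needed for spec decoding).
num_new_tokens = min(num_new_tokens,
                     max_model_len - 1 - num_computed_tokens)  # min(8, 19) = 8
print(num_new_tokens)  # -> 8
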
@@ -420,11 +411,11 @@ def schedule(self) -> SchedulerOutput:
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = self._calculate_num_new_tokens(
request,
num_computed_tokens,
token_budget,
is_running_request=False)

# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
@@ -433,8 +424,6 @@
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue

num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

# Schedule encoder inputs.
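
The WAITING path is simpler: a resumed (previously preempted) request may already have generated output tokens, which is why `request.num_tokens` (prompt + output) rather than `request.num_prompt_tokens` is used. A minimal sketch with made-up numbers, also showing why the standalone `min(num_new_tokens, token_budget)` line above could be deleted:

# Hypothetical sketch of the WAITING-path calculation for a resumed request.
num_prompt_tokens = 100
num_output_tokens = 5            # already generated before preemption
num_tokens = num_prompt_tokens + num_output_tokens
num_computed_tokens = 60         # tokens recovered from the prefix cache

num_new_tokens = num_tokens - num_computed_tokens     # 45
long_prefill_token_threshold = 32
if 0 < long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = long_prefill_token_threshold     # 32

# The min() against token_budget that used to sit here now lives inside
# the shared helper, so it cannot be forgotten on either path.
token_budget = 16
num_new_tokens = min(num_new_tokens, token_budget)    # 16
assert num_new_tokens > 0
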
Expand Down Expand Up @@ -1021,6 +1010,38 @@ def update_from_output(

return engine_core_outputs

def _calculate_num_new_tokens(
self,
request: Request,
num_computed_tokens: int,
token_budget: int,
is_running_request: bool = False,
) -> int:
# Initial calculation differs between RUNNING and WAITING requests
if is_running_request:
# RUNNING requests: include speculative decoding and output
# placeholders
num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
num_computed_tokens)
else:
# WAITING requests: simple calculation
num_new_tokens = request.num_tokens - num_computed_tokens

if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)

num_new_tokens = min(num_new_tokens, token_budget)

# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(num_new_tokens,
self.max_model_len - 1 - num_computed_tokens)

return num_new_tokens

def _update_request_with_output(
self,
request: Request,
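
As a closing illustration, the consolidated helper can be exercised in isolation. The stub below is hypothetical scaffolding (the real Scheduler, SchedulerConfig, and Request classes in vllm are far more involved); only the method body is copied from the diff above.

from types import SimpleNamespace

class SchedulerStub:
    """Minimal stand-in exposing only what the helper reads."""

    def __init__(self):
        self.scheduler_config = SimpleNamespace(
            long_prefill_token_threshold=0)
        self.max_model_len = 2048

    # Body copied from the diff above.
    def _calculate_num_new_tokens(self, request, num_computed_tokens,
                                  token_budget, is_running_request=False):
        if is_running_request:
            num_new_tokens = (request.num_tokens_with_spec +
                              request.num_output_placeholders -
                              num_computed_tokens)
        else:
            num_new_tokens = request.num_tokens - num_computed_tokens
        if (0 < self.scheduler_config.long_prefill_token_threshold <
                num_new_tokens):
            num_new_tokens = (
                self.scheduler_config.long_prefill_token_threshold)
        num_new_tokens = min(num_new_tokens, token_budget)
        return min(num_new_tokens,
                   self.max_model_len - 1 - num_computed_tokens)

sched = SchedulerStub()
req = SimpleNamespace(num_tokens=105, num_tokens_with_spec=107,
                      num_output_placeholders=2)
# WAITING path: 105 - 100 = 5 new tokens, clamped by the budget to 4.
assert sched._calculate_num_new_tokens(req, 100, 4) == 4
# RUNNING path: 107 + 2 - 100 = 9 new tokens, budget of 16 allows all 9.
assert sched._calculate_num_new_tokens(req, 100, 16,
                                       is_running_request=True) == 9
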