vllm/v1/core/sched/scheduler.py (51 changes: 36 additions & 15 deletions)
@@ -580,6 +580,13 @@ def schedule(self) -> SchedulerOutput:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)

self._update_after_schedule(scheduler_output)
return scheduler_output

def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
@@ -589,11 +596,15 @@ def schedule(self) -> SchedulerOutput:
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
request = self.requests[req_id]
request.num_computed_tokens += num_scheduled_token

# Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
# it will also affect the scheduler output.
self.finished_req_ids = set()
return scheduler_output
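
A minimal sketch of why the new code rebinds `self.finished_req_ids` to a fresh set instead of calling `.clear()`: the just-built `SchedulerOutput` holds a reference to the same set object, so an in-place clear would also wipe the IDs it carries. The names below are illustrative, not the scheduler's actual fields.

```python
# Illustrative only: the aliasing hazard behind the NOTE above.
finished_req_ids = {"req-1", "req-2"}
output_finished_ids = finished_req_ids   # SchedulerOutput keeps this same object

# finished_req_ids.clear() would also empty output_finished_ids.
# Rebinding the attribute leaves the output's copy untouched:
finished_req_ids = set()
assert output_finished_ids == {"req-1", "req-2"}
```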

def _make_cached_request_data(
self,
@@ -763,19 +774,10 @@ def update_from_output(
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)

cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty.
if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input(
request, input_id)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
if request.has_encoder_inputs:
self._free_encoder_inputs(request)

stopped = False
new_logprobs = None
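
The NOTE above pins down an ordering constraint: encoder inputs may only be freed after `request.num_computed_tokens` reflects the current step, including any downward correction for rejected speculative tokens. A hedged walk-through with made-up numbers (not the real update path); the freeing condition itself is sketched after `_free_encoder_inputs` below.

```python
# Illustrative numbers only: the encoder-free decision must see the
# corrected count, not the optimistic one from scheduling time.
num_computed_tokens = 600                        # advanced in _update_after_schedule
num_rejected_spec_tokens = 3                     # assumption: 3 draft tokens rejected
num_computed_tokens -= num_rejected_spec_tokens  # 597 tokens actually in the KV cache
```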
@@ -891,6 +893,25 @@

return engine_core_outputs

def _free_encoder_inputs(self, request: Request) -> None:
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty.
if not cached_encoder_input_ids:
return

# Here, we use list(set) to avoid modifying the set while iterating
# over it.
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input(
request, input_id)
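
To make the freeing condition concrete, here is a small hypothetical example (values invented for illustration): a multimodal input whose placeholder tokens occupy `[offset, offset + length)` can be freed only once the request's computed tokens cover that whole range, i.e. its encoder output is fully folded into the decoder's KV cache.

```python
# Hypothetical values illustrating the check in _free_encoder_inputs.
offset, length = 5, 576          # placeholder span of one image input
num_computed_tokens = 400
print(offset + length <= num_computed_tokens)   # False -> keep encoder output

num_computed_tokens = 600
print(offset + length <= num_computed_tokens)   # True -> safe to free
```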

def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)