diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 20a40d74f311..263eec777a84 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -580,6 +580,13 @@ def schedule(self) -> SchedulerOutput:
             batch = KVEventBatch(ts=time.time(), events=events)
             self.kv_event_publisher.publish(batch)
 
+        self._update_after_schedule(scheduler_output)
+        return scheduler_output
+
+    def _update_after_schedule(
+        self,
+        scheduler_output: SchedulerOutput,
+    ) -> None:
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the
@@ -589,11 +596,15 @@ def schedule(self) -> SchedulerOutput:
         #    scheduling step.
         # 3. If some tokens (e.g. spec tokens) are rejected later, the number of
         #    computed tokens will be adjusted in update_from_output.
+        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         for req_id, num_scheduled_token in num_scheduled_tokens.items():
-            self.requests[req_id].num_computed_tokens += num_scheduled_token
+            request = self.requests[req_id]
+            request.num_computed_tokens += num_scheduled_token
 
+        # Clear the finished request IDs.
+        # NOTE: We shouldn't do self.finished_req_ids.clear() here because
+        # it will also affect the scheduler output.
         self.finished_req_ids = set()
-        return scheduler_output
 
     def _make_cached_request_data(
         self,
@@ -763,19 +774,10 @@ def update_from_output(
                     num_draft_tokens=len(scheduled_spec_token_ids),
                     num_accepted_tokens=len(generated_token_ids) - 1)
 
-            cached_encoder_input_ids = (
-                self.encoder_cache_manager.get_cached_input_ids(request))
-            # OPTIMIZATION: Avoid list(set) if the set is empty.
-            if cached_encoder_input_ids:
-                for input_id in list(cached_encoder_input_ids):
-                    mm_positions = request.mm_positions[input_id]
-                    start_pos = mm_positions.offset
-                    num_tokens = mm_positions.length
-                    if start_pos + num_tokens <= request.num_computed_tokens:
-                        # The encoder output is already processed and stored
-                        # in the decoder's KV cache.
-                        self.encoder_cache_manager.free_encoder_input(
-                            request, input_id)
+            # NOTE(woosuk): This has to be executed after updating
+            # `request.num_computed_tokens`.
+            if request.has_encoder_inputs:
+                self._free_encoder_inputs(request)
 
             stopped = False
             new_logprobs = None
@@ -891,6 +893,25 @@ def update_from_output(
 
         return engine_core_outputs
 
+    def _free_encoder_inputs(self, request: Request) -> None:
+        cached_encoder_input_ids = (
+            self.encoder_cache_manager.get_cached_input_ids(request))
+        # OPTIMIZATION: Avoid list(set) if the set is empty.
+        if not cached_encoder_input_ids:
+            return
+
+        # Here, we use list(set) to avoid modifying the set while iterating
+        # over it.
+        for input_id in list(cached_encoder_input_ids):
+            mm_positions = request.mm_positions[input_id]
+            start_pos = mm_positions.offset
+            num_tokens = mm_positions.length
+            if start_pos + num_tokens <= request.num_computed_tokens:
+                # The encoder output is already processed and stored
+                # in the decoder's KV cache.
+                self.encoder_cache_manager.free_encoder_input(
+                    request, input_id)
+
     def get_request_counts(self) -> tuple[int, int]:
         """Returns (num_running_reqs, num_waiting_reqs)."""
         return len(self.running), len(self.waiting)