From e06d14f4ebd832e269a2b10348b1253b95b580c2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 20 Feb 2025 15:27:47 -0800 Subject: [PATCH 01/12] work Signed-off-by: Cody Yu --- tests/v1/engine/test_engine_core.py | 93 ++++++++++++++++++++++++---- vllm/v1/core/scheduler.py | 95 ++++++++++++++++++----------- vllm/v1/engine/core.py | 8 +-- 3 files changed, 145 insertions(+), 51 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 5fdbcf5b9963..d0ff2a89c26e 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -214,8 +214,10 @@ def test_engine_core_concurrent_batches(monkeypatch): Test that the engine can handle multiple concurrent batches. """ - def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest: + def make_request_with_max_tokens(req_id: int, + max_tokens: int) -> EngineCoreRequest: request = make_request() + request.request_id = req_id request.sampling_params.max_tokens = max_tokens return request @@ -262,6 +264,8 @@ def max_concurrent_batches(self) -> int: # Avoid all requests being scheduled once. enable_prefix_caching=False, max_num_batched_tokens=10, + # Reduce startup time. + enforce_eager=True, ) vllm_config = engine_args.create_engine_config() engine_core = EngineCore(vllm_config=vllm_config, @@ -269,23 +273,86 @@ def max_concurrent_batches(self) -> int: executor_class=DummyExecutor) assert engine_core.batch_queue is not None - # Add two requests in a row. - req = make_request_with_max_tokens(5) - engine_core.add_request(req) - req = make_request_with_max_tokens(5) - engine_core.add_request(req) + # Add two requests in a row. Each request have 12 prompt tokens. + req0 = make_request_with_max_tokens(0, 5) + engine_core.add_request(req0) + req1 = make_request_with_max_tokens(1, 5) + engine_core.add_request(req1) - # First saturate the batch queue. + # Schedule Batch 1: (10, req0) assert engine_core.step_with_batch_queue() is None assert engine_core.batch_queue.qsize() == 1 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 10 + # num_computed_tokens should have been updated immediately. + assert engine_core.scheduler.requests[ + req0.request_id].num_computed_tokens == 10 + + # Schedule Batch 2: (2, req0), (8, req1) assert engine_core.step_with_batch_queue() is None assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 2 + assert scheduler_output.num_scheduled_tokens[1] == 8 + # num_computed_tokens should have been updated immediately. + assert engine_core.scheduler.requests[0].num_computed_tokens == 12 + assert engine_core.scheduler.requests[1].num_computed_tokens == 8 + assert engine_core.scheduler.get_num_unfinished_requests() == 2 - # Loop through both requests. - while engine_core.scheduler.get_num_unfinished_requests() == 2: - engine_core.step_with_batch_queue() + # Batch queue is full. Finish Batch 1. + engine_core.step_with_batch_queue() - # Reaching here when got the result of the first request. - while engine_core.scheduler.get_num_unfinished_requests() == 1: - engine_core.step_with_batch_queue() + # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled + # because it is in the decoding stage now. + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[1] == 4 + + # Batch queue is full. Finish Batch 2. Get first token of req0. + output = engine_core.step_with_batch_queue() + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 + + # Schedule Batch 4: (1, req0). + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 1 + + # Batch queue is full. Finish Batch 3. Get first token of req1. + engine_core.step_with_batch_queue() + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 + + # Schedule Batch 5: (1, req1). + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[1] == 1 + + # Loop until req0 is finished. + step = 0 + req_id = 0 + expected_num_tokens = [ + engine_core.scheduler.requests[0].num_tokens + 1, + engine_core.scheduler.requests[1].num_tokens + 1, + ] + while engine_core.scheduler.get_num_unfinished_requests() == 2: + output = engine_core.step_with_batch_queue() + if step % 2 == 0: + # Even steps consumes an output. + assert output is not None + assert len(output.outputs) == 1 + if req_id in engine_core.scheduler.requests: + assert engine_core.scheduler.requests[ + req_id].num_tokens == expected_num_tokens[req_id] + expected_num_tokens[req_id] += 1 + req_id = (req_id + 1) % 2 + else: + # Odd steps schedules a new batch. + assert output is None + step += 1 diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index a7e50f8f40ec..16a32d0d6495 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -3,9 +3,10 @@ from __future__ import annotations import time -from collections import deque from collections.abc import Iterable -from typing import Optional, Union +from collections import defaultdict, deque +from queue import Queue +from typing import Deque, Iterable, List, Optional, Union from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, SpeculativeConfig) @@ -65,11 +66,13 @@ def __init__( # req_id -> Request self.requests: dict[str, Request] = {} # Priority queues for requests. - self.waiting: deque[Request] = deque() - self.running: list[Request] = [] - # The requests that have been scheduled and are being executed - # by the executor. - self.scheduled_req_ids: set[str] = set() + self.waiting: Deque[Request] = deque() + self.running: List[Request] = [] + # req_id -> Number of times the request has been scheduled. + # We can only schedule a request more then once before the previous + # scheduling step is finished when PP is enabled and the request + # prompt is chunked. + self.scheduled_req_ids: dict[str, int] = defaultdict(int) # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished @@ -79,8 +82,9 @@ def __init__( # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - # Request id -> CachedRequestData - self._cached_reqs_data: dict[str, CachedRequestData] = {} + # Request id -> Queue of CachedRequestData + self._cached_reqs_data: dict[ + str, Queue[CachedRequestData]] = defaultdict(Queue) # Encoder-related. # Calculate encoder cache size if applicable @@ -143,8 +147,11 @@ def schedule(self) -> SchedulerOutput: req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if request.request_id in self.scheduled_req_ids: - # This request has already been scheduled. + if (request.num_computed_tokens >= request.num_tokens + and self.scheduled_req_ids.get(request.request_id, 0) > 0): + # We avoid re-scheduling the decoding requests because + # the number of new decoded output tokens is unknown due + # to speculative decoding or jump decoding. req_index += 1 continue @@ -196,7 +203,7 @@ def schedule(self) -> SchedulerOutput: # Schedule the request. scheduled_running_reqs.append(request) - self.scheduled_req_ids.add(request.request_id) + self.scheduled_req_ids[request.request_id] += 1 if request.use_structured_output: # PERF: in case of chunked prefill, # request might not include any new tokens. @@ -319,7 +326,7 @@ def schedule(self) -> SchedulerOutput: request.request_id] = req_index req_index += 1 self.running.append(request) - self.scheduled_req_ids.add(request.request_id) + self.scheduled_req_ids[request.request_id] += 1 self.request_scheduled(request, scheduled_timestamp) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) @@ -419,6 +426,15 @@ def schedule(self) -> SchedulerOutput: grammar_bitmask=grammar_bitmask, ) + # Update the number of computed tokens for the request right after + # the request is scheduled. This allows the request doing chunk prefill + # to be scheduled again immediately in the next scheduling step. + # If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + for req in (scheduled_new_reqs + scheduled_resumed_reqs + + scheduled_running_reqs): + req.num_computed_tokens += num_scheduled_tokens[req.request_id] + self.finished_req_ids = set() return scheduler_output @@ -436,18 +452,21 @@ def _make_cached_request_data( num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens new_token_ids = request.all_token_ids[ num_computed_tokens:num_computed_tokens + num_regular_tokens] - req_data = self._cached_reqs_data.get(request.request_id) - if req_data is not None: + + req_data_queue = self._cached_reqs_data.get(request.request_id) + if req_data_queue: + req_data = req_data_queue.get() req_data.resumed_from_preemption = resumed_from_preemption req_data.new_token_ids = new_token_ids req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: + # No cached request data, or all cached request data has been + # used by the scheduled requests. req_data = CachedRequestData.from_request(request, resumed_from_preemption, new_token_ids, new_block_ids) - self._cached_reqs_data[request.request_id] = req_data return req_data def _try_schedule_encoder_inputs( @@ -530,6 +549,11 @@ def update_from_output( logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens + num_computed_tokens_at_schedule = { + req_data.req_id: req_data.num_computed_tokens + for req_data in (scheduler_output.scheduled_cached_reqs + + scheduler_output.scheduled_new_reqs) + } new_running: list[Request] = [] outputs: list[EngineCoreOutput] = [] @@ -547,28 +571,24 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] + num_computed_tokens_at_schedule[req_id] += num_tokens_scheduled if req_id not in scheduler_output.scheduled_spec_decode_tokens: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. # Otherwise, we ignore the sampler output for the request. - request.num_computed_tokens += num_tokens_scheduled - assert request.num_computed_tokens <= request.num_tokens + assert num_computed_tokens_at_schedule[ + req_id] <= request.num_tokens else: - # num_computed_tokens_step represents the number of tokens - # processed in the current step, considering scheduled - # tokens and rejections. - # It is calculated as: - # num_computed_tokens_step = num_scheduled_tokens - - # num_tokens_rejected, - # where num_tokens_rejected is given by: + # request.num_computed_tokens already includes + # num_tokens_scheduled, so we only need to subtract + # num_tokens_rejected, which is given by: # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). scheduled_spec_token_ids = ( scheduler_output.scheduled_spec_decode_tokens[req_id]) - - num_computed_tokens_step = num_scheduled_tokens[req_id] - ( - len(scheduled_spec_token_ids) + 1 - - len(generated_token_ids)) - request.num_computed_tokens += num_computed_tokens_step + num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - + len(generated_token_ids)) + request.num_computed_tokens -= num_tokens_rejected + num_computed_tokens_at_schedule[req_id] -= num_tokens_rejected cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) @@ -577,7 +597,8 @@ def update_from_output( for input_id in list(cached_encoder_input_ids): start_pos = request.mm_positions[input_id]["offset"] num_tokens = request.mm_positions[input_id]["length"] - if start_pos + num_tokens <= request.num_computed_tokens: + if (start_pos + num_tokens + ) <= num_computed_tokens_at_schedule[req_id]: # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( @@ -594,7 +615,7 @@ def update_from_output( new_logprobs = None new_token_ids: list[int] = [] - if request.num_computed_tokens >= request.num_tokens: + if num_computed_tokens_at_schedule[req_id] >= request.num_tokens: for output_token_id in generated_token_ids: request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) @@ -635,10 +656,16 @@ def update_from_output( stop_reason=request.stop_reason, events=request.take_events())) - self.scheduled_req_ids.remove(request.request_id) + self.scheduled_req_ids[request.request_id] -= 1 + if self.scheduled_req_ids[request.request_id] == 0: + del self.scheduled_req_ids[request.request_id] if not stopped: new_running.append(request) + # Return the cached request data to the queue so they can be reused. + for req_data in scheduler_output.scheduled_cached_reqs: + self._cached_reqs_data[req_data.req_id].put_nowait(req_data) + self.running = new_running return EngineCoreOutputs( outputs=outputs, @@ -694,7 +721,7 @@ def finish_requests( if request.status == RequestStatus.RUNNING: self.running.remove(request) if request.request_id in self.scheduled_req_ids: - self.scheduled_req_ids.remove(request.request_id) + del self.scheduled_req_ids[request.request_id] else: self.waiting.remove(request) request.status = finished_status diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 1ba557977707..bc2e80ae4f1c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -195,10 +195,10 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: engine_core_outputs = None scheduler_output = None - # If there are unscheduled requests and the job queue - # is not full, schedule a new batch. Note that this is not blocking. - if (self.scheduler.get_num_unscheduled_requests() > 0 - and not self.batch_queue.full()): + # Try to schedule a new batch if the batch queue is not full, but + # the scheduler may return an empty batch if all requests are scheduled. + # Note that this is not blocking. + if not self.batch_queue.full(): scheduler_output = self.scheduler.schedule() if scheduler_output.total_num_scheduled_tokens > 0: future = self.model_executor.execute_model(scheduler_output) From a1d8b5ed78753ac055a4a8b613b127b98071972f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 21 Feb 2025 09:37:41 -0800 Subject: [PATCH 02/12] doc Signed-off-by: Cody Yu --- vllm/v1/engine/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index bc2e80ae4f1c..35dc25a07f76 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -182,10 +182,10 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: Note that if nothing to output in this step, None is returned. The execution flow is as follows: - 1. Try to schedule a new batch if there are unscheduled requests - and the job queue is not full. If a new batch is scheduled, directly - return an empty engine core output. In other words, we won't check - and return model outputs before the batch queue is full. + 1. Try to schedule a new batch if the batch queue is not full. + If a new batch is scheduled, directly return an empty engine core + output. In other words, fulfilling the batch queue has a higher priority + then getting model outputs. 2. If there is no new scheduled batch, meaning that the batch queue is full or no other requests can be scheduled, we block until the first batch in the job queue is finished. From a2766a2d059844458041b688d98157db725aac41 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 27 Feb 2025 12:22:20 -0800 Subject: [PATCH 03/12] refactor Signed-off-by: Cody Yu --- vllm/v1/core/scheduler.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 16a32d0d6495..a9d0f486291d 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -68,10 +68,11 @@ def __init__( # Priority queues for requests. self.waiting: Deque[Request] = deque() self.running: List[Request] = [] + # req_id -> Number of times the request has been scheduled. - # We can only schedule a request more then once before the previous - # scheduling step is finished when PP is enabled and the request - # prompt is chunked. + # With PP, when the input prompt is divided into chunks, we can + # schedule a new chunk even before the previous chunk has completed + # the full pipeline stages. This helps reduce TTFT. self.scheduled_req_ids: dict[str, int] = defaultdict(int) # The request IDs that are finished in between the previous and the @@ -150,8 +151,7 @@ def schedule(self) -> SchedulerOutput: if (request.num_computed_tokens >= request.num_tokens and self.scheduled_req_ids.get(request.request_id, 0) > 0): # We avoid re-scheduling the decoding requests because - # the number of new decoded output tokens is unknown due - # to speculative decoding or jump decoding. + # there is no tokens for decoding requests to be scheduled. req_index += 1 continue @@ -426,14 +426,17 @@ def schedule(self) -> SchedulerOutput: grammar_bitmask=grammar_bitmask, ) - # Update the number of computed tokens for the request right after - # the request is scheduled. This allows the request doing chunk prefill - # to be scheduled again immediately in the next scheduling step. - # If some tokens (e.g. spec tokens) are rejected later, the number of - # computed tokens will be adjusted in update_from_output. - for req in (scheduled_new_reqs + scheduled_resumed_reqs + - scheduled_running_reqs): - req.num_computed_tokens += num_scheduled_tokens[req.request_id] + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the (prefill) request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + self.requests[req_id].num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output @@ -549,6 +552,9 @@ def update_from_output( logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + # We cannot use num_computed_tokens from self.requests because + # their values have been advanced when the requests are scheduled. num_computed_tokens_at_schedule = { req_data.req_id: req_data.num_computed_tokens for req_data in (scheduler_output.scheduled_cached_reqs + @@ -598,7 +604,7 @@ def update_from_output( start_pos = request.mm_positions[input_id]["offset"] num_tokens = request.mm_positions[input_id]["length"] if (start_pos + num_tokens - ) <= num_computed_tokens_at_schedule[req_id]: + <= num_computed_tokens_at_schedule[req_id]): # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( From d9565e8e4d1f08e8b50787216a2b066c24676ef3 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 11 Mar 2025 15:23:53 -0700 Subject: [PATCH 04/12] lint Signed-off-by: Cody Yu --- vllm/v1/core/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index a9d0f486291d..bb460f7a6fb3 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -3,10 +3,10 @@ from __future__ import annotations import time -from collections.abc import Iterable from collections import defaultdict, deque +from collections.abc import Iterable from queue import Queue -from typing import Deque, Iterable, List, Optional, Union +from typing import Optional, Union from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, SpeculativeConfig) @@ -66,8 +66,8 @@ def __init__( # req_id -> Request self.requests: dict[str, Request] = {} # Priority queues for requests. - self.waiting: Deque[Request] = deque() - self.running: List[Request] = [] + self.waiting: deque[Request] = deque() + self.running: list[Request] = [] # req_id -> Number of times the request has been scheduled. # With PP, when the input prompt is divided into chunks, we can From bedc1f8849e355163568b931455efe262dc273a8 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 11 Mar 2025 15:52:08 -0700 Subject: [PATCH 05/12] refactor Signed-off-by: Cody Yu --- vllm/v1/core/scheduler.py | 57 ++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index bb460f7a6fb3..d8c75db913a7 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -5,7 +5,6 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from queue import Queue from typing import Optional, Union from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, @@ -73,7 +72,8 @@ def __init__( # With PP, when the input prompt is divided into chunks, we can # schedule a new chunk even before the previous chunk has completed # the full pipeline stages. This helps reduce TTFT. - self.scheduled_req_ids: dict[str, int] = defaultdict(int) + self.scheduled_req_ids_to_orig_computed_tokens: dict[ + str, deque[int]] = defaultdict(deque) # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished @@ -85,7 +85,7 @@ def __init__( # them at each scheduling step. # Request id -> Queue of CachedRequestData self._cached_reqs_data: dict[ - str, Queue[CachedRequestData]] = defaultdict(Queue) + str, deque[CachedRequestData]] = defaultdict(deque) # Encoder-related. # Calculate encoder cache size if applicable @@ -149,7 +149,8 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] if (request.num_computed_tokens >= request.num_tokens - and self.scheduled_req_ids.get(request.request_id, 0) > 0): + and self.scheduled_req_ids_to_orig_computed_tokens.get( + request.request_id, deque())): # We avoid re-scheduling the decoding requests because # there is no tokens for decoding requests to be scheduled. req_index += 1 @@ -203,7 +204,8 @@ def schedule(self) -> SchedulerOutput: # Schedule the request. scheduled_running_reqs.append(request) - self.scheduled_req_ids[request.request_id] += 1 + self.scheduled_req_ids_to_orig_computed_tokens[ + request.request_id].append(request.num_computed_tokens) if request.use_structured_output: # PERF: in case of chunked prefill, # request might not include any new tokens. @@ -326,7 +328,8 @@ def schedule(self) -> SchedulerOutput: request.request_id] = req_index req_index += 1 self.running.append(request) - self.scheduled_req_ids[request.request_id] += 1 + self.scheduled_req_ids_to_orig_computed_tokens[ + request.request_id].append(request.num_computed_tokens) self.request_scheduled(request, scheduled_timestamp) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) @@ -458,7 +461,7 @@ def _make_cached_request_data( req_data_queue = self._cached_reqs_data.get(request.request_id) if req_data_queue: - req_data = req_data_queue.get() + req_data = req_data_queue.popleft() req_data.resumed_from_preemption = resumed_from_preemption req_data.new_token_ids = new_token_ids req_data.new_block_ids = new_block_ids @@ -553,14 +556,6 @@ def update_from_output( prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens - # We cannot use num_computed_tokens from self.requests because - # their values have been advanced when the requests are scheduled. - num_computed_tokens_at_schedule = { - req_data.req_id: req_data.num_computed_tokens - for req_data in (scheduler_output.scheduled_cached_reqs + - scheduler_output.scheduled_new_reqs) - } - new_running: list[Request] = [] outputs: list[EngineCoreOutput] = [] @@ -575,15 +570,20 @@ def update_from_output( new_running.append(request) continue + num_computed_tokens_with_output = ( + self.scheduled_req_ids_to_orig_computed_tokens[req_id].popleft( + )) + if not self.scheduled_req_ids_to_orig_computed_tokens[req_id]: + del self.scheduled_req_ids_to_orig_computed_tokens[req_id] + req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] - num_computed_tokens_at_schedule[req_id] += num_tokens_scheduled + num_computed_tokens_with_output += num_tokens_scheduled if req_id not in scheduler_output.scheduled_spec_decode_tokens: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. # Otherwise, we ignore the sampler output for the request. - assert num_computed_tokens_at_schedule[ - req_id] <= request.num_tokens + assert num_computed_tokens_with_output <= request.num_tokens else: # request.num_computed_tokens already includes # num_tokens_scheduled, so we only need to subtract @@ -594,7 +594,7 @@ def update_from_output( num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - len(generated_token_ids)) request.num_computed_tokens -= num_tokens_rejected - num_computed_tokens_at_schedule[req_id] -= num_tokens_rejected + num_computed_tokens_with_output -= num_tokens_rejected cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) @@ -604,7 +604,7 @@ def update_from_output( start_pos = request.mm_positions[input_id]["offset"] num_tokens = request.mm_positions[input_id]["length"] if (start_pos + num_tokens - <= num_computed_tokens_at_schedule[req_id]): + <= num_computed_tokens_with_output): # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( @@ -621,7 +621,7 @@ def update_from_output( new_logprobs = None new_token_ids: list[int] = [] - if num_computed_tokens_at_schedule[req_id] >= request.num_tokens: + if num_computed_tokens_with_output >= request.num_tokens: for output_token_id in generated_token_ids: request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) @@ -662,15 +662,12 @@ def update_from_output( stop_reason=request.stop_reason, events=request.take_events())) - self.scheduled_req_ids[request.request_id] -= 1 - if self.scheduled_req_ids[request.request_id] == 0: - del self.scheduled_req_ids[request.request_id] if not stopped: new_running.append(request) # Return the cached request data to the queue so they can be reused. for req_data in scheduler_output.scheduled_cached_reqs: - self._cached_reqs_data[req_data.req_id].put_nowait(req_data) + self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running return EngineCoreOutputs( @@ -726,8 +723,10 @@ def finish_requests( if request.status == RequestStatus.RUNNING: self.running.remove(request) - if request.request_id in self.scheduled_req_ids: - del self.scheduled_req_ids[request.request_id] + if (request.request_id + in self.scheduled_req_ids_to_orig_computed_tokens): + del self.scheduled_req_ids_to_orig_computed_tokens[ + request.request_id] else: self.waiting.remove(request) request.status = finished_status @@ -756,10 +755,6 @@ def has_requests(self): not yet returned in SchedulerOutputs.""" return self.has_unfinished_requests() or self.has_finished_requests() - def get_num_unscheduled_requests(self) -> int: - """Number of requests that are not being processed by the executor.""" - return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) - def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() From f8c2cd2f67b8aa7b0fa7bf91c960955a7a2d957a Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 11 Mar 2025 16:13:58 -0700 Subject: [PATCH 06/12] comment Signed-off-by: Cody Yu --- vllm/v1/core/scheduler.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index d8c75db913a7..dd38bde09f60 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -68,10 +68,13 @@ def __init__( self.waiting: deque[Request] = deque() self.running: list[Request] = [] - # req_id -> Number of times the request has been scheduled. - # With PP, when the input prompt is divided into chunks, we can - # schedule a new chunk even before the previous chunk has completed - # the full pipeline stages. This helps reduce TTFT. + # req_id -> a queue of computed tokens when the request is scheduled. + # With PP, when an input prompt is split into chunks, we can schedule + # a new chunk even before the previous chunk has completed the full + # pipeline stages. This helps reduce TTFT. + # In this case, the deque will have multiple elements with the + # computed tokens before each chunk was scheduled. This is used by + # update_from_output() to determine the request status. self.scheduled_req_ids_to_orig_computed_tokens: dict[ str, deque[int]] = defaultdict(deque) @@ -83,7 +86,7 @@ def __init__( # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - # Request id -> Queue of CachedRequestData + # Request id -> deque of CachedRequestData self._cached_reqs_data: dict[ str, deque[CachedRequestData]] = defaultdict(deque) From 95785360f787ee36e304d0ae98c43ab582822675 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 14 Mar 2025 10:08:06 -0700 Subject: [PATCH 07/12] test Signed-off-by: Cody Yu --- tests/v1/core/test_scheduler.py | 13 +++++++--- vllm/v1/core/scheduler.py | 45 ++++++++++++++------------------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 9413373390fe..2c063fcead91 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +from collections import deque from typing import Optional import pytest @@ -272,7 +273,8 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) + scheduler.orig_num_computed_tokens[req.request_id] = deque( + [req.num_tokens]) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -324,7 +326,8 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) + scheduler.orig_num_computed_tokens[req.request_id] = deque( + [req.num_tokens]) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -374,7 +377,8 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) + scheduler.orig_num_computed_tokens[req.request_id] = deque( + [req.num_tokens]) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -424,7 +428,8 @@ def test_stop_via_update_from_output(): requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) - scheduler.scheduled_req_ids.add(requests[0].request_id) + scheduler.orig_num_computed_tokens[requests[0].request_id] = deque( + [requests[0].num_tokens]) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index fc163c22e471..1c328c5e0fa0 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -75,8 +75,8 @@ def __init__( # In this case, the deque will have multiple elements with the # computed tokens before each chunk was scheduled. This is used by # update_from_output() to determine the request status. - self.scheduled_req_ids_to_orig_computed_tokens: dict[ - str, deque[int]] = defaultdict(deque) + self.orig_num_computed_tokens: dict[str, + deque[int]] = defaultdict(deque) # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished @@ -151,18 +151,9 @@ def schedule(self) -> SchedulerOutput: req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if (request.num_computed_tokens >= request.num_tokens - and self.scheduled_req_ids_to_orig_computed_tokens.get( - request.request_id, deque())): - # We avoid re-scheduling the decoding requests because - # there is no tokens for decoding requests to be scheduled. - req_index += 1 - continue - num_new_tokens = (request.num_tokens_with_spec - request.num_computed_tokens) num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 # Schedule encoder inputs. encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( @@ -171,8 +162,13 @@ def schedule(self) -> SchedulerOutput: num_new_tokens, encoder_budget)) if num_new_tokens == 0: - # The request cannot be scheduled because the encoder budget - # or the encoder cache is exhausted. + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. # NOTE(woosuk): Here, by doing `continue` instead of `break`, # we do not strictly follow the FCFS scheduling policy and # allow the lower-priority requests to be scheduled. @@ -209,8 +205,8 @@ def schedule(self) -> SchedulerOutput: # Schedule the request. scheduled_running_reqs.append(request) - self.scheduled_req_ids_to_orig_computed_tokens[ - request.request_id].append(request.num_computed_tokens) + self.orig_num_computed_tokens[request.request_id].append( + request.num_computed_tokens) if request.use_structured_output: # PERF: in case of chunked prefill, # request might not include any new tokens. @@ -333,8 +329,8 @@ def schedule(self) -> SchedulerOutput: request.request_id] = req_index req_index += 1 self.running.append(request) - self.scheduled_req_ids_to_orig_computed_tokens[ - request.request_id].append(request.num_computed_tokens) + self.orig_num_computed_tokens[request.request_id].append( + request.num_computed_tokens) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) @@ -505,7 +501,7 @@ def _try_schedule_encoder_inputs( limitations, the method adjusts `num_new_tokens` to schedule only the decoder tokens up to just before the unschedulable encoder input. """ - if not request.has_encoder_inputs(): + if not request.has_encoder_inputs() or num_new_tokens == 0: return [], num_new_tokens, encoder_budget encoder_inputs_to_schedule: list[int] = [] @@ -578,10 +574,9 @@ def update_from_output( continue num_computed_tokens_with_output = ( - self.scheduled_req_ids_to_orig_computed_tokens[req_id].popleft( - )) - if not self.scheduled_req_ids_to_orig_computed_tokens[req_id]: - del self.scheduled_req_ids_to_orig_computed_tokens[req_id] + self.orig_num_computed_tokens[req_id].popleft()) + if not self.orig_num_computed_tokens[req_id]: + del self.orig_num_computed_tokens[req_id] req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] @@ -730,10 +725,8 @@ def finish_requests( if request.status == RequestStatus.RUNNING: self.running.remove(request) - if (request.request_id - in self.scheduled_req_ids_to_orig_computed_tokens): - del self.scheduled_req_ids_to_orig_computed_tokens[ - request.request_id] + if request.request_id in self.orig_num_computed_tokens: + del self.orig_num_computed_tokens[request.request_id] else: self.waiting.remove(request) request.status = finished_status From 72c792d8b18c79fa2ea2ba802e9470952336b6d7 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 14 Mar 2025 14:14:59 -0700 Subject: [PATCH 08/12] comment Signed-off-by: Cody Yu --- vllm/v1/engine/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 996dcdcf7b78..df4e673a9a4a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -203,7 +203,7 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: 1. Try to schedule a new batch if the batch queue is not full. If a new batch is scheduled, directly return an empty engine core output. In other words, fulfilling the batch queue has a higher priority - then getting model outputs. + than getting model outputs. 2. If there is no new scheduled batch, meaning that the batch queue is full or no other requests can be scheduled, we block until the first batch in the job queue is finished. @@ -228,6 +228,10 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: # If no more requests can be scheduled and the job queue is not empty, # block until the first batch in the job queue is finished. + # TODO(comaniac): Ideally we should peek the first batch in the + # job queue to check if it's finished before scheduling a new batch, + # but peeking the first element in a queue is not thread-safe, + # so we need more work. if not scheduled_batch and not self.batch_queue.empty(): future, scheduler_output = self.batch_queue.get_nowait() # Blocking until the first result is available. From 55f26b45cfef41326cf53bc113bdcd429330f5b0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 14 Mar 2025 16:40:22 -0700 Subject: [PATCH 09/12] fix prefix caching Signed-off-by: Cody Yu --- vllm/v1/core/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 1c328c5e0fa0..0ef8d6a904ad 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -329,8 +329,6 @@ def schedule(self) -> SchedulerOutput: request.request_id] = req_index req_index += 1 self.running.append(request) - self.orig_num_computed_tokens[request.request_id].append( - request.num_computed_tokens) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) @@ -351,6 +349,8 @@ def schedule(self) -> SchedulerOutput: token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens + self.orig_num_computed_tokens[request.request_id].append( + request.num_computed_tokens) # Encoder-related. if encoder_inputs_to_schedule: From f14f928c91611592e8c505ace7d1a824520df932 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 27 Mar 2025 10:09:10 -0700 Subject: [PATCH 10/12] fix Signed-off-by: Cody Yu --- tests/v1/core/test_scheduler.py | 9 --------- vllm/v1/core/sched/interface.py | 5 ----- vllm/v1/core/sched/scheduler.py | 35 ++++++++++++--------------------- 3 files changed, 13 insertions(+), 36 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index afbd251635f7..4e057baa6369 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from collections import deque from typing import Optional import pytest @@ -358,8 +357,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.orig_num_computed_tokens[req.request_id] = deque( - [req.num_tokens]) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -411,8 +408,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.orig_num_computed_tokens[req.request_id] = deque( - [req.num_tokens]) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -462,8 +457,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.orig_num_computed_tokens[req.request_id] = deque( - [req.num_tokens]) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -513,8 +506,6 @@ def test_stop_via_update_from_output(): requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) - scheduler.orig_num_computed_tokens[requests[0].request_id] = deque( - [requests[0].num_tokens]) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index bfed44f9d58c..1de236d42f02 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -117,11 +117,6 @@ def has_requests(self) -> bool: not yet returned in SchedulerOutputs.""" return self.has_unfinished_requests() or self.has_finished_requests() - @abstractmethod - def get_num_unscheduled_requests(self) -> int: - """Number of requests that are not being processed by the executor.""" - raise NotImplementedError - @abstractmethod def reset_prefix_cache(self) -> bool: """Reset the prefix cache for KV cache. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index fbc9782de79c..a11467c02e62 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -71,15 +71,11 @@ def __init__( self.waiting: deque[Request] = deque() self.running: list[Request] = [] - # req_id -> a queue of computed tokens when the request is scheduled. - # With PP, when an input prompt is split into chunks, we can schedule - # a new chunk even before the previous chunk has completed the full - # pipeline stages. This helps reduce TTFT. - # In this case, the deque will have multiple elements with the - # computed tokens before each chunk was scheduled. This is used by - # update_from_output() to determine the request status. - self.orig_num_computed_tokens: dict[str, - deque[int]] = defaultdict(deque) + # req_id -> Number of times the request has been scheduled. + # With PP, when the input prompt is divided into chunks, we can + # schedule a new chunk even before the previous chunk has completed + # the full pipeline stages. This helps reduce TTFT. + self.scheduled_req_ids: dict[str, int] = defaultdict(int) # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished @@ -212,8 +208,7 @@ def schedule(self) -> SchedulerOutput: # Schedule the request. scheduled_running_reqs.append(request) - self.orig_num_computed_tokens[request.request_id].append( - request.num_computed_tokens) + self.scheduled_req_ids[request.request_id] += 1 if request.use_structured_output: # PERF: in case of chunked prefill, # request might not include any new tokens. @@ -339,6 +334,7 @@ def schedule(self) -> SchedulerOutput: request.request_id] = req_index req_index += 1 self.running.append(request) + self.scheduled_req_ids[request.request_id] += 1 if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) @@ -359,8 +355,6 @@ def schedule(self) -> SchedulerOutput: token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - self.orig_num_computed_tokens[request.request_id].append( - request.num_computed_tokens) # Encoder-related. if encoder_inputs_to_schedule: @@ -583,11 +577,6 @@ def update_from_output( new_running.append(request) continue - num_computed_tokens_with_output = ( - self.orig_num_computed_tokens[req_id].popleft()) - if not self.orig_num_computed_tokens[req_id]: - del self.orig_num_computed_tokens[req_id] - req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] @@ -611,8 +600,7 @@ def update_from_output( for input_id in list(cached_encoder_input_ids): start_pos = request.mm_positions[input_id]["offset"] num_tokens = request.mm_positions[input_id]["length"] - if (start_pos + num_tokens - <= num_computed_tokens_with_output): + if start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( @@ -673,6 +661,9 @@ def update_from_output( # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors + self.scheduled_req_ids[request.request_id] -= 1 + if self.scheduled_req_ids[request.request_id] == 0: + del self.scheduled_req_ids[request.request_id] if not stopped: new_running.append(request) @@ -716,8 +707,8 @@ def finish_requests( if request.status == RequestStatus.RUNNING: self.running.remove(request) - if request.request_id in self.orig_num_computed_tokens: - del self.orig_num_computed_tokens[request.request_id] + if request.request_id in self.scheduled_req_ids: + del self.scheduled_req_ids[request.request_id] else: self.waiting.remove(request) request.status = finished_status From 0c5357492bf411bc0eedfad2635d1ea293fb96c2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 27 Mar 2025 10:53:29 -0700 Subject: [PATCH 11/12] minor Signed-off-by: Cody Yu --- tests/v1/core/test_scheduler.py | 4 ++++ vllm/v1/core/sched/scheduler.py | 7 +------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 4e057baa6369..166fde685f73 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -357,6 +357,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + scheduler.scheduled_req_ids[req.request_id] = 1 scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -408,6 +409,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + scheduler.scheduled_req_ids[req.request_id] = 1 scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -457,6 +459,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + scheduler.scheduled_req_ids[req.request_id] = 1 scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -506,6 +509,7 @@ def test_stop_via_update_from_output(): requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) + scheduler.scheduled_req_ids[requests[0].request_id] = 1 scheduler_output = SchedulerOutput( scheduled_new_reqs=[], diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a11467c02e62..05abdbf76d2b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -505,7 +505,7 @@ def _try_schedule_encoder_inputs( limitations, the method adjusts `num_new_tokens` to schedule only the decoder tokens up to just before the unschedulable encoder input. """ - if not request.has_encoder_inputs() or num_new_tokens == 0: + if num_new_tokens == 0 or not request.has_encoder_inputs(): return [], num_new_tokens, encoder_budget encoder_inputs_to_schedule: list[int] = [] @@ -729,11 +729,6 @@ def get_num_unfinished_requests(self) -> int: def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 - def has_requests(self): - """Returns True if there are unfinished requests, or finished requests - not yet returned in SchedulerOutputs.""" - return self.has_unfinished_requests() or self.has_finished_requests() - def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() From 1154bca3b84a9a14b5351e6c831a283693629874 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 18 Apr 2025 11:26:09 -0700 Subject: [PATCH 12/12] comment Signed-off-by: Cody Yu Signed-off-by: Cody Yu --- tests/v1/core/test_scheduler.py | 4 ---- vllm/v1/core/sched/scheduler.py | 13 ------------- 2 files changed, 17 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 3853224b0702..9a982199e754 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -432,7 +432,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids[req.request_id] = 1 scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -484,7 +483,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids[req.request_id] = 1 scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -534,7 +532,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids[req.request_id] = 1 scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -584,7 +581,6 @@ def test_stop_via_update_from_output(): requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) - scheduler.scheduled_req_ids[requests[0].request_id] = 1 scheduler_output = SchedulerOutput( scheduled_new_reqs=[], diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 40874af63347..9df0e336c7f4 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -89,12 +89,6 @@ def __init__( self.waiting: deque[Request] = deque() self.running: list[Request] = [] - # req_id -> Number of times the request has been scheduled. - # With PP, when the input prompt is divided into chunks, we can - # schedule a new chunk even before the previous chunk has completed - # the full pipeline stages. This helps reduce TTFT. - self.scheduled_req_ids: dict[str, int] = defaultdict(int) - # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished # requests so that they can free the cached states for those requests. @@ -238,7 +232,6 @@ def schedule(self) -> SchedulerOutput: # Schedule the request. scheduled_running_reqs.append(request) - self.scheduled_req_ids[request.request_id] += 1 if request.use_structured_output: # PERF: in case of chunked prefill, # request might not include any new tokens. @@ -374,7 +367,6 @@ def schedule(self) -> SchedulerOutput: request.request_id] = req_index req_index += 1 self.running.append(request) - self.scheduled_req_ids[request.request_id] += 1 if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) @@ -726,9 +718,6 @@ def update_from_output( # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors - self.scheduled_req_ids[req_id] -= 1 - if self.scheduled_req_ids[req_id] == 0: - del self.scheduled_req_ids[req_id] if not stopped: new_running.append(request) @@ -778,8 +767,6 @@ def finish_requests( if request.status == RequestStatus.RUNNING: self.running.remove(request) - if request.request_id in self.scheduled_req_ids: - del self.scheduled_req_ids[request.request_id] else: self.waiting.remove(request) request.status = finished_status