1 change: 1 addition & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=0,
)
43 changes: 28 additions & 15 deletions vllm/v1/core/scheduler.py
@@ -121,6 +121,8 @@ def schedule(self) -> "SchedulerOutput":
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: Dict[str, List[int]] = {}

# For logging.
scheduled_timestamp = time.monotonic()

# First, schedule the RUNNING requests.
@@ -187,6 +189,15 @@ def schedule(self) -> "SchedulerOutput":
token_budget -= num_new_tokens
req_index += 1

# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids[:num_scheduled_spec_tokens])
Review comment (Collaborator Author): This part fixes a bug in #12193
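For context, the accounting in this hunk can be sanity-checked with a small standalone sketch; every concrete number below is made up for illustration and is not taken from the PR:

    # A decode-step request: prompt plus already-sampled output is 10 tokens,
    # of which 9 have their KV cache computed, and the drafter has proposed
    # 3 speculative tokens for this step.
    num_tokens = 10
    num_computed_tokens = 9
    spec_token_ids = [101, 102, 103]

    # Suppose the remaining token budget caps this step at 3 new tokens,
    # so the last draft token does not fit.
    num_new_tokens = 3

    # num_new_tokens covers one regular token plus however many draft tokens
    # fit, so the draft-token count is recovered by the expression above.
    num_scheduled_spec_tokens = (num_new_tokens + num_computed_tokens -
                                 num_tokens)
    assert num_scheduled_spec_tokens == 2

    # Only the draft tokens that actually fit the budget are recorded, which
    # is what the slice guarantees (the previous code, removed further down
    # in this diff, always recorded the full spec_token_ids list).
    assert spec_token_ids[:num_scheduled_spec_tokens] == [101, 102]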


# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
@@ -196,11 +207,6 @@ def schedule(self) -> "SchedulerOutput":
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

# Speculative decode related.
if request.spec_token_ids:
scheduled_spec_decode_tokens[
request.request_id] = request.spec_token_ids

# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
if self.lora_config:
@@ -324,23 +330,24 @@ def schedule(self) -> "SchedulerOutput":
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id],
req.num_computed_tokens)
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
resumed_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
req.num_computed_tokens,
resumed_from_preemption=True,
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
req.num_computed_tokens,
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
@@ -349,8 +356,8 @@ def schedule(self) -> "SchedulerOutput":
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
@@ -366,22 +373,28 @@ def schedule(self) -> "SchedulerOutput":
def _make_cached_request_data(
self,
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: List[int],
num_computed_tokens: int,
resumed_from_preemption: bool,
) -> "CachedRequestData":
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
if request.request_id in self._cached_reqs_data:
req_data = self._cached_reqs_data[request.request_id]
num_computed_tokens = request.num_computed_tokens
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_data = self._cached_reqs_data.get(request.request_id)
if req_data is not None:
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_block_ids,
num_computed_tokens)
new_token_ids,
new_block_ids)
self._cached_reqs_data[request.request_id] = req_data
return req_data

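The new_token_ids slice in _make_cached_request_data above can be illustrated the same way; the values are again hypothetical, and the point is only that draft token IDs are excluded because they travel separately in scheduled_spec_decode_tokens:

    # Hypothetical request: 8 prompt tokens plus 2 sampled output tokens.
    all_token_ids = list(range(10))
    num_computed_tokens = 9      # KV cache covers everything but the newest token
    num_scheduled_tokens = 4     # 1 regular token + 3 draft tokens this step
    num_scheduled_spec_tokens = 3

    num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
    new_token_ids = all_token_ids[
        num_computed_tokens:num_computed_tokens + num_regular_tokens]

    # Only the newest already-sampled token ID is shipped to the worker;
    # the 3 draft token IDs are not part of new_token_ids.
    assert new_token_ids == [9]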
15 changes: 8 additions & 7 deletions vllm/v1/core/scheduler_output.py
@@ -30,7 +30,6 @@ def from_request(
cls,
request: "Request",
block_ids: List[int],
num_computed_tokens: int,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
@@ -41,7 +40,7 @@ def from_request(
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
)

@@ -54,6 +53,7 @@ class CachedRequestData:
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: List[int]
new_block_ids: List[int]
num_computed_tokens: int

@@ -62,14 +62,15 @@ def from_request(
cls,
request: "Request",
resumed_from_preemption: bool,
new_token_ids: List[int],
new_block_ids: List[int],
num_computed_tokens: int,
) -> "CachedRequestData":
return cls(
req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_computed_tokens=request.num_computed_tokens,
)


@@ -91,9 +92,9 @@ class SchedulerOutput:
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_decode_tokens
# If a request does not have any spec decode tokens, it will
# not be included in the dictionary.
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: Dict[str, List[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
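Based on the resumed_from_preemption docstring above, a consumer of CachedRequestData either replaces or appends to its cached block IDs. A minimal sketch of that contract with a hypothetical worker-side RequestState holder (this is not the PR's actual model-runner code):

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class RequestState:
        # Hypothetical per-request cache kept on the worker side.
        token_ids: List[int] = field(default_factory=list)
        block_ids: List[int] = field(default_factory=list)
        num_computed_tokens: int = 0


    def apply_cached_request_data(state: RequestState,
                                  resumed_from_preemption: bool,
                                  new_token_ids: List[int],
                                  new_block_ids: List[int],
                                  num_computed_tokens: int) -> None:
        if resumed_from_preemption:
            # The request was preempted and its blocks were freed, so the
            # new block IDs replace whatever was cached.
            state.block_ids = list(new_block_ids)
        else:
            # Otherwise the new blocks extend the existing allocation.
            state.block_ids.extend(new_block_ids)
        state.token_ids.extend(new_token_ids)
        state.num_computed_tokens = num_computed_tokens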