From 8406f11ce3c4226e004d8ed9d4c036abb6301712 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Feb 2025 12:54:42 -0800 Subject: [PATCH 1/8] [V1] Get input tokens from scheduler Signed-off-by: Woosuk Kwon --- vllm/v1/core/scheduler.py | 21 +++++++----- vllm/v1/core/scheduler_output.py | 9 ++--- vllm/v1/worker/gpu_model_runner.py | 55 +++++++++++++++--------------- 3 files changed, 44 insertions(+), 41 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 2d5a1192c227..f7af451633b6 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -312,23 +312,22 @@ def schedule(self) -> "SchedulerOutput": # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request(req, - req_to_new_block_ids[req.request_id], - req.num_computed_tokens) + req_to_new_block_ids[req.request_id]) for req in scheduled_new_reqs ] resumed_reqs_data = [ self._make_cached_request_data( req, + num_scheduled_tokens[req.request_id], req_to_new_block_ids[req.request_id], - req.num_computed_tokens, resumed_from_preemption=True, ) for req in scheduled_resumed_reqs ] running_reqs_data = [ self._make_cached_request_data( req, + num_scheduled_tokens[req.request_id], req_to_new_block_ids[req.request_id], - req.num_computed_tokens, resumed_from_preemption=False, ) for req in scheduled_running_reqs ] @@ -353,22 +352,26 @@ def schedule(self) -> "SchedulerOutput": def _make_cached_request_data( self, request: Request, + num_scheduled_tokens: int, new_block_ids: List[int], - num_computed_tokens: int, resumed_from_preemption: bool, ) -> "CachedRequestData": # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - if request.request_id in self._cached_reqs_data: - req_data = self._cached_reqs_data[request.request_id] + num_computed_tokens = request.num_computed_tokens + input_token_ids = request.all_token_ids[ + num_computed_tokens:num_computed_tokens + num_scheduled_tokens] + req_data = self._cached_reqs_data.get(request.request_id) + if req_data is not None: req_data.resumed_from_preemption = resumed_from_preemption + req_data.input_token_ids = input_token_ids req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: req_data = CachedRequestData.from_request(request, resumed_from_preemption, - new_block_ids, - num_computed_tokens) + input_token_ids, + new_block_ids) self._cached_reqs_data[request.request_id] = req_data return req_data diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 990b3dd0ed78..409f4d7b64a2 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -30,7 +30,6 @@ def from_request( cls, request: "Request", block_ids: List[int], - num_computed_tokens: int, ) -> "NewRequestData": return cls( req_id=request.request_id, @@ -41,7 +40,7 @@ def from_request( mm_positions=request.mm_positions, sampling_params=request.sampling_params, block_ids=block_ids, - num_computed_tokens=num_computed_tokens, + num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, ) @@ -54,6 +53,7 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. 
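The heart of this first patch is that the scheduler, rather than the worker, now decides exactly which token ids are fed to the model each step: it slices request.all_token_ids by num_computed_tokens and num_scheduled_tokens, and it keeps one CachedRequestData object per request that it mutates in place on later steps instead of reallocating. A minimal standalone sketch of that pattern, assuming simplified stand-in classes rather than the actual vLLM types:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class _Request:
    request_id: str
    all_token_ids: List[int]
    num_computed_tokens: int = 0

@dataclass
class _CachedRequestData:
    req_id: str
    resumed_from_preemption: bool
    input_token_ids: List[int]
    new_block_ids: List[int]
    num_computed_tokens: int

class _MiniScheduler:
    def __init__(self) -> None:
        # One cached object per request id, reused and mutated on later steps.
        self._cached_reqs_data: Dict[str, _CachedRequestData] = {}

    def _make_cached_request_data(self, request: _Request,
                                  num_scheduled_tokens: int,
                                  new_block_ids: List[int],
                                  resumed_from_preemption: bool) -> _CachedRequestData:
        start = request.num_computed_tokens
        # The scheduler hands the worker exactly the token ids it scheduled.
        input_token_ids = request.all_token_ids[start:start + num_scheduled_tokens]
        req_data = self._cached_reqs_data.get(request.request_id)
        if req_data is not None:
            req_data.resumed_from_preemption = resumed_from_preemption
            req_data.input_token_ids = input_token_ids
            req_data.new_block_ids = new_block_ids
            req_data.num_computed_tokens = start
        else:
            req_data = _CachedRequestData(request.request_id,
                                          resumed_from_preemption,
                                          input_token_ids, new_block_ids, start)
            self._cached_reqs_data[request.request_id] = req_data
        return req_data

req = _Request("r0", all_token_ids=list(range(10)), num_computed_tokens=8)
data = _MiniScheduler()._make_cached_request_data(
    req, num_scheduled_tokens=2, new_block_ids=[7], resumed_from_preemption=False)
assert data.input_token_ids == [8, 9]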
resumed_from_preemption: bool + input_token_ids: List[int] new_block_ids: List[int] num_computed_tokens: int @@ -62,14 +62,15 @@ def from_request( cls, request: "Request", resumed_from_preemption: bool, + input_token_ids: List[int], new_block_ids: List[int], - num_computed_tokens: int, ) -> "CachedRequestData": return cls( req_id=request.request_id, resumed_from_preemption=resumed_from_preemption, + input_token_ids=input_token_ids, new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, + num_computed_tokens=request.num_computed_tokens, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 821c9e138028..8955aeaa709d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,7 +2,7 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import numpy as np import torch @@ -322,7 +322,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_state = self.requests[req_id] # Update the cached states. - req_state.num_computed_tokens = req_data.num_computed_tokens + num_computed_tokens = req_data.num_computed_tokens + req_state.num_computed_tokens = num_computed_tokens + # Add the sampled token(s) from the previous step (if any). + num_new_tokens = (num_computed_tokens + + len(req_data.input_token_ids) - + req_state.num_tokens) + new_token_ids = (req_data.input_token_ids[-num_new_tokens:] + if num_new_tokens > 0 else []) + if new_token_ids: + req_state.output_token_ids.extend(new_token_ids) + # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. req_state.block_ids.extend(req_data.new_block_ids) @@ -341,12 +351,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) + num_computed_tokens) start_index = len(req_state.block_ids) - len( req_data.new_block_ids) self.input_batch.block_table.append_row(req_index, start_index, req_data.new_block_ids) - + if new_token_ids: + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + num_new_tokens + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = new_token_ids + + # Check if the batch has changed. If not, we can skip copying the + # sampling metadata from CPU to GPU. batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0 # Add the new or resumed requests to the persistent batch. @@ -856,34 +874,21 @@ def execute_model( # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs - request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] - for i, req_id in enumerate( # type: ignore[assignment] - self.input_batch.req_ids[:num_reqs]): + req_ids: List[str] = [] + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None + req_ids.append(req_id) req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len <= req_state.num_tokens - if seq_len == req_state.num_tokens: - # Append the sampled token to the output token ids. - self.input_batch.num_tokens[i] += 1 - # OPTIMIZATION: Priming the state updates for later updates. 
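On the worker side, the model runner no longer writes sampled token ids back into its state right after sampling; on the next step, _update_states recovers them from the tail of the token slice the scheduler sends. A small worked example of that arithmetic, with illustrative numbers only:

# Worker's view before this step: 8 prompt tokens + 1 output token recorded.
runner_num_tokens = 9
# Scheduler's view: 9 tokens already computed, 1 more token scheduled this step,
# so it sends all_token_ids[9:10], i.e. the token sampled in the previous step.
num_computed_tokens = 9
input_token_ids = [1234]

# Tokens the worker has not recorded yet (appended on the scheduler side after
# the previous step, now arriving as the tail of the input slice).
num_new_tokens = num_computed_tokens + len(input_token_ids) - runner_num_tokens
new_token_ids = input_token_ids[-num_new_tokens:] if num_new_tokens > 0 else []
assert num_new_tokens == 1 and new_token_ids == [1234]

# Chunked-prefill case: still inside the prompt, so nothing new to record.
num_computed_tokens = 4          # only part of an 8-token prompt computed so far
input_token_ids = [11, 12, 13]   # next prompt chunk
runner_num_tokens = 8
num_new_tokens = num_computed_tokens + len(input_token_ids) - runner_num_tokens
assert num_new_tokens <= 0       # -1 here: no output tokens yet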
- req_state.output_token_ids.append(0) - request_seq_lens.append((i, req_state, seq_len)) - else: - # Ignore the sampled token from the partial request. + if seq_len < req_state.num_tokens: + # Ignore the sampled token. # Rewind the generator state as if the token was not sampled. generator = self.input_batch.generators.get(i) if generator is not None: # This relies on cuda-specific torch-internal impl details generator.set_offset(generator.get_offset() - 4) - # num_reqs entries should be non-None - assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" - req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. sampled_token_ids = sampler_output.sampled_token_ids.tolist() @@ -897,12 +902,6 @@ def execute_model( scheduler_output, ) - # Update with the actual token ids - for i, req_state, seq_len in request_seq_lens: - token_id = sampled_token_ids[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids[-1] = token_id - model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, From 0399f095ec89f77f17fee43f804feae0c9d21a83 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Feb 2025 17:17:50 -0800 Subject: [PATCH 2/8] fix Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8955aeaa709d..e87ba1dd455f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -362,6 +362,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.token_ids_cpu[ req_index, start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens[req_index] += num_new_tokens # Check if the batch has changed. If not, we can skip copying the # sampling metadata from CPU to GPU. From c54ff6c286e560279f42abea1e0cbe34f88ac483 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Feb 2025 00:59:36 -0800 Subject: [PATCH 3/8] fix Signed-off-by: Woosuk Kwon --- vllm/v1/core/scheduler.py | 31 +++-- vllm/v1/core/scheduler_output.py | 12 +- vllm/v1/worker/gpu_model_runner.py | 205 ++++++++++++++--------------- 3 files changed, 128 insertions(+), 120 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index e0b32fb24ac2..dda391eaa4f9 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -121,6 +121,8 @@ def schedule(self) -> "SchedulerOutput": encoder_budget = self.max_num_encoder_input_tokens # Spec decode-related. scheduled_spec_decode_tokens: Dict[str, List[int]] = {} + + # For logging. scheduled_timestamp = time.monotonic() # First, schedule the RUNNING requests. @@ -187,6 +189,15 @@ def schedule(self) -> "SchedulerOutput": token_budget -= num_new_tokens req_index += 1 + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids[:num_scheduled_spec_tokens]) + # Encoder-related. 
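With speculative decoding, a request wants its remaining verified tokens plus its draft tokens, but num_new_tokens can be smaller than that (for example because of the token budget), so the slice spec_token_ids[:num_scheduled_spec_tokens] keeps only the drafts that actually fit. A worked example with made-up numbers:

# Request state: 20 verified tokens, 19 already computed, 3 draft (spec) tokens.
num_tokens = 20
num_computed_tokens = 19
spec_token_ids = [7, 8, 9]

# The request wants 1 verified token + 3 drafts = 4 tokens, but assume only 3
# tokens of budget remain for it this step.
num_new_tokens = min(num_tokens - num_computed_tokens + len(spec_token_ids), 3)

num_scheduled_spec_tokens = num_new_tokens + num_computed_tokens - num_tokens
scheduled_spec = spec_token_ids[:num_scheduled_spec_tokens]
assert num_new_tokens == 3
assert num_scheduled_spec_tokens == 2 and scheduled_spec == [7, 8]
# The remaining num_new_tokens - num_scheduled_spec_tokens == 1 token is the
# regular (verified) part that _make_cached_request_data later sends as
# new_token_ids.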
if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( @@ -196,11 +207,6 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget - # Speculative decode related. - if request.spec_token_ids: - scheduled_spec_decode_tokens[ - request.request_id] = request.spec_token_ids - # Record the LoRAs in scheduled_running_reqs requested_loras: Set[int] = set() if self.lora_config: @@ -331,6 +337,7 @@ def schedule(self) -> "SchedulerOutput": self._make_cached_request_data( req, num_scheduled_tokens[req.request_id], + scheduled_spec_decode_tokens.get(req.request_id, []), req_to_new_block_ids[req.request_id], resumed_from_preemption=True, ) for req in scheduled_resumed_reqs @@ -339,6 +346,7 @@ def schedule(self) -> "SchedulerOutput": self._make_cached_request_data( req, num_scheduled_tokens[req.request_id], + scheduled_spec_decode_tokens.get(req.request_id, []), req_to_new_block_ids[req.request_id], resumed_from_preemption=False, ) for req in scheduled_running_reqs @@ -348,8 +356,8 @@ def schedule(self) -> "SchedulerOutput": scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, - scheduled_encoder_inputs=scheduled_encoder_inputs, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. @@ -366,24 +374,27 @@ def _make_cached_request_data( self, request: Request, num_scheduled_tokens: int, + scheduled_spec_token_ids: List[int], new_block_ids: List[int], resumed_from_preemption: bool, ) -> "CachedRequestData": # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. num_computed_tokens = request.num_computed_tokens - input_token_ids = request.all_token_ids[ - num_computed_tokens:num_computed_tokens + num_scheduled_tokens] + num_spec_tokens = len(scheduled_spec_token_ids) + num_regular_tokens = num_scheduled_tokens - num_spec_tokens + new_token_ids = request.all_token_ids[ + num_computed_tokens:num_computed_tokens + num_regular_tokens] req_data = self._cached_reqs_data.get(request.request_id) if req_data is not None: req_data.resumed_from_preemption = resumed_from_preemption - req_data.input_token_ids = input_token_ids + req_data.new_token_ids = new_token_ids req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: req_data = CachedRequestData.from_request(request, resumed_from_preemption, - input_token_ids, + new_token_ids, new_block_ids) self._cached_reqs_data[request.request_id] = req_data return req_data diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index b97ebe96ba77..47413527c32f 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -53,7 +53,7 @@ class CachedRequestData: # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. 
resumed_from_preemption: bool - input_token_ids: List[int] + new_token_ids: List[int] new_block_ids: List[int] num_computed_tokens: int @@ -62,13 +62,13 @@ def from_request( cls, request: "Request", resumed_from_preemption: bool, - input_token_ids: List[int], + new_token_ids: List[int], new_block_ids: List[int], ) -> "CachedRequestData": return cls( req_id=request.request_id, resumed_from_preemption=resumed_from_preemption, - input_token_ids=input_token_ids, + new_token_ids=new_token_ids, new_block_ids=new_block_ids, num_computed_tokens=request.num_computed_tokens, ) @@ -92,9 +92,9 @@ class SchedulerOutput: # Total number of tokens scheduled for all requests. # Equal to sum(num_scheduled_tokens.values()) total_num_scheduled_tokens: int - # req_id -> spec_decode_tokens - # If a request does not have any spec decode tokens, it will - # not be included in the dictionary. + # req_id -> spec_token_ids + # If a request does not have any spec decode tokens, it will not be + # included in the dictionary. scheduled_spec_decode_tokens: Dict[str, List[int]] # req_id -> encoder input indices that need processing. # E.g., if a request has [0, 1], it could mean the vision encoder needs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3af45a9c25b6..961ee5bf4e0f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -32,7 +32,7 @@ KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata -# from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID +from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -181,7 +181,6 @@ def __init__( self.max_model_len, self.max_num_tokens), dtype=np.int32) - self.arange_cpu = torch.from_numpy(self.arange_np) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. Thus, we should # not make any assumptions about the values in these tensors. @@ -327,13 +326,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: num_computed_tokens = req_data.num_computed_tokens req_state.num_computed_tokens = num_computed_tokens # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec decode tokens. num_new_tokens = (num_computed_tokens + - len(req_data.input_token_ids) - + len(req_data.new_token_ids) - req_state.num_tokens) - new_token_ids = (req_data.input_token_ids[-num_new_tokens:] + new_token_ids = (req_data.new_token_ids[-num_new_tokens:] if num_new_tokens > 0 else []) - if new_token_ids: - req_state.output_token_ids.extend(new_token_ids) + req_state.output_token_ids.extend(new_token_ids) # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. @@ -354,17 +353,26 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Update the persistent batch. 
self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - start_index = len(req_state.block_ids) - len( - req_data.new_block_ids) + start_index = (len(req_state.block_ids) - + len(req_data.new_block_ids)) self.input_batch.block_table.append_row(req_index, start_index, req_data.new_block_ids) - if new_token_ids: - start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + num_new_tokens + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(req_data.new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = req_data.new_token_ids + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, []) + if spec_token_ids: + start_index = end_token_index + end_token_index += len(spec_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens[req_index] += num_new_tokens + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec decode tokens. + self.input_batch.num_tokens[req_index] = end_token_index # Check if the batch has changed. If not, we can skip copying the # sampling metadata from CPU to GPU. @@ -390,7 +398,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: return batch_changed def _prepare_inputs( - self, scheduler_output: "SchedulerOutput" + self, + scheduler_output: "SchedulerOutput", ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -404,23 +413,13 @@ def _prepare_inputs( # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens_list: List[int] = [] - max_num_scheduled_tokens = 0 - all_spec_token_ids: List[int] = [] - num_spec_tokens_list: List[int] = [] for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens_list.append(num_tokens) - max_num_scheduled_tokens = max(max_num_scheduled_tokens, - num_tokens) - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, []) - all_spec_token_ids.extend(spec_token_ids) - num_spec_tokens_list.append(len(spec_token_ids)) - - num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, - dtype=np.int32) - assert max_num_scheduled_tokens > 0 + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + max_num_scheduled_tokens = num_scheduled_tokens.max() # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] @@ -457,78 +456,6 @@ def _prepare_inputs( token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - use_spec_decode = len(all_spec_token_ids) > 0 - if use_spec_decode: - - # 1. Write spec_token_ids to input batch. - # Step 1. Get req indices that perform spec decode and repeat - # the req indices by the number of spec tokens. Note - # for requests that don't perform spec decode, the - # number of spec tokens is 0 and the req index is - # repeated 0 times. - # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1] - # spec_req_indices: [0, 0, 0, 2, 2, 4] - spec_req_indices = np.repeat(self.arange_np[:num_reqs], - num_spec_tokens_list) - # spec_offsets: offsets within each spec token list. 
- # E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here - spec_offsets = np.concatenate( - [self.arange_np[1:val + 1] for val in num_spec_tokens_list]) - # spec_seq_offsets: offsets within each sequence. - # E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2] - # after repeating: [1, 1, 1, 3, 3, 2] - # spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1] - # = [2, 3, 4, 4, 5, 3] - spec_seq_offsets = np.repeat( - self.input_batch.num_computed_tokens_cpu[:num_reqs], - num_spec_tokens_list) + spec_offsets - # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3] - cumsums_spec_offsets = ( - spec_seq_offsets + - spec_req_indices * self.input_batch.token_ids_cpu.shape[1]) - cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to( - torch.int64) - all_spec_token_ids = torch.tensor(all_spec_token_ids, - device="cpu", - dtype=self.input_ids_cpu.dtype) - - # Step 2. Write spec token ids to input_ids_cpu. - self.input_batch.token_ids_cpu_tensor.flatten().scatter_( - 0, cumsums_spec_offsets, all_spec_token_ids) - - # 2. Get spec decode logits indices. - # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] - # cu_num_tokens: [4, 104, 107, 207, 209] - # num_spec_tokens_list: [3, 0, 2, 0, 1] - # num_sampled_tokens: [4, 1, 3, 1, 2] - # spec_decode_logits_indices: - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32) - num_sampled_tokens = num_spec_tokens_np + 1 - # logits_start_loc: [0, 103, 104, 206, 207] - logits_start_loc = cu_num_tokens - num_sampled_tokens - # [0, 103, 104, 206, 207] -> - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] - logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) - # The following three lines: - # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] - cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) - # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] - # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - cumsums_sampled_offsets = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) - # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - total_num_sampled_tokens = num_sampled_tokens.sum() - sampled_arange = (self.arange_np[:total_num_sampled_tokens] - - cumsums_sampled_offsets) - - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - spec_decode_logits_indices = logits_start_loc + sampled_arange - # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. @@ -622,9 +549,11 @@ def _prepare_inputs( suffix_kv_lens=suffix_kv_lens, ) + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 if use_spec_decode: - logits_indices = torch.from_numpy(spec_decode_logits_indices).to( - self.device, non_blocking=True) + logits_indices = self._calc_spec_decode_metadata( + scheduler_output, cu_num_tokens) else: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -778,6 +707,56 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr += completion_part_len + def _calc_spec_decode_metadata( + self, + scheduler_output: "SchedulerOutput", + cu_num_tokens: np.ndarray, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get the number of spec decode tokens for each request. 
+ num_reqs = self.input_batch.num_reqs + num_spec_decode_tokens_list: List[int] = [] + for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, []) + num_spec_decode_tokens_list.append(len(spec_token_ids)) + num_spec_decode_tokens = np.array(num_spec_decode_tokens_list, + dtype=np.int32) + + # Get spec decode logits indices. + # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] + # cu_num_tokens: [4, 104, 107, 207, 209] + # num_spec_tokens_list: [3, 0, 2, 0, 1] + # num_sampled_tokens: [4, 1, 3, 1, 2] + # spec_decode_logits_indices: + # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + num_sampled_tokens = num_spec_decode_tokens + 1 + # logits_start_loc: [0, 103, 104, 206, 207] + logits_start_loc = cu_num_tokens - num_sampled_tokens + # [0, 103, 104, 206, 207] -> + # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) + # The following three lines: + # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) + # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] + # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_sampled_offsets = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + total_num_sampled_tokens = num_sampled_tokens.sum() + sampled_arange = (self.arange_np[:total_num_sampled_tokens] - + cumsums_sampled_offsets) + + # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> + # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + spec_decode_logits_indices = logits_start_loc + sampled_arange + return torch.from_numpy(spec_decode_logits_indices).to( + self.device, non_blocking=True) + def _prepare_sampling( self, batch_changed: bool, @@ -789,7 +768,9 @@ def _prepare_sampling( for req_id, req in self.requests.items()} sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, req_to_spec_token_ids, not batch_changed) + req_id_output_token_ids, + req_to_spec_token_ids, + skip_copy=not batch_changed) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -993,10 +974,26 @@ def execute_model( scheduler_output, ) + # Get the valid generated tokens. + sampled_token_ids = sampler_output.sampled_token_ids + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_mask = sampled_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + # TODO(woosuk): Optimize this. 
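The index arithmetic in _calc_spec_decode_metadata is easiest to check in isolation. The sketch below reproduces the example from the comments with plain NumPy; the arrays are the illustrative values from those comments, not real scheduler output:

import numpy as np

num_scheduled_tokens = np.array([4, 100, 3, 100, 2], dtype=np.int32)
cu_num_tokens = np.cumsum(num_scheduled_tokens)            # [4, 104, 107, 207, 209]
num_spec_decode_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

# One logit is needed per draft token plus one for the regular sampled token.
num_sampled_tokens = num_spec_decode_tokens + 1            # [4, 1, 3, 1, 2]
# First sampled position of each request within the flattened token batch.
logits_start_loc = cu_num_tokens - num_sampled_tokens      # [0, 103, 104, 206, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# Per-request offsets 0..num_sampled_tokens-1, built without a Python loop.
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)      # [4, 5, 8, 9, 11]
offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
sampled_arange = np.arange(num_sampled_tokens.sum()) - offsets
spec_decode_logits_indices = logits_start_loc + sampled_arange

assert spec_decode_logits_indices.tolist() == [
    0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]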
+ valid_sampled_token_ids = [ + seq.tolist() + for seq in sampled_token_ids[valid_mask].split(gen_lens) + ] + model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=sampler_output.sampled_token_ids, + sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, ) From c833429ceae6748a05a9a4c06bd226e074f22843 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Feb 2025 10:43:15 -0800 Subject: [PATCH 4/8] comment Signed-off-by: Woosuk Kwon --- tests/v1/worker/test_gpu_model_runner.py | 1 + vllm/v1/worker/gpu_model_runner.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 576d906fa749..c655b0fded6e 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner): cached_req_data = CachedRequestData( req_id=req_id, resumed_from_preemption=False, + new_token_ids=[], new_block_ids=[], num_computed_tokens=0, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2fd72a995998..19106d5fa081 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -959,6 +959,9 @@ def execute_model( # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs req_ids: List[str] = [] + # Because `input_batch.req_ids` is a list of length `max_num_reqs`, + # we need to stop at `num_reqs`. + # FIXME(woosuk): This is hacky. Refactor. for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_ids.append(req_id) From b42a16f4e390c39c7e2be6c86d0331b48cb1c96a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Feb 2025 12:39:18 -0800 Subject: [PATCH 5/8] [V1][Spec decode] Move drafter to model runner Signed-off-by: Woosuk Kwon --- vllm/v1/core/scheduler.py | 11 +++---- vllm/v1/engine/core.py | 30 ----------------- vllm/v1/outputs.py | 3 ++ vllm/v1/request.py | 12 ------- vllm/v1/spec_decode/ngram_proposer.py | 23 ++++++++----- vllm/v1/worker/gpu_input_batch.py | 7 ++++ vllm/v1/worker/gpu_model_runner.py | 47 +++++++++++++++++++++++++++ vllm/v1/worker/tpu_model_runner.py | 1 + 8 files changed, 77 insertions(+), 57 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index dda391eaa4f9..50ac1f9723b4 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -475,6 +475,7 @@ def update_from_output( model_runner_output: "ModelRunnerOutput", ) -> EngineCoreOutputs: sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -531,13 +532,9 @@ def update_from_output( self.encoder_cache_manager.free_encoder_input( request, input_id) - if request.num_computed_tokens >= request.num_tokens: - # Clear the spec tokens as the request has generated - # a new token. Here, We assume all spec tokens are verified - # if we perform speculative decoding for this request. - # Therefore, we can clear all spec tokens after - # the generation step. - request.clear_spec_tokens() + # Add newly generated spec token ids to the request. 
+ if spec_token_ids is not None: + request.spec_token_ids = spec_token_ids[req_index] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c7ea7b1a94d8..6718a5f7b02d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -27,7 +27,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder -from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -86,15 +85,6 @@ def __init__( self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size) - # Setup speculative decode. - # TODO: find a better way to check if we are using ngram. - self.use_spec_decode = False - if self.scheduler.speculative_config: - assert self.scheduler.speculative_config.ngram_prompt_lookup_min \ - , "Only ngram spec decode is supported in V1." - self.proposer = NgramProposer() - self.use_spec_decode = True - def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: start = time.time() @@ -158,9 +148,6 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - if self.use_spec_decode: - self.propose_tokens() - scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( @@ -221,23 +208,6 @@ def shutdown(self): def profile(self, is_start: bool = True): self.model_executor.profile(is_start) - def propose_tokens(self): - assert self.scheduler.speculative_config is not None - for req in self.scheduler.running: - # Ignore requests that are doing chunked prefill. - if req.num_computed_tokens < req.num_tokens - 1: - continue - # Ignore requests that already have spec tokens. - if req.spec_token_ids: - continue - spec_tokens = self.proposer.propose( - req.all_token_ids, - self.scheduler.speculative_config.ngram_prompt_lookup_min, - self.scheduler.speculative_config.num_speculative_tokens, - ) - if spec_tokens: - req.append_spec_token_ids(spec_tokens) - def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index fb6c4051e9a6..0c8eca38ade7 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -67,6 +67,9 @@ class ModelRunnerOutput: # each request due to speculative/jump decoding. 
sampled_token_ids: List[List[int]] + # num_reqs x num_spec_tokens + spec_token_ids: Optional[List[List[int]]] + # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] # [num_reqs] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index a1bcc2d0393c..52d7faeeb066 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -104,18 +104,6 @@ def append_output_token_ids( self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) - def append_spec_token_ids( - self, - token_ids: Union[int, List[int]], - ) -> None: - if isinstance(token_ids, int): - self.spec_token_ids.append(token_ids) - else: - self.spec_token_ids.extend(token_ids) - - def clear_spec_tokens(self) -> None: - self.spec_token_ids.clear() - @property def num_tokens(self) -> int: return len(self._all_token_ids) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 8eee99506b1f..9b116e00af97 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import List, Optional -from vllm.v1.utils import ConstantList +import numpy as np class NgramProposer: @@ -9,8 +9,12 @@ class NgramProposer: def __init__(self): pass - def propose(self, context_token_ids: ConstantList[int], n: int, - k: int) -> Optional[List[int]]: + def propose( + self, + context_token_ids: np.ndarray, + n: int, + k: int, + ) -> Optional[np.ndarray]: """Proposes the next sequence of tokens based on n-gram pattern matching in the context. The function finds matches of the last n tokens in the previous context, and returns k tokens that followed @@ -25,8 +29,8 @@ def propose(self, context_token_ids: ConstantList[int], n: int, the maximum amount of tokens until the end. Returns: - List[int]: The sequence of tokens that followed - the matched n-gram in the context. + np.ndarray: The sequence of tokens that followed + the matched n-gram in the context. None: If no matching n-gram pattern is found. Example: @@ -66,9 +70,12 @@ def _kmp_lps_array(pattern: List[int]) -> List[int]: return lps @staticmethod - def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int, - k: int) -> Optional[List[int]]: - context_len = len(context_token_ids) + def _find_subarray_kmp( + context_token_ids: np.ndarray, + n: int, + k: int, + ) -> Optional[np.ndarray]: + context_len = context_token_ids.shape[0] assert n > 0 pattern = context_token_ids[-n:] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 805d8f618d2e..cb7411a44e2f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -78,6 +78,7 @@ def __init__( ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) @@ -217,7 +218,11 @@ def add_request( end_idx = start_idx + len(request.output_token_ids) self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids + # Number of token ids in token_ids_cpu. + # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens + # Number of tokens without spec decode tokens. 
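Patch 5 also changes NgramProposer.propose to take a plain np.ndarray of token ids (a row of token_ids_cpu) instead of a ConstantList. A naive sketch of the proposal idea it implements (match the last n tokens earlier in the context and return the k tokens that followed), using a simple scan rather than the KMP search in the patch:

import numpy as np
from typing import Optional

def naive_ngram_propose(context_token_ids: np.ndarray, n: int,
                        k: int) -> Optional[np.ndarray]:
    # Look for an earlier occurrence of the last n tokens and return up to k
    # tokens that followed it (naive scan; the real proposer uses KMP).
    context_len = context_token_ids.shape[0]
    if context_len <= n:
        return None
    pattern = context_token_ids[-n:]
    for start in range(context_len - n):
        if np.array_equal(context_token_ids[start:start + n], pattern):
            follow = context_token_ids[start + n:start + n + k]
            return follow if follow.shape[0] > 0 else None
    return None

tokens = np.array([1, 2, 3, 4, 2, 3], dtype=np.int32)
# The last 2 tokens [2, 3] also appear at positions 1..2, followed by [4, 2],
# so those two tokens become the draft.
assert naive_ngram_propose(tokens, n=2, k=2).tolist() == [4, 2]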
+ self.num_tokens_no_spec[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.block_table.add_row(req_index, request.block_ids) @@ -356,6 +361,8 @@ def condense(self, empty_req_indices: List[int]) -> None: self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] self.num_tokens[empty_index] = num_tokens + self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ + last_req_index] self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ last_req_index] self.num_computed_tokens_cpu[ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 19106d5fa081..dcc0a03a869e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -33,6 +33,7 @@ from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID +from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -115,6 +116,15 @@ def __init__( # req_id -> (input_id -> encoder_output) self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + # Set up speculative decoding. + self.use_spec_decode = False + if self.speculative_config: + # TODO: find a better way to check if we are using ngram. + assert self.speculative_config.ngram_prompt_lookup_min, \ + "Currently, only ngram spec decode is supported in V1." + self.drafter = NgramProposer() + self.use_spec_decode = True + # Request states. self.requests: Dict[str, CachedRequestState] = {} # Persistent batch. @@ -364,6 +374,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.token_ids_cpu[ req_index, start_token_index:end_token_index] = req_data.new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( req_id, []) @@ -1004,15 +1015,51 @@ def execute_model( for seq in sampled_token_ids[valid_mask].split(gen_lens) ] + if not self.use_spec_decode: + spec_token_ids = None + else: + spec_token_ids = self.generate_draft_token_ids( + valid_sampled_token_ids) + model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, ) return model_runner_output + def generate_draft_token_ids( + self, + sampled_token_ids: List[List[int]], + ) -> List[List[int]]: + # TODO(woosuk): Optimize. + num_reqs = len(sampled_token_ids) + draft_token_ids: List[List[int]] = [] + for i in range(num_reqs): + if len(sampled_token_ids[i]) == 0: + # Skip speculative decoding. + draft_token_ids.append([]) + continue + + # Add sampled_token_ids to token_ids_cpu. 
+ start_idx = self.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + len(sampled_token_ids[i]) + self.input_batch.token_ids_cpu[ + i, start_idx:end_idx] = sampled_token_ids[i] + drafter_output = self.drafter.propose( + self.input_batch.token_ids_cpu[i, :end_idx], + self.speculative_config.ngram_prompt_lookup_min, + self.speculative_config.num_speculative_tokens, + ) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8635ffce7027..4ee6853ba7ef 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -696,6 +696,7 @@ def execute_model( req_ids=all_req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[[token_id] for token_id in sampled_token_ids], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] ) From 74aad9edc1b3c42f1efc8f07e73596c0af39f737 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Feb 2025 13:36:56 -0800 Subject: [PATCH 6/8] fix test Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index e39a7f9f40bd..8b978fbc17d0 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -203,6 +203,7 @@ def test_schedule_partial_requests(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, ) @@ -434,6 +435,7 @@ def test_schedule_concurrent_batches(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, ) @@ -450,6 +452,7 @@ def test_schedule_concurrent_batches(): req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, ) From 30a72eb8952170ba4293f4f453d637e6512d2cdc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Feb 2025 13:38:32 -0800 Subject: [PATCH 7/8] fix test Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 8b978fbc17d0..34e623ac1c81 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -260,6 +260,7 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]], # First request hits EOS, second continues + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}) From 6aac1c583405f6ec176b9158c71ed3fa96aa2abc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Feb 2025 13:40:06 -0800 Subject: [PATCH 8/8] fix test Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 34e623ac1c81..eb730973c946 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -309,6 +309,7 @@ def test_stop_via_update_from_output(): }, 
sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}) @@ -356,6 +357,7 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}) @@ -396,6 +398,7 @@ def test_stop_via_update_from_output(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + spec_token_ids=None, logprobs=None, prompt_logprobs_dict={})
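The ragged sampled_token_ids lists these tests feed to update_from_output (for example [[EOS_TOKEN_ID], [10, 11]]) are what the model runner now produces from the rejection sampler's padded output in execute_model. A minimal torch sketch of that unpacking, with a made-up padding value standing in for INVALID_TOKEN_ID:

import torch

INVALID = -1  # stand-in for vllm.v1.sample.rejection_sampler.INVALID_TOKEN_ID

# [num_reqs, max_gen_len]: rejected draft positions are padded with INVALID.
sampled_token_ids = torch.tensor([
    [2, 10, 11],             # request 0: all three positions kept
    [13, INVALID, INVALID],  # request 1: only the first position is valid
])
valid_mask = sampled_token_ids != INVALID
gen_lens = valid_mask.sum(dim=1).tolist()          # [3, 1]
valid_sampled_token_ids = [
    seq.tolist() for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
assert valid_sampled_token_ids == [[2, 10, 11], [13]]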