diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index e39a7f9f40bd..eb730973c946 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -203,6 +203,7 @@ def test_schedule_partial_requests():
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
         sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
         sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]],  # First request hits EOS, second continues
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
@@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
         },
         sampled_token_ids=[[10, 42, 12], [13, 14]],  # First request hits stop token
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
@@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
         },
         sampled_token_ids=[[10, 11, 12], [13]],  # First request exceeds max_tokens
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
@@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
         sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
@@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
         sampled_token_ids=[[0]],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
         sampled_token_ids=[[0]],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
     )
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index e5c60afeb492..8f10834251c1 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -474,6 +474,7 @@ def update_from_output(
         model_runner_output: "ModelRunnerOutput",
     ) -> EngineCoreOutputs:
         sampled_token_ids = model_runner_output.sampled_token_ids
+        spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -530,13 +531,9 @@ def update_from_output(
                     self.encoder_cache_manager.free_encoder_input(
                         request, input_id)
 
-            if request.num_computed_tokens >= request.num_tokens:
-                # Clear the spec tokens as the request has generated
-                # a new token. Here, We assume all spec tokens are verified
-                # if we perform speculative decoding for this request.
-                # Therefore, we can clear all spec tokens after
-                # the generation step.
-                request.clear_spec_tokens()
+            # Add newly generated spec token ids to the request.
+            if spec_token_ids is not None:
+                request.spec_token_ids = spec_token_ids[req_index]
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index c7ea7b1a94d8..6718a5f7b02d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -27,7 +27,6 @@
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -86,15 +85,6 @@ def __init__(
                                     self.batch_queue_size)
             self.batch_queue = queue.Queue(self.batch_queue_size)
 
-        # Setup speculative decode.
-        # TODO: find a better way to check if we are using ngram.
-        self.use_spec_decode = False
-        if self.scheduler.speculative_config:
-            assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
-                , "Only ngram spec decode is supported in V1."
-            self.proposer = NgramProposer()
-            self.use_spec_decode = True
-
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
@@ -158,9 +148,6 @@ def step(self) -> EngineCoreOutputs:
             return EngineCoreOutputs(
                 outputs=[], scheduler_stats=self.scheduler.make_stats())
 
-        if self.use_spec_decode:
-            self.propose_tokens()
-
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
@@ -221,23 +208,6 @@ def shutdown(self):
     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)
 
-    def propose_tokens(self):
-        assert self.scheduler.speculative_config is not None
-        for req in self.scheduler.running:
-            # Ignore requests that are doing chunked prefill.
-            if req.num_computed_tokens < req.num_tokens - 1:
-                continue
-            # Ignore requests that already have spec tokens.
-            if req.spec_token_ids:
-                continue
-            spec_tokens = self.proposer.propose(
-                req.all_token_ids,
-                self.scheduler.speculative_config.ngram_prompt_lookup_min,
-                self.scheduler.speculative_config.num_speculative_tokens,
-            )
-            if spec_tokens:
-                req.append_spec_token_ids(spec_tokens)
-
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index fb6c4051e9a6..0c8eca38ade7 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -67,6 +67,9 @@ class ModelRunnerOutput:
     # each request due to speculative/jump decoding.
     sampled_token_ids: List[List[int]]
+    # num_reqs x num_spec_tokens
+    spec_token_ids: Optional[List[List[int]]]
+
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs]
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index a1bcc2d0393c..52d7faeeb066 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -104,18 +104,6 @@ def append_output_token_ids(
         self._output_token_ids.extend(token_ids)
         self._all_token_ids.extend(token_ids)
 
-    def append_spec_token_ids(
-        self,
-        token_ids: Union[int, List[int]],
-    ) -> None:
-        if isinstance(token_ids, int):
-            self.spec_token_ids.append(token_ids)
-        else:
-            self.spec_token_ids.extend(token_ids)
-
-    def clear_spec_tokens(self) -> None:
-        self.spec_token_ids.clear()
-
     @property
     def num_tokens(self) -> int:
         return len(self._all_token_ids)
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index 8eee99506b1f..9b116e00af97 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import List, Optional
 
-from vllm.v1.utils import ConstantList
+import numpy as np
 
 
 class NgramProposer:
@@ -9,8 +9,12 @@ class NgramProposer:
     def __init__(self):
         pass
 
-    def propose(self, context_token_ids: ConstantList[int], n: int,
-                k: int) -> Optional[List[int]]:
+    def propose(
+        self,
+        context_token_ids: np.ndarray,
+        n: int,
+        k: int,
+    ) -> Optional[np.ndarray]:
         """Proposes the next sequence of tokens based on n-gram pattern
         matching in the context. The function finds matches of the last n
         tokens in the previous context, and returns k tokens that followed
@@ -25,8 +29,8 @@ def propose(self, context_token_ids: ConstantList[int], n: int,
                the maximum amount of tokens until the end.
 
         Returns:
-            List[int]: The sequence of tokens that followed
-                the matched n-gram in the context.
+            np.ndarray: The sequence of tokens that followed
+                the matched n-gram in the context.
             None: If no matching n-gram pattern is found.
 
         Example:
@@ -66,9 +70,12 @@ def _kmp_lps_array(pattern: List[int]) -> List[int]:
         return lps
 
     @staticmethod
-    def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
-                           k: int) -> Optional[List[int]]:
-        context_len = len(context_token_ids)
+    def _find_subarray_kmp(
+        context_token_ids: np.ndarray,
+        n: int,
+        k: int,
+    ) -> Optional[np.ndarray]:
+        context_len = context_token_ids.shape[0]
         assert n > 0
 
         pattern = context_token_ids[-n:]
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 805d8f618d2e..cb7411a44e2f 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -78,6 +78,7 @@ def __init__(
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
+        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
@@ -217,7 +218,11 @@ def add_request(
         end_idx = start_idx + len(request.output_token_ids)
         self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
+        # Number of token ids in token_ids_cpu.
+        # NOTE(woosuk): This may include spec decode tokens.
         self.num_tokens[req_index] = request.num_tokens
+        # Number of tokens without spec decode tokens.
+        self.num_tokens_no_spec[req_index] = request.num_tokens
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
         self.block_table.add_row(req_index, request.block_ids)
@@ -356,6 +361,8 @@ def condense(self, empty_req_indices: List[int]) -> None:
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
             self.num_tokens[empty_index] = num_tokens
+            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
+                last_req_index]
             self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                 last_req_index]
             self.num_computed_tokens_cpu[
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e1d1e43427b8..76dcb82e99ad 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -33,6 +33,7 @@
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -117,6 +118,15 @@ def __init__(
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
 
+        # Set up speculative decoding.
+        self.use_spec_decode = False
+        if self.speculative_config:
+            # TODO: find a better way to check if we are using ngram.
+            assert self.speculative_config.ngram_prompt_lookup_min, \
+                "Currently, only ngram spec decode is supported in V1."
+            self.drafter = NgramProposer()
+            self.use_spec_decode = True
+
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
         # Persistent batch.
@@ -366,6 +376,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             self.input_batch.token_ids_cpu[
                 req_index,
                 start_token_index:end_token_index] = req_data.new_token_ids
+            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, [])
@@ -1003,15 +1014,51 @@ def execute_model(
                 for seq in sampled_token_ids[valid_mask].split(gen_lens)
             ]
 
+        if not self.use_spec_decode:
+            spec_token_ids = None
+        else:
+            spec_token_ids = self.generate_draft_token_ids(
+                valid_sampled_token_ids)
+
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
+            spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
         return model_runner_output
 
+    def generate_draft_token_ids(
+        self,
+        sampled_token_ids: List[List[int]],
+    ) -> List[List[int]]:
+        # TODO(woosuk): Optimize.
+        num_reqs = len(sampled_token_ids)
+        draft_token_ids: List[List[int]] = []
+        for i in range(num_reqs):
+            if len(sampled_token_ids[i]) == 0:
+                # Skip speculative decoding.
+                draft_token_ids.append([])
+                continue
+
+            # Add sampled_token_ids to token_ids_cpu.
+            start_idx = self.input_batch.num_tokens_no_spec[i]
+            end_idx = start_idx + len(sampled_token_ids[i])
+            self.input_batch.token_ids_cpu[
+                i, start_idx:end_idx] = sampled_token_ids[i]
+            drafter_output = self.drafter.propose(
+                self.input_batch.token_ids_cpu[i, :end_idx],
+                self.speculative_config.ngram_prompt_lookup_min,
+                self.speculative_config.num_speculative_tokens,
+            )
+            if drafter_output is None or len(drafter_output) == 0:
+                draft_token_ids.append([])
+            else:
+                draft_token_ids.append(drafter_output.tolist())
+        return draft_token_ids
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 8635ffce7027..4ee6853ba7ef 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -696,6 +696,7 @@ def execute_model(
             req_ids=all_req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
+            spec_token_ids=None,
             logprobs=None,
             prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
         )
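Note (illustrative, not part of the patch): the sketch below exercises the new numpy-based NgramProposer.propose() signature the way GPUModelRunner.generate_draft_token_ids() now drives it, but standalone. The token values and the n/k arguments (standing in for ngram_prompt_lookup_min and num_speculative_tokens) are made-up examples, and the expected result is inferred from the docstring's description of the n-gram matching.

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer

drafter = NgramProposer()

# One request's prompt + generated tokens, as a row of token_ids_cpu would hold them.
# These token values are arbitrary examples.
context_token_ids = np.array([1, 2, 3, 4, 2, 3], dtype=np.int32)

# n plays the role of ngram_prompt_lookup_min, k of num_speculative_tokens.
draft = drafter.propose(context_token_ids, n=2, k=2)

if draft is None or len(draft) == 0:
    print("no draft tokens proposed")
else:
    # Per the docstring, these should be the tokens that followed the earlier
    # occurrence of the last 2-gram [2, 3], i.e. [4, 2] for this context.
    print(draft.tolist())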