7 changes: 7 additions & 0 deletions tests/v1/core/test_scheduler.py
@@ -203,6 +203,7 @@ def test_schedule_partial_requests():
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
@@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})

@@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})

@@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})

@@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})

@@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
@@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
11 changes: 4 additions & 7 deletions vllm/v1/core/scheduler.py
@@ -474,6 +474,7 @@ def update_from_output(
model_runner_output: "ModelRunnerOutput",
) -> EngineCoreOutputs:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -530,13 +531,9 @@ def update_from_output(
self.encoder_cache_manager.free_encoder_input(
request, input_id)

if request.num_computed_tokens >= request.num_tokens:
# Clear the spec tokens as the request has generated
# a new token. Here, We assume all spec tokens are verified
# if we perform speculative decoding for this request.
# Therefore, we can clear all spec tokens after
# the generation step.
request.clear_spec_tokens()
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
request.spec_token_ids = spec_token_ids[req_index]

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
30 changes: 0 additions & 30 deletions vllm/v1/engine/core.py
@@ -27,7 +27,6 @@
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -86,15 +85,6 @@ def __init__(
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)

# Setup speculative decode.
# TODO: find a better way to check if we are using ngram.
self.use_spec_decode = False
if self.scheduler.speculative_config:
assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
, "Only ngram spec decode is supported in V1."
self.proposer = NgramProposer()
self.use_spec_decode = True

def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> Tuple[int, int]:
start = time.time()
@@ -158,9 +148,6 @@ def step(self) -> EngineCoreOutputs:
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())

if self.use_spec_decode:
self.propose_tokens()

scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
@@ -221,23 +208,6 @@ def shutdown(self):
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)

def propose_tokens(self):
assert self.scheduler.speculative_config is not None
for req in self.scheduler.running:
# Ignore requests that are doing chunked prefill.
if req.num_computed_tokens < req.num_tokens - 1:
continue
# Ignore requests that already have spec tokens.
if req.spec_token_ids:
continue
spec_tokens = self.proposer.propose(
req.all_token_ids,
self.scheduler.speculative_config.ngram_prompt_lookup_min,
self.scheduler.speculative_config.num_speculative_tokens,
)
if spec_tokens:
req.append_spec_token_ids(spec_tokens)

def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()

3 changes: 3 additions & 0 deletions vllm/v1/outputs.py
@@ -67,6 +67,9 @@ class ModelRunnerOutput:
# each request due to speculative/jump decoding.
sampled_token_ids: List[List[int]]

# num_reqs x num_spec_tokens
spec_token_ids: Optional[List[List[int]]]

# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs]
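For reference, a minimal sketch of how a speculative runner might populate the new field, using only the constructor arguments exercised in the updated tests; the request ids and token values are made up:

```python
from vllm.v1.outputs import ModelRunnerOutput

# Two requests: the drafter proposed two draft tokens for the first request
# and none for the second. Non-speculative runners pass spec_token_ids=None.
output = ModelRunnerOutput(
    req_ids=["req-0", "req-1"],
    req_id_to_index={"req-0": 0, "req-1": 1},
    sampled_token_ids=[[101], [202]],
    spec_token_ids=[[111, 112], []],  # num_reqs x num_spec_tokens (ragged)
    logprobs=None,
    prompt_logprobs_dict={},
)
```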
12 changes: 0 additions & 12 deletions vllm/v1/request.py
@@ -104,18 +104,6 @@ def append_output_token_ids(
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)

def append_spec_token_ids(
self,
token_ids: Union[int, List[int]],
) -> None:
if isinstance(token_ids, int):
self.spec_token_ids.append(token_ids)
else:
self.spec_token_ids.extend(token_ids)

def clear_spec_tokens(self) -> None:
self.spec_token_ids.clear()

@property
def num_tokens(self) -> int:
return len(self._all_token_ids)
23 changes: 15 additions & 8 deletions vllm/v1/spec_decode/ngram_proposer.py
@@ -1,16 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional

from vllm.v1.utils import ConstantList
import numpy as np


class NgramProposer:

def __init__(self):
pass

def propose(self, context_token_ids: ConstantList[int], n: int,
k: int) -> Optional[List[int]]:
def propose(
self,
context_token_ids: np.ndarray,
n: int,
k: int,
) -> Optional[np.ndarray]:
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
@@ -25,8 +29,8 @@ def propose(self, context_token_ids: ConstantList[int], n: int,
the maximum amount of tokens until the end.

Returns:
List[int]: The sequence of tokens that followed
the matched n-gram in the context.
np.ndarray: The sequence of tokens that followed
the matched n-gram in the context.
None: If no matching n-gram pattern is found.

Example:
@@ -66,9 +70,12 @@ def _kmp_lps_array(pattern: List[int]) -> List[int]:
return lps

@staticmethod
def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
k: int) -> Optional[List[int]]:
context_len = len(context_token_ids)
def _find_subarray_kmp(
context_token_ids: np.ndarray,
n: int,
k: int,
) -> Optional[np.ndarray]:
context_len = context_token_ids.shape[0]
assert n > 0

pattern = context_token_ids[-n:]
Expand Down
7 changes: 7 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -78,6 +78,7 @@ def __init__(
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

@@ -217,7 +218,11 @@ def add_request(
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens

self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(req_index, request.block_ids)
@@ -356,6 +361,8 @@ def condense(self, empty_req_indices: List[int]) -> None:
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
last_req_index]
self.num_computed_tokens_cpu[
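To make the two counters concrete, here is a toy stand-in for the InputBatch bookkeeping (array sizes and token values are invented): num_tokens tracks everything written into token_ids_cpu, including scheduled draft tokens, while num_tokens_no_spec stops at the last verified token.

```python
import numpy as np

max_num_reqs, max_model_len = 2, 16
token_ids_cpu = np.zeros((max_num_reqs, max_model_len), dtype=np.int32)
num_tokens = np.zeros(max_num_reqs, dtype=np.int32)          # may include draft tokens
num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)  # prompt + verified outputs only

# Request 0: six prompt/output tokens, then two unverified draft tokens.
token_ids_cpu[0, :6] = [1, 2, 3, 4, 2, 3]
num_tokens_no_spec[0] = 6
token_ids_cpu[0, 6:8] = [4, 2]   # scheduled spec decode tokens
num_tokens[0] = 8

# Consumers that must not see unverified drafts (e.g. the n-gram drafter)
# read only the first num_tokens_no_spec tokens of the row.
print(token_ids_cpu[0, :num_tokens_no_spec[0]])  # [1 2 3 4 2 3]
```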
47 changes: 47 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -33,6 +33,7 @@
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -117,6 +118,15 @@ def __init__(
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

# Set up speculative decoding.
self.use_spec_decode = False
if self.speculative_config:
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1."
self.drafter = NgramProposer()
self.use_spec_decode = True

# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
@@ -366,6 +376,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, [])
@@ -1003,15 +1014,51 @@ def execute_model(
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]

if not self.use_spec_decode:
spec_token_ids = None
else:
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids)

model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
return model_runner_output

def generate_draft_token_ids(
self,
sampled_token_ids: List[List[int]],
) -> List[List[int]]:
# TODO(woosuk): Optimize.
num_reqs = len(sampled_token_ids)
draft_token_ids: List[List[int]] = []
for i in range(num_reqs):
if len(sampled_token_ids[i]) == 0:
# Skip speculative decoding.
draft_token_ids.append([])
continue

# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + len(sampled_token_ids[i])
self.input_batch.token_ids_cpu[
i, start_idx:end_idx] = sampled_token_ids[i]
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids

def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
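A stripped-down, single-request sketch of the generate_draft_token_ids flow above, with toy buffers standing in for the real InputBatch and hard-coded values in place of ngram_prompt_lookup_min and num_speculative_tokens:

```python
import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer

token_ids_cpu = np.zeros((1, 32), dtype=np.int32)
token_ids_cpu[0, :5] = [1, 2, 3, 4, 2]           # prompt + previously verified tokens
num_tokens_no_spec = np.array([5], dtype=np.int32)

sampled_token_ids = [[3]]                        # this step's newly sampled token(s)
start_idx = num_tokens_no_spec[0]
end_idx = start_idx + len(sampled_token_ids[0])
token_ids_cpu[0, start_idx:end_idx] = sampled_token_ids[0]

drafter = NgramProposer()
drafter_output = drafter.propose(token_ids_cpu[0, :end_idx], n=2, k=2)
draft_token_ids = ([] if drafter_output is None or len(drafter_output) == 0
                   else drafter_output.tolist())
print(draft_token_ids)  # [4, 2]
```

Running the drafting step inside the model runner, rather than in EngineCore as before, lets the drafts travel back to the scheduler through ModelRunnerOutput.spec_token_ids instead of being written onto requests ahead of scheduling.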
1 change: 1 addition & 0 deletions vllm/v1/worker/tpu_model_runner.py
@@ -696,6 +696,7 @@ def execute_model(
req_ids=all_req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
)