From 7d6ee8f0399a653ca08279298cf72d7d2d6945ee Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 13 Feb 2025 15:22:52 -0800 Subject: [PATCH 01/13] [V1] Optimize handling of sampling metadata and req_ids list - Move SamplingMetadata to a field in the persistent batch, updated only when the batch changes rather than constructed every step - Keep input_batch.req_ids sized to the number of requests in the batch, so that anywhere that iterates over it doesn't need to slice (copy) the list or keep track of the separate request count. It is still updated in-place Signed-off-by: Nick Hill --- tests/v1/worker/test_gpu_input_batch.py | 3 +- vllm/v1/utils.py | 11 ++ vllm/v1/worker/gpu_input_batch.py | 147 ++++++++++++++---------- vllm/v1/worker/gpu_model_runner.py | 50 +++----- 4 files changed, 109 insertions(+), 102 deletions(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 5e70cfb53777..11fda89f45cb 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -201,8 +201,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch.condense(req_indices_to_remove) # Generate the sampling metadata - sampling_metadata = input_batch.make_sampling_metadata( - req_id_output_token_ids, skip_copy=False) + sampling_metadata = input_batch._make_sampling_metadata() # Create expected output. expected_sampling_metadata = _construct_expected_sampling_metadata( diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 5494542c181d..546fe189568f 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -188,3 +188,14 @@ def bind_kv_cache( for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. forward_context[layer_name].kv_cache = [kv_cache] + + +def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, + length: int) -> None: + """ + Copy the first length elements of a tensor into another tensor in a + non-blocking manner + + Used to copy pinned CPU tensor data to pre-allocated GPU tensors. + """ + to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index d52b8827d35e..f9af7c35bdeb 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -3,7 +3,7 @@ # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import numpy as np import torch @@ -12,6 +12,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: @@ -61,7 +62,7 @@ def __init__( self.pin_memory = pin_memory self.vocab_size = vocab_size - self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self._req_ids: List[Optional[str]] = [] self.req_id_to_index: Dict[str, int] = {} # TODO(woosuk): This buffer could be too large if max_model_len is big. 
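A minimal usage sketch of the copy_slice helper added in vllm/v1/utils.py above (illustrative only, not part of the patch; the buffer names and sizes are made up):

import torch

def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> None:
    # Same helper as in the patch: copy only the populated prefix, asynchronously.
    to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)

if torch.cuda.is_available():
    max_num_reqs, num_reqs = 256, 8                  # made-up sizes
    temperature_cpu = torch.zeros(max_num_reqs, pin_memory=True)   # pinned staging buffer
    temperature_gpu = torch.zeros(max_num_reqs, device="cuda")     # pre-allocated GPU tensor
    temperature_cpu[:num_reqs] = 0.7                 # host-side updates for active requests
    copy_slice(temperature_cpu, temperature_gpu, num_reqs)  # non-blocking H2D copy of the used prefix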
@@ -159,10 +160,7 @@ def __init__( self.repetition_penalties_reqs: Set[str] = set() self.min_tokens: List[int] = [0] * max_num_reqs - self.stop_token_ids: List[Set[int]] = [ - set() for _ in range(max_num_reqs) - ] - self.prompt_token_ids: Optional[torch.Tensor] = None + self.stop_token_ids: List[Set[int]] = [set()] * max_num_reqs # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), @@ -183,6 +181,17 @@ def __init__( self.logit_bias: List[Optional[Dict[int, float]]] = [None] * max_num_reqs + self.req_output_token_ids: List[Optional[List[int]]] = [] + + # This is updated each time the batch constituents change. + self.sampling_metadata = self._make_sampling_metadata() + + @property + def req_ids(self) -> List[str]: + # None elements should only be present transiently + # while performing state updates to the batch. + return cast(List[str], self._req_ids) + def add_request( self, request: "CachedRequestState", @@ -193,7 +202,13 @@ def add_request( assert req_index < self.max_num_reqs req_id = request.req_id - self.req_ids[req_index] = req_id + if req_index == len(self._req_ids): + self._req_ids.append(req_id) + self.req_output_token_ids.append(request.output_token_ids) + else: + self._req_ids[req_index] = req_id + self.req_output_token_ids[req_index] = request.output_token_ids + self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. @@ -264,10 +279,13 @@ def add_request( self.request_lora_mapping[req_index] = 0 def remove_request(self, req_id: str) -> Optional[int]: + """This method must always be followed by a call to condense().""" + req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None - self.req_ids[req_index] = None + self._req_ids[req_index] = None + self.req_output_token_ids[req_index] = None self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) @@ -293,7 +311,8 @@ def remove_request(self, req_id: str) -> Optional[int]: return req_index def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs + self._req_ids.clear() + self.req_output_token_ids.clear() self.req_id_to_index.clear() self.greedy_reqs.clear() self.random_reqs.clear() @@ -311,13 +330,15 @@ def clear(self) -> None: self.logit_bias = [None] * self.max_num_reqs def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: + num_reqs = self.num_reqs + if num_reqs == 0: # The batched states are empty. + self.clear() return # NOTE(woosuk): This function assumes that the empty_req_indices # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 + last_req_index = num_reqs + len(empty_req_indices) - 1 while empty_req_indices: # Find the largest non-empty index. while last_req_index in empty_req_indices: @@ -329,10 +350,13 @@ def condense(self, empty_req_indices: List[int]) -> None: break # Swap the states. - req_id = self.req_ids[last_req_index] + req_id = self._req_ids[last_req_index] + output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None + self._req_ids[empty_index] = req_id + self._req_ids[last_req_index] = None + self.req_output_token_ids[empty_index] = output_token_ids + self.req_output_token_ids[last_req_index] = None self.req_id_to_index[req_id] = empty_index num_tokens = self.num_tokens[last_req_index] @@ -369,43 +393,39 @@ def condense(self, empty_req_indices: List[int]) -> None: # Decrement last_req_index since it is now empty. 
last_req_index -= 1 - def make_sampling_metadata( - self, - req_id_output_token_ids: Dict[str, List[int]], - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - if not self.no_penalties: - # Since syncing these tensors is expensive only copy them - # if necessary i.e. if there are requests which require - # penalties to be applied during sampling. - self.frequency_penalties[:self.num_reqs].copy_( - self.frequency_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - self.presence_penalties[:self.num_reqs].copy_( - self.presence_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - self.repetition_penalties[:self.num_reqs].copy_( - self.repetition_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - # The prompt tokens are used only for applying penalties during - # the sampling process. Hence copy these tensors only when - # there are requests which need penalties to be applied. - self.prompt_token_ids = self._make_prompt_token_ids_tensor() - - output_token_ids: List[List[int]] = [] - - for req_id in self.req_ids[:self.num_reqs]: - assert req_id is not None + del self._req_ids[self.num_reqs:] + del self.req_output_token_ids[self.num_reqs:] + + # num_reqs entries should be non-None + assert all(req_id is not None + for req_id in self._req_ids), "req_ids contains None" + + def refresh_sampling_metadata(self): + self.sampling_metadata = self._make_sampling_metadata() + + def _make_sampling_metadata(self) -> SamplingMetadata: + + num_reqs = self.num_reqs + copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs) + copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) + copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) + + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + copy_slice(self.frequency_penalties_cpu_tensor, + self.frequency_penalties, num_reqs) + copy_slice(self.presence_penalties_cpu_tensor, + self.presence_penalties, num_reqs) + copy_slice(self.repetition_penalties_cpu_tensor, + self.repetition_penalties, num_reqs) + + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + prompt_token_ids = self._make_prompt_token_ids_tensor() + # Currently we create a tensor for output_token_ids from scratch # at each step. However, for the penalties computation what we # need is stats about the token ids present in the output. This @@ -413,27 +433,28 @@ def make_sampling_metadata( # from scratch at each step. # TODO - Replace this with incremental update to output token # statistics. 
- output_token_ids.append(req_id_output_token_ids[req_id]) + else: + prompt_token_ids = None return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], + temperature=self.temperature[:num_reqs], all_greedy=self.all_greedy, all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], + top_p=self.top_p[:num_reqs], + top_k=self.top_k[:num_reqs], no_top_p=self.no_top_p, no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=self.prompt_token_ids, - frequency_penalties=self.frequency_penalties[:self.num_reqs], - presence_penalties=self.presence_penalties[:self.num_reqs], - repetition_penalties=self.repetition_penalties[:self.num_reqs], - output_token_ids=output_token_ids, - min_tokens=self.min_tokens[:self.num_reqs], - stop_token_ids=self.stop_token_ids[:self.num_reqs], + prompt_token_ids=prompt_token_ids, + frequency_penalties=self.frequency_penalties[:num_reqs], + presence_penalties=self.presence_penalties[:num_reqs], + repetition_penalties=self.repetition_penalties[:num_reqs], + output_token_ids=cast(List[List[int]], self.req_output_token_ids), + min_tokens=self.min_tokens[:num_reqs], + stop_token_ids=self.stop_token_ids[:num_reqs], no_penalties=self.no_penalties, - logit_bias=self.logit_bias[:self.num_reqs], + logit_bias=self.logit_bias[:num_reqs], ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e90b76dcdd9a..1af7d7e13e04 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,7 +2,7 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import numpy as np import torch @@ -31,7 +31,6 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput -from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -209,16 +208,15 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. The updated states are used by the `_prepare_inputs` function to create the input GPU tensors for the model. - Returns: - True if there is a new/resumed/paused/finished request in the batch. - If False, we can skip copying SamplingMetadata to the GPU. + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. """ # Remove finished requests from the cached states. 
for req_id in scheduler_output.finished_req_ids: @@ -366,7 +364,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: if removed_req_indices: self.input_batch.condense(removed_req_indices) - return batch_changed + if batch_changed: + self.input_batch.refresh_sampling_metadata() def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -382,7 +381,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens_list: List[int] = [] max_num_scheduled_tokens = 0 - for req_id in self.input_batch.req_ids[:num_reqs]: + for req_id in self.input_batch.req_ids: assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens_list.append(num_tokens) @@ -617,8 +616,7 @@ def _compute_cascade_attn_prefix_len( def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 - num_reqs = self.input_batch.num_reqs - for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + for index, req_id in enumerate(self.input_batch.req_ids): assert req_id is not None req = self.requests[req_id] @@ -670,19 +668,6 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr += completion_part_len - def _prepare_sampling( - self, - batch_changed: bool, - ) -> SamplingMetadata: - # Create the sampling metadata. - req_id_output_token_ids: Dict[str, List[int]] = \ - {req_id: req.output_token_ids \ - for req_id, req in self.requests.items()} - - sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, skip_copy=not batch_changed) - return sampling_metadata - def _execute_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: @@ -736,8 +721,7 @@ def _gather_encoder_outputs( scheduler_output: "SchedulerOutput", ) -> List[torch.Tensor]: encoder_outputs: List[torch.Tensor] = [] - num_reqs = self.input_batch.num_reqs - for req_id in self.input_batch.req_ids[:num_reqs]: + for req_id in self.input_batch.req_ids: assert req_id is not None num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] @@ -780,7 +764,7 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> ModelRunnerOutput: - batch_changed = self._update_states(scheduler_output) + self._update_states(scheduler_output) if self.is_multimodal_model: # Run the multimodal encoder if any. @@ -847,7 +831,7 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Sample the next token and get logprobs if needed. - sampling_metadata = self._prepare_sampling(batch_changed) + sampling_metadata = self.input_batch.sampling_metadata sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, @@ -855,10 +839,8 @@ def execute_model( # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. 
- num_reqs = self.input_batch.num_reqs request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] - for i, req_id in enumerate( # type: ignore[assignment] - self.input_batch.req_ids[:num_reqs]): + for i, req_id in enumerate(self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + @@ -878,12 +860,6 @@ def execute_model( # This relies on cuda-specific torch-internal impl details generator.set_offset(generator.get_offset() - 4) - # num_reqs entries should be non-None - assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" - req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. sampled_token_ids = sampler_output.sampled_token_ids.tolist() @@ -904,7 +880,7 @@ def execute_model( req_state.output_token_ids[-1] = token_id model_runner_output = ModelRunnerOutput( - req_ids=req_ids, + req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=sampled_token_ids, logprobs=logprobs_lists, From 37d1f98cf04439ac393c4c3c803d5e8b155160dc Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 14 Feb 2025 14:58:53 -0800 Subject: [PATCH 02/13] don't mutate "constant" sampling metadata tensors Signed-off-by: Nick Hill --- vllm/model_executor/layers/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index dfe71028c1bc..a9ef973917e1 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, vocab_size, num_seqs) output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat( + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( 1, vocab_size) logits[logits > 0] /= torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)[logits > 0] @@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, repetition_penalties, 1.0)[logits <= 0] # We follow the definition in OpenAI API. 
# Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits From 602d3b6d772e92e094a8da4536cf397a109f5ef1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 14 Feb 2025 21:49:12 -0800 Subject: [PATCH 03/13] simplify sampling metadata Signed-off-by: Nick Hill --- tests/v1/sample/test_sampler.py | 26 ++++++-------- tests/v1/worker/test_gpu_input_batch.py | 23 ++++++------ vllm/v1/sample/metadata.py | 11 +++--- vllm/v1/sample/ops/penalties.py | 6 ++-- vllm/v1/sample/ops/topk_topp_sampler.py | 48 +++++++++++-------------- vllm/v1/sample/sampler.py | 12 +++---- vllm/v1/worker/gpu_input_batch.py | 33 ++++++++++------- 7 files changed, 76 insertions(+), 83 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index cfef475d8dee..7b6157ce185e 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -77,12 +77,9 @@ def _create_default_sampling_metadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, all_random=False, - top_p=torch.empty(batch_size, ), - top_k=torch.empty(batch_size, ), - no_top_p=True, - no_top_k=True, - min_p=torch.empty(batch_size, ), - no_min_p=True, + top_p=None, + top_k=None, + min_p=None, generators={}, max_num_logprobs=0, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, @@ -92,7 +89,7 @@ def _create_default_sampling_metadata( presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), no_penalties=True, - min_tokens=[], + min_tokens={}, stop_token_ids=[], logit_bias=[None] * batch_size, ) @@ -102,9 +99,9 @@ def _create_default_sampling_metadata( def _generate_min_token_penalties_and_stop_tokens( num_output_tokens: int, batch_size: int, vocab_size: int, batch_indices_for_min_token_penalty: List[int] -) -> Tuple[List[int], List[Set[int]]]: +) -> Tuple[Dict[int, int], List[Set[int]]]: """ - Generates and returns a list of minimum token penalties (`min_tokens`) + Generates and returns a dict of minimum token penalties (`min_tokens`) and a corresponding list of stop token IDs (`stop_token_ids`) for each batch. @@ -114,21 +111,20 @@ def _generate_min_token_penalties_and_stop_tokens( `min_tokens` value is assigned, and the stop token IDs set is empty. 
""" stop_token_ids: List[Set[int]] = [] - min_tokens: List[int] = [] + min_tokens: Dict[int, int] = {} for index in range(batch_size): if index in batch_indices_for_min_token_penalty: - min_tokens.append( - np.random.randint(num_output_tokens + 1, - 2 * num_output_tokens)) + min_tokens[index] = np.random.randint(num_output_tokens + 1, + 2 * num_output_tokens) stop_token_ids.append( set( np.random.randint(0, vocab_size - 1) for _ in range(np.random.randint(0, vocab_size)))) else: - min_tokens.append(np.random.randint(0, num_output_tokens)) + min_tokens[index] = np.random.randint(0, num_output_tokens) stop_token_ids.append(set()) - return (min_tokens, stop_token_ids) + return min_tokens, stop_token_ids def _create_weighted_output_token_list( diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 11fda89f45cb..5ed9a0f480a4 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import numpy as np import pytest @@ -91,10 +91,10 @@ def _construct_expected_sampling_metadata( device=device), all_greedy=False, all_random=True, - top_p=torch.tensor(top_p, dtype=torch.float, device=device), - top_k=torch.tensor(top_k, dtype=torch.int, device=device), - no_top_p=all(x == 1.0 for x in top_p), - no_top_k=all(x == 0 for x in top_k), + top_p=None if all(x == 1.0 for x in top_p) else torch.tensor( + top_p, dtype=torch.float, device=device), + top_k=None if all(x == 0 for x in top_k) else torch.tensor( + top_k, dtype=torch.int, device=device), generators={}, max_num_logprobs=0, prompt_token_ids=make_tensor_with_pad( @@ -210,13 +210,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch.req_id_to_index, device=torch.device(device)) + def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: + return (t1 is None + and t2 is None) or (t1 is not None and t2 is not None + and torch.allclose(t1, t2)) + # Assert the actual and expected output. 
assert torch.allclose(expected_sampling_metadata.temperature, sampling_metadata.temperature) - assert torch.allclose(expected_sampling_metadata.top_p, - sampling_metadata.top_p) - assert torch.allclose(expected_sampling_metadata.top_k, - sampling_metadata.top_k) + assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p) + assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k) assert torch.allclose( expected_sampling_metadata.frequency_penalties, sampling_metadata.frequency_penalties, @@ -238,6 +241,4 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): sampling_metadata.stop_token_ids assert expected_sampling_metadata.no_penalties == \ sampling_metadata.no_penalties - assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p - assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index cfcc54b7e343..68dc81f1df2d 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -13,12 +13,9 @@ class SamplingMetadata: all_greedy: bool all_random: bool - top_p: torch.Tensor - top_k: torch.Tensor - no_top_p: bool - no_top_k: bool - min_p: torch.Tensor - no_min_p: bool + top_p: Optional[torch.Tensor] + top_k: Optional[torch.Tensor] + min_p: Optional[torch.Tensor] generators: Dict[int, torch.Generator] @@ -32,7 +29,7 @@ class SamplingMetadata: repetition_penalties: torch.Tensor output_token_ids: List[List[int]] - min_tokens: List[int] + min_tokens: Dict[int, int] stop_token_ids: List[Set[int]] logit_bias: List[Optional[Dict[int, float]]] diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index ba368b44ab9c..611ac66f1be3 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Set, Tuple +from typing import List, Set, Tuple, Dict import torch @@ -11,13 +11,13 @@ def apply_min_token_penalties(logits: torch.Tensor, output_token_ids: List[List[int]], stop_token_ids: List[Set[int]], - min_tokens: List[int]) -> None: + min_tokens: Dict[int, int]) -> None: """ Applies minimum token penalty by setting the logits of the stop tokens to -inf. 
""" min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] - for index, min_token in enumerate(min_tokens): + for index, min_token in min_tokens.items(): if len(output_token_ids[index]) < min_token: for stop_token_id in stop_token_ids[index]: min_tokens_logits_to_penalize.append((index, stop_token_id)) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 27431001e3e7..78c88ad8b830 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict +from typing import Dict, Optional import torch import torch.nn as nn @@ -55,13 +55,11 @@ def forward_native( self, logits: torch.Tensor, generators: Dict[int, torch.Generator], - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """PyTorch-native implementation of top-k and top-p sampling.""" - logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) + logits = apply_top_k_top_p(logits, k, p) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) @@ -69,37 +67,33 @@ def forward_cuda( self, logits: torch.Tensor, generators: Dict[int, torch.Generator], - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """More optimized implementation for top-k and top-p sampling.""" probs = logits.softmax(dim=-1, dtype=torch.float32) - if no_top_k and no_top_p: + if k is None and p is None: # We prefer `random_sample` over `flashinfer_sample` when sorting is # not needed. This is because `random_sample` does not require # CPU-GPU synchronization while `flashinfer_sample` does. return random_sample(probs, generators) - return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) + return flashinfer_sample(probs, k, p, generators) def apply_top_k_top_p( logits: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. This function sorts the logits tensor, which can be slow for large batches. """ - if no_top_k and no_top_p: + if k is None and p is None: return logits logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - if not no_top_k: + if k is not None: # Apply top-k. top_k_mask = logits_sort.size(1) - k.to(torch.long) # Get all the top_k values. @@ -107,7 +101,7 @@ def apply_top_k_top_p( top_k_mask = logits_sort < top_k_mask logits_sort.masked_fill_(top_k_mask, -float("inf")) - if not no_top_p: + if p is not None: # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) probs_sum = probs_sort.cumsum(dim=-1) @@ -147,10 +141,8 @@ def random_sample( def flashinfer_sample( probs: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], generators: Dict[int, torch.Generator], ) -> torch.Tensor: """Sample from the probabilities using FlashInfer. @@ -167,7 +159,7 @@ def flashinfer_sample( does not. Call this function at the end of the forward pass to minimize the synchronization overhead. 
""" - assert not (no_top_k and no_top_p) + assert not (k is None and p is None) max_top_k_round = 32 batch_size = probs.shape[0] uniform_samples = torch.empty((max_top_k_round, batch_size), @@ -178,11 +170,11 @@ def flashinfer_sample( for i, generator in generators.items(): uniform_samples[:, i].uniform_(generator=generator) - if no_top_k: + if k is None: # Top-p only. next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( probs, uniform_samples, p, deterministic=True) - elif no_top_p: + elif p is None: # Top-k only. next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( probs, uniform_samples, k, deterministic=True) @@ -194,9 +186,9 @@ def flashinfer_sample( # NOTE: CPU-GPU synchronization happens here. if not success.all(): - if not no_top_k: + if k is not None: probs = flashinfer.sampling.top_k_renorm_prob(probs, k) - if not no_top_p: + if p is not None: probs = flashinfer.sampling.top_p_renorm_prob(probs, p) next_token_ids = flashinfer.sampling.sampling_from_probs( probs, uniform_samples[0], deterministic=True) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index ac32c90d6769..cfd26a89a1af 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -88,13 +88,11 @@ def sample( random_sampled = self.topk_topp_sampler( logits, sampling_metadata.generators, - sampling_metadata.no_top_k, sampling_metadata.top_k, - sampling_metadata.no_top_p, sampling_metadata.top_p, ) - if not sampling_metadata.no_min_p: + if sampling_metadata.min_p is not None: logits = self.apply_min_p(logits, sampling_metadata.min_p) if sampling_metadata.all_random: @@ -160,9 +158,11 @@ def apply_penalties( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - apply_min_token_penalties(logits, sampling_metadata.output_token_ids, - sampling_metadata.stop_token_ids, - sampling_metadata.min_tokens) + if sampling_metadata.min_tokens: + apply_min_token_penalties(logits, + sampling_metadata.output_token_ids, + sampling_metadata.stop_token_ids, + sampling_metadata.min_tokens) if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None logits = apply_all_penalties( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 83e8ccc79df0..8d01d9499817 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -171,7 +171,8 @@ def __init__( self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() - self.min_tokens: List[int] = [0] * max_num_reqs + # req_index -> min_tokens + self.min_tokens: Dict[int, int] = {} self.stop_token_ids: List[Set[int]] = [set()] * max_num_reqs # lora related @@ -265,7 +266,8 @@ def add_request( req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) - self.min_tokens[req_index] = sampling_params.min_tokens + if sampling_params.min_tokens: + self.min_tokens[req_index] = sampling_params.min_tokens self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids # NOTE(woosuk): self.generators should not include the requests that @@ -307,6 +309,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) self.min_p_reqs.discard(req_id) + self.min_tokens.pop(req_index) self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) @@ -338,6 +341,7 @@ def clear(self) 
-> None: self.frequency_penalties_reqs.clear() self.presence_penalties_reqs.clear() self.repetition_penalties_reqs.clear() + self.min_tokens.clear() self.generators.clear() self.num_logprobs.clear() self.num_prompt_logprobs.clear() @@ -396,13 +400,16 @@ def condense(self, empty_req_indices: List[int]) -> None: self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] - self.min_tokens[empty_index] = self.min_tokens[last_req_index] self.stop_token_ids[empty_index] = self.stop_token_ids[ last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator + min_token = self.min_tokens.pop(last_req_index, 0) + if min_token: + self.min_tokens[empty_index] = min_token + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ last_req_index] @@ -424,9 +431,12 @@ def refresh_sampling_metadata(self): def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs) - copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) - copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) - copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs) + if not self.no_top_p: + copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) + if not self.no_top_k: + copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) + if not self.no_min_p: + copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs) if not self.no_penalties: # Since syncing these tensors is expensive only copy them @@ -458,12 +468,9 @@ def _make_sampling_metadata(self) -> SamplingMetadata: temperature=self.temperature[:num_reqs], all_greedy=self.all_greedy, all_random=self.all_random, - top_p=self.top_p[:num_reqs], - top_k=self.top_k[:num_reqs], - min_p=self.min_p[:num_reqs], - no_min_p=self.no_min_p, - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, + top_p=None if self.no_top_p else self.top_p[:num_reqs], + top_k=None if self.no_top_k else self.top_k[:num_reqs], + min_p=None if self.no_min_p else self.min_p[:num_reqs], generators=self.generators, max_num_logprobs=self.max_num_logprobs, prompt_token_ids=prompt_token_ids, @@ -471,7 +478,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata: presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], output_token_ids=cast(List[List[int]], self.req_output_token_ids), - min_tokens=self.min_tokens[:num_reqs], + min_tokens=self.min_tokens, stop_token_ids=self.stop_token_ids[:num_reqs], no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], From 57cd6113b8c1ef10f8ab414fe9ec94028ade680d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 14 Feb 2025 23:30:59 -0800 Subject: [PATCH 04/13] group stop_token_ids with min_tokens Signed-off-by: Nick Hill --- tests/v1/sample/test_sampler.py | 18 ++++++++---------- tests/v1/worker/test_gpu_input_batch.py | 8 +------- vllm/v1/sample/metadata.py | 7 ++++--- vllm/v1/sample/ops/penalties.py | 11 +++++------ vllm/v1/sample/sampler.py | 1 - vllm/v1/worker/gpu_input_batch.py | 18 +++++++----------- vllm/v1/worker/tpu_model_runner.py | 2 -- 7 files changed, 25 insertions(+), 40 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index ac011c71ce8e..514b517b1d0f 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -99,7 +99,7 @@ def _create_default_sampling_metadata( def 
_generate_min_token_penalties_and_stop_tokens( num_output_tokens: int, batch_size: int, vocab_size: int, batch_indices_for_min_token_penalty: List[int] -) -> Tuple[Dict[int, int], List[Set[int]]]: +) -> Dict[int, Tuple[int, Set[int]]]: """ Generates and returns a dict of minimum token penalties (`min_tokens`) and a corresponding list of stop token IDs (`stop_token_ids`) for each @@ -110,21 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens( and a random set of stop token IDs is created. Otherwise, a lower `min_tokens` value is assigned, and the stop token IDs set is empty. """ - stop_token_ids: List[Set[int]] = [] - min_tokens: Dict[int, int] = {} + min_tokens: Dict[int, Tuple[int, Set[int]]] = {} for index in range(batch_size): if index in batch_indices_for_min_token_penalty: - min_tokens[index] = np.random.randint(num_output_tokens + 1, - 2 * num_output_tokens) - stop_token_ids.append( + min_tokens[index] = ( + np.random.randint(num_output_tokens + 1, + 2 * num_output_tokens), set( np.random.randint(0, vocab_size - 1) for _ in range(np.random.randint(0, vocab_size)))) - else: - min_tokens[index] = np.random.randint(0, num_output_tokens) - stop_token_ids.append(set()) - return min_tokens, stop_token_ids + min_tokens[index] = (np.random.randint(0, + num_output_tokens), set()) + return min_tokens def _create_weighted_output_token_list( diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 87c45acc0696..184f5560efd6 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata( top_p = [0.0 for _ in range(num_reqs)] min_p = [0.0 for _ in range(num_reqs)] temperature = [0.0 for _ in range(num_reqs)] - stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] - min_tokens = [0 for _ in range(num_reqs)] + min_tokens = {} logit_bias = [None] * num_reqs for req in reqs: if req.req_id not in req_ids_retained: @@ -83,8 +82,6 @@ def _construct_expected_sampling_metadata( top_p[index_in_input_batch] = req.sampling_params.top_p min_p[index_in_input_batch] = req.sampling_params.min_p temperature[index_in_input_batch] = req.sampling_params.temperature - stop_token_ids[ - index_in_input_batch] = req.sampling_params.all_stop_token_ids min_tokens[index_in_input_batch] = req.sampling_params.min_tokens logit_bias[index_in_input_batch] = req.sampling_params.logit_bias return SamplingMetadata( @@ -117,7 +114,6 @@ def _construct_expected_sampling_metadata( device=device), output_token_ids=output_token_ids, min_tokens=min_tokens, - stop_token_ids=stop_token_ids, no_penalties=(all(x == 0 for x in presence_penalties) and all(x == 0 for x in frequency_penalties) and all(x == 1 for x in repetition_penalties)), @@ -240,8 +236,6 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: assert (expected_sampling_metadata.output_token_ids == sampling_metadata.output_token_ids) assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens - assert expected_sampling_metadata.stop_token_ids == \ - sampling_metadata.stop_token_ids assert expected_sampling_metadata.no_penalties == \ sampling_metadata.no_penalties assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 68dc81f1df2d..a474ff6584c3 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses 
import dataclass -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple import torch @@ -29,7 +29,8 @@ class SamplingMetadata: repetition_penalties: torch.Tensor output_token_ids: List[List[int]] - min_tokens: Dict[int, int] - stop_token_ids: List[Set[int]] + + # req_index -> (min_tokens, stop_token_ids) + min_tokens: Dict[int, Tuple[int, Set[int]]] logit_bias: List[Optional[Dict[int, float]]] diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index df526b7594d8..8d9f6529fa0b 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -8,18 +8,17 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad -def apply_min_token_penalties(logits: torch.Tensor, - output_token_ids: List[List[int]], - stop_token_ids: List[Set[int]], - min_tokens: Dict[int, int]) -> None: +def apply_min_token_penalties( + logits: torch.Tensor, output_token_ids: List[List[int]], + min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None: """ Applies minimum token penalty by setting the logits of the stop tokens to -inf. """ min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] - for index, min_token in min_tokens.items(): + for index, (min_token, stop_token_ids) in min_tokens.items(): if len(output_token_ids[index]) < min_token: - for stop_token_id in stop_token_ids[index]: + for stop_token_id in stop_token_ids: min_tokens_logits_to_penalize.append((index, stop_token_id)) if min_tokens_logits_to_penalize: logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 0a48e0d6e0c0..4489400c638f 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -167,7 +167,6 @@ def apply_penalties( if sampling_metadata.min_tokens: apply_min_token_penalties(logits, sampling_metadata.output_token_ids, - sampling_metadata.stop_token_ids, sampling_metadata.min_tokens) if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 8d01d9499817..8c6ba7d0b474 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -171,9 +171,8 @@ def __init__( self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() - # req_index -> min_tokens - self.min_tokens: Dict[int, int] = {} - self.stop_token_ids: List[Set[int]] = [set()] * max_num_reqs + # req_index -> (min_tokens, stop_token_ids) + self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {} # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), @@ -267,8 +266,8 @@ def add_request( if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) if sampling_params.min_tokens: - self.min_tokens[req_index] = sampling_params.min_tokens - self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids + self.min_tokens[req_index] = (sampling_params.min_tokens, + sampling_params.all_stop_token_ids) # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. 
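A self-contained sketch of the grouping introduced in this patch — min_tokens maps req_index to (min_tokens, stop_token_ids) — mirroring the updated apply_min_token_penalties; the logits shape and token ids below are made-up example values:

import torch
from typing import Dict, List, Set, Tuple

def apply_min_token_penalties(
        logits: torch.Tensor, output_token_ids: List[List[int]],
        min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
    # Mask stop tokens for any request that has not yet produced min_tokens outputs.
    to_penalize: List[Tuple[int, int]] = []
    for index, (min_token, stop_token_ids) in min_tokens.items():
        if len(output_token_ids[index]) < min_token:
            to_penalize.extend((index, t) for t in stop_token_ids)
    if to_penalize:
        logits[tuple(zip(*to_penalize))] = -float("inf")

logits = torch.zeros(2, 10)
# Only request 0 carries a min_tokens constraint, so request 1 is simply absent from the dict.
apply_min_token_penalties(logits,
                          output_token_ids=[[11, 12], [13]],
                          min_tokens={0: (4, {7, 9})})
assert logits[0, 7] == -float("inf") and logits[0, 9] == -float("inf")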
@@ -309,7 +308,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) self.min_p_reqs.discard(req_id) - self.min_tokens.pop(req_index) + self.min_tokens.pop(req_index, None) self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) @@ -400,14 +399,12 @@ def condense(self, empty_req_indices: List[int]) -> None: self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] - self.stop_token_ids[empty_index] = self.stop_token_ids[ - last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator - min_token = self.min_tokens.pop(last_req_index, 0) - if min_token: + min_token = self.min_tokens.pop(last_req_index, None) + if min_token is not None: self.min_tokens[empty_index] = min_token self.request_lora_mapping[empty_index] = self.request_lora_mapping[ @@ -479,7 +476,6 @@ def _make_sampling_metadata(self) -> SamplingMetadata: repetition_penalties=self.repetition_penalties[:num_reqs], output_token_ids=cast(List[List[int]], self.req_output_token_ids), min_tokens=self.min_tokens, - stop_token_ids=self.stop_token_ids[:num_reqs], no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b64581bf5f42..3ab69a7c8dd0 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1047,8 +1047,6 @@ def swap_positions(b: InputBatch, id_1, id_2): b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ id_1] - b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[ - id_2], b.stop_token_ids[id_1] gen_1 = b.generators.pop(id_1, None) gen_2 = b.generators.pop(id_2, None) From c7e2bfd33c226b8b54d218a78cd7ca0fc5f5f786 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 15 Feb 2025 16:28:50 -0800 Subject: [PATCH 05/13] test updates Signed-off-by: Nick Hill --- tests/v1/sample/test_sampler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 514b517b1d0f..35bb28ccc40a 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -90,7 +90,6 @@ def _create_default_sampling_metadata( repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), no_penalties=True, min_tokens={}, - stop_token_ids=[], logit_bias=[None] * batch_size, ) return fake_sampling_metadata @@ -101,8 +100,8 @@ def _generate_min_token_penalties_and_stop_tokens( batch_indices_for_min_token_penalty: List[int] ) -> Dict[int, Tuple[int, Set[int]]]: """ - Generates and returns a dict of minimum token penalties (`min_tokens`) - and a corresponding list of stop token IDs (`stop_token_ids`) for each + Generates and returns a dict of minimum token penalties and + corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each batch. 
If a batch index is included in `batch_indices_for_min_token_penalty`, @@ -157,7 +156,7 @@ def _create_weighted_output_token_list( output_token_ids_for_batch.extend( [token_id for _ in range(index + 1)]) output_token_ids.append(output_token_ids_for_batch) - return (output_token_ids, sorted_token_ids_in_output) + return output_token_ids, sorted_token_ids_in_output @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -174,17 +173,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) batch_indices_for_min_token_penalty = np.random.randint( 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() - min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens( + min_tokens = _generate_min_token_penalties_and_stop_tokens( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, batch_indices_for_min_token_penalty) sampling_metadata.min_tokens = min_tokens - sampling_metadata.stop_token_ids = stop_token_ids sampler = Sampler() logits = sampler.apply_penalties(fake_logits, sampling_metadata) logits = logits.cpu() for batch_idx in range(batch_size): for token_id in range(VOCAB_SIZE): - if token_id in stop_token_ids[batch_idx]: + _, stop_token_ids = min_tokens.get(batch_idx, (0, set())) + if token_id in stop_token_ids: assert logits[batch_idx][token_id] == -float("inf") else: assert logits[batch_idx][token_id] != -float("inf") From d246ce52cc30f6b50c69b2a61c64dba30cc75b87 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 17 Feb 2025 17:31:48 -0800 Subject: [PATCH 06/13] Some more small list/tuple optimizations; fix linting Signed-off-by: Nick Hill --- vllm/v1/core/scheduler.py | 12 ++++++------ vllm/v1/core/scheduler_output.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 13 ++++++++----- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 599c9f1f3419..2d7678ff1ca1 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -118,7 +118,7 @@ def schedule(self) -> "SchedulerOutput": num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. - scheduled_encoder_inputs: Dict[str, List[int]] = {} + scheduled_encoder_inputs: Dict[str, Sequence[int]] = {} encoder_budget = self.max_num_encoder_input_tokens # Spec decode-related. scheduled_spec_decode_tokens: Dict[str, Sequence[int]] = {} @@ -199,8 +199,8 @@ def schedule(self) -> "SchedulerOutput": if isinstance(request.spec_token_ids, list): del request.spec_token_ids[num_scheduled_spec_tokens:] else: - request.spec_token_ids = request.spec_token_ids[ - :num_scheduled_spec_tokens] + request.spec_token_ids = ( + request.spec_token_ids[:num_scheduled_spec_tokens]) scheduled_spec_decode_tokens[request.request_id] = ( request.spec_token_ids) @@ -410,7 +410,7 @@ def _try_schedule_encoder_inputs( num_computed_tokens: int, num_new_tokens: int, encoder_budget: int, - ) -> Tuple[List[int], int, int]: + ) -> Tuple[Sequence[int], int, int]: """ Determine which encoder inputs need to be scheduled in the current step, and update `num_new_tokens` and encoder token budget accordingly. @@ -428,7 +428,7 @@ def _try_schedule_encoder_inputs( decoder tokens up to just before the unschedulable encoder input. 
""" if not request.has_encoder_inputs(): - return [], num_new_tokens, encoder_budget + return (), num_new_tokens, encoder_budget encoder_inputs_to_schedule: List[int] = [] mm_positions = request.mm_positions @@ -573,7 +573,7 @@ def update_from_output( outputs.append( EngineCoreOutput( request_id=req_id, - new_token_ids=new_token_ids or [], + new_token_ids=new_token_ids, finish_reason=request.get_finished_reason(), new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index b136e1879a34..ec8dcd6e1c3c 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -99,7 +99,7 @@ class SchedulerOutput: # req_id -> encoder input indices that need processing. # E.g., if a request has [0, 1], it could mean the vision encoder needs # to process that the request's 0-th and 1-th images in the current step. - scheduled_encoder_inputs: Dict[str, List[int]] + scheduled_encoder_inputs: Dict[str, Sequence[int]] # Number of common prefix blocks for all requests. # This can be used for cascade attention. num_common_prefix_blocks: int diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a669698c5787..9690e9201f45 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -342,9 +342,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_new_tokens = (num_computed_tokens + len(req_data.new_token_ids) - req_state.num_tokens) - new_token_ids = (req_data.new_token_ids[-num_new_tokens:] - if num_new_tokens > 0 else []) - req_state.output_token_ids.extend(new_token_ids) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(req_data.new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + req_data.new_token_ids[-num_new_tokens:]) # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. @@ -378,7 +381,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, []) + req_id, ()) if spec_token_ids: start_index = end_token_index end_token_index += len(spec_token_ids) @@ -1176,7 +1179,7 @@ def profile_run(self) -> None: # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. 
dummy_kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) + torch.tensor((), dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] From 5e216c74777bde62123893300094f533e105264f Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 17 Feb 2025 17:44:59 -0800 Subject: [PATCH 07/13] Small adjustment Signed-off-by: Nick Hill --- vllm/v1/worker/gpu_input_batch.py | 7 +++++-- vllm/v1/worker/gpu_model_runner.py | 10 +++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 7e6f5d23e15f..18b6304c65da 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -480,13 +480,16 @@ def _make_sampling_metadata(self) -> SamplingMetadata: logit_bias=self.logit_bias[:num_reqs], ) - def set_spec_token_ids_in_sampling_metadata( - self, req_id_to_spec_token_ids: Dict[str, Sequence[int]]): + def get_sampling_metadata( + self, + req_id_to_spec_token_ids: Dict[str, + Sequence[int]]) -> SamplingMetadata: self.sampling_metadata.spec_token_ids.clear() if req_id_to_spec_token_ids: for req_id in self.req_ids: spec_token_ids = req_id_to_spec_token_ids.get(req_id, ()) self.sampling_metadata.spec_token_ids.append(spec_token_ids) + return self.sampling_metadata def _make_prompt_token_ids_tensor(self) -> torch.Tensor: max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9690e9201f45..01880c167efd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -724,7 +724,7 @@ def _calc_spec_decode_metadata( self, scheduler_output: "SchedulerOutput", cu_num_tokens: np.ndarray, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: # Get the number of spec decode tokens for each request. num_reqs = self.input_batch.num_reqs num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32) @@ -942,13 +942,9 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) - # Update the sampling metadata with any spec decode tokens from the - # scheduler. - self.input_batch.set_spec_token_ids_in_sampling_metadata( - scheduler_output.scheduled_spec_decode_tokens) - # Sample the next token and get logprobs if needed. 
- sampling_metadata = self.input_batch.sampling_metadata + sampling_metadata = self.input_batch.get_sampling_metadata( + scheduler_output.scheduled_spec_decode_tokens) sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, From b2a43baac11e8c0f97ad04637a4bdb7ac3dae7b3 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 17 Feb 2025 18:58:21 -0800 Subject: [PATCH 08/13] Fix rejection sampler test Signed-off-by: Nick Hill --- tests/v1/sample/test_rejection_sampler.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 20a571da2cdf..3e810e525e1c 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -26,16 +26,13 @@ def create_logits_tensor(token_ids: List[int], def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: batch_size = len(spec_tokens) return SamplingMetadata( - temperature=0.0, + temperature=torch.tensor([]), all_greedy=True, all_random=False, spec_token_ids=spec_tokens, top_p=None, top_k=None, - no_top_p=False, - no_top_k=False, min_p=torch.empty(batch_size, ), - no_min_p=True, generators={}, max_num_logprobs=0, no_penalties=False, @@ -44,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: presence_penalties=torch.tensor([]), repetition_penalties=torch.tensor([]), output_token_ids=[], - min_tokens=[], - stop_token_ids=[], + min_tokens={}, logit_bias=[None] * batch_size, ) From 2fbc6e1b801aa58a06a1051dfde7deae6711ac1f Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 17 Feb 2025 19:19:37 -0800 Subject: [PATCH 09/13] Revert change related to list vs tuple Signed-off-by: Nick Hill --- vllm/v1/core/scheduler.py | 7 +------ vllm/v1/worker/gpu_input_batch.py | 2 ++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 2d7678ff1ca1..03a0e5e3ba9e 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -196,13 +196,8 @@ def schedule(self) -> "SchedulerOutput": request.num_computed_tokens - request.num_tokens) if num_scheduled_spec_tokens > 0: - if isinstance(request.spec_token_ids, list): - del request.spec_token_ids[num_scheduled_spec_tokens:] - else: - request.spec_token_ids = ( - request.spec_token_ids[:num_scheduled_spec_tokens]) scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + request.spec_token_ids[:num_scheduled_spec_tokens]) # Encoder-related. if encoder_inputs_to_schedule: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 18b6304c65da..7d97a4f261d7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -422,6 +422,7 @@ def condense(self, empty_req_indices: List[int]) -> None: # Decrement last_req_index since it is now empty. last_req_index -= 1 + # Trim lists to the batch size. del self._req_ids[self.num_reqs:] del self.req_output_token_ids[self.num_reqs:] @@ -486,6 +487,7 @@ def get_sampling_metadata( Sequence[int]]) -> SamplingMetadata: self.sampling_metadata.spec_token_ids.clear() if req_id_to_spec_token_ids: + # Set the new spec token ids in the cached sampling metadata. 
for req_id in self.req_ids: spec_token_ids = req_id_to_spec_token_ids.get(req_id, ()) self.sampling_metadata.spec_token_ids.append(spec_token_ids) From 1b68e03bf62a135d17f982d8d16262cd3f396a37 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 17 Feb 2025 19:31:35 -0800 Subject: [PATCH 10/13] Revert List->Sequence changes Signed-off-by: Nick Hill --- vllm/v1/core/scheduler.py | 14 +++++++------- vllm/v1/core/scheduler_output.py | 6 +++--- vllm/v1/outputs.py | 4 ++-- vllm/v1/request.py | 4 ++-- vllm/v1/sample/metadata.py | 4 ++-- vllm/v1/utils.py | 2 +- vllm/v1/worker/gpu_input_batch.py | 11 +++++------ vllm/v1/worker/gpu_model_runner.py | 10 +++++----- 8 files changed, 27 insertions(+), 28 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 03a0e5e3ba9e..d5627164e967 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -2,8 +2,7 @@ import time from collections import deque -from typing import (Deque, Dict, Iterable, List, Optional, Sequence, Set, - Tuple, Union) +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, SpeculativeConfig) @@ -118,10 +117,10 @@ def schedule(self) -> "SchedulerOutput": num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. - scheduled_encoder_inputs: Dict[str, Sequence[int]] = {} + scheduled_encoder_inputs: Dict[str, List[int]] = {} encoder_budget = self.max_num_encoder_input_tokens # Spec decode-related. - scheduled_spec_decode_tokens: Dict[str, Sequence[int]] = {} + scheduled_spec_decode_tokens: Dict[str, List[int]] = {} # For logging. scheduled_timestamp = time.monotonic() @@ -196,8 +195,9 @@ def schedule(self) -> "SchedulerOutput": request.num_computed_tokens - request.num_tokens) if num_scheduled_spec_tokens > 0: + del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids[:num_scheduled_spec_tokens]) + request.spec_token_ids) # Encoder-related. if encoder_inputs_to_schedule: @@ -405,7 +405,7 @@ def _try_schedule_encoder_inputs( num_computed_tokens: int, num_new_tokens: int, encoder_budget: int, - ) -> Tuple[Sequence[int], int, int]: + ) -> Tuple[List[int], int, int]: """ Determine which encoder inputs need to be scheduled in the current step, and update `num_new_tokens` and encoder token budget accordingly. @@ -423,7 +423,7 @@ def _try_schedule_encoder_inputs( decoder tokens up to just before the unschedulable encoder input. """ if not request.has_encoder_inputs(): - return (), num_new_tokens, encoder_budget + return [], num_new_tokens, encoder_budget encoder_inputs_to_schedule: List[int] = [] mm_positions = request.mm_positions diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index ec8dcd6e1c3c..47413527c32f 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple if TYPE_CHECKING: from vllm.lora.request import LoRARequest @@ -95,11 +95,11 @@ class SchedulerOutput: # req_id -> spec_token_ids # If a request does not have any spec decode tokens, it will not be # included in the dictionary. 
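# Editor's sketch (not part of the patch): the scheduler.py hunk earlier in
# this patch trims request.spec_token_ids in place with `del`, so the entry
# stored in scheduled_spec_decode_tokens aliases the same truncated list
# rather than holding a sliced copy. Values below are illustrative only.
spec_token_ids = [3, 5, 8, 2]
num_scheduled_spec_tokens = 2
del spec_token_ids[num_scheduled_spec_tokens:]            # truncate in place, no copy
assert spec_token_ids == [3, 5]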
- scheduled_spec_decode_tokens: Dict[str, Sequence[int]] + scheduled_spec_decode_tokens: Dict[str, List[int]] # req_id -> encoder input indices that need processing. # E.g., if a request has [0, 1], it could mean the vision encoder needs # to process that the request's 0-th and 1-th images in the current step. - scheduled_encoder_inputs: Dict[str, Sequence[int]] + scheduled_encoder_inputs: Dict[str, List[int]] # Number of common prefix blocks for all requests. # This can be used for cascade attention. num_common_prefix_blocks: int diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 5fa8c204c3b5..0c8eca38ade7 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, NamedTuple, Optional, Sequence +from typing import Dict, List, NamedTuple, Optional import torch @@ -68,7 +68,7 @@ class ModelRunnerOutput: sampled_token_ids: List[List[int]] # num_reqs x num_spec_tokens - spec_token_ids: Optional[List[Sequence[int]]] + spec_token_ids: Optional[List[List[int]]] # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 00787e6d2b12..52d7faeeb066 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, List, Optional, Union from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -46,7 +46,7 @@ def __init__( self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: List[int] = [] self._all_token_ids: List[int] = self.prompt_token_ids.copy() - self.spec_token_ids: Sequence[int] = [] + self.spec_token_ids: List[int] = [] self.num_computed_tokens = 0 # Multi-modal related diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 4cb80b5942aa..dac6cb64aa71 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import torch @@ -14,7 +14,7 @@ class SamplingMetadata: all_random: bool # The list will empty if no requests have spec tokens. - spec_token_ids: List[Sequence[int]] + spec_token_ids: List[List[int]] top_p: Optional[torch.Tensor] top_k: Optional[torch.Tensor] diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 546fe189568f..5be465014242 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -194,7 +194,7 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int) -> None: """ Copy the first length elements of a tensor into another tensor in a - non-blocking manner + non-blocking manner. Used to copy pinned CPU tensor data to pre-allocated GPU tensors. 
""" diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 7d97a4f261d7..0900cd632332 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -2,8 +2,7 @@ # Datastructures defining an input batch from dataclasses import dataclass -from typing import (TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, - cast) +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import numpy as np import torch @@ -482,14 +481,14 @@ def _make_sampling_metadata(self) -> SamplingMetadata: ) def get_sampling_metadata( - self, - req_id_to_spec_token_ids: Dict[str, - Sequence[int]]) -> SamplingMetadata: + self, + req_id_to_spec_token_ids: Dict[str, + List[int]]) -> SamplingMetadata: self.sampling_metadata.spec_token_ids.clear() if req_id_to_spec_token_ids: # Set the new spec token ids in the cached sampling metadata. for req_id in self.req_ids: - spec_token_ids = req_id_to_spec_token_ids.get(req_id, ()) + spec_token_ids = req_id_to_spec_token_ids.get(req_id, []) self.sampling_metadata.spec_token_ids.append(spec_token_ids) return self.sampling_metadata diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 01880c167efd..0ecc00acc790 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,7 +2,7 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -1011,14 +1011,14 @@ def execute_model( def generate_draft_token_ids( self, sampled_token_ids: List[List[int]], - ) -> List[Sequence[int]]: + ) -> List[List[int]]: # TODO(woosuk): Optimize. - draft_token_ids: List[Sequence[int]] = [] + draft_token_ids: List[List[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: # Skip speculative decoding. - draft_token_ids.append(()) + draft_token_ids.append([]) continue # Add sampled_token_ids to token_ids_cpu. 
@@ -1031,7 +1031,7 @@ def generate_draft_token_ids( self.speculative_config.num_speculative_tokens, ) if drafter_output is None or len(drafter_output) == 0: - draft_token_ids.append(()) + draft_token_ids.append([]) else: draft_token_ids.append(drafter_output.tolist()) return draft_token_ids From 28a17aeb17e9fca7ecb922dfef27e966681c9458 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 18 Feb 2025 07:51:13 -0800 Subject: [PATCH 11/13] Address review comments Signed-off-by: Nick Hill --- tests/v1/sample/test_sampler.py | 2 +- tests/v1/worker/test_gpu_input_batch.py | 2 +- tests/v1/worker/test_gpu_model_runner.py | 6 ++-- vllm/v1/sample/metadata.py | 4 +-- vllm/v1/sample/rejection_sampler.py | 2 ++ vllm/v1/worker/gpu_input_batch.py | 46 ++++++------------------ 6 files changed, 18 insertions(+), 44 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 0462566a00c2..3f6301c54267 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -85,7 +85,7 @@ def _create_default_sampling_metadata( prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, - spec_token_ids=[], + spec_token_ids=None, frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 836c35ca1b4f..bf17f9224155 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -115,7 +115,7 @@ def _construct_expected_sampling_metadata( dtype=torch.float, device=device), output_token_ids=output_token_ids, - spec_token_ids=[], + spec_token_ids=None, min_tokens=min_tokens, no_penalties=(all(x == 0 for x in presence_penalties) and all(x == 0 for x in frequency_penalties) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index c655b0fded6e..a3c339eba07a 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -88,8 +88,7 @@ def test_update_states_new_request(model_runner): # new req scheduler_output = _schedule_new_request(req_id) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + model_runner._update_states(scheduler_output) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -117,8 +116,7 @@ def test_update_states_request_finished(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + model_runner._update_states(scheduler_output) assert not _is_req_added(model_runner, req_id) assert not _is_req_scheduled(model_runner, req_id) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index dac6cb64aa71..2184a1866ff5 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -13,8 +13,8 @@ class SamplingMetadata: all_greedy: bool all_random: bool - # The list will empty if no requests have spec tokens. - spec_token_ids: List[List[int]] + # None when there are no speculated tokens. 
+ spec_token_ids: Optional[List[List[int]]] top_p: Optional[torch.Tensor] top_k: Optional[torch.Tensor] diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index df1da8930211..580ad44297aa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -68,6 +68,7 @@ def flashinfer_sample( # NOTE: The following input preparationg can be moved # to the model runner with a persistent manner for better # performance. + assert sampling_metadata.spec_token_ids is not None spec_token_ids = sampling_metadata.spec_token_ids max_spec_len = max(len(s) for s in spec_token_ids) batch_size = len(spec_token_ids) @@ -119,6 +120,7 @@ def forward_native( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + assert sampling_metadata.spec_token_ids is not None spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] # Add 1 to include the 'bonus' token. sample_lens = [x + 1 for x in spec_lens] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 0900cd632332..ccafc325b53f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -332,32 +332,12 @@ def remove_request(self, req_id: str) -> Optional[int]: self.logit_bias[req_index] = None return req_index - def clear(self) -> None: - self._req_ids.clear() - self.req_output_token_ids.clear() - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.min_p_reqs.clear() - self.frequency_penalties_reqs.clear() - self.presence_penalties_reqs.clear() - self.repetition_penalties_reqs.clear() - self.min_tokens.clear() - self.generators.clear() - self.num_logprobs.clear() - self.num_prompt_logprobs.clear() - self.request_lora_mapping.fill(0) - self.lora_id_to_lora_request.clear() - self.lora_id_to_request_ids.clear() - self.logit_bias = [None] * self.max_num_reqs - def condense(self, empty_req_indices: List[int]) -> None: num_reqs = self.num_reqs if num_reqs == 0: # The batched states are empty. - self.clear() + self._req_ids.clear() + self.req_output_token_ids.clear() return # NOTE(woosuk): This function assumes that the empty_req_indices @@ -425,10 +405,6 @@ def condense(self, empty_req_indices: List[int]) -> None: del self._req_ids[self.num_reqs:] del self.req_output_token_ids[self.num_reqs:] - # num_reqs entries should be non-None - assert all(req_id is not None - for req_id in self._req_ids), "req_ids contains None" - def refresh_sampling_metadata(self): self.sampling_metadata = self._make_sampling_metadata() @@ -474,22 +450,20 @@ def _make_sampling_metadata(self) -> SamplingMetadata: presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], output_token_ids=cast(List[List[int]], self.req_output_token_ids), - spec_token_ids=[], + spec_token_ids=None, min_tokens=self.min_tokens, no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], ) def get_sampling_metadata( - self, - req_id_to_spec_token_ids: Dict[str, - List[int]]) -> SamplingMetadata: - self.sampling_metadata.spec_token_ids.clear() - if req_id_to_spec_token_ids: - # Set the new spec token ids in the cached sampling metadata. 
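# Editor's sketch (not part of the patch): with spec_token_ids now Optional,
# consumers such as the rejection sampler (hunks above) first assert it is
# present and then size each request's sample length with one extra 'bonus'
# token, as the existing sampler code does. Token values are examples only.
spec_token_ids = [[4, 4], [9]]                            # two requests with spec tokens
assert spec_token_ids is not None
spec_lens = [len(s) for s in spec_token_ids]
sample_lens = [n + 1 for n in spec_lens]                  # +1 for the bonus token
assert sample_lens == [3, 2]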
- for req_id in self.req_ids: - spec_token_ids = req_id_to_spec_token_ids.get(req_id, []) - self.sampling_metadata.spec_token_ids.append(spec_token_ids) + self, + req_id_to_spec_token_ids: Dict[str, List[int]], + ) -> SamplingMetadata: + # Set the new spec token ids in the cached sampling metadata. + self.sampling_metadata.spec_token_ids = [ + req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids + ] if req_id_to_spec_token_ids else None return self.sampling_metadata def _make_prompt_token_ids_tensor(self) -> torch.Tensor: From 9250721aca0c728a16355fe0461138b113f1bc9d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 18 Feb 2025 08:41:03 -0800 Subject: [PATCH 12/13] Fix up gpu_model_runner tests Signed-off-by: Nick Hill --- tests/v1/worker/test_gpu_input_batch.py | 2 +- tests/v1/worker/test_gpu_model_runner.py | 27 ++++++++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index bf17f9224155..cb3b3d21fbb3 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -41,7 +41,7 @@ def _remove_requests( for index in req_indices_to_remove: input_batch.remove_request(reqs[index].req_id) req_ids_to_remove.add(reqs[index].req_id) - return (req_ids_to_remove, req_indices_to_remove_list) + return req_ids_to_remove, req_indices_to_remove_list def _construct_expected_sampling_metadata( diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index a3c339eba07a..973efcbf8e50 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -5,6 +5,7 @@ from vllm.sampling_params import SamplingParams from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, SchedulerOutput) +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -82,13 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool: return req_id in model_runner.requests +def _is_sampling_metadata_changed(model_runner, + sampling_metadata_before: SamplingMetadata): + return model_runner.input_batch.sampling_metadata is not ( + sampling_metadata_before) + + def test_update_states_new_request(model_runner): req_id = "req_0" # new req scheduler_output = _schedule_new_request(req_id) + metadata_before = model_runner.input_batch.sampling_metadata model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -116,7 +125,9 @@ def test_update_states_request_finished(model_runner): free_encoder_input_ids=[], ) + metadata_before = model_runner.input_batch.sampling_metadata model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert not _is_req_added(model_runner, req_id) assert not _is_req_scheduled(model_runner, req_id) @@ -140,7 +151,7 @@ def test_update_states_request_resumed(model_runner): scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, - finished_req_ids={}, + finished_req_ids=set(), free_encoder_input_ids=[], ) @@ -169,8 +180,9 @@ def test_update_states_request_resumed(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner.input_batch.sampling_metadata + 
model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -198,8 +210,9 @@ def test_update_states_no_changes(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is False + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert not _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -231,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_ids[0]) assert _is_req_scheduled(model_runner, req_ids[0]) From ce3c3f429ebca1bf4a3977011ff8dd41922beadf Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 18 Feb 2025 09:44:50 -0800 Subject: [PATCH 13/13] Add comment Signed-off-by: Nick Hill --- vllm/v1/core/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index d5627164e967..535aa644c53c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -195,6 +195,7 @@ def schedule(self) -> "SchedulerOutput": request.num_computed_tokens - request.num_tokens) if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( request.spec_token_ids)
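Illustrative sketch (not part of the patch series): PATCH 12/13 switches the model-runner tests to detect "batch changed" by object identity of the cached SamplingMetadata, since it is rebuilt only when the batch composition changes. A minimal stand-in for that idea, using a hypothetical _Batch class in place of the real input batch:

class _Batch:
    """Stand-in for the input batch: metadata is rebuilt only on change."""

    def __init__(self) -> None:
        self.sampling_metadata = object()

    def refresh(self, changed: bool) -> None:
        if changed:
            # A new object means identity comparison detects the change.
            self.sampling_metadata = object()

batch = _Batch()
before = batch.sampling_metadata
batch.refresh(changed=False)
assert batch.sampling_metadata is before        # unchanged batch, same object
batch.refresh(changed=True)
assert batch.sampling_metadata is not before    # changed batch, new object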