From 488240d3ca7ea29dc78360f0819bbe6172137d31 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 27 Aug 2024 15:58:23 -0600 Subject: [PATCH 1/9] fix: use INVALID_TOKEN_ID instead of uninitialized tensor data Signed-off-by: Travis Johnson --- vllm/model_executor/layers/sampler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 487f5a3d2a44..c7a2c9343cbf 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -18,6 +18,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics +from vllm.transformers_utils.detokenizer import INVALID_TOKEN_ID if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -762,10 +763,10 @@ def _sample_with_torch( # Create output tensor for sampled token ids. if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.empty(logprobs.shape[0], - 1, - dtype=torch.long, - device=logprobs.device) + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) else: sampled_token_ids_tensor = None From fb58cd4103f62d73d79598c701a8509c98f9ffd2 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 27 Aug 2024 15:57:31 -0600 Subject: [PATCH 2/9] refactor: use INVALID_TOKEN_ID instead of magic number Signed-off-by: Travis Johnson --- vllm/engine/output_processor/multi_step.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b..9757f3eb5e07 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.detokenizer import Detokenizer, INVALID_TOKEN_ID from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -110,7 +110,7 @@ def process_outputs(self, # we can take the first sample. samples = [output.samples[0] for output in outputs] - # -1 means the output token is not valid (eg. due to spec decode + # entries in sample tokens may be invalid (eg. due to spec decode # rejecting tokens). 
valid_samples = [ sample for sample in samples if sample.output_token != -1 From 9bcf56b3a56d7ba3786fb1506643d9794e600f7d Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 27 Aug 2024 15:59:53 -0600 Subject: [PATCH 3/9] fix: skip entries for prompt tokens in spec decoding Signed-off-by: Travis Johnson --- vllm/spec_decode/spec_decode_worker.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 91f0a98c7bc3..8ff37a59b8ec 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -32,6 +32,7 @@ get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) +from vllm.transformers_utils.detokenizer import INVALID_TOKEN_ID from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -455,7 +456,12 @@ def _serialize_sampler_output_no_logprobs( IDs populated. """ seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) - sampled_token_ids_list = sampler_output.sampled_token_ids.tolist() + # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID + sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + sampler_output.sampled_token_ids - INVALID_TOKEN_ID)[0]] \ + if any(seq.is_prompt + for seq in execute_model_req.seq_group_metadata_list) else \ + sampler_output.sampled_token_ids).tolist() completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] for index, seq_id in enumerate(seq_ids): @@ -487,6 +493,11 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # Store hidden states from target model execution. hidden_states = sampler_output.hidden_states if hidden_states is not None: + # remove hidden_states for prompt tokens + if any(seq.is_prompt + for seq in execute_model_req.seq_group_metadata_list): + hidden_states = hidden_states[torch.where( + sampler_output.sampled_token_ids - INVALID_TOKEN_ID)[0]] if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( hidden_states, execute_model_req.seq_group_metadata_list) From 526f14c977e24c4f6a2c0d97b62bdfd1913dc624 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Fri, 30 Aug 2024 10:48:57 -0600 Subject: [PATCH 4/9] refactor: move INVALID_TOKEN_ID to sequence.py and rename Signed-off-by: Travis Johnson --- vllm/engine/output_processor/multi_step.py | 9 +++++---- vllm/model_executor/layers/sampler.py | 6 +++--- vllm/sequence.py | 2 ++ vllm/spec_decode/batch_expansion.py | 10 +++++----- vllm/spec_decode/spec_decode_worker.py | 11 ++++++----- vllm/transformers_utils/detokenizer.py | 10 ++++------ 6 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 9757f3eb5e07..31c2bbc8e712 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -9,9 +9,9 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, - SequenceOutput, SequenceStatus) -from vllm.transformers_utils.detokenizer import Detokenizer, INVALID_TOKEN_ID +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import 
Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -113,7 +113,8 @@ def process_outputs(self, # entries in sample tokens may be invalid (eg. due to spec decode # rejecting tokens). valid_samples = [ - sample for sample in samples if sample.output_token != -1 + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID ] assert valid_samples diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c7a2c9343cbf..31a26ef8103c 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -15,10 +15,10 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.transformers_utils.detokenizer import INVALID_TOKEN_ID if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -764,7 +764,7 @@ def _sample_with_torch( # Create output tensor for sampled token ids. if include_gpu_probs_tensor: sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), - INVALID_TOKEN_ID, + VLLM_INVALID_TOKEN_ID, dtype=torch.long, device=logprobs.device) else: diff --git a/vllm/sequence.py b/vllm/sequence.py index 07ceccf12354..8b161c2e144a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -25,6 +25,8 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" +VLLM_INVALID_TOKEN_ID = -1 + # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index b2204e8b27af..9eb8bbfc5407 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,9 +6,9 @@ from vllm import SamplingParams from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SequenceData, SequenceGroupMetadata, - get_all_seq_ids) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len @@ -69,10 +69,10 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() - # Filter the list to ignore -1 proposals. + # Filter the list to ignore invalid proposals. 
proposal_token_ids_list_without_skips = [ proposals for proposals in proposal_token_ids_list - if -1 not in proposals + if VLLM_INVALID_TOKEN_ID not in proposals ] (spec_indices, non_spec_indices, target_seq_group_metadata_list, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 8ff37a59b8ec..21a25fafc4b6 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -13,7 +13,8 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) -from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, get_all_seq_ids, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer @@ -32,7 +33,6 @@ get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) -from vllm.transformers_utils.detokenizer import INVALID_TOKEN_ID from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -458,7 +458,7 @@ def _serialize_sampler_output_no_logprobs( seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( - sampler_output.sampled_token_ids - INVALID_TOKEN_ID)[0]] \ + sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ if any(seq.is_prompt for seq in execute_model_req.seq_group_metadata_list) else \ sampler_output.sampled_token_ids).tolist() @@ -496,8 +496,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # remove hidden_states for prompt tokens if any(seq.is_prompt for seq in execute_model_req.seq_group_metadata_list): - hidden_states = hidden_states[torch.where( - sampler_output.sampled_token_ids - INVALID_TOKEN_ID)[0]] + hidden_states = hidden_states[ + torch.where(sampler_output.sampled_token_ids - + VLLM_INVALID_TOKEN_ID)[0]] if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( hidden_states, execute_model_req.seq_group_metadata_list) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index d27d7ba9e67b..111158f032c7 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,13 +1,11 @@ from typing import Dict, List, Optional, Tuple -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, + Sequence, SequenceGroup) from .tokenizer import AnyTokenizer from .tokenizer_group import BaseTokenizerGroup -# Used eg. for marking rejected tokens in spec decoding. 
-INVALID_TOKEN_ID = -1 - class Detokenizer: """Provides methods to decode the output of a model into text.""" @@ -61,7 +59,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, continue for token_id, sample_logprob in prompt_logprobs_for_token.items(): if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): prompt_token_ids_with_token = ( prompt_token_ids[:token_position] + [token_id]) (new_tokens, new_text, new_prefix_offset, @@ -143,7 +141,7 @@ def decode_sequence_inplace(self, seq: Sequence, continue if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): all_input_ids_with_logprob = previous_tokens + [token_id] (_, new_text, _, _) = detokenize_incrementally( tokenizer=tokenizer, From 041e5fafd30fea8df332264c9f78b73b1efadb66 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Thu, 5 Sep 2024 09:49:57 -0600 Subject: [PATCH 5/9] fix: detokenize negative id to empty string Signed-off-by: Travis Johnson --- vllm/transformers_utils/detokenizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 111158f032c7..9b9248936422 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -280,14 +280,14 @@ def detokenize_incrementally( assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: + if 0 < new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) if isinstance(new_tokens, str): new_tokens = [new_tokens] + else: + new_tokens = [""] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. 
From ad367ff63d020d3bbc79329a65dfdc9bc131cb6b Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Thu, 5 Sep 2024 11:40:26 -0600 Subject: [PATCH 6/9] feat: support prompt_logprobs output with spec decoding Signed-off-by: Travis Johnson --- vllm/sequence.py | 11 ++++++ vllm/spec_decode/spec_decode_worker.py | 48 ++++++++++++++++++++------ vllm/spec_decode/util.py | 45 +++++++++++++++++++----- 3 files changed, 85 insertions(+), 19 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 8b161c2e144a..31c4cd26f8c7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1127,6 +1127,17 @@ def __eq__(self, other: object): self.__class__) and self.outputs == other.outputs +def get_all_seq_data_entries( + seq_group_metadata_list: List[SequenceGroupMetadata] +) -> List[Tuple[int, SequenceData]]: + """Given a list of SequenceGroupMetadata, create a dict of + sequence ids to SequenceData + """ + return [(seq_id, seq_data) for sg in seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ] + + def get_all_seq_ids( seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: """Given a list of SequenceGroupMetadata, create a list of all diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 21a25fafc4b6..64d92fd27870 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -16,7 +16,8 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, - get_all_seq_ids, get_all_seq_ids_and_request_ids) + get_all_seq_data_entries, + get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -29,7 +30,8 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (Timer, create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_logprobs_output, + create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -439,8 +441,8 @@ def _serialize_sampler_output_no_logprobs( self, execute_model_req: ExecuteModelRequest, sampler_output: SamplerOutput) -> SamplerOutput: """ - Creates and returns a `SamplerOutput` with only the sampled token IDs - being serialized to CPU & populated in `CompletionSequenceGroupOutput`. + Creates and returns a `SamplerOutput` with only the token IDs being + serialized to CPU and populated in `CompletionSequenceGroupOutput`. All other parameters in `CompletionSequenceGroupOutput` related to log probabilities are skipped. @@ -452,19 +454,43 @@ def _serialize_sampler_output_no_logprobs( Returns: SamplerOutput: A new `SamplerOutput` instance containing a list of - `CompletionSequenceGroupOutput` objects with only sampled token - IDs populated. + `CompletionSequenceGroupOutput` objects with only token IDs + populated. 
""" - seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) + seq_output_prompt_logprobs = [ + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 + for seq in execute_model_req.seq_group_metadata_list + ] # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + # subtracting is faster than testing for equality sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ - if any(seq.is_prompt - for seq in execute_model_req.seq_group_metadata_list) else \ + if any(seq_output_prompt_logprobs) else \ sampler_output.sampled_token_ids).tolist() + + seq_data_entries = get_all_seq_data_entries( + execute_model_req.seq_group_metadata_list) completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] - for index, seq_id in enumerate(seq_ids): + for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ + enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)): + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) + # no prompt logprobs for the first token + for p_token_id in prompt_token_ids[1:] + ] + else: + prompt_logprobs = None + completion_seq_group_output_list.append( create_sequence_group_output( token_id=sampled_token_ids_list[index][0], @@ -473,7 +499,7 @@ def _serialize_sampler_output_no_logprobs( seq_id=seq_id, topk_token_ids=[], topk_logprobs=[], - )) + prompt_logprobs=prompt_logprobs)) return SamplerOutput(outputs=completion_seq_group_output_list) @nvtx_range("spec_decode_worker._run_no_spec") diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 54e718bc4901..193ef870dfce 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceGroupMetadata, SequenceOutput) + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) SeqId = int @@ -49,21 +50,19 @@ def get_sampled_token_logprobs( return sampled_token_ids_ranks, selected_logprobs -def create_sequence_group_output( +def create_logprobs_output( token_id: int, token_id_logprob_rank: int, token_id_logprob: float, - seq_id: SeqId, topk_token_ids: List[Optional[int]], topk_logprobs: List[Optional[float]], -) -> CompletionSequenceGroupOutput: - """Create a SequenceGroupOutput given the sampling results. +) -> Dict[int, Logprob]: + """Create a Logprob Dict for a token given the sampling results. Args: token_id (int): The sampled token for the sequence. token_id_logprob_rank (int): The logprob rank of the sampled token. token_id_logprob (float): The logprob value of the sampled token. - seq_id (int): The sequence id. topk_token_ids (List[Optional[int]]): The list of top-k token ids. topk_logprobs (List[Optional[float]]): The list of top-k logprobs. 
""" @@ -85,14 +84,44 @@ def create_sequence_group_output( if topk_token_id is not None }) + return logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], + prompt_logprobs: Optional[PromptLogprobs] = None, +) -> CompletionSequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + + logprobs = create_logprobs_output( + token_id, + token_id_logprob_rank, + token_id_logprob, + topk_token_ids, + topk_logprobs, + ) + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, logprobs=logprobs) ], - # TODO add prompt logprobs support. - prompt_logprobs=None, + prompt_logprobs=prompt_logprobs, ) From 6f420db2c7200e1c5837710b9ba6bbc727f13137 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 16 Sep 2024 11:28:28 -0600 Subject: [PATCH 7/9] test: allow generate_w_logprobs even if no logprobs requested Signed-off-by: Travis Johnson --- tests/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c2616bcf7091..a03173e106a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -675,8 +675,6 @@ def generate_w_logprobs( videos: Optional[PromptVideoInput] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: - assert sampling_params.logprobs is not None - if images is not None: assert len(prompts) == len(images) @@ -754,7 +752,7 @@ def generate_greedy_logprobs( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, - prompt_logprobs=(num_prompt_logprobs), + prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids) return self.generate_w_logprobs(prompts, From a15fea64865ccfc74bcf5b2d9ff95db39e2371ca Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 16 Sep 2024 17:00:03 -0600 Subject: [PATCH 8/9] test: update spec_decode e2e tests Include logprobs cases and where disable_logprobs is True. 
Signed-off-by: Travis Johnson --- tests/spec_decode/e2e/conftest.py | 139 ++++++++++++------ .../spec_decode/e2e/test_eagle_correctness.py | 58 ++++++++ tests/spec_decode/e2e/test_logprobs.py | 95 ++++++------ .../e2e/test_medusa_correctness.py | 59 ++++++++ tests/spec_decode/e2e/test_mlp_correctness.py | 57 ++++++- .../spec_decode/e2e/test_ngram_correctness.py | 59 ++++++++ 6 files changed, 378 insertions(+), 89 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 3d93f4a23b68..b450ef97c89d 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,13 +1,16 @@ from itertools import cycle -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Union import pytest from vllm import LLM, SamplingParams from vllm.model_executor.utils import set_random_seed +from vllm.sequence import PromptLogprobs, SampleLogprobs from ...conftest import cleanup -from ...models.utils import check_logprobs_close, check_outputs_equal +from ...models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + check_logprobs_close, check_outputs_equal) from ...utils import RemoteOpenAIServer PROMPTS = [ @@ -81,45 +84,77 @@ def get_output_from_llm_generator( return tokens, token_ids, acceptance_rate -def run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - max_output_len: int, - seed: Optional[int] = 0, - temperature: float = 0.0, - logprobs: int = 1): - org_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **baseline_llm_kwargs, - } - - sd_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - - sampling_params = SamplingParams(temperature=temperature, - max_tokens=max_output_len, - seed=seed, - logprobs=logprobs) - - with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - with vllm_runner(**sd_args) as vllm_model: - sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - check_logprobs_close(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, - name_0="org", - name_1="sd") +def check_logprobs_correctness( + spec_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + baseline_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + disable_logprobs: bool = False, +): + """Compare sampled and prompt logprobs between baseline and spec decoding + """ + if not disable_logprobs: + return check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=spec_outputs, + name_0="org", + name_1="sd", + ) + + # Check correctness when disable_logprobs == True + for spec_output, baseline_output in zip(spec_outputs, baseline_outputs): + # Check generated token logprobs. 
+ spec_logprobs = spec_output[2] + baseline_logprobs = baseline_output[2] + _check_logprobs_when_output_disabled(spec_logprobs, + baseline_logprobs, + is_prompt_logprobs=False) + + # Check prompt logprobs too, if they exist + if len(baseline_output) == 4: + assert len(spec_output) == 4 + spec_prompt_logprobs = spec_output[3] + baseline_prompt_logprobs = baseline_output[3] + _check_logprobs_when_output_disabled(spec_prompt_logprobs, + baseline_prompt_logprobs, + is_prompt_logprobs=True) + + +def _check_logprobs_when_output_disabled( + spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + is_prompt_logprobs: bool = False, +): + # Prompt logprobs are optional + if is_prompt_logprobs and baseline_logprobs is None: + assert spec_logprobs is None + return + + assert spec_logprobs is not None + assert baseline_logprobs is not None + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # First prompt logprob is expected to be None + if is_prompt_logprobs and baseline_pos_logprobs is None: + assert spec_pos_logprobs is None + assert pos == 0 + continue + + assert spec_pos_logprobs is not None + assert baseline_pos_logprobs is not None + + # When disabled, the 1 logprob is returned with dummy values for the + # score and rank, but the token id should match the baseline model + assert len(spec_pos_logprobs) == 1 + (spec_pos_logprob_token_id, + spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) + assert spec_pos_logprob.rank == -1 + assert spec_pos_logprob.logprob == 0.0 + assert spec_pos_logprob_token_id in baseline_pos_logprobs def run_equality_correctness_test( @@ -135,7 +170,10 @@ def run_equality_correctness_test( disable_seed: bool = False, ignore_eos: bool = True, ensure_all_accepted: bool = False, - expected_acceptance_rate: Optional[float] = None): + expected_acceptance_rate: Optional[float] = None, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + disable_logprobs: bool = False): org_args = { **common_llm_kwargs, @@ -157,10 +195,12 @@ def run_equality_correctness_test( sampling_params = SamplingParams(temperature=temperature, max_tokens=max_output_len, seed=seed, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs) with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate(prompts, sampling_params) + org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) with vllm_runner(**sd_args) as vllm_model: if ensure_all_accepted or expected_acceptance_rate is not None: @@ -169,7 +209,7 @@ def run_equality_correctness_test( 'prometheus'] stat_logger.local_interval = -100 - sd_outputs = vllm_model.generate(prompts, sampling_params) + sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) if ensure_all_accepted or expected_acceptance_rate is not None: acceptance_rate = (stat_logger.metrics. 
@@ -185,11 +225,16 @@ def run_equality_correctness_test( if expected_acceptance_rate is not None: assert acceptance_rate >= expected_acceptance_rate - 1e-2 - check_outputs_equal(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, + # Only pass token entries, not the logprobs + check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs], + outputs_1_lst=[out[0:2] for out in sd_outputs], name_0="org", name_1="sd") + # Check logprobs if requested + if logprobs is not None or prompt_logprobs is not None: + check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs) + def run_equality_correctness_test_tp(model, common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index f2af2c2bedb1..d7ca8815ec25 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 03c1733f104f..b7d54991e053 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -4,7 +4,7 @@ from vllm import SamplingParams -from .conftest import run_logprob_correctness_test +from .conftest import run_equality_correctness_test @pytest.mark.parametrize( @@ -25,6 +25,10 @@ "speculative_model": "JackFram/llama-160m", "num_speculative_tokens": 3, "disable_logprobs_during_spec_decoding": False, + }, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": True, }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify output logprobs are equal with and without speculative decoding. 
""" - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, output_len: int, seed: int, logprobs: int): """Veriy logprob greedy equality with different speculation lens. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify logprobs greedy equality when some sequences skip speculation. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, """Check the behavior when logprobs are disabled. Token choices should match with the base model. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 568c2d65fca5..38eae22cbb8a 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int, logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 2d0d6fb923ad..7f3180befaff 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -16,7 +16,7 @@ * Test greedy equality under various number of speculative tokens. With those tests, we can say at least, MLPSpeculator would not break the -correctess for the target model outputs. +correctness for the target model outputs. """ from unittest.mock import patch @@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [8]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 89301f24e115..850114eb7f5a 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model_name": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ From 6e8e54d334a36c83c80c8fc54462ac1e99d72597 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 23 Sep 2024 15:39:30 -0600 Subject: [PATCH 9/9] changes from code review Signed-off-by: Travis Johnson --- vllm/sequence.py | 11 ----------- vllm/spec_decode/spec_decode_worker.py | 8 +++++--- vllm/transformers_utils/detokenizer.py | 2 +- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 31c4cd26f8c7..8b161c2e144a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1127,17 +1127,6 @@ def __eq__(self, other: object): self.__class__) and self.outputs == other.outputs -def get_all_seq_data_entries( - seq_group_metadata_list: List[SequenceGroupMetadata] -) -> List[Tuple[int, SequenceData]]: - """Given a list of SequenceGroupMetadata, create a dict of - sequence ids to SequenceData - """ - return [(seq_id, seq_data) for sg in seq_group_metadata_list \ - for seq_id, seq_data in sg.seq_data.items() - ] - - def get_all_seq_ids( seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: """Given a list of SequenceGroupMetadata, create a list of all diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 64d92fd27870..4c18a001fae9 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -16,7 +16,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, - get_all_seq_data_entries, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner @@ -469,8 +468,11 @@ def _serialize_sampler_output_no_logprobs( if any(seq_output_prompt_logprobs) else \ sampler_output.sampled_token_ids).tolist() - seq_data_entries = get_all_seq_data_entries( - execute_model_req.seq_group_metadata_list) + seq_data_entries = ( + (seq_id, seq_data) for sg in \ + execute_model_req.seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ) completion_seq_group_output_list: List[ 
CompletionSequenceGroupOutput] = [] for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 9b9248936422..2b418f3603a0 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -280,7 +280,7 @@ def detokenize_incrementally( assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. - if 0 < new_token_id < len(tokenizer): + if 0 <= new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens)
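
Note (illustrative, not part of the patches above): the series replaces an uninitialized `torch.empty` output tensor with one pre-filled by a sentinel token id, so downstream code can reliably detect and drop slots that were never sampled (for example prompt positions during spec decode). A minimal self-contained sketch of that idea, using made-up shapes and values rather than the vLLM API:

    import torch

    INVALID_TOKEN_ID = -1  # same sentinel value the patches use

    # Pre-fill with the sentinel instead of torch.empty(): slots that no
    # sampling step ever writes keep a well-defined value instead of
    # whatever happened to be in memory.
    num_slots = 6
    sampled_token_ids = torch.full((num_slots, 1), INVALID_TOKEN_ID,
                                   dtype=torch.long)

    # Pretend only some slots were actually sampled.
    sampled_token_ids[torch.tensor([1, 3, 4])] = torch.tensor([[11], [22], [33]])

    # Keep only the rows holding real samples. The patches use
    # torch.where(x - INVALID_TOKEN_ID)[0]; x != INVALID_TOKEN_ID is equivalent.
    valid_rows = torch.where(sampled_token_ids - INVALID_TOKEN_ID)[0]
    print(sampled_token_ids[valid_rows].flatten().tolist())  # [11, 22, 33]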
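
Note (illustrative, not part of the patches above): patches 5 and 9 adjust detokenize_incrementally so that any token id outside the range [0, len(tokenizer)), including the -1 sentinel, decodes to an empty string; patch 5's `0 <` guard accidentally dropped token id 0, which patch 9 corrects to `0 <=`. A stand-alone sketch of the guard, with a toy vocabulary standing in for a real tokenizer:

    from typing import List

    toy_vocab = ["<unk>", "hello", "world"]  # stand-in for a real tokenizer

    def convert_ids_to_tokens(ids: List[int]) -> List[str]:
        # Patch 5 used `0 < i`, which wrongly dropped token id 0;
        # patch 9 fixes the guard to `0 <= i`.
        return [toy_vocab[i] if 0 <= i < len(toy_vocab) else "" for i in ids]

    print(convert_ids_to_tokens([1, 0, -1, 99]))  # ['hello', '<unk>', '', '']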
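
Note (illustrative, not part of the patches above): with disable_logprobs_during_spec_decoding, patch 6 still emits one placeholder entry per prompt token (and none for the first token) so the output shape matches the non-spec-decode path; the e2e tests then assert rank == -1 and logprob == 0.0 for those placeholders. A sketch of that contract using a stand-in Logprob dataclass, not the vLLM class:

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class Logprob:  # stand-in for vllm.sequence.Logprob
        logprob: float
        rank: Optional[int] = None

    def placeholder_prompt_logprobs(
            prompt_token_ids: List[int]) -> List[Optional[Dict[int, Logprob]]]:
        # No logprob is defined for the very first prompt token.
        out: List[Optional[Dict[int, Logprob]]] = [None]
        # Each later prompt token gets a single dummy entry keyed by its id.
        out += [{tok: Logprob(logprob=0.0, rank=-1)}
                for tok in prompt_token_ids[1:]]
        return out

    print(placeholder_prompt_logprobs([5, 7, 9]))
    # [None, {7: Logprob(logprob=0.0, rank=-1)}, {9: Logprob(logprob=0.0, rank=-1)}]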