5 changes: 4 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -87,7 +87,10 @@ steps:

- label: Engine Test
mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenization

- label: Entrypoints Test
mirror_hardwares: [amd]
109 changes: 87 additions & 22 deletions tests/tokenization/test_detokenize.py
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Any, Dict, List, Optional

import pytest
from transformers import AutoTokenizer
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
} for token_id in complete_sequence_token_ids]


def create_dummy_prompt_logprobs(
complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
# logprob for the first prompt token is None.
logprobs: List[Optional[Dict[int, Any]]] = [None]
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
return logprobs
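
For orientation, a minimal sketch of the shape this helper is expected to produce. The token ids and Logprob values are hypothetical, and the import path for Logprob is an assumption rather than something taken from this diff:

from vllm.sequence import Logprob  # assumed import path

# Hypothetical prompt token ids [2, 15, 8]: one list entry per prompt position,
# with position 0 left as None because the first token has nothing scoring it.
# The inner dict contents are placeholders for whatever create_dummy_logprobs
# generates (the assertions below expect keys token_id and token_id + 1).
expected_shape = [
    None,
    {15: Logprob(logprob=-0.1), 16: Logprob(logprob=-2.3)},
    {8: Logprob(logprob=-0.4), 9: Logprob(logprob=-1.7)},
]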


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,

@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True])
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
detokenizer: Detokenizer):
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
sampling_params = SamplingParams(skip_special_tokens=True,
prompt_logprobs=1)

# Run sequentially.
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
seqs=[seq],
sampling_params=sampling_params,
arrival_time=0.0)
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
decoded_prompt_logprobs = dummy_logprobs
dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group,
dummy_logprobs,
position_offset=0)
# First logprob is None.
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
1:] # type: ignore

if skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that this will only be true if we skip
# special tokens.
assert complete_sequence == "".join([
logprobs[token_id].decoded_token for token_id, logprobs in zip(
complete_sequence_token_ids, decoded_prompt_logprobs)
])
assert complete_sequence != "".join([
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
complete_sequence_token_ids, decoded_prompt_logprobs)
])
# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenizer = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
text = text_full[len(text_first):]

# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that the first logprob is None.
assert text == "".join([
logprobs[token_id].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
assert text != "".join([
logprobs[token_id + 1].decoded_token
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
])
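
To see why the test slices off the first token's text above, here is a small standalone illustration; the gpt2 checkpoint is an arbitrary choice and not necessarily one of the TOKENIZERS parametrized in this file:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok("Hello world, how are you?")["input_ids"]
full = tok.decode(ids, skip_special_tokens=True)      # "Hello world, how are you?"
first = tok.decode(ids[0], skip_special_tokens=True)  # "Hello"
# Prompt logprobs only cover positions 1..N-1, so the text they decode to is
# the full prompt minus the first token's text.
print(full[len(first):])                              # " world, how are you?"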


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
vllm_runner,
model,
chunked_prefill_token_size: int,
example_prompts,
):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size

with vllm_runner(model,
dtype="half",
max_logprobs=5,
gpu_memory_utilization=0.5,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs) as vllm_model:

vllm_sampling_params = SamplingParams(max_tokens=10,
logprobs=5,
prompt_logprobs=5,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)

for idx, result in enumerate(vllm_results):
assert result.prompt_logprobs is not None
assert result.prompt_logprobs[0] is None

# Compare the detokenized prompt token ids to the original prompt.
generated_string = ""
for (prompt_token,
prompt_logprobs) in zip(result.prompt_token_ids[1:],
result.prompt_logprobs[1:]):
# prompt_logprobs maps token_id -> Logprob for this position. Select the
# entry for the actual prompt token; its decoded_token is the piece of the
# detokenized string corresponding to this prompt token.
generated_string += prompt_logprobs[prompt_token].decoded_token

assert generated_string == example_prompts[idx], (
"Detokenized prompt logprobs do not match original prompt")
19 changes: 14 additions & 5 deletions vllm/engine/output_processor/single_step.py
@@ -60,14 +60,23 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
assert len(outputs) == 1, ("Single step should only have 1 output.")
output = outputs[0]
prompt_logprobs = output.prompt_logprobs

# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []

if seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
if not seq_group.prompt_logprobs:
# The first prompt token's logprob is None because it doesn't
# have tokens that are precedent.
seq_group.prompt_logprobs = [None]
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))

seq_group.prompt_logprobs.extend(prompt_logprobs)

def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
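
A worked illustration of the bookkeeping above, assuming a 4-token prompt handled in a single prefill step; the token ids, logprob values, and the import path for Logprob are hypothetical:

from vllm.sequence import Logprob  # assumed import path

# Prompt token ids: [2, 15, 8, 30]  ->  N = 4 prompt tokens.
# The Sampler emits N - 1 = 3 prompt logprob dicts, for positions 1, 2 and 3;
# position 0 has no preceding context, so nothing is emitted for it.
sampler_prompt_logprobs = [
    {15: Logprob(logprob=-0.1)},
    {8: Logprob(logprob=-0.4)},
    {30: Logprob(logprob=-0.2)},
]

# First chunk of this prompt: prepend None so the accumulated list has exactly
# one entry per prompt token, and detokenize starting at position_offset=0.
# A later chunk would instead pass
# position_offset=len(seq_group.prompt_logprobs) accumulated so far.
prompt_logprobs = [None] + sampler_prompt_logprobs
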
16 changes: 12 additions & 4 deletions vllm/transformers_utils/detokenizer.py
@@ -21,14 +21,17 @@ def get_tokenizer_for_seq(self,
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

def decode_prompt_logprobs_inplace(
self, seq_group: SequenceGroup,
prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: List[Optional[Dict[
int, Logprob]]],
position_offset: int) -> None:
"""Decodes the logprobs for the prompt of a sequence group.

Args:
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill).

Returns:
The prompt logprobs with the decoded tokens.
@@ -47,8 +50,13 @@ def decode_prompt_logprobs_inplace(
next_iter_tokens: List[str] = []
prev_tokens = None

for token_position, prompt_logprobs_for_token in enumerate(
for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
prompt_logprobs):

# Absolute token position equals the index in the logprobs
# list plus the offset of the entire logprobs list relative
# to the start of the sequence.
token_position = token_position_in_logprob + position_offset
if not prompt_logprobs_for_token:
continue
for token_id, sample_logprob in prompt_logprobs_for_token.items():
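
A minimal sketch of the intended calling pattern across prefill chunks, using the names from this diff; seq_group and the chunk logprob lists are placeholders rather than values constructed here:

# First prefill chunk: its list begins with None for position 0, so it starts
# at an offset of 0 relative to the prompt.
detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                           first_chunk_logprobs,
                                           position_offset=0)

# A later chunk is offset by the number of prompt positions already processed.
detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                           next_chunk_logprobs,
                                           position_offset=len(first_chunk_logprobs))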