
Commit 603ad84

[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)
1 parent a88081b commit 603ad84

18 files changed: +862 −633 lines

tests/samplers/test_logprobs.py

Lines changed: 41 additions & 3 deletions
@@ -9,26 +9,44 @@

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
+@pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
 def test_get_prompt_logprobs(
     hf_runner,
     vllm_runner,
     model,
     dtype,
+    chunked_prefill_token_size: int,
+    num_top_logprobs: int,
     example_prompts,
 ):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+
     max_tokens = 5
-    num_top_logprobs = 6
     hf_model = hf_runner(model, dtype=dtype)
     hf_logprobs = hf_model.generate_greedy_logprobs(
         example_prompts,
         max_tokens=max_tokens,
     )
     del hf_model

-    vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        max_logprobs=num_top_logprobs,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+    )
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                           logprobs=num_top_logprobs,
-                                          prompt_logprobs=5,
+                                          prompt_logprobs=num_top_logprobs,
                                           temperature=0.0)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
             "The output text from the top logprob for each token position "
             "should be the same as the output text in the result.")

+        # The first prompt logprob is always None
+        assert result.prompt_logprobs[0] is None
+        for prompt_logprobs in result.prompt_logprobs[1:]:
+            # If the prompt token is not included in the top X
+            # logprob, it can return 1 more data
+            assert (len(prompt_logprobs) == num_top_logprobs
+                    or len(prompt_logprobs) == num_top_logprobs + 1)
+
     # Test whether prompt logprobs are consistent with HF
     for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
         # Check prompt logprobs
+        # The first prompt logprob is always None, so we compare it from 1:.
         vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
         for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
             for token_id, logprob in vllm_prompt_logprob_dict.items():
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
                 "The token should be decoded by the time it is returned "
                 " to the user.")

+    # Test if prompt logprobs are correctly set.
+    for vllm_result in vllm_results:
+        token_ids = vllm_result.prompt_token_ids
+        prompt_logprobs = vllm_result.prompt_logprobs
+
+        # The first token doesn't have logprob.
+        assert prompt_logprobs[0] is None
+
+        for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
+            assert token_id in logprob_dict
+

 def test_max_logprobs():
     runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
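Note: the test above maps chunked_prefill_token_size onto the engine arguments and then requests prompt logprobs as usual. A usage-level sketch of the same behavior through the public API is below; it assumes vLLM at this commit, an illustrative prompt string, and that facebook/opt-125m can be loaded.

from vllm import LLM, SamplingParams

# Chunk prompts into at most 16 tokens per scheduler step.
llm = LLM(model="facebook/opt-125m",
          enable_chunked_prefill=True,
          max_num_batched_tokens=16,
          max_num_seqs=16)

params = SamplingParams(max_tokens=5,
                        logprobs=6,
                        prompt_logprobs=6,
                        temperature=0.0)
output = llm.generate(["The quick brown fox jumps over the lazy dog"],
                      params)[0]

# Mirrors the assertions in the test: no logprob for the first prompt token,
# and at least num_top_logprobs entries for every later position.
assert output.prompt_logprobs[0] is None
assert all(len(d) >= 6 for d in output.prompt_logprobs[1:])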

tests/samplers/test_sampler.py

Lines changed: 31 additions & 16 deletions
@@ -8,6 +8,7 @@
 from transformers import GenerationConfig, GenerationMixin

 from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter
@@ -54,6 +55,7 @@ def _do_sample(
     sampler: MockLogitsSampler,
     model_runner: ModelRunner,
     sampling_params: SamplingParams,
+    device: str,
 ):
     seq_group_metadata_list = []
     prompt_lens = []
@@ -68,9 +70,12 @@ def _do_sample(
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=device,
+        pin_memory=model_runner.pin_memory)
     return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):

     sampling_params = SamplingParams(temperature=0)
     sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
-                                sampling_params)
+                                sampling_params, device)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
         n=random.randint(1, 10),
     )
     sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
-                                sampling_params)
+                                sampling_params, device)

     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
         seed=random.randint(0, 10000),
     )
     sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
-                                sampling_params)
+                                sampling_params, device)

     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
         seed=random.randint(0, 10000),
     )
     first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                      model_runner, sampling_params)
+                                      model_runner, sampling_params, device)

     second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                       model_runner, sampling_params)
+                                       model_runner, sampling_params, device)

     assert first_sampler_output == second_sampler_output

@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
         best_of=2,
         use_beam_search=True,
     )
-    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
+    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
+               device)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -443,10 +449,12 @@ def run_test_case(*,
                 "batch size")

         _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
-        sampling_metadata = model_runner._prepare_sample(
+        sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             prompt_lens=prompt_lens if prompt_lens else None,
-            subquery_lens=prompt_lens if prompt_lens else None)
+            subquery_lens=prompt_lens if prompt_lens else None,
+            device=device,
+            pin_memory=model_runner.pin_memory)
         # the logits tensor is modified in-place by the sampler
         _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

@@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

     def test_sampling(model_runner: ModelRunner):
-        sampling_metadata = model_runner._prepare_sample(
-            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            prompt_lens,
+            subquery_lens=prompt_lens,
+            device=device,
+            pin_memory=model_runner.pin_memory)
         sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)

@@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=device,
+        pin_memory=model_runner.pin_memory)

     sample_probs = None
tests/test_logits_processor.py

Lines changed: 7 additions & 3 deletions
@@ -6,6 +6,7 @@
 import torch

 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner
@@ -82,9 +83,12 @@ def pick_ith(token_ids, logits):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     logits_processor_output = logits_processor(
         embedding=None,
         hidden_states=input_tensor,

tests/worker/test_model_runner.py

Lines changed: 13 additions & 6 deletions
@@ -2,6 +2,7 @@
 import torch

 from vllm.config import ModelConfig, SchedulerConfig
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size

@@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
     assert len(input_positions) == sum(prompt_lens)
     torch.testing.assert_close(input_tokens, input_positions)

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     assert len(input_tokens) == sum(prompt_lens)
     assert len(input_positions) == sum(prompt_lens)
     actual = sampling_metadata.selected_token_indices
@@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
     for prompt_len in prompt_lens:
         expected_selected_token_indices.append(selected_token_start_idx)
         selected_token_start_idx += 1
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
                             device=actual.device,

vllm/core/scheduler.py

Lines changed: 15 additions & 0 deletions
@@ -915,6 +915,20 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 self.block_manager.get_common_computed_block_ids(
                     seq_group.get_seqs(status=SequenceStatus.RUNNING)))

+            do_sample = True
+            if seq_group.is_prefill():
+                seqs = seq_group.get_seqs()
+                # Prefill has only 1 sequence.
+                assert len(seqs) == 1
+                # In the next iteration, all prompt tokens are not computed.
+                # It means the prefill is chunked, and we don't need sampling.
+                # NOTE: We use get_len instead of get_prompt_len because when
+                # a sequence is preempted, prefill includes previous generated
+                # output tokens.
+                if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
+                        seqs[0].data.get_len()):
+                    do_sample = False
+
             # It assumes the scheduled_seq_groups is ordered by
             # prefill < decoding.
             is_prompt = seq_group.is_prefill()
@@ -924,6 +938,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
+                do_sample=do_sample,
                 token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
                 computed_block_nums=common_computed_block_nums,
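Note: the new do_sample flag encodes a simple rule — sample only when the scheduled chunk reaches the end of the sequence. Below is a standalone restatement of that predicate with a small worked example; it is a sketch, not the vLLM code path.

def should_sample(token_chunk_size: int, num_computed_tokens: int,
                  seq_len: int) -> bool:
    # Mirrors the scheduler check above: skip sampling while some prompt
    # (or preemption-restored) tokens remain uncomputed after this chunk.
    return token_chunk_size + num_computed_tokens >= seq_len

# A 10-token prompt prefilled with max_num_batched_tokens=4:
assert should_sample(4, 0, 10) is False   # chunk covers tokens 0-3
assert should_sample(4, 4, 10) is False   # chunk covers tokens 4-7
assert should_sample(2, 8, 10) is True    # final chunk covers tokens 8-9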

vllm/engine/async_llm_engine.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ async def step_async(self) -> List[RequestOutput]:

         request_outputs = self._process_model_outputs(
             output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups)
+            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

         # Log stats.
         if self.log_stats:

vllm/engine/llm_engine.py

Lines changed: 13 additions & 12 deletions
@@ -22,7 +22,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceStage)
+                           SequenceGroup, SequenceGroupMetadata)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -476,9 +476,12 @@ def has_unfinished_requests(self) -> bool:
         return self.scheduler.has_unfinished_seqs()

     def _process_model_outputs(
-            self, output: List[SamplerOutput],
-            scheduled_seq_groups: List[SequenceGroup],
-            ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
+        self,
+        output: List[SamplerOutput],
+        scheduled_seq_groups: List[SequenceGroup],
+        ignored_seq_groups: List[SequenceGroup],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.

         Returns RequestOutputs that can be returned to the client.
@@ -492,17 +495,15 @@ def _process_model_outputs(
             sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))

         # Update the scheduled sequence groups with the model outputs.
-        for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
-                                                output_by_sequence_group):
+        for scheduled_seq_group, outputs, seq_group_meta in zip(
+                scheduled_seq_groups, output_by_sequence_group,
+                seq_group_metadata_list):
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)

-            # If all sequences in the sequence group are in DECODE, then we can
-            # process the output tokens. Otherwise, they are (chunked) prefill
-            # samples and should not be processed.
-            stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
-            if all(stage == SequenceStage.DECODE for stage in stages):
+            self.output_processor.process_prompt_logprob(seq_group, outputs)
+            if seq_group_meta.do_sample:
                 self.output_processor.process_outputs(seq_group, outputs)

         # Free the finished sequence groups.
@@ -585,7 +586,7 @@ def step(self) -> List[RequestOutput]:

         request_outputs = self._process_model_outputs(
             output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups)
+            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

         # Log stats.
         if self.log_stats:
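Note: the effect of the do_sample plumbing is that prompt logprobs are now processed for every scheduled group, while sampled tokens are appended only once a group's prefill is complete. The self-contained sketch below shows that gating with stand-in types (FakeGroupMeta is hypothetical, not a vLLM class).

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class FakeGroupMeta:
    # Stand-in for SequenceGroupMetadata carrying only what the sketch needs.
    request_id: str
    do_sample: bool

def process_step(metas: List[FakeGroupMeta]) -> Tuple[List[str], List[str]]:
    prompt_logprob_calls, output_calls = [], []
    for meta in metas:
        # Prompt logprobs are handled unconditionally...
        prompt_logprob_calls.append(meta.request_id)
        # ...but sampled tokens only for decodes and final prefill chunks.
        if meta.do_sample:
            output_calls.append(meta.request_id)
    return prompt_logprob_calls, output_calls

logprob_calls, output_calls = process_step([
    FakeGroupMeta("chunked-prefill", do_sample=False),
    FakeGroupMeta("last-prefill-chunk", do_sample=True),
    FakeGroupMeta("decode", do_sample=True),
])
assert logprob_calls == ["chunked-prefill", "last-prefill-chunk", "decode"]
assert output_calls == ["last-prefill-chunk", "decode"]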

vllm/engine/output_processor/interfaces.py

Lines changed: 6 additions & 0 deletions
@@ -68,3 +68,9 @@ def process_outputs(self, sequence_group: SequenceGroup,
         scheduler.
         """
         pass
+
+    @abstractmethod
+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        """Update prompt logprobs received from outputs to seq_group."""
+        pass
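Note: implementations of the new hook have to accumulate prompt logprobs across chunks, with None standing in for the first prompt token (which has no logprob), so that the invariants checked in test_logprobs.py hold. The sketch below shows only that accumulation rule with plain Python types; it is illustrative, not the single-step processor's actual code.

from typing import Dict, List, Optional

PromptLogprobs = List[Optional[Dict[int, float]]]

def accumulate_prompt_logprobs(existing: Optional[PromptLogprobs],
                               chunk: List[Dict[int, float]]) -> PromptLogprobs:
    if existing is None:
        # First chunk of the prompt: the first token never has a logprob.
        return [None] + list(chunk)
    # Later chunks simply extend what previous chunks produced.
    return existing + list(chunk)

# A 5-token prompt prefilled in two chunks ends up with one entry per token:
acc = accumulate_prompt_logprobs(None, [{11: -0.5}, {12: -1.2}])
acc = accumulate_prompt_logprobs(acc, [{13: -0.1}, {14: -2.0}])
assert acc[0] is None and len(acc) == 5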

vllm/engine/output_processor/multi_step.py

Lines changed: 9 additions & 0 deletions
@@ -44,6 +44,15 @@ def __init__(
         self.get_tokenizer_for_seq = get_tokenizer_for_seq
         self.stop_checker = stop_checker

+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        # TODO(sang): Prompt logprob currently not implemented in multi step
+        # workers.
+        logger.warning(
+            "Prompt logprob is not supported by multi step workers. "
+            "(e.g., speculative decode uses multi step workers).")
+        pass
+
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput]) -> None:
         """Append new tokens in the outputs to sequences in the sequence group.

0 commit comments
