Commit dc1f3f2

Also exclude prompt details in subsequent outputs in delta mode
1 parent: ef2e59f
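
This change affects streaming with `RequestOutputKind.DELTA`: previously every delta `RequestOutput` repeated the prompt details (prompt text, token ids, prompt logprobs); after this commit only the first output for a request carries them, and subsequent delta outputs return `None` for the prompt fields. A minimal sketch of the resulting behavior, assuming a locally available model ("facebook/opt-125m" and the request id are placeholders):

```python
# Sketch only: delta-mode streaming behavior after this commit.
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m"))  # placeholder model
params = SamplingParams(max_tokens=8,
                        output_kind=RequestOutputKind.DELTA)
engine.add_request("req-0", "Hello, my name is", params)

while engine.has_unfinished_requests():
    for out in engine.step():
        # Prompt details appear only in the first delta output for a
        # request; later outputs have prompt / prompt_token_ids = None.
        print(out.prompt, out.prompt_token_ids,
              [o.text for o in out.outputs])
```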

File tree: 2 files changed (+36, -13 lines)

vllm/engine/llm_engine.py

Lines changed: 16 additions & 4 deletions
```diff
@@ -1182,9 +1182,12 @@ def _process_model_outputs(
         output_by_sequence_group = create_output_by_sequence_group(
             output, num_seq_groups=len(scheduled_seq_groups))
 
-        # seq_id to (output token count, text len)
+        # seq_id to (output token count, text len),
         # only for delta output seq groups
         previous_output_lens: Dict[int, Tuple[int, int]] = {}
+        # Seq groups whose outputs should not have prompt details included,
+        # only applies to delta output seq groups
+        exclude_prompt_seq_group_ids = set()
 
         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs, seq_group_meta in zip(
@@ -1196,8 +1199,13 @@ def _process_model_outputs(
                     == RequestOutputKind.DELTA):
                 text_buffer_length = params.output_text_buffer_length
                 for seq in seq_group.seqs:
+                    output_len = seq.get_output_len()
+                    if output_len:
+                        # Exclude the prompt if the seq group already has
+                        # completion tokens
+                        exclude_prompt_seq_group_ids.add(seq_group.request_id)
                     previous_output_lens[seq.seq_id] = (
-                        seq.get_output_len(),
+                        output_len,
                         seq.get_output_text_to_return_len(text_buffer_length))
 
             seq_group.update_num_computed_tokens(
@@ -1236,23 +1244,27 @@ def _process_model_outputs(
         for scheduled_seq_group in scheduled_seq_groups:
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
+            include_prompt = seq_group.request_id not in (
+                exclude_prompt_seq_group_ids)
             request_output = RequestOutputFactory.create(
-                seq_group, previous_output_lens)
+                seq_group, previous_output_lens, include_prompt)
             if request_output:
                 request_outputs.append(request_output)
         for seq_group in ignored_seq_groups:
             params = seq_group.sampling_params
+            include_prompt = True
             if params is not None and params.output_kind == (
                     RequestOutputKind.DELTA):
                 if not seq_group.is_finished():
                     continue
                 # Ignored seq groups have no delta, but we must still return
                 # an "empty" RequestOutput when finished
+                include_prompt = False
                 for seq in seq_group.seqs:
                     previous_output_lens[seq.seq_id] = (seq.get_output_len(),
                                                         seq.output_text)
             request_output = RequestOutputFactory.create(
-                seq_group, previous_output_lens)
+                seq_group, previous_output_lens, include_prompt)
             if request_output:
                 request_outputs.append(request_output)
         return request_outputs
```
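
The engine-side bookkeeping above reduces to a small pattern: before this step's new tokens are appended, any request that already has completion tokens is recorded, and that record later decides whether prompt details are attached to its output. A self-contained sketch of that pattern, using a hypothetical `Seq` stand-in for vLLM's sequence type:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set, Tuple


@dataclass
class Seq:
    """Hypothetical stand-in for vLLM's Sequence."""
    seq_id: int
    request_id: str
    output_token_ids: List[int] = field(default_factory=list)


def record_delta_state(seqs: List[Seq]) -> Tuple[Dict[int, int], Set[str]]:
    """Return per-seq output lengths plus the request ids whose next
    delta output should omit prompt details (prompt already streamed)."""
    previous_output_lens: Dict[int, int] = {}
    exclude_prompt_ids: Set[str] = set()
    for seq in seqs:
        output_len = len(seq.output_token_ids)
        if output_len:
            # This request already emitted completion tokens, so its
            # prompt went out with an earlier delta output.
            exclude_prompt_ids.add(seq.request_id)
        previous_output_lens[seq.seq_id] = output_len
    return previous_output_lens, exclude_prompt_ids


# Usage: a fresh request keeps its prompt, a streaming one drops it.
lens, excluded = record_delta_state([Seq(0, "a"), Seq(1, "b", [42, 43])])
assert "a" not in excluded and "b" in excluded
```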

vllm/outputs.py

Lines changed: 20 additions & 9 deletions
```diff
@@ -93,7 +93,7 @@ def __init__(
         self,
         request_id: str,
         prompt: Optional[str],
-        prompt_token_ids: List[int],
+        prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
@@ -118,6 +118,7 @@ def from_seq_group(
         cls,
         seq_group: SequenceGroup,
         prior_output_lens: Dict[int, Tuple[int, int]],
+        include_prompt: bool = True,
     ) -> Optional["RequestOutput"]:
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
@@ -176,11 +177,18 @@ def from_seq_group(
                                  seq.stop_reason))
 
         # Every sequence in the sequence group should have the same prompt.
-        prompt = seq_group.prompt
-        prompt_token_ids = seq_group.prompt_token_ids
-        encoder_prompt = seq_group.encoder_prompt
-        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
-        prompt_logprobs = seq_group.prompt_logprobs
+        if include_prompt:
+            prompt = seq_group.prompt
+            prompt_token_ids = seq_group.prompt_token_ids
+            encoder_prompt = seq_group.encoder_prompt
+            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
+            prompt_logprobs = seq_group.prompt_logprobs
+        else:
+            prompt = None
+            prompt_token_ids = None
+            encoder_prompt = None
+            encoder_prompt_token_ids = None
+            prompt_logprobs = None
         finished_time = time.time() if finished else None
         seq_group.set_finished_time(finished_time)
         return cls(seq_group.request_id,
@@ -256,12 +264,15 @@ def __repr__(self):
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group,
-               previous_output_lens: Dict[int, Tuple[int, int]] = {}): # noqa
+    def create(
+            seq_group,
+            previous_output_lens: Dict[int, Tuple[int, int]] = {}, # noqa
+            include_prompt: bool = True):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
             return RequestOutput.from_seq_group(seq_group,
-                                                previous_output_lens)
+                                                previous_output_lens,
+                                                include_prompt)
```
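
On the consumer side, a delta stream can still be folded back into a complete result; since only the first chunk now carries prompt fields, an accumulator should hold on to the first non-`None` prompt. A sketch, assuming `chunks` is an iterable of `RequestOutput`-like objects with a single sequence each:

```python
from typing import Iterable, Optional, Tuple


def accumulate_deltas(chunks: Iterable) -> Tuple[Optional[str], str]:
    """Fold delta RequestOutputs into (prompt, full completion text)."""
    prompt: Optional[str] = None
    parts: list = []
    for chunk in chunks:
        # Only the first chunk carries the prompt after this commit.
        if prompt is None and chunk.prompt is not None:
            prompt = chunk.prompt
        # In delta mode, each CompletionOutput.text is only the new text.
        parts.append(chunk.outputs[0].text)
    return prompt, "".join(parts)
```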
