@@ -1182,9 +1182,12 @@ def _process_model_outputs(
1182
1182
output_by_sequence_group = create_output_by_sequence_group (
1183
1183
output , num_seq_groups = len (scheduled_seq_groups ))
1184
1184
1185
- # seq_id to (output token count, text len)
1185
+ # seq_id to (output token count, text len),
1186
1186
# only for delta output seq groups
1187
1187
previous_output_lens : Dict [int , Tuple [int , int ]] = {}
1188
+ # Seq groups whose outputs should not have prompt details included,
1189
+ # only applies to delta output seq groups
1190
+ exclude_prompt_seq_group_ids = set ()
1188
1191
1189
1192
# Update the scheduled sequence groups with the model outputs.
1190
1193
for scheduled_seq_group , outputs , seq_group_meta in zip (
@@ -1196,8 +1199,13 @@ def _process_model_outputs(
1196
1199
== RequestOutputKind .DELTA ):
1197
1200
text_buffer_length = params .output_text_buffer_length
1198
1201
for seq in seq_group .seqs :
1202
+ output_len = seq .get_output_len ()
1203
+ if output_len :
1204
+ # Exclude the prompt if the seq group already has
1205
+ # completion tokens
1206
+ exclude_prompt_seq_group_ids .add (seq_group .request_id )
1199
1207
previous_output_lens [seq .seq_id ] = (
1200
- seq . get_output_len () ,
1208
+ output_len ,
1201
1209
seq .get_output_text_to_return_len (text_buffer_length ))
1202
1210
1203
1211
seq_group .update_num_computed_tokens (
@@ -1236,23 +1244,27 @@ def _process_model_outputs(
1236
1244
for scheduled_seq_group in scheduled_seq_groups :
1237
1245
seq_group = scheduled_seq_group .seq_group
1238
1246
seq_group .maybe_set_first_token_time (now )
1247
+ include_prompt = seq_group .request_id not in (
1248
+ exclude_prompt_seq_group_ids )
1239
1249
request_output = RequestOutputFactory .create (
1240
- seq_group , previous_output_lens )
1250
+ seq_group , previous_output_lens , include_prompt )
1241
1251
if request_output :
1242
1252
request_outputs .append (request_output )
1243
1253
for seq_group in ignored_seq_groups :
1244
1254
params = seq_group .sampling_params
1255
+ include_prompt = True
1245
1256
if params is not None and params .output_kind == (
1246
1257
RequestOutputKind .DELTA ):
1247
1258
if not seq_group .is_finished ():
1248
1259
continue
1249
1260
# Ignored seq groups have no delta, but we must still return
1250
1261
# an "empty" RequestOutput when finished
1262
+ include_prompt = False
1251
1263
for seq in seq_group .seqs :
1252
1264
previous_output_lens [seq .seq_id ] = (seq .get_output_len (),
1253
1265
seq .output_text )
1254
1266
request_output = RequestOutputFactory .create (
1255
- seq_group , previous_output_lens )
1267
+ seq_group , previous_output_lens , include_prompt )
1256
1268
if request_output :
1257
1269
request_outputs .append (request_output )
1258
1270
return request_outputs
0 commit comments