Skip to content

Commit 7987523

Browse files
committed
Update from_sampling_metadata method to also return top_k_scalar and top_p_scalar
Signed-off-by: Artur Fierka <[email protected]>
1 parent c8b6433 commit 7987523

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/model_executor/sampling_metadata.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,8 @@ def from_sampling_metadata(
382382
vocab_size: int,
383383
device: torch.device,
384384
dtype: torch.dtype,
385-
) -> Tuple["SamplingTensors", bool, bool, bool]:
385+
) -> Tuple["SamplingTensors", bool, bool, bool, Optional[int],
386+
Optional[float]]:
386387
prompt_tokens: List[array] = []
387388
output_tokens: List[array] = []
388389
top_ks: List[int] = []
@@ -470,6 +471,11 @@ def from_sampling_metadata(
470471
prompt_tokens.append(seq_data.prompt_token_ids_array)
471472
output_tokens.append(seq_data.output_token_ids_array)
472473

474+
top_k_scalar = top_ks[0] if do_top_p_top_k and all(
475+
k == top_ks[0] for k in top_ks) else None
476+
top_p_scalar = top_ps[0] if do_top_p_top_k and all(
477+
p == top_ps[0] for p in top_ps) else None
478+
473479
sampling_tensors = SamplingTensors.from_lists(
474480
temperatures,
475481
top_ps,
@@ -484,7 +490,8 @@ def from_sampling_metadata(
484490
device,
485491
dtype,
486492
)
487-
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
493+
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
494+
top_k_scalar, top_p_scalar)
488495

489496
@classmethod
490497
def from_lists(

0 commit comments

Comments
 (0)