Commit c8b6433

Optimize topp/topk calculation in sampler

Signed-off-by: Artur Fierka <[email protected]>

1 parent 69d765f

File tree: 2 files changed, +181 -4 lines

tests/samplers/test_sampler.py

Lines changed: 61 additions & 1 deletion
@@ -9,7 +9,7 @@
 from transformers import GenerationConfig, GenerationMixin

 import vllm.envs as envs
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.sampler import ApplyToppTopkScalar, Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -754,3 +754,63 @@ def test_sampler_include_gpu_probs_tensor(device: str):
     assert sampler_output.sampled_token_probs is not None
     assert sampler_output.logprobs is not None
     assert sampler_output.sampled_token_ids is not None
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_topp_topk_scalar(device: str):
+    obj1 = ApplyToppTopkScalar(2)
+    assert ApplyToppTopkScalar._padded_k == 0
+    x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0],
+                      [10, 10, 9, 9, 9, 8, 5, 5, 5]])
+
+    retval1 = obj1(x, p=0.9, k=5)
+    ninf = -float("inf")
+    expected1 = torch.tensor([[9., 9., 8., 8., 8., 8., ninf, ninf, ninf],
+                              [10., 10., 9., 9., 9., ninf, ninf, ninf, ninf]])
+    assert torch.all(retval1 == expected1).item()
+    assert ApplyToppTopkScalar._padded_k == 9
+
+    obj2 = ApplyToppTopkScalar(2)
+    assert obj2._padded_k == 9
+
+    x = torch.tensor([[2, 2, 9, 9, 2, 2, 1, 1, 1.0],
+                      [10, 9, 9, 5, 9, 9, 5, 9, 10]])
+    retval2 = obj2(x, p=0.9, k=5)
+    expected2 = torch.tensor(
+        [[ninf, ninf, 9., 9., ninf, ninf, ninf, ninf, ninf],
+         [10., ninf, 9., ninf, 9., 9., ninf, 9., 10.]])
+    assert torch.all(retval2 == expected2).item()
+    assert obj2._padded_k == 9
+
+    retval3 = obj2(x, p=1.0, k=5)
+    expected3 = torch.tensor([[2., 2., 9., 9., 2., 2., ninf, ninf, ninf],
+                              [10., 9., 9., ninf, 9., 9., ninf, 9., 10.]])
+    assert torch.all(retval3 == expected3).item()
+
+    # resetting class state should not be done in general;
+    # it is done here only for testing purposes
+    ApplyToppTopkScalar._padded_k = 0
+    x = torch.tensor([[1, 1, 1, 9, 8, 1, 1, 1, 1.0],
+                      [2, 1, 2, 2, 1, 1, 1, 1, 1]])
+    obj3 = ApplyToppTopkScalar(2)
+    retval4 = obj3(x, p=0.9, k=2)
+    expected4 = torch.tensor(
+        [[ninf, ninf, ninf, 9., 8., ninf, ninf, ninf, ninf],
+         [2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]])
+    assert torch.all(retval4 == expected4).item()
+    assert obj3._padded_k == 4
+    y = torch.tensor([[8, 8, 8, 9, 8, 1, 1, 1, 1.0],
+                      [2, 1, 2, 2, 1, 1, 1, 1, 1]])
+    retval5 = obj3(y, p=0.9, k=2)
+    assert obj3._padded_k == 8
+    expected5 = torch.tensor(
+        [[8., 8., 8., 9., 8., ninf, ninf, ninf, ninf],
+         [2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]])
+    assert torch.all(retval5 == expected5).item()
+    y = torch.tensor([[8, 8, 8, 9, 8, 8, 1, 1, 1.0],
+                      [2, 1, 2, 2, 3, 1, 1, 1, 1]])
+    retval6 = obj3(y, p=0.9, k=2)
+    expected6 = torch.tensor([[8., 8., 8., 9., 8., 8., ninf, ninf, ninf],
+                              [2., ninf, 2., 2., 3., ninf, ninf, ninf, ninf]])
+    assert torch.all(retval6 == expected6).item()
+    assert obj3._padded_k == 8
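The assertions above exercise the boundary-duplicate handling described in the class docstring. As a standalone illustration (not part of the commit), the sketch below shows why a plain torch.topk call snips one of the tied 8s in the first test row, and how masking a padded topk by the value at the kth position keeps all of them:

import torch

# Plain topk with k=5 keeps exactly five values, losing one tied "8":
x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0]])
vals, _ = torch.topk(x, k=5, dim=1)
print(vals)  # tensor([[9., 9., 8., 8., 8.]])

# Padding k (here to the full row width) and masking by the value at the
# kth position, rather than by position alone, keeps all four 8s:
padded_vals, _ = torch.topk(x, k=9, dim=1)
kth_val = padded_vals[:, 5 - 1].unsqueeze(1)
masked = padded_vals.masked_fill(padded_vals < kth_val, -float("inf"))
print(masked)  # tensor([[9., 9., 8., 8., 8., 8., -inf, -inf, -inf]])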

vllm/model_executor/layers/sampler.py

Lines changed: 120 additions & 3 deletions
@@ -1,9 +1,10 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import os
 import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
-from math import inf
+from math import ceil, inf
 from typing import Dict, Iterator, List, Optional, Tuple, Union

 import msgspec
@@ -204,14 +205,18 @@ def _init_sampling_tensors(
         self._sampling_tensors = None

         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_top_p_top_k,
-         do_min_p) = SamplingTensors.from_sampling_metadata(
+        (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+         top_k_scalar, top_p_scalar) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
         self._do_top_p_top_k = do_top_p_top_k
         self._do_min_p = do_min_p
+        self._top_k_scalar = top_k_scalar
+        self._top_p_scalar = top_p_scalar
+
+        self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

     def forward(
         self,
@@ -337,6 +342,118 @@ def _should_modify_greedy_probs_inplace(self) -> bool:
         return self.should_modify_greedy_probs_inplace


+class ApplyToppTopkScalar:
+    """
+    The original implementation of _apply_top_k_top_p is more general,
+    as it supports per-element (vector) top-p and top-k values.
+    However, in many cases top-p and top-k are the same for all batch
+    elements; for such "scalar" top-p/top-k cases, this class can be used.
+
+    The main optimization in this class is to use topk instead of sort,
+    which is much faster, especially for small k.
+    However, topk alone may not suffice, as the following case shows.
+    Consider the tensor: 9 9 8 8 8 8 7 7 7
+    topk with k=5 yields 9 9 8 8 8. The value "8" sits on the boundary,
+    so the last "8" gets snipped off. The original implementation
+    accepts all of the "8"s, so it should output:
+    9 9 8 8 8 8 (6 values, even though k=5)
+    To preserve these semantics, we perform topk with _padded_k elements.
+    If boundary elements are still left over, we keep incrementing
+    _padded_k and use the expanded value in future calls.
+
+    Increments to _padded_k should be > 1 to prevent excessive
+    recompilations due to dynamic shapes (the output shape of the topk).
+
+    The main logic lives in __call__. This is a class rather than a
+    function only to keep track of the monotonically non-decreasing
+    state _padded_k.
+
+    To include the duplicates that fall outside the kth border,
+    set VLLM_HANDLE_TOPK_DUPLICATES to 1 or true.
+    """
+    _padded_k = 0
+    _handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
+                                   '0').lower() in ['1', 'true']
+
+    def __init__(self, increment: int):
+        self._increment = increment
+
+    def __call__(self, logits: torch.Tensor, p: float, k: int):
+        if k == 1 and not ApplyToppTopkScalar._handle_duplicates:
+            new_logits = torch.full(logits.shape,
+                                    -float("inf"),
+                                    device=logits.device)
+            vals, idx = torch.max(logits, keepdim=True, dim=1)
+            new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+            return new_logits
+
+        if k > ApplyToppTopkScalar._padded_k:
+            ApplyToppTopkScalar._padded_k = min(k + self._increment,
+                                                logits.shape[1])
+
+        vals, idx = torch.topk(logits,
+                               k=ApplyToppTopkScalar._padded_k,
+                               dim=1,
+                               sorted=True)
+
+        # this "if" checks whether we have bucketed so much that
+        # _padded_k has been padded up to the full width of logits
+        if self._handle_duplicates and \
+                ApplyToppTopkScalar._padded_k != logits.shape[1]:
+            smallest_of_top_k = vals[:, k - 1]
+            num_duplicates_of_smallest_of_topk = torch.sum(
+                logits == smallest_of_top_k.unsqueeze(1), 1)
+            max_num_duplicates_of_smallest_of_topk = torch.max(
+                num_duplicates_of_smallest_of_topk).item()
+
+            # there are n repeats of the border value
+            # (the smallest value of the top k).
+            # we do not know whether 1, 2, or (n-1) of them
+            # lie outside the kth border, so we conservatively
+            # increase by n-1 when num_duplicates > _padded_k - k
+            if max_num_duplicates_of_smallest_of_topk - 1 > (
+                    ApplyToppTopkScalar._padded_k - k):
+                # compute the full increment in one go (rounded up to a
+                # multiple of self._increment) instead of growing by
+                # self._increment and retrying, so topk is recomputed
+                # at most once
+                incr = int(
+                    ceil((max_num_duplicates_of_smallest_of_topk - 1) /
+                         self._increment) * self._increment)
+                ApplyToppTopkScalar._padded_k = min(
+                    ApplyToppTopkScalar._padded_k + incr, logits.shape[1])
+
+                # recompute topk with the expanded padded_k
+                vals, idx = torch.topk(logits,
+                                       k=ApplyToppTopkScalar._padded_k,
+                                       dim=1,
+                                       sorted=True)
+
+        idx = torch.fliplr(idx)
+        vals = torch.fliplr(vals)
+
+        # top-k: mask out everything strictly smaller than the kth value,
+        # keeping all boundary duplicates
+        top_k_smallest_val_idx = vals.size(1) - k
+        top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1)
+        top_k_mask = vals < top_k_mask
+        vals.masked_fill_(top_k_mask, -float("inf"))
+
+        # top-p: drop the low-probability tail of the ascending cumsum
+        probs_sort = vals.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= (1 - p)
+        top_p_mask[:, -1] = False
+        vals.masked_fill_(top_p_mask, -float("inf"))
+
+        # scatter surviving values back to their original positions
+        new_logits = torch.full(logits.shape,
+                                -float("inf"),
+                                device=logits.device)
+        new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+
+        return new_logits
+
+
 def _apply_min_tokens_penalty(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
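The hunks above only show where the scalar values and the ApplyToppTopkScalar instance are stored; the dispatch site is not part of this diff. Below is a hedged sketch of how forward() might choose between the fast path and the pre-existing general routine, reusing the attribute names set in _init_sampling_tensors; the branch structure here is an assumption, not taken from the commit:

# Hypothetical dispatch sketch -- not shown in this commit's hunks.
if self._do_top_p_top_k:
    if self._top_k_scalar is not None and self._top_p_scalar is not None:
        # every sequence shares one (p, k): use the topk-based fast path
        logits = self._apply_top_k_top_p_opt(logits, self._top_p_scalar,
                                             self._top_k_scalar)
    else:
        # per-sequence p/k tensors: fall back to the sort-based routine
        logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                    sampling_tensors.top_ks)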

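As a worked example of the _padded_k growth rule (values chosen purely for illustration): with increment=5, padded_k=10, and k=5, finding 9 duplicates of the border value gives max_dups - 1 = 8 > padded_k - k = 5, so incr = ceil(8 / 5) * 5 = 10 and padded_k grows to 20 (capped at the vocabulary size in the real code). Rounding to a multiple of the increment keeps the topk output shape within a small set of buckets, limiting recompilation:

from math import ceil

# Worked example of the growth rule, with illustrative values.
increment, padded_k, k = 5, 10, 5
max_dups = 9  # duplicates of the smallest top-k value found in some row
if max_dups - 1 > padded_k - k:  # 8 > 5: padding may be too small
    incr = int(ceil((max_dups - 1) / increment) * increment)  # -> 10
    padded_k += incr  # -> 20 (min'd with vocab size in the real code)
print(padded_k)  # 20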