[Core] Optimize topp/topk calculation in sampler #12156
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import os
 import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
-from math import inf
+from math import ceil, inf
 from typing import Dict, Iterator, List, Optional, Tuple, Union

 import msgspec
@@ -204,14 +205,18 @@ def _init_sampling_tensors(
         self._sampling_tensors = None

         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_top_p_top_k,
-         do_min_p) = SamplingTensors.from_sampling_metadata(
+        (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
+         top_k_scalar, top_p_scalar) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
         self._do_top_p_top_k = do_top_p_top_k
         self._do_min_p = do_min_p
+        self._top_k_scalar = top_k_scalar
+        self._top_p_scalar = top_p_scalar
+
+        self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

Review comment (on the ApplyToppTopkScalar(5) line): this is not yet used, right?

     def forward(
         self,
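
The scalar values captured above feed a batch-uniform fast path. As a minimal sketch of how the dispatch could work (the helper name `route_top_k_top_p` and its uniformity check are assumptions for illustration; only `_top_k_scalar`, `_top_p_scalar`, and `_apply_top_k_top_p_opt` come from this diff):

```python
import torch

def route_top_k_top_p(logits: torch.Tensor,
                      top_ps: torch.Tensor,
                      top_ks: torch.Tensor,
                      apply_scalar,       # e.g. an ApplyToppTopkScalar instance
                      apply_vectorized):  # e.g. the general _apply_top_k_top_p
    """Use the scalar fast path only when every request in the batch
    shares one (top_p, top_k) pair; otherwise fall back to the
    vectorized implementation."""
    same_k = bool((top_ks == top_ks[0]).all())
    same_p = bool((top_ps == top_ps[0]).all())
    if same_k and same_p:
        # All batch elements agree: pass plain Python scalars.
        return apply_scalar(logits, top_ps[0].item(), int(top_ks[0].item()))
    # Per-request values differ: use the general vectorized kernel.
    return apply_vectorized(logits, top_ps, top_ks)
```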
@@ -337,6 +342,118 @@ def _should_modify_greedy_probs_inplace(self) -> bool:
         return self.should_modify_greedy_probs_inplace


+class ApplyToppTopkScalar:
Review comment: shouldn't this be a …
+    """
+    The original implementation of _apply_top_k_top_p is more general,
+    as it uses vectorized top-p and top-k.
+    However, in many cases top-p and top-k are the same for all batch
+    elements. For such "scalar" top-p/top-k cases, we can use this class.
+
+    The main optimization in this class is:
+    use topk instead of sort, which is much faster, especially for small k.
+    However, just using topk might not suffice, as the case below shows.
+    Consider a tensor: 9 9 8 8 8 8 7 7 7
+    Topk with k=5 on this yields 9 9 8 8 8
+    The value "8" is on the boundary, hence the last "8" gets snipped off.
+    However, the original implementation accepts all the "8"s,
+    so it should output:
+    9 9 8 8 8 8 (6 values, even though k=5)
+    To ensure these semantics, we perform topk with _padded_k elements.
+    If we find more boundary elements left over,
+    then we keep incrementing _padded_k
+    and in future calls use the expanded value of _padded_k.
+
+    The increments to _padded_k should be done
+    with a value > 1 to prevent excessive recompilations
+    due to dynamic shapes (the output shape of the topk).
+
+    The main logic of this is in __call__.
+    This is a class instead of a function, just to keep track of
+    the monotonically non-decreasing state _padded_k.
+
+    To keep the duplicates that fall outside the kth border,
+    set VLLM_HANDLE_TOPK_DUPLICATES to 1 or true.
+    """
+    _padded_k = 0
+    _handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
+                                   '0').lower() in ['1', 'true']

Review comment (on the os.getenv line): I think we use …
+
+    def __init__(self, increment: int):
+        self._increment = increment
+
+    def __call__(self, logits: torch.Tensor, p: float, k: int):
+        if k == 1 and not ApplyToppTopkScalar._handle_duplicates:
+            new_logits = torch.full(logits.shape,
+                                    -float("inf"),
+                                    device=logits.device)
+            vals, idx = torch.max(logits, keepdim=True, dim=1)
+            new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+            return new_logits
+
+        if k > ApplyToppTopkScalar._padded_k:
+            ApplyToppTopkScalar._padded_k = min(k + self._increment,
+                                                logits.shape[1])
+
+        vals, idx = torch.topk(logits,
+                               k=ApplyToppTopkScalar._padded_k,
+                               dim=1,
+                               sorted=True)
+
+        # this "if" checks whether we have bucketed so much that
+        # we have padded k up to the shape of logits
+        if self._handle_duplicates and \
+                ApplyToppTopkScalar._padded_k != logits.shape[1]:
+            smallest_of_top_k = vals[:, k - 1]
+            num_duplicates_of_smallest_of_topk = torch.sum(
+                logits == smallest_of_top_k.unsqueeze(1), 1)
+            max_num_duplicates_of_smallest_of_topk = torch.max(
+                num_duplicates_of_smallest_of_topk).item()
+
+            # there are n repeats of the border value
+            # (border meaning the smallest value of the top k).
+            # we do not know whether only 1, 2, or (n-1)
+            # of them lie outside the kth border,
+            # so we conservatively increase by n-1
+            # when num_duplicates > _padded_k - k
+            if max_num_duplicates_of_smallest_of_topk - 1 > (
+                    ApplyToppTopkScalar._padded_k - k):
+                incr = int(
+                    ceil((max_num_duplicates_of_smallest_of_topk - 1) /
+                         self._increment) * self._increment)
+                # this branch is taken at most once per call:
+                # rather than incrementing by self._increment and retrying,
+                # we compute the needed incr in one go
+                ApplyToppTopkScalar._padded_k = min(
+                    ApplyToppTopkScalar._padded_k + incr, logits.shape[1])
+
+                # recompute topk with the expanded padded_k
+                vals, idx = torch.topk(logits,
+                                       k=ApplyToppTopkScalar._padded_k,
+                                       dim=1,
+                                       sorted=True)
+
+        idx = torch.fliplr(idx)
+        vals = torch.fliplr(vals)
+
+        top_k_smallest_val_idx = vals.size(1) - k
+        top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1)
+        top_k_mask = vals < top_k_mask
+        vals.masked_fill_(top_k_mask, -float("inf"))
+
+        probs_sort = vals.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= (1 - p)
+        top_p_mask[:, -1] = False
+        vals.masked_fill_(top_p_mask, -float("inf"))
+
+        new_logits = torch.full(logits.shape,
+                                -float("inf"),
+                                device=logits.device)
+        new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+
+        return new_logits
+
+
 def _apply_min_tokens_penalty(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
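
A self-contained illustration of the boundary case the docstring describes (standalone snippet, not part of the diff): plain topk snips boundary duplicates, while a sort-based reference, like the general `_apply_top_k_top_p`, keeps every value tied with the kth largest.

```python
import torch

logits = torch.tensor([[9., 9., 8., 8., 8., 8., 7., 7., 7.]])
k = 5

# Plain topk keeps exactly k values: only three of the four 8s survive.
vals, idx = torch.topk(logits, k=k, dim=1)
print(vals)  # tensor([[9., 9., 8., 8., 8.]])

# Sort-based reference: mask everything strictly smaller than the kth
# largest value, so all ties with the boundary value survive --
# six values here, even though k=5.
kth_val = torch.sort(logits, descending=True).values[:, k - 1:k]
masked = logits.masked_fill(logits < kth_val, -float("inf"))
print(masked)  # tensor([[9., 9., 8., 8., 8., 8., -inf, -inf, -inf]])
```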

Review comment: is the increment arbitrary?
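
A minimal standalone sketch of the `_padded_k` bucketing the docstring motivates (`increment=5` mirrors the `ApplyToppTopkScalar(5)` construction in this diff; the loop below is an illustration, not code from the PR). Because `_padded_k` grows only in coarse steps, the topk output shape takes few distinct values across calls, which bounds recompilations under dynamic-shape compilation.

```python
# Illustration of the _padded_k growth rule from __call__.
increment = 5
vocab_size = 32000
padded_k = 0
for k in [3, 4, 7, 8, 20]:
    if k > padded_k:
        padded_k = min(k + increment, vocab_size)
    print(f"k={k} -> topk width {padded_k}")
# k=3  -> topk width 8
# k=4  -> topk width 8   (shape reused, no recompile)
# k=7  -> topk width 8   (shape reused)
# k=8  -> topk width 8   (shape reused)
# k=20 -> topk width 25
```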