[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling #11394
Merged
+355 −190
Changes from all commits (17 commits)
ccf53d1 [V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (WoosukKwon)
9df8ccf Merge branch 'main' into v1-topk-top (WoosukKwon)
83d9aa4 update (WoosukKwon)
0c6d409 Add warning (WoosukKwon)
121cea5 fix (WoosukKwon)
cf097f4 minor (WoosukKwon)
98374e0 comment (WoosukKwon)
e068d68 Minor (WoosukKwon)
6e97c5f fix (WoosukKwon)
15fda81 minor (WoosukKwon)
3dcac1c Fix tests (WoosukKwon)
5cac3e1 Minor (WoosukKwon)
8061a16 comment (WoosukKwon)
e968e18 Minor (WoosukKwon)
0f784a5 Minor (WoosukKwon)
6bea166 Consider VLLM_USE_FLASHINFER_SAMPLER (WoosukKwon)
68ffc96 Minor (WoosukKwon)
Empty file.
@@ -0,0 +1,57 @@
from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.utils import (
    apply_penalties as _apply_penalties)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def apply_min_token_penalties(logits: torch.Tensor,
                              output_token_ids: List[List[int]],
                              stop_token_ids: List[Set[int]],
                              min_tokens: List[int]) -> None:
    """
    Applies minimum token penalty by setting the logits of the stop tokens
    to -inf.
    """
    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
    for index, min_token in enumerate(min_tokens):
        if len(output_token_ids[index]) < min_token:
            for stop_token_id in stop_token_ids[index]:
                min_tokens_logits_to_penalize.append((index, stop_token_id))
    if min_tokens_logits_to_penalize:
        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor,
                    output_token_ids: List[List[int]]) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
                            presence_penalties, frequency_penalties,
                            repetition_penalties)


def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
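As an illustration (not part of the diff), here is a minimal sketch of how the min-token penalty behaves. The import path is an assumption based on this PR's V1 sampler layout, and the shapes, token ids, and values are made up for demonstration.

    import torch

    # Assumed module path for the new helpers added in this PR.
    from vllm.v1.sample.ops.penalties import apply_min_token_penalties

    # Two requests, vocabulary of 6 tokens; token id 5 is the stop token.
    logits = torch.zeros(2, 6)

    # Request 0 has produced 1 token but must produce at least 3 before it may
    # stop; request 1 has already satisfied its minimum of 2 tokens.
    output_token_ids = [[11], [12, 13, 14]]
    stop_token_ids = [{5}, {5}]
    min_tokens = [3, 2]

    # Modifies `logits` in place.
    apply_min_token_penalties(logits, output_token_ids, stop_token_ids,
                              min_tokens)

    # Only request 0's stop-token logit is masked out.
    assert logits[0, 5] == -float("inf")
    assert logits[1, 5] == 0.0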
@@ -0,0 +1,201 @@
from typing import Dict

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import flashinfer.sampling
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False


class TopKTopPSampler(nn.Module):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling."""
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling."""
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            # We prefer `random_sample` over `flashinfer_sample` when sorting
            # is not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)


def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    This function sorts the logits tensor, which can be slow for large batches.
    """
    if no_top_k and no_top_p:
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Gather the k-th largest logit of each row as the threshold.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Keep at least one token.
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-scatter the masked logits back to their original positions.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because
    torch.multinomial causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the probabilities using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it uses rejection sampling instead
    of sorting the probability tensor.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
    assert not (no_top_k and no_top_p)
    max_top_k_round = 32
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        uniform_samples.uniform_()
    if generators:
        for i, generator in generators.items():
            uniform_samples[:, i].uniform_(generator=generator)

    if no_top_k:
        # Top-p only.
        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
            probs, uniform_samples, p, deterministic=True)
    elif no_top_p:
        # Top-k only.
        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
            probs, uniform_samples, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))

    # NOTE: CPU-GPU synchronization happens here.
    if not success.all():
        # Rejection sampling did not converge for some rows: renormalize the
        # truncated distributions and sample from them directly.
        if not no_top_k:
            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
        if not no_top_p:
            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)
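Finally, a hedged end-to-end usage sketch of the new sampler on the PyTorch-native path. The import path is assumed from this PR's layout (the file names are not visible in this diff view), and the batch size, vocabulary size, and sampling parameters are illustrative only; `forward_native` is called directly so the sketch does not depend on CUDA or FlashInfer being present.

    import torch

    # Assumed module path for the sampler added in this PR.
    from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

    sampler = TopKTopPSampler()

    batch_size, vocab_size = 4, 128
    logits = torch.randn(batch_size, vocab_size)

    # Per-request top-k / top-p parameters; request 2 additionally carries its
    # own seeded generator, matching the `generators` dict the sampler expects.
    k = torch.full((batch_size,), 50, dtype=torch.int64)
    p = torch.full((batch_size,), 0.9)
    generators = {2: torch.Generator().manual_seed(42)}

    sampled = sampler.forward_native(
        logits,
        generators,
        no_top_k=False,
        k=k,
        no_top_p=False,
        p=p,
    )
    print(sampled.shape)  # torch.Size([4]), one token id per request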