Merged
Commits
245 commits
b504b73
[RFC][V1] LogitsProcessor interface
njhill Feb 16, 2025
55328d8
extra_args
afeldman-nm Apr 18, 2025
cc44096
Merge branch 'main' into extra_args
afeldman-nm Apr 21, 2025
876de25
Merge branch 'main' into extra_args
afeldman-nm Apr 21, 2025
191b9e1
rename
afeldman-nm Apr 22, 2025
1b658cd
rename
afeldman-nm Apr 22, 2025
6c892d8
Merge branch 'main' into extra_args
afeldman-nm Apr 22, 2025
6a0f87c
extra_body
afeldman-nm Apr 22, 2025
ac57a7f
completion custom arg unit test
afeldman-nm Apr 22, 2025
9753c75
Merge branch 'main' into extra_args
afeldman-nm Apr 22, 2025
c2f39bd
Merge branch 'main' into extra_args
afeldman-nm Apr 23, 2025
5c43609
tweak extra_args; test sampling params extra args via api
afeldman-nm Apr 23, 2025
1f8d6d1
Merge branch 'main' into extra_args
afeldman-nm Apr 23, 2025
368f907
remove unnecessary extra_body field/breakout
afeldman-nm Apr 23, 2025
a90311a
removed transcription scenario
afeldman-nm Apr 23, 2025
0e7809d
Merge branch 'main' into extra_args
afeldman-nm Apr 25, 2025
42b0d31
small changes
afeldman-nm May 1, 2025
f1ef8ef
spec decode min p
afeldman-nm May 2, 2025
b270ac4
spec decode min p
afeldman-nm May 2, 2025
49531cb
wip TPU fix
afeldman-nm May 5, 2025
066761d
merge
afeldman-nm May 5, 2025
6a3f618
Merge branch 'logitsprocs' into logitsprocs_tpu
afeldman-nm May 5, 2025
c18c558
merge
afeldman-nm May 5, 2025
ddf3255
merge
afeldman-nm May 5, 2025
bc5fd4f
test logits processors
afeldman-nm May 5, 2025
6e03ca0
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 6, 2025
f18610d
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 6, 2025
510623c
Merge branch 'main' into extra_args
afeldman-nm May 6, 2025
52988b8
revert sampling params
afeldman-nm May 7, 2025
730fb25
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 7, 2025
94e5855
Merge branch 'main' into extra_args
afeldman-nm May 7, 2025
a869a6d
impl based on rfc
afeldman-nm May 7, 2025
934de06
Merge branch 'main' into extra_args
afeldman-nm May 7, 2025
ea35594
merge
afeldman-nm May 9, 2025
c9a193f
merge
afeldman-nm May 9, 2025
cf6d7c5
upstream merge
afeldman-nm May 13, 2025
c8001b7
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 13, 2025
eadd6ea
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 16, 2025
a47c414
requires_nongreedy
afeldman-nm May 16, 2025
40c9ac3
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 16, 2025
94d93a6
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 19, 2025
c035aff
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 22, 2025
6f62142
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 22, 2025
50ee0b5
remove TPU hacks
afeldman-nm May 22, 2025
cfa0c86
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 22, 2025
000794d
is_tpu
afeldman-nm May 22, 2025
5a2c7f8
removed property
afeldman-nm May 27, 2025
4405c94
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 27, 2025
7226815
feature flag re-enables hard-coded min-p for TPU
afeldman-nm May 27, 2025
494b52d
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 27, 2025
b36ea72
_device_tensor
afeldman-nm May 27, 2025
47c7e41
class method
afeldman-nm May 27, 2025
77c0959
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 27, 2025
029f003
Merge branch 'logitsprocs' into logitsprocs_merge
afeldman-nm May 27, 2025
10744c2
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 27, 2025
fa93444
upstream merge swap logic
afeldman-nm May 27, 2025
647bbea
bugfixes
afeldman-nm May 27, 2025
af7fcbf
test type annotation
afeldman-nm May 27, 2025
f278fd2
Merge branch 'main' into logitsprocs_merge
afeldman-nm May 27, 2025
8318f1a
removed errant todo
afeldman-nm May 27, 2025
1474fcf
Merge branch 'logitsprocs' into logitsprocs_merge
afeldman-nm Jun 2, 2025
1c29a6f
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 2, 2025
cefb163
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 2, 2025
7ebec2d
fixed req removal bookkeeping
afeldman-nm Jun 2, 2025
fdb770d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 2, 2025
481c985
proper order of add in logitsprocs
afeldman-nm Jun 2, 2025
f655cc2
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 3, 2025
0695f26
upstream merge
afeldman-nm Jun 3, 2025
c3047cc
Merge branch 'main' into extra_args_merge
afeldman-nm Jun 3, 2025
8065a51
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 3, 2025
0cf5721
upstream merge
afeldman-nm Jun 4, 2025
3eef20e
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 4, 2025
6627c6e
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 4, 2025
87a1835
revert backend changes
afeldman-nm Jun 4, 2025
ecf26ac
wip
afeldman-nm Jun 4, 2025
ea966c8
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 4, 2025
a1ab94c
Merge branch 'logitsprocs' into logitsprocs_reorg
afeldman-nm Jun 4, 2025
e51a1e4
wip
afeldman-nm Jun 4, 2025
dc2d57e
restructure
afeldman-nm Jun 4, 2025
f4f6980
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 4, 2025
fdc0c4f
Merge branch 'logitsprocs' into logitsprocs_reorg
afeldman-nm Jun 4, 2025
b7d0779
fixed some tests
afeldman-nm Jun 4, 2025
906105b
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 4, 2025
8539576
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 6, 2025
abf42cc
bugfix
afeldman-nm Jun 6, 2025
60e5016
bugfix
afeldman-nm Jun 6, 2025
bd1ffa3
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 6, 2025
279679b
Merge branch 'logitsprocs_reorg' into logitsprocs_reorg_bugfix
afeldman-nm Jun 6, 2025
8cf4817
merge
afeldman-nm Jun 7, 2025
17c10ca
refactor
afeldman-nm Jun 7, 2025
849d829
bugfix - redundant batch update
afeldman-nm Jun 7, 2025
03a836b
min tokens test bugfix
afeldman-nm Jun 7, 2025
5ab8af1
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 7, 2025
d92a3f3
remove prints
afeldman-nm Jun 7, 2025
1f87ec8
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 8, 2025
ef2294d
rejection sampling test bugfix
afeldman-nm Jun 8, 2025
198db48
sampler test bugfix
afeldman-nm Jun 8, 2025
2f2550b
removed logitsprocs where not needed in test
afeldman-nm Jun 8, 2025
1b1f8ca
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 8, 2025
5fc130b
refactor
afeldman-nm Jun 9, 2025
7b8f299
sampling_params min-p check
afeldman-nm Jun 9, 2025
4d5ea01
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 9, 2025
d8a6761
Merge branch 'logitsprocs' into logitsprocs_valid
afeldman-nm Jun 9, 2025
0515848
small test optimization
afeldman-nm Jun 9, 2025
17e7f62
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 9, 2025
c392898
refactor
afeldman-nm Jun 10, 2025
dc4b6b8
wip tests
afeldman-nm Jun 11, 2025
fd26581
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 11, 2025
5fb16a6
refactor
afeldman-nm Jun 11, 2025
b0658c2
passing mixed batch test for min_p and none
afeldman-nm Jun 11, 2025
ac608f1
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 11, 2025
7f44262
Merge branch 'logitsprocs' into logitsprocs_reorder
afeldman-nm Jun 11, 2025
0a20965
mix batch test passes without reorder
afeldman-nm Jun 12, 2025
bdea83c
refactor
afeldman-nm Jun 12, 2025
7f4d72e
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 12, 2025
ec25ab5
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 12, 2025
5a5e38f
move-only
afeldman-nm Jun 12, 2025
19d3882
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 12, 2025
588b845
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 13, 2025
38746ae
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 13, 2025
9703c4a
fake reordering logic
afeldman-nm Jun 13, 2025
6078602
fake logitsproc invocation against fake batch
afeldman-nm Jun 13, 2025
ae5b600
almost passing
afeldman-nm Jun 13, 2025
d5679bb
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 13, 2025
84cad20
Merge branch 'logitsprocs' into logitsprocs_reorder
afeldman-nm Jun 13, 2025
89ea6dd
wip refactor
afeldman-nm Jun 13, 2025
76438fb
test fix
afeldman-nm Jun 13, 2025
e03c561
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 13, 2025
045bc01
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 16, 2025
9ac6190
latest
afeldman-nm Jun 16, 2025
c1b8e69
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 16, 2025
e83f90b
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 17, 2025
360f2c4
removed tpu hack
afeldman-nm Jun 17, 2025
eac0c82
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 17, 2025
395d472
wip tpu backward compat
afeldman-nm Jun 17, 2025
5c53a8c
typing
afeldman-nm Jun 17, 2025
e08e4f4
Merge branch 'logitsprocs' into logitsprocs_tpu
afeldman-nm Jun 17, 2025
f7969c5
wip
afeldman-nm Jun 17, 2025
1117f51
first pass at tpu/gpu separation
afeldman-nm Jun 17, 2025
e0fd74b
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 17, 2025
2a4e09c
first pass at new TPU approach
afeldman-nm Jun 17, 2025
1f4cad3
docstrings
afeldman-nm Jun 17, 2025
a6be23c
Merge branch 'main' into tpu-isolate
afeldman-nm Jun 17, 2025
b28588c
merged in GPU/TPU decoupling PR
afeldman-nm Jun 17, 2025
ca87319
bugfix
afeldman-nm Jun 17, 2025
32e4275
type checking
afeldman-nm Jun 17, 2025
9aeb49d
Merge branch 'main' into tpu-isolate
afeldman-nm Jun 18, 2025
0383e73
InputBatch fix
afeldman-nm Jun 18, 2025
9564879
Merge branch 'tpu-isolate' into logitsprocs_merge
afeldman-nm Jun 18, 2025
c02ef1b
merge
afeldman-nm Jun 18, 2025
b804423
vllm_xargs/kv_transfer_params compatibility
afeldman-nm Jun 18, 2025
17f02ee
fix
afeldman-nm Jun 18, 2025
061ac67
remove unnecessary unit test
afeldman-nm Jun 18, 2025
421c278
precedence
afeldman-nm Jun 18, 2025
f315e0e
pre-commit fix
afeldman-nm Jun 18, 2025
3d92a07
Merge branch 'main' into extra_args_merge
afeldman-nm Jun 18, 2025
873b89f
merge
afeldman-nm Jun 18, 2025
f8609ff
Merge branch 'main' into extra_args_merge
afeldman-nm Jun 18, 2025
9c5f407
Documentation changes
afeldman-nm Jun 18, 2025
0857dc4
refactor
afeldman-nm Jun 18, 2025
f9c4e19
typing
afeldman-nm Jun 18, 2025
03c6010
typing
afeldman-nm Jun 18, 2025
95e1b0d
typing
afeldman-nm Jun 18, 2025
9daeaed
Update vllm/entrypoints/openai/protocol.py
afeldman-nm Jun 18, 2025
baf90c9
feedback
afeldman-nm Jun 18, 2025
4f04198
remove swap type
afeldman-nm Jun 18, 2025
da23801
refactor
afeldman-nm Jun 18, 2025
34c9866
move/swap refactoring
afeldman-nm Jun 18, 2025
e1f0455
refactoring
afeldman-nm Jun 18, 2025
fe088ea
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 18, 2025
f506dd7
small fixes
afeldman-nm Jun 18, 2025
3257deb
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 18, 2025
c0b2068
Merge branch 'main' into extra_args_merge
afeldman-nm Jun 18, 2025
5894110
Merge branch 'extra_args' into lp_ext
afeldman-nm Jun 18, 2025
3885bc5
merge
afeldman-nm Jun 20, 2025
e06f9e9
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 24, 2025
7d89720
batch update builder
afeldman-nm Jun 24, 2025
33e0f14
comments
afeldman-nm Jun 24, 2025
26c18d6
Merge branch 'logitsprocs' into lp_ext
afeldman-nm Jun 24, 2025
2e56aec
add custom logitsprocs arg
afeldman-nm Jun 24, 2025
2213b44
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 24, 2025
36d6f69
logitsprocs+pooling bugfix
afeldman-nm Jun 24, 2025
28b6606
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm Jun 24, 2025
e422caa
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jun 24, 2025
3cca78f
small tweaks
afeldman-nm Jun 24, 2025
4177594
refactor
afeldman-nm Jun 24, 2025
40407b7
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 24, 2025
5209ffe
Fixed min tokens bug
afeldman-nm Jun 25, 2025
301db58
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
6f41503
fixed logit bias bug
afeldman-nm Jun 25, 2025
36f161d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
a14d3a4
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm Jun 25, 2025
1716f07
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jun 25, 2025
b429c10
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
fbdb595
comment Re: output tokens list ref
afeldman-nm Jun 25, 2025
e3dc71e
Merge branch 'logitsprocs' into logitsprocs_merge
afeldman-nm Jun 25, 2025
aa4c519
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
3ae8a6b
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm Jun 25, 2025
d58bf24
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jun 25, 2025
77bba48
refactor
afeldman-nm Jun 25, 2025
890a9cd
refactor
afeldman-nm Jun 25, 2025
6b3ea9f
Update vllm/v1/sample/logits_processor.py
afeldman-nm Jun 25, 2025
070d71d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
5384732
feedback
afeldman-nm Jun 25, 2025
9aebc9f
Update vllm/v1/sample/sampler.py
afeldman-nm Jun 25, 2025
8bb6bf0
revert some changes
afeldman-nm Jun 25, 2025
0a88e16
refactor
afeldman-nm Jun 25, 2025
18721da
Merge branch 'logitsprocs' of https://github.com/neuralmagic/vllm int…
afeldman-nm Jun 25, 2025
dc0b23a
refactor
afeldman-nm Jun 25, 2025
21ad212
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
2f0de77
argmax_invariant
afeldman-nm Jun 25, 2025
8d97a7c
batch update builder impl
afeldman-nm Jun 25, 2025
2abd24d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
d1c6607
refactor
afeldman-nm Jun 25, 2025
9fe0bc3
wip dict removal
afeldman-nm Jun 25, 2025
aa18e8f
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
f7a162c
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 26, 2025
de81e42
updated unit tests
afeldman-nm Jun 26, 2025
20928f0
refactor
afeldman-nm Jun 26, 2025
a0e5398
iterators
afeldman-nm Jun 26, 2025
d4704d7
refactor
afeldman-nm Jun 26, 2025
729729d
reorg
afeldman-nm Jun 27, 2025
9948fd3
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 27, 2025
bc48f38
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 27, 2025
9eeea03
feedback
afeldman-nm Jun 28, 2025
1078a24
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 28, 2025
cd766a4
feedback
afeldman-nm Jun 28, 2025
2628f98
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 30, 2025
2ecb37d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 30, 2025
64ac2cf
input batch tests
afeldman-nm Jul 1, 2025
4da82cc
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
bd62df4
refactor
afeldman-nm Jul 1, 2025
8455bb6
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
a6dc218
attempted fmt fix
afeldman-nm Jul 1, 2025
55fd6e7
fixed cancellation bug
afeldman-nm Jul 1, 2025
b55f88e
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
402d012
Update vllm/v1/worker/gpu_model_runner.py
afeldman-nm Jul 1, 2025
06fc926
pr feedback
afeldman-nm Jul 1, 2025
8d229ed
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
45dade4
mem util
afeldman-nm Jul 1, 2025
d377a6b
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
6ae7574
memory util
afeldman-nm Jul 1, 2025
5203324
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
68aab25
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
066736d
merge'
afeldman-nm Jul 2, 2025
626 changes: 626 additions & 0 deletions tests/v1/sample/test_logits_processors.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions tests/v1/sample/test_logprobs_e2e.py
@@ -13,9 +13,10 @@

# FIXME(rob): enable prefix caching once supported.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501
SERVER_ARGS = [
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests"
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests",
"--gpu-memory-utilization=0.8"
]
NUM_CONCURRENT = 100

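The change above caps GPU memory at 80% in both the test's model-args string and the server flags. For reference, a minimal offline equivalent of the same configuration via the vllm.LLM constructor (this snippet is illustrative and not part of the diff; the keyword arguments are the standard engine args matching the flags above):

from vllm import LLM

# Mirrors MODEL_ARGS above: eager mode, prefix caching off, 80% GPU memory.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enforce_eager=True,
          enable_prefix_caching=False,
          gpu_memory_utilization=0.8)
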
5 changes: 2 additions & 3 deletions tests/v1/sample/test_rejection_sampler.py
Expand Up @@ -7,6 +7,7 @@
import torch.nn.functional as F

from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
RejectionSampler)
@@ -58,7 +59,6 @@ def create_sampling_metadata(
all_random=not all_greedy,
top_p=top_p,
top_k=top_k,
min_p=torch.empty(1, ),
generators=generators,
max_num_logprobs=0,
no_penalties=False,
@@ -67,10 +67,9 @@
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
min_tokens={},
logit_bias=[None],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
)


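Net effect in this file (2 additions, 3 deletions): the per-request min_p, min_tokens, and logit_bias fields disappear from the SamplingMetadata construction, replaced by a single logitsprocs=LogitsProcessorManager() field that owns such state. A minimal sketch of the empty manager these tests rely on (assuming default construction registers no processors; the .all iterator is the one used in tests/v1/sample/utils.py further down):

from vllm.v1.sample.logits_processor import LogitsProcessorManager

# Default-constructed manager: no logits processors registered, so
# sampling metadata built this way applies no extra logits transforms.
manager = LogitsProcessorManager()
assert list(manager.all) == []
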
154 changes: 5 additions & 149 deletions tests/v1/sample/test_sampler.py
@@ -8,10 +8,13 @@
import torch

from vllm.platforms import current_platform
from vllm.utils import make_tensor_with_pad
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
@@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor(
)


def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> list[Optional[dict[int, float]]]:
res: list[Optional[dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res


def _create_allowed_token_ids(
batch_size: int,
vocab_size: int,
@@ -145,7 +136,6 @@ def _create_default_sampling_metadata(
all_random=False,
top_p=None,
top_k=None,
min_p=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
@@ -155,43 +145,13 @@
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
)
return fake_sampling_metadata


def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: list[int]
) -> dict[int, tuple[int, set[int]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.

If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens: dict[int, tuple[int, set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens[index] = (
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens),
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))
else:
min_tokens[index] = (np.random.randint(0,
num_output_tokens), set())
return min_tokens


def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
@@ -227,36 +187,6 @@ def _create_weighted_output_token_list(
return output_token_ids, sorted_token_ids_in_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
torch.set_default_device(device)
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens = _generate_min_token_penalties_and_stop_tokens(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
batch_indices_for_min_token_penalty)
sampling_metadata.min_tokens = min_tokens
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
if token_id in stop_token_ids:
assert logits[batch_idx][token_id] == -float("inf")
else:
assert logits[batch_idx][token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
@@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(
or non_penalized_token_id in output_tokens)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("min_p", [0.0, 0.1])
def test_sampler_min_p(device: str, batch_size: int, min_p: float):
"""
Tests that when min_p is applied, tokens with probability below
min_p * max_prob are masked with -inf.
"""
torch.set_default_device(device)
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)

# Create one dominant token per batch
for i in range(batch_size):
fake_logits[i, 0] = 10.0 # High logit for first token
fake_logits[i, 1:] = 1e-2 # Others remain low

sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))

# Configure min_p parameters
sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)

sampler = Sampler()
logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
logits = logits.cpu()

for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
if token_id == 0:
# Dominant token should always be unmasked
assert logits[batch_idx][token_id] != -float("inf")
else:
if min_p > 0.0:
# Non-dominant tokens should be masked when min_p > 0
assert logits[batch_idx][token_id] == -float("inf")
else:
# No masking when min_p is 0
assert logits[batch_idx][token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.logit_bias = _create_logit_bias(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
bias_value=bias_value,
)
sampler = Sampler()
logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
biased_index = min(batch_idx, VOCAB_SIZE - 1)
for token_id in range(VOCAB_SIZE):
if biased_index == token_id:
assert logits_for_req[token_id] == pytest.approx(bias_value +
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
81 changes: 80 additions & 1 deletion tests/v1/sample/utils.py
@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterator
from enum import Enum
from typing import Optional
from typing import NamedTuple, Optional

import regex as re
import torch

from vllm import CompletionOutput
from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata


class BatchLogprobsComposition(Enum):
@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
logprobs = completion_output.logprobs
assert logprobs is not None
return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)])


def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
return fake_logits


def create_penalty_tensor(batch_size: int, penalty_value: float,
device: torch.device) -> torch.Tensor:
return torch.full((batch_size, ),
fill_value=penalty_value,
dtype=torch.float,
device=device)


def create_prompt_tokens_tensor(
prompt_token_ids: list[list[int]],
vocab_size: int,
device: torch.device,
) -> torch.Tensor:
return make_tensor_with_pad(
prompt_token_ids,
pad=vocab_size,
device=device,
dtype=torch.int64,
pin_memory=False,
)


class LogitsprocsTestFakes(NamedTuple):
"""Wraps fake data structures to support testing"""
logits: torch.Tensor
sampling_metadata: SamplingMetadata

def get_logitsprocs_by_cls(
self,
cls: type[LogitsProcessor],
) -> Iterator[LogitsProcessor]:
"""Yield logits processors of a specific class.

Args:
cls: :class:`LogitsProcessor` subclass

Returns:
Iterator over logits processors
"""
return (lp for lp in self.sampling_metadata.logitsprocs.all
if isinstance(lp, cls))

def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
return self.sampling_metadata.logitsprocs.all


def fake_update_logitsprocs_state(
test_fakes: LogitsprocsTestFakes,
batch_update: BatchUpdate,
) -> None:
"""Imitate logits processors persistent batch state update
in engine core"""
for logitproc in test_fakes.get_logitsprocs():
logitproc.update_state(batch_update)


def fake_apply_logitsprocs(
test_fakes: LogitsprocsTestFakes,
slice_indices: list[int],
) -> torch.Tensor:
"""Imitate application of logits processors in engine core"""
logits = test_fakes.logits[torch.tensor(slice_indices,
dtype=torch.long)].clone()
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits
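
Taken together, these helpers support a replay-style test: build fake logits and a fake SamplingMetadata, imitate the engine core's per-step state updates, then apply the processors and check the resulting logits. A minimal sketch of that pattern (the helper names come from this file; the test function itself and the assumption that the supplied metadata carries an empty LogitsProcessorManager are ours):

import torch

from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
                                   fake_apply_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata


def check_noop_logitsprocs(sampling_metadata: SamplingMetadata) -> None:
    # With no processors registered, applying the pipeline over every
    # row must return logits identical to the fakes it started from.
    batch_size, vocab_size = 4, 1024
    fakes = LogitsprocsTestFakes(
        logits=create_fake_logits(batch_size, vocab_size),
        sampling_metadata=sampling_metadata,
    )
    out = fake_apply_logitsprocs(fakes, slice_indices=list(range(batch_size)))
    assert torch.equal(out, fakes.logits)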