Merged

Changes from all commits (30 commits)
56e4d74  wrapper (afeldman-nm, Aug 28, 2025)
ff1d90b  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Aug 28, 2025)
e66e2e0  skeleton of example (afeldman-nm, Aug 28, 2025)
bdcb7af  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Aug 28, 2025)
2dc7b61  feedback (afeldman-nm, Aug 28, 2025)
52fb9c8  feedback (afeldman-nm, Aug 29, 2025)
ba7805b  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Aug 29, 2025)
8662cc9  feedback (afeldman-nm, Aug 29, 2025)
484780d  example (afeldman-nm, Aug 29, 2025)
e225a80  test (afeldman-nm, Aug 29, 2025)
df9a5d7  annotation fix (afeldman-nm, Aug 29, 2025)
5e9c6af  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Aug 29, 2025)
245971d  test passing (afeldman-nm, Aug 29, 2025)
5be18e6  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Aug 29, 2025)
633e755  refactor (afeldman-nm, Aug 29, 2025)
02a6b06  rename (afeldman-nm, Sep 1, 2025)
c559c8e  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Sep 1, 2025)
9347520  wip: (afeldman-nm, Sep 1, 2025)
2dfcd00  refactor (afeldman-nm, Sep 1, 2025)
17c7af5  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Sep 1, 2025)
df51606  refactor (afeldman-nm, Sep 1, 2025)
14d246d  rename (afeldman-nm, Sep 1, 2025)
97143c4  refactor (afeldman-nm, Sep 1, 2025)
68ec593  refactor (afeldman-nm, Sep 1, 2025)
e61a521  Merge branch 'main' into lp_v0_plumb (afeldman-nm, Sep 2, 2025)
c23fec8  small fix; refactor (afeldman-nm, Sep 2, 2025)
8a8bbd5  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Sep 2, 2025)
67d1860  lint (afeldman-nm, Sep 2, 2025)
7256952  Merge branch 'main' into lp_v0_wrap (afeldman-nm, Sep 2, 2025)
cd81b39  Merge branch 'lp_v0_plumb' of https://github.com/neuralmagic/vllm int… (afeldman-nm, Sep 2, 2025)
151 changes: 151 additions & 0 deletions examples/offline_inference/logits_processor/custom_req.py
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""This example demonstrates wrapping a request-level logits processor to be
compatible with vLLM's batch-level logits processing.

For demo purposes, a dummy logits processor is employed which, if
`target_token` is passed as a key in `SamplingParams.extra_args`,
will mask out all tokens except `target_token`. This logits processor can be
applied to a vector of logits associated with a single decode step for a single
request. The logits processor cannot be applied to a request which does not
pass in a `target_token` custom argument.

The request-level dummy logits processor is wrapped to create a batch-level
logits processor, which can apply the logits processor to output logits from
all requests in the persistent batch in a given decode step. For requests which
do not provide a `target_token` argument, the corresponding row of `logits`
will not be modified.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Any, Optional

import torch

from vllm import LLM, SamplingParams
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
)

logger = init_logger(__name__)


class DummyPerReqLogitsProcessor:
"""The request-level logits processor masks out all logits except the
token id identified by `target_token`"""

def __init__(self, target_token: int) -> None:
"""Specify `target_token`"""
self.target_token = target_token

def __call__(
self,
output_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
val_to_keep = logits[self.target_token].item()
logits[:] = float("-inf")
logits[self.target_token] = val_to_keep
return logits


class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""

def is_argmax_invariant(self) -> bool:
return False

def new_req_logits_processor(
self,
params: SamplingParams,
) -> Optional[RequestLogitsProcessor]:
"""This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request.

Returns None if the logits processor should not be applied to the
particular request. To use the logits processor the request must have
a "target_token" custom argument with an integer value.

Args:
params: per-request sampling params

Returns:
`Callable` request logits processor, or None
"""
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)


# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
SamplingParams(temperature=0.0),
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
SamplingParams(temperature=0.0),
]


def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
logits_processors=[WrappedPerReqLogitsProcessor],
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params_list)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)


if __name__ == "__main__":
main()
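
To make the adapter mechanism described in the example above more concrete, the
following is a minimal illustrative sketch of how a batch-level adapter can
apply per-request logits processors row by row. It is not part of the changed
files: the names `SimpleBatchAdapter`, `req_procs`, and `apply` are invented,
and vLLM's actual `AdapterLogitsProcessor` additionally handles
persistent-batch bookkeeping (requests being added, removed, and reordered),
which is omitted here.

# Illustrative sketch only (not part of this PR); see the caveats above.
from typing import Callable, Optional

import torch

# A request-level logits processor maps (output_ids, logits_row) -> logits_row.
RequestProc = Callable[[list[int], torch.Tensor], torch.Tensor]


class SimpleBatchAdapter:
    def __init__(self) -> None:
        # Batch row index -> per-request processor, or None if the request
        # opted out (e.g. no `target_token` was supplied).
        self.req_procs: dict[int, Optional[RequestProc]] = {}

    def apply(
        self, logits: torch.Tensor, output_ids: list[list[int]]
    ) -> torch.Tensor:
        # `logits` has shape (num_requests, vocab_size); rows whose request
        # has no processor are left unmodified.
        for row, proc in self.req_procs.items():
            if proc is not None:
                logits[row] = proc(output_ids[row], logits[row])
        return logits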
165 changes: 165 additions & 0 deletions examples/offline_inference/logits_processor/custom_req_init.py
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""This example demonstrates a special case of wrapping a request-level logits
processor, namely the case where it is necessary to utilize engine config or
environment info passed to the constructor. The subclass must override the
wrapper base class `__init__()` method to access the engine config, the device
identifier, or the flag which indicates whether pinned memory is available.

For demo purposes, a request-level dummy logits processor is employed which
causes the same token (`target_token`) to be decoded in each step. The
request-level dummy logits processor is wrapped to create a batch-level logits
processor, which can apply the logits processor to output logits from all
requests in the persistent batch in a given decode step.

The wrapped dummy logits processor below models a scenario where we must
disable the logits processor on non-"cuda" platforms. The wrapper base class
`__init__()` is overridden in order to check this condition and set a flag.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect that on a "cuda" device the output will look something like:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------

which indicates that the logits processor is running. However, on a non-"cuda"
device, the first and third requests would not repeat the same token.
"""

from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
)

logger = init_logger(__name__)


class DummyPerReqLogitsProcessor:
"""The request-level logits processor masks out all logits except the
token id identified by `target_token`"""

def __init__(self, target_token: int) -> None:
"""Specify `target_token`"""
self.target_token = target_token

def __call__(
self,
output_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
val_to_keep = logits[self.target_token].item()
logits[:] = float("-inf")
logits[self.target_token] = val_to_keep
return logits


class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
super().__init__(vllm_config, device, is_pin_memory)
self.is_cuda = device.type == "cuda"

def is_argmax_invariant(self) -> bool:
return False

def new_req_logits_processor(
self,
params: SamplingParams,
) -> Optional[RequestLogitsProcessor]:
"""This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request.

Returns None if the logits processor should not be applied to the
particular request. To use the logits processor the request must have
a "target_token" custom argument with an integer value, and the device
must be "cuda"-type

Args:
params: per-request sampling params

Returns:
`Callable` request logits processor, or None
"""
if (
not self.is_cuda
or (
target_token := params.extra_args
and params.extra_args.get("target_token")
)
is None
):
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)


# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
SamplingParams(temperature=0.0),
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
SamplingParams(temperature=0.0),
]


def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
logits_processors=[WrappedPerReqLogitsProcessor],
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params_list)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)


if __name__ == "__main__":
main()
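
As a further illustration of the constructor-override pattern above (not part
of the changed files): a subclass can also read engine-level settings from
`vllm_config` in its overridden `__init__()`. In the sketch below, the class
name `LengthAwareWrappedLogitsProcessor` and its gating rule are invented for
demonstration; it assumes `vllm_config.model_config.max_model_len` is
available and reuses the `DummyPerReqLogitsProcessor` class defined in the
example above.

# Illustrative sketch only (not part of this PR); see the caveats above.
from typing import Optional

import torch

from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor,
    RequestLogitsProcessor,
)


class LengthAwareWrappedLogitsProcessor(AdapterLogitsProcessor):
    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        super().__init__(vllm_config, device, is_pin_memory)
        # Capture an engine-level setting once, at construction time.
        self.max_model_len = vllm_config.model_config.max_model_len

    def is_argmax_invariant(self) -> bool:
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        target_token = (params.extra_args or {}).get("target_token")
        if not isinstance(target_token, int):
            return None
        # Hypothetical gating rule: skip the processor if the request's token
        # budget exceeds the model's maximum sequence length.
        if params.max_tokens is not None and params.max_tokens > self.max_model_len:
            return None
        # DummyPerReqLogitsProcessor is the request-level processor defined in
        # the example above.
        return DummyPerReqLogitsProcessor(target_token)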
33 changes: 33 additions & 0 deletions tests/v1/logits_processors/test_custom_offline.py
@@ -15,6 +15,7 @@
POOLING_MODEL_NAME, TEMP_GREEDY,
CustomLogitprocSource,
DummyLogitsProcessor,
WrappedPerReqLogitsProcessor,
dummy_module)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from tests.v1.logits_processors.utils import prompts
@@ -161,6 +162,38 @@ def test_custom_logitsprocs(monkeypatch,
_run_test(kwargs, logitproc_loaded=True)


@create_new_process_for_each_test()
def test_custom_logitsprocs_req(monkeypatch):
"""Test passing request-level logits processor to offline Python interface

    Wrap a request-level logits processor to create a batch-level logits
    processor with a well-defined behavior (mask out all tokens except one
    `target_token`).

Construct an `LLM` instance which loads the wrapped logits processor. Pass
the custom logitproc as a class object.

Construct a reference `LLM` instance with no custom logitproc

    Pass in a batch of requests, 50% of which pass a `target_token` value
    through `SamplingParams.extra_args` and 50% of which do not.

    Validate that:
    * Requests which do not activate the custom logitproc yield the same
      results for both `LLM` instances
    * Requests which activate the custom logitproc only output `target_token`

Args:
monkeypatch: for setting env vars
"""

# Test that logitproc info is passed to workers
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
random.seed(40)
_run_test({"logits_processors": [WrappedPerReqLogitsProcessor]},
logitproc_loaded=True)


@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", [
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
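
The docstring of `test_custom_logitsprocs_req` above describes the validation
informally. The fragment below is a minimal sketch of that kind of check, not
the actual `_run_test` helper from the vLLM test suite; `ref_outputs`,
`test_outputs`, and `target_tokens` are assumed inputs (per-request
generations from the reference `LLM` and the `LLM` under test, plus a mapping
from request index to `target_token` or None).

# Minimal sketch only (not part of this PR); see the caveats above.
def check_outputs(ref_outputs, test_outputs, target_tokens):
    for i, (ref, out) in enumerate(zip(ref_outputs, test_outputs)):
        target = target_tokens.get(i)
        if target is None:
            # Requests that never activate the logitproc must match the
            # reference LLM's greedy output exactly.
            assert out.outputs[0].text == ref.outputs[0].text
        else:
            # Requests that activate the logitproc must emit only target_token.
            assert all(tok == target for tok in out.outputs[0].token_ids)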