[V1] Logits processor docs #22919
Conversation
Signed-off-by: Andrew Feldman <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Code Review
This pull request introduces documentation for the custom logits processor extensibility feature. The new markdown file explains how to create and use custom logits processors, including a code example. My review focuses on fixing several issues in this code example to make it functional and clear for developers.
docs/features/custom_logitsprocs.md
Outdated
The contrived example below implements a

??? code "Example custom logits processor definition"

    ```python
    from typing import Optional

    import torch

    from vllm.config import VllmConfig
    from vllm.sampling_params import SamplingParams
    from vllm.v1.sample.logits_processor import (BatchUpdate,
                                                 LogitsProcessor,
                                                 MoveDirectionality)


    class DummyLogitsProcessor(LogitsProcessor):
        """Fake logit processor to support unit testing and examples"""

        def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                     is_pin_memory: bool):
            self.req_info: dict[int, SamplingParams] = {}

        def is_argmax_invariant(self) -> bool:
            """Never impacts greedy sampling"""
            return False

        def update_state(self, batch_update: Optional[BatchUpdate]):
            if not batch_update:
                return

            # Process added requests.
            for index, params, _, _ in batch_update.added:
                assert params is not None
                if params.extra_args and (target_token :=
                                          params.extra_args.get("target_token")):
                    self.req_info[index] = target_token

            if self.req_info:
                # Process removed requests.
                for index in batch_update.removed:
                    self.req_info.pop(index, None)

                # Process moved requests, unidirectional move (a->b) and swap
                # (a<->b)
                for adx, bdx, direct in batch_update.moved:
                    a_val = self.req_info.pop(adx, None)
                    b_val = self.req_info.pop(bdx, None)
                    if a_val is not None:
                        self.req_info[bdx] = a_val
                    if direct == MoveDirectionality.SWAP and b_val is not None:
                        self.req_info[adx] = b_val

        def apply(self, logits: torch.Tensor) -> torch.Tensor:
            if not self.req_info:
                return logits

            # Save target values before modification
            rows_list = list(self.req_info.keys())
            cols = torch.tensor([self.req_info[i] for i in rows_list],
                                dtype=torch.long,
                                device=logits.device)
            rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
            values_to_keep = logits[rows, cols].clone()
    ```
The example `DummyLogitsProcessor` has a few issues that prevent it from working correctly and could confuse users:

- The introductory sentence on line 16 is incomplete.
- The type hint for `self.req_info` in `__init__` is `dict[int, SamplingParams]`, but it is used to store integer `target_token` values. This should be `dict[int, int]`.
- The `apply` method is incomplete. When `self.req_info` is populated, it doesn't return the `logits` tensor, which will lead to a runtime error. The logic is also unfinished.
I've provided a corrected version of the example that addresses these points, making it a complete and functional illustration of a custom logits processor.
The contrived example below implements a logits processor that forces the model to select a specific `target_token` for requests that provide it.

??? code "Example custom logits processor definition"

    ```python
    from typing import Optional

    import torch

    from vllm.config import VllmConfig
    from vllm.sampling_params import SamplingParams
    from vllm.v1.sample.logits_processor import (BatchUpdate,
                                                 LogitsProcessor,
                                                 MoveDirectionality)


    class DummyLogitsProcessor(LogitsProcessor):
        """Fake logit processor to support unit testing and examples"""

        def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                     is_pin_memory: bool):
            self.req_info: dict[int, int] = {}

        def is_argmax_invariant(self) -> bool:
            """Never impacts greedy sampling"""
            return False

        def update_state(self, batch_update: Optional[BatchUpdate]):
            if not batch_update:
                return

            # Process added requests.
            for index, params, _, _ in batch_update.added:
                assert params is not None
                if params.extra_args and (target_token :=
                                          params.extra_args.get("target_token")):
                    self.req_info[index] = target_token

            if self.req_info:
                # Process removed requests.
                for index in batch_update.removed:
                    self.req_info.pop(index, None)

                # Process moved requests, unidirectional move (a->b) and swap
                # (a<->b)
                for adx, bdx, direct in batch_update.moved:
                    a_val = self.req_info.pop(adx, None)
                    b_val = self.req_info.pop(bdx, None)
                    if a_val is not None:
                        self.req_info[bdx] = a_val
                    if direct == MoveDirectionality.SWAP and b_val is not None:
                        self.req_info[adx] = b_val

        def apply(self, logits: torch.Tensor) -> torch.Tensor:
            if not self.req_info:
                return logits

            rows_list = list(self.req_info.keys())
            cols = torch.tensor([self.req_info[i] for i in rows_list],
                                dtype=torch.long,
                                device=logits.device)
            rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)

            # Get the original logits for the target tokens.
            values_to_keep = logits[rows, cols].clone()

            # For requests with a target token, set all other logits to -inf.
            # This is a contrived example to force the model to select the
            # target token.
            for row_idx in rows_list:
                logits[row_idx, :] = -float("inf")
            logits[rows, cols] = values_to_keep

            return logits
    ```
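To see how the moved-request bookkeeping in `update_state` behaves, here is a standalone sketch (no vLLM or torch required) that mimics the move/swap handling above with plain dicts. `MoveDirectionality` is stubbed as a simple enum here purely for illustration; the real enum lives in `vllm.v1.sample.logits_processor`.

```python
from enum import Enum


class MoveDirectionality(Enum):
    # Stub of vLLM's enum, for illustration only.
    UNIDIRECTIONAL = 1
    SWAP = 2


def apply_moves(req_info: dict, moved: list) -> dict:
    """Mimics the moved-request handling in update_state above.

    req_info maps batch index -> target token; moved holds
    (source_index, destination_index, directionality) tuples.
    """
    for adx, bdx, direct in moved:
        a_val = req_info.pop(adx, None)
        b_val = req_info.pop(bdx, None)
        if a_val is not None:
            req_info[bdx] = a_val
        if direct == MoveDirectionality.SWAP and b_val is not None:
            req_info[adx] = b_val
    return req_info


# Unidirectional move: the request at index 0 relocates to index 2,
# carrying its target token with it.
print(apply_moves({0: 7}, [(0, 2, MoveDirectionality.UNIDIRECTIONAL)]))

# Swap: indices 0 and 1 exchange their target tokens.
print(apply_moves({0: 7, 1: 9}, [(0, 1, MoveDirectionality.SWAP)]))
```

Note that a unidirectional move overwrites any state at the destination index, while a swap preserves both entries; this mirrors how the persistent batch compacts and reorders request slots between steps.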
Signed-off-by: Andrew Feldman <[email protected]>
Thank you @JosephMarinier for your review; I believe I addressed everything you mentioned.
Thank you for the cool feature! 🙏
Signed-off-by: Andrew Feldman <[email protected]>
Thanks @afeldman-nm, looks great apart from the one comment, we could merge this since that will likely need to change soon anyhow.
Signed-off-by: Andrew Feldman <[email protected]>
Thanks @afeldman-nm
Signed-off-by: Andrew Feldman <[email protected]> Signed-off-by: afeldman-nm <[email protected]> Co-authored-by: Joseph Marinier <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]> Signed-off-by: afeldman-nm <[email protected]> Co-authored-by: Joseph Marinier <[email protected]> Signed-off-by: charlifu <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]> Signed-off-by: afeldman-nm <[email protected]> Co-authored-by: Joseph Marinier <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
Purpose
Test Plan
N/A
Test Result
N/A
(Optional) Documentation Update
See Purpose
Essential Elements of an Effective PR Description Checklist
- `supported_models.md` and `examples` for a new model.