Commit ce63ba4

[RFC][V1] LogitsProcessor interface
Signed-off-by: Nick Hill <[email protected]>
1 parent 30172b4 commit ce63ba4

7 files changed: +304 −119 lines changed

vllm/v1/sample/logits_processor.py

Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Sequence, Set, Tuple

import torch

from vllm import SamplingParams


@dataclasses.dataclass
class BatchUpdate:
    # Batch indices of any removed requests.
    removed: List[int]
    # (from, to) batch indices of any requests
    # moved within the batch.
    moved: List[Tuple[int, int]]
    # (index, params, output_tok_ids) for new
    # requests added to the batch.
    added: List[Tuple[int, SamplingParams, List[int]]]
    # The current number of requests in the batch.
    batch_size: int


class LogitsProcessor(ABC):

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def update_states(
        self,
        batch_update: Optional[BatchUpdate] = None,
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update is non-None iff there have been
            changes to the batch makeup.
        """
        raise NotImplementedError


###### ----- LogitsProcessor impls below here


class MinPLogitsProcessor(LogitsProcessor):

    def __init__(self, max_num_reqs: int, pin_memory: bool,
                 device: torch.device):
        self.min_p_count: int = 0

        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        # Pre-allocated device tensor
        self.min_p_gpu: torch.Tensor = torch.empty((max_num_reqs, ),
                                                   dtype=torch.float32,
                                                   device=device)
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_gpu[:0]

    def update_states(self, batch_update: Optional[BatchUpdate] = None):
        if not batch_update:
            return

        needs_update = False
        if self.min_p_count:
            # Process removed and moved requests.
            for index in batch_update.removed:
                if self.min_p_cpu[index]:
                    self.min_p_count -= 1
                    needs_update = True

            for from_index, to_index in batch_update.moved:
                min_p = self.min_p_cpu[from_index]
                self.min_p_cpu[to_index] = min_p
                if min_p:
                    needs_update = True

        # Process added requests.
        for index, sampling_params, _ in batch_update.added:
            min_p = sampling_params.min_p
            self.min_p_cpu[index] = min_p
            if min_p:
                self.min_p_count += 1
                needs_update = True

        # Update tensors if needed.
        size = batch_update.batch_size
        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
            self.min_p = self.min_p_gpu[:size]
            self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.min_p_count:
            return logits

        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Adjust min_p
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify valid tokens using threshold comparison
        invalid_token_mask = probability_values < adjusted_min_p
        # Apply mask using boolean indexing
        logits[invalid_token_mask] = -float('inf')
        return logits


class LogitBiasLogitsProcessor(LogitsProcessor):

    def __init__(self, pin_memory: bool, device: torch.device):
        self.biases: Dict[int, Dict[int, float]] = {}
        self.device = device
        self.pin_memory = pin_memory

        self.bias_tensor: torch.Tensor = torch.tensor(())
        self.logits_slice: Tuple[torch.Tensor, torch.Tensor] = (
            torch.tensor(()), torch.tensor(()))

    def update_states(self, batch_update: Optional[BatchUpdate] = None):
        if not batch_update:
            return

        needs_update = False
        if self.biases:
            # Process removed and moved requests.
            for index in batch_update.removed:
                if self.biases.pop(index, None):
                    needs_update = True

            for from_index, to_index in batch_update.moved:
                if entry := self.biases.pop(from_index, None):
                    self.biases[to_index] = entry
                    needs_update = True

        # Process added requests.
        for index, sampling_params, _ in batch_update.added:
            if lb := sampling_params.logit_bias:
                self.biases[index] = lb
                needs_update = True

        # Update tensors if needed.
        if self.biases and needs_update:
            reqs, tok_ids, biases = [], [], []
            for req, lb in self.biases.items():
                reqs.extend([req] * len(lb))
                tok_ids.extend(lb.keys())
                biases.extend(lb.values())

            self.bias_tensor = self._tensor(biases, torch.float32)
            self.logits_slice = (self._tensor(reqs, torch.int32),
                                 self._tensor(tok_ids, torch.int32))

    def _tensor(self, data: List, dtype: torch.dtype) -> torch.Tensor:
        return (torch.tensor(data,
                             device="cpu",
                             dtype=dtype,
                             pin_memory=self.pin_memory).to(
                                 device=self.device, non_blocking=True))

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.biases:
            logits[self.logits_slice] += self.bias_tensor
        return logits


class MinTokensLogitsProcessor(LogitsProcessor):

    def __init__(self, pin_memory: bool, device: torch.device):
        # index -> (min_toks, output_token_ids, stop_token_ids)
        self.min_toks: Dict[int, Tuple[int, Sequence[int], Set[int]]] = {}
        self.device = device
        self.pin_memory = pin_memory

        self.logits_slice: Tuple[torch.Tensor, torch.Tensor] = (
            torch.tensor(()), torch.tensor(()))

    def update_states(self, batch_update: Optional[BatchUpdate] = None):
        needs_update = False
        if batch_update:
            if self.min_toks:
                # Process removed and moved requests.
                for index in batch_update.removed:
                    if self.min_toks.pop(index, None):
                        needs_update = True

                for from_index, to_index in batch_update.moved:
                    if entry := self.min_toks.pop(from_index, None):
                        self.min_toks[to_index] = entry
                        needs_update = True

            # Process added requests.
            for index, sampling_params, output_tok_ids in batch_update.added:
                if ((min_tokens := sampling_params.min_tokens)
                        and len(output_tok_ids) < min_tokens):
                    self.min_toks[index] = (min_tokens, output_tok_ids,
                                            sampling_params.all_stop_token_ids)
                    needs_update = True

        if self.min_toks:
            # Check for any requests that have attained their min tokens.
            to_remove = tuple(index for index, (min_toks, out_tok_ids,
                                                _) in self.min_toks.items()
                              if len(out_tok_ids) >= min_toks)
            if to_remove:
                needs_update = True
                for index in to_remove:
                    del self.min_toks[index]

            # Update tensors if needed.
            if needs_update and self.min_toks:
                reqs: List[int] = []
                tok_ids: List[int] = []
                for req, (_, _, stop_tok_ids) in self.min_toks.items():
                    reqs.extend([req] * len(stop_tok_ids))
                    tok_ids.extend(stop_tok_ids)

                self.logits_slice = (self._tensor(reqs, torch.int32),
                                     self._tensor(tok_ids, torch.int32))

    def _tensor(self, data: List, dtype: torch.dtype) -> torch.Tensor:
        return (torch.tensor(data,
                             device="cpu",
                             dtype=dtype,
                             pin_memory=self.pin_memory).to(
                                 device=self.device, non_blocking=True))

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.min_toks:
            logits[self.logits_slice] = -float("inf")
        return logits
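
The abstract base class above is the extension point this RFC proposes: a processor implements apply() for the per-step logits transform and update_states() to maintain any persistent per-request state as the batch changes. A minimal sketch of a custom processor against that interface follows; the class name and the fixed banned-token behaviour are hypothetical illustrations, not part of this commit.

    from typing import Optional, Set

    import torch

    from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor


    class BannedTokensLogitsProcessor(LogitsProcessor):
        """Hypothetical: masks a fixed set of token ids for every request."""

        def __init__(self, banned_token_ids: Set[int]):
            # Applies batch-wide, so there is no per-request state to track.
            self.banned = sorted(banned_token_ids)

        def update_states(self,
                          batch_update: Optional[BatchUpdate] = None) -> None:
            # Nothing to reconcile as requests are added, moved, or removed.
            pass

        def apply(self, logits: torch.Tensor) -> torch.Tensor:
            if self.banned:
                # Mask the banned vocabulary columns for every row in the batch.
                logits[:, self.banned] = -float("inf")
            return logits
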

vllm/v1/sample/metadata.py

Lines changed: 5 additions & 6 deletions
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional
 
 import torch
 
+from vllm.v1.sample.logits_processor import LogitsProcessor
+
 
 @dataclass
 class SamplingMetadata:
@@ -18,7 +20,6 @@ class SamplingMetadata:
 
     top_p: Optional[torch.Tensor]
     top_k: Optional[torch.Tensor]
-    min_p: Optional[torch.Tensor]
 
     generators: Dict[int, torch.Generator]
 
@@ -33,7 +34,5 @@
 
     output_token_ids: List[List[int]]
 
-    # req_index -> (min_tokens, stop_token_ids)
-    min_tokens: Dict[int, Tuple[int, Set[int]]]
-
-    logit_bias: List[Optional[Dict[int, float]]]
+    logits_procs: List[LogitsProcessor]
+    nongreedy_logits_procs: List[LogitsProcessor]

vllm/v1/sample/ops/penalties.py

Lines changed: 1 addition & 17 deletions
@@ -1,29 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Set, Tuple
+from typing import List
 
 import torch
 
 from vllm.model_executor.layers.utils import apply_penalties
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 
 
-def apply_min_token_penalties(
-        logits: torch.Tensor, output_token_ids: List[List[int]],
-        min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
-    """
-    Applies minimum token penalty by setting the logits of the stop tokens
-    to -inf.
-    """
-    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
-    for index, (min_token, stop_token_ids) in min_tokens.items():
-        if len(output_token_ids[index]) < min_token:
-            for stop_token_id in stop_token_ids:
-                min_tokens_logits_to_penalize.append((index, stop_token_id))
-    if min_tokens_logits_to_penalize:
-        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
-
-
 def apply_all_penalties(
     logits: torch.Tensor,
     prompt_token_ids: torch.Tensor,

vllm/v1/sample/sampler.py

Lines changed: 9 additions & 47 deletions
@@ -6,8 +6,7 @@
 
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.ops.penalties import (apply_all_penalties,
-                                          apply_min_token_penalties)
+from vllm.v1.sample.ops.penalties import apply_all_penalties
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 
@@ -47,8 +46,11 @@ def forward(
 
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
-        # Apply logits bias.
-        logits = self.apply_logits_bias(logits, sampling_metadata)
+
+        # Apply logits processors.
+        for processor in sampling_metadata.logits_procs:
+            logits = processor.apply(logits)
+
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
         # Sample the next token.
@@ -103,9 +105,9 @@ def sample(
         # Apply temperature.
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
 
-        # Apply min_p.
-        if sampling_metadata.min_p is not None:
-            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+        # Apply logits processors.
+        for processor in sampling_metadata.nongreedy_logits_procs:
+            logits = processor.apply(logits)
 
         # Apply top_k and/or top_p.
         random_sampled = self.topk_topp_sampler(
@@ -177,10 +179,6 @@ def apply_penalties(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        if sampling_metadata.min_tokens:
-            apply_min_token_penalties(logits,
-                                      sampling_metadata.output_token_ids,
-                                      sampling_metadata.min_tokens)
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
             logits = apply_all_penalties(
@@ -190,39 +188,3 @@
             sampling_metadata.repetition_penalties,
             sampling_metadata.output_token_ids)
         return logits
-
-    def apply_min_p(
-        self,
-        logits: torch.Tensor,
-        min_p: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Filters logits using adaptive probability thresholding.
-        """
-        # Convert logits to probability distribution
-        probability_values = torch.nn.functional.softmax(logits, dim=-1)
-        # Calculate maximum probabilities per sequence
-        max_probabilities = torch.amax(probability_values,
-                                       dim=-1,
-                                       keepdim=True)
-        # Reshape min_p for broadcasting
-        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
-        # Identify valid tokens using threshold comparison
-        valid_token_mask = probability_values >= adjusted_min_p
-        # Apply mask using boolean indexing
-        logits[~valid_token_mask] = -float('inf')
-        return logits
-
-    def apply_logits_bias(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
-        # TODO(houseroad): this implementation is extremely inefficient.
-        # One idea is implement this as a PyTorch C++ op, and we may
-        # even optimize the logit_bias layout.
-        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
-            if logit_bias:
-                for token_id, bias in logit_bias.items():
-                    logits[i, token_id] += bias
-        return logits
