
Commit 2f8dc6f

[None][feat] Return topk logprobs in torch backend (#7756)
Signed-off-by: Dong Cao <[email protected]>
1 parent 6256376 commit 2f8dc6f

6 files changed: +68 -68 lines changed


tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 4 additions & 0 deletions
@@ -311,6 +311,7 @@ def __init__(
             is_draft: bool = False,
             seq_slot: Optional[int] = None,
             target_seq_slot: Optional[int] = None,
+            num_logprobs: int = 0,
             **kwargs):

         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -354,6 +355,7 @@ def __init__(
         self.py_lora_task_layer_module_configs: list[
             tensorrt_llm.bindings.internal.runtime.
             TaskLayerModuleConfig] | None = None
+        self.py_num_logprobs = num_logprobs

         self.py_return_log_probs = return_log_probs
         self.py_return_context_logits = return_context_logits
@@ -562,6 +564,8 @@ def executor_request_to_llm_request(
         mrope_position_deltas=mrope_position_deltas,
         lookahead_config=None,
         return_log_probs=executor_request.output_config.return_log_probs,
+        num_logprobs=executor_request.py_num_logprobs if hasattr(
+            executor_request, "py_num_logprobs") else 0,
         return_context_logits=executor_request.output_config.
         return_context_logits,
         return_perf_metrics=executor_request.output_config.return_perf_metrics,
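Note: together with the base_worker.py change further down, the plumbing here is intentionally thin: the worker attaches the requested top-k count to the executor request as py_num_logprobs, and executor_request_to_llm_request reads it back with a hasattr guard so requests that never set the attribute fall back to 0. A minimal sketch of that hand-off, using a hypothetical SimpleNamespace stand-in for the real executor request binding object:

# Minimal sketch of the num_logprobs hand-off; SimpleNamespace is a
# hypothetical stand-in for the real executor request binding object.
from types import SimpleNamespace


def num_logprobs_from(executor_request) -> int:
    # Same guard as executor_request_to_llm_request: default to 0 when
    # the attribute was never attached by the caller.
    return executor_request.py_num_logprobs if hasattr(
        executor_request, "py_num_logprobs") else 0


executor_request = SimpleNamespace()  # request from a caller that sets nothing
assert num_logprobs_from(executor_request) == 0

executor_request.py_num_logprobs = 3  # what base_worker.py now sets from sampling_params.logprobs
assert num_logprobs_from(executor_request) == 3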

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 36 additions & 63 deletions
@@ -9,6 +9,7 @@
 from typing import Any, List, Literal, Optional, cast

 import torch
+import torch.nn.functional as F

 from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
     MakeDecodingBatchInputOutput
@@ -852,15 +853,19 @@ def _handle_stop_criteria(self, request: LlmRequest,

     def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
                         beam: int, count: int):
-        current_slice = slice(0, count), request.py_seq_slot, beam
         if request.py_return_log_probs:
-            assert state.host.log_probs is not None
-            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
-            current_tokens = state.host.new_tokens[current_slice]
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]

             token_log_probs = [{
-                int(token): Logprob(logprob=logprob, rank=1)
-            } for token, logprob in zip(current_tokens, log_probs.tolist())]
+                int(token):
+                Logprob(logprob=logprob, rank=rank + 1)
+                for rank, (token, logprob) in enumerate(
+                    zip(topk_token, topk_logprob.tolist()))
+            }
+                               for topk_token, topk_logprob in zip(
+                                   topk_log_probs_indices, topk_log_probs_vals)]
+
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])

@@ -970,13 +975,8 @@ def log_probs_host(
             self,
             scheduled_requests: ScheduledRequests) -> Optional[torch.Tensor]:
         """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
-        if any(req.py_return_log_probs
-               for req in scheduled_requests.all_requests()):
-            return torch.empty(
-                (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
-                device="cpu",
-                pin_memory=True)
-        return None
+        return any(req.py_return_log_probs
+                   for req in scheduled_requests.all_requests())

     @override
     @torch.inference_mode()
@@ -1001,8 +1001,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         sampler_event.record()
         return SampleState(scheduled_requests=scheduled_requests,
                            device=SampleStateTensors(new_tokens=new_tokens),
-                           host=SampleStateTensors(new_tokens=new_tokens_host,
-                                                   log_probs=log_probs_host),
+                           host=SampleStateTensors(new_tokens=new_tokens_host),
                            sampler_event=sampler_event)

     @staticmethod
@@ -1111,12 +1110,24 @@ def _sample_batched_by_strategy(
         model_outputs: dict[str, torch.Tensor],
         *,
         cuda_device: torch.device,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
         req_num_steps: torch.Tensor,
         req_offsets: torch.Tensor,
         steps_dim_size: int,
         token_dtype: torch.dtype,
     ) -> _BatchedSamplingResult:
+        if log_probs_host:
+            assert logits_cuda.dim() == 2, "logits should be 2D"
+            logprobs = F.log_softmax(logits_cuda.to("cuda",
+                                                    dtype=torch.float32),
+                                     dim=-1)
+            topk_vals, topk_indices = torch.topk(logprobs,
+                                                 k=max(req.py_num_logprobs
+                                                       for req in requests),
+                                                 dim=-1)
+            topk_vals = topk_vals.to(device="cpu", non_blocking=True)
+            topk_indices = topk_indices.to(device="cpu", non_blocking=True)
+
         requests_by_strategy = _group_requests_by_sampling_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
@@ -1160,12 +1171,18 @@
             # softmax_grp_indices: Indices of 'speculation_group_indices' entries requesting probs
             # speculation_softmax_indices: Indices of 'softmax_grp_indices' entries corresponding
             #   to requests with draft logits.
-            if log_probs_host is not None:
+            if log_probs_host:
                 softmax_req_indices = group_req_indices
                 softmax_grp_indices = torch.arange(len(group_req_indices),
                                                    dtype=torch.int32)
                 speculation_softmax_indices = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
+                for req_id in group_req_indices:
+                    req = requests[req_id]
+                    req.py_topk_logprobs_vals = topk_vals[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs]
+                    req.py_topk_logprobs_indices = topk_indices[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs]
             else:
                 speculation_group_indices_tensor = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
@@ -1257,7 +1274,7 @@ def _unbatch_sampling_results(
         new_tokens_cuda: torch.Tensor,
         req_num_steps: torch.Tensor,
         seq_slots: torch.Tensor,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
     ) -> torch.Tensor:
         beam = self.BEAM
         assert beam == 0, "beam_width != 1 not supported"
@@ -1274,17 +1291,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         # Assert destination tensor dimensions are canonically ordered ("row"-major); this
         # matters for element ordering in the .view(...).scatter_(...) calls below.
         assert _dims_canonically_ordered(new_tokens_cuda)
-        assert log_probs_host is None or _dims_canonically_ordered(
-            log_probs_host)
-
-        # new_tokens_cuda indexed by
-        #   slice(0, steps), slot, beam
-        # log_probs_host indexed by
-        #   slot, beam, slice(0, steps)
-        # batch_... tensors indexed by slice(batch_req_index, batch_req_index + steps)
-        #
-        if log_probs_host is not None:
-            assert new_tokens_cuda.size(0) == log_probs_host.size(-2)

         # Construct index mapping from slice indices of computed tensors
         # (packed request_idx and step dimensions) to linearized indices
@@ -1306,39 +1312,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
             0, batch_dest_indices_1d_cuda,
             batch_next_tokens_cuda_int)
         new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True)
-        # NB: In order to avoid a scatter_ on the host and the necessary D2H copy + synchronization,
-        #     the 'step' and 'seq_slot' dimensions are unpacked on GPU and later asynchronously
-        #     copied into the destination buffer. Note that this overwrites all 'step' and token slots for the
-        #     requests in 'requests' (passed to _process_requests). In fact, the current implementation
-        #     even overwrites the destination tensors completely (including slices corresponding to request
-        #     slots not present in 'requests', cf. 'FIXME' below).
-        if log_probs_host is not None:
-            # FIXME: If log_probs_host were indexed by request indices, rather than request slots, this
-            #        tensor could be packed densely along the request axis.
-            log_probs_cuda = torch.empty_like(
-                log_probs_host, device=batch_dest_indices_1d_cuda.device)
-            # FIXME: Needs a separate indexer because tensor layout differs from new_tokens_cuda
-            batch_dest_probs_cuda_indexer = _UnpackedStepIndexer(
-                seq_slots=seq_slots[batch_req_indices],
-                num_steps=req_num_steps[batch_req_indices],
-                steps_dim_size=new_tokens_cuda.size(0),
-                slots_dim_size=new_tokens_cuda.size(1),
-                dim_order=_UnpackedStepIndexer.DimOrder.SLOT_MAJOR,
-                index_dtype=torch.int64,  # enforced by Tensor.scatter_
-            )
-            batch_dest_probs_indices_cuda = batch_dest_probs_cuda_indexer[:].to(
-                batch_softmax_cuda.device, non_blocking=True)
-            # NB: torch.arange is needed to enable "advanced indexing",
-            #     cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing
-            batch_token_probs = batch_softmax_cuda[
-                torch.arange(batch_softmax_cuda.size(0),
-                             device=batch_softmax_cuda.device,
-                             dtype=torch.int32), batch_next_tokens_cuda_int]
-            log_probs_cuda[:, beam,
-                           ...].view(-1, *log_probs_cuda.shape[3:]).scatter_(
-                               0, batch_dest_probs_indices_cuda,
-                               torch.log(batch_token_probs))
-            log_probs_host.copy_(log_probs_cuda, non_blocking=True)
+
         # For requests with LlmRequest.py_draft_logits, return py_target_probs
         for request, batch_softmax_index_cuda in py_draft_logits_indices:
             request.py_target_probs = batch_softmax_cuda[
@@ -1481,7 +1455,6 @@ def _process_requests(

         logits_cuda = self._apply_min_length_penalty(logits_cuda, requests,
                                                      req_num_steps_list)
-
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
             logits_cuda,
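In short, the sampler no longer scatters per-token logprobs into a preallocated pinned host tensor; it log-softmaxes the sampling-step logits, takes the top-k once per batch, and stashes per-request slices on the LlmRequest, which handle_logprobs later packages into per-step dicts keyed by token id with 1-based ranks. The sketch below replays that flow in isolation under simplifying assumptions (a single request, CPU tensors, and a stand-in Logprob dataclass rather than the one sampler.py imports):

# Standalone sketch of the new top-k logprob path (assumptions: one request,
# CPU tensors, and a stand-in Logprob class instead of tensorrt_llm's).
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class Logprob:  # stand-in for the Logprob class used by sampler.py
    logprob: float
    rank: int


num_steps, vocab_size, num_logprobs = 4, 32, 3
logits = torch.randn(num_steps, vocab_size)  # one row of logits per decoded step

# Same math as _sample_batched_by_strategy: log-softmax in float32, then top-k.
logprobs = F.log_softmax(logits.to(dtype=torch.float32), dim=-1)
topk_vals, topk_indices = torch.topk(logprobs, k=num_logprobs, dim=-1)

# Same packaging as handle_logprobs: one dict per step, keyed by token id,
# with ranks counted from 1 in descending logprob order.
token_log_probs = [{
    int(token): Logprob(logprob=logprob, rank=rank + 1)
    for rank, (token, logprob) in enumerate(
        zip(topk_token, topk_logprob.tolist()))
} for topk_token, topk_logprob in zip(topk_indices, topk_vals)]

assert len(token_log_probs) == num_steps
assert all(len(step) == num_logprobs for step in token_log_probs)

Because torch.topk returns values in descending order, rank + 1 lines up with the ordering convention that the new unit test below checks.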

tensorrt_llm/executor/base_worker.py

Lines changed: 1 addition & 0 deletions
@@ -480,6 +480,7 @@ def _deduce_max_tokens(request: GenerationRequest,
             context_phase_params=context_phase_params,
             type=request_type,
             cache_salt_id=request.cache_salt_id)
+        executor_request.py_num_logprobs = request.sampling_params.logprobs
         executor_request.py_lora_path = py_lora_path

         if self._is_pytorch_backend and request.multimodal_params is not None:

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 4 deletions
@@ -598,10 +598,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                          is_gen_only: bool) -> None:

         if self.args.backend in ["pytorch", "_autodeploy"]:
-            if sampling_params.logprobs and sampling_params.logprobs > 1:
-                raise ValueError(
-                    f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
-                )
             # Check prompt length and query length against max_num_tokens to filter illegal requests.
             # Skip check for gen-only requests
             if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:

tensorrt_llm/scaffolding/worker.py

Lines changed: 2 additions & 1 deletion
@@ -180,7 +180,8 @@ def convert_task_params(self, task: GenerationTask):
             temperature=task.temperature,
             top_p=task.top_p,
             top_k=task.top_k,
-            return_context_logits=task.return_context_logits)
+            return_context_logits=task.return_context_logits,
+            logprobs=task.num_logprobs)
         return sampling_params

     async def generation_handler(self, task: GenerationTask) -> TaskStatus:

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 25 additions & 0 deletions
@@ -175,6 +175,31 @@ def test_llm_reward_model():
     assert not outputs[0].outputs[0].text


+def test_llm_topk_logprobs():
+    topk_logprobs = 3
+    max_tokens = 10
+    llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     logprobs=topk_logprobs)
+    outputs = llm.generate(prompts, sampling_params)
+    logprobs = outputs[0].outputs[0].logprobs
+
+    assert len(logprobs) == max_tokens
+    for step_logprobs in logprobs:
+        assert len(step_logprobs) == topk_logprobs
+
+        logprob_items = [(logprob_obj.logprob, logprob_obj.rank)
+                         for logprob_obj in step_logprobs.values()]
+        sorted_by_rank = sorted(logprob_items, key=lambda x: x[1])
+
+        for i in range(len(sorted_by_rank) - 1):
+            current_logprob, current_rank = sorted_by_rank[i]
+            next_logprob, next_rank = sorted_by_rank[i + 1]
+            assert current_logprob >= next_logprob
+            assert current_rank == i + 1
+            assert next_rank == current_rank + 1
+
+
 def test_llm_perf_metrics():
     llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
     sampling_params = SamplingParams(max_tokens=10, return_perf_metrics=True)
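From the API side, removing the check in llmapi/llm.py means the PyTorch backend now accepts logprobs > 1 directly through SamplingParams. A usage sketch, where the model path and prompt are placeholders rather than the fixtures used by the tests above:

# Sketch of consuming top-k logprobs through the LLM API; the model path and
# prompt below are placeholders, not the test fixtures.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="/path/to/llama/checkpoint")  # placeholder path
prompts = ["The capital of France is"]        # placeholder prompt

outputs = llm.generate(prompts, SamplingParams(max_tokens=10, logprobs=3))

# One dict per generated token: token id -> Logprob(logprob=..., rank=...).
for step, step_logprobs in enumerate(outputs[0].outputs[0].logprobs):
    for token_id, logprob_obj in step_logprobs.items():
        print(step, token_id, logprob_obj.logprob, logprob_obj.rank)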
