Commit f925fa9

add helpful comments
Signed-off-by: Venky Ganesh <[email protected]>
1 parent f8d0be7 commit f925fa9

File tree

4 files changed (+30 -19 lines changed):
- tensorrt_llm/executor/result.py
- tensorrt_llm/executor/worker.py
- tensorrt_llm/llmapi/llm.py
- tensorrt_llm/sampling_params.py


tensorrt_llm/executor/result.py

Lines changed: 12 additions & 9 deletions
@@ -244,13 +244,16 @@ def _handle_sequence(self,
         if response_tensors.cum_log_probs is not None:
             output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]
 
-        if logprobs_result:
+        # prompt logprobs handling
+        if logprobs_result and logprobs_result.prompt is not None:  # both backends
+            output.prompt_logprobs = logprobs_result.prompt
+        # generation logprobs handling (provenance varies by backend)
+        if logprobs_result and logprobs_result.generation is not None:  # TRT backend
             # update logprobs from ResponseWrapper (TRT top logprobs WAR)
             output._last_logprobs_len = len(output.logprobs)
-            output.prompt_logprobs = logprobs_result.prompt
             output.logprobs += logprobs_result.generation
-        elif response_tensors.log_probs is not None:
-            # handle logprobs directly from response tensors
+        elif response_tensors.log_probs is not None:  # PyTorch backend
+            # handle logprobs directly from response tensors given by sampler
             output._last_logprobs_len = len(output.logprobs)
             output.logprobs = response_tensors.log_probs[src_idx]
             # overcome some WAR in the cpp executor
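For context on the two branches above: the handling splits LogProbsResult into its prompt and generation fields. A minimal sketch of that shape, based on the NamedTuple description in the docstring replaced below (the real Logprob type and class layout in result.py may differ):

from typing import Dict, List, NamedTuple, Optional

Logprob = float  # hypothetical stand-in; the real Logprob type lives in tensorrt_llm

class LogProbsResult(NamedTuple):
    # logprobs for prompt tokens: one {token_id: Logprob} dict per prompt position
    prompt: Optional[List[Dict[int, Logprob]]] = None
    # logprobs for generated tokens: one dict per generated position (TRT backend path)
    generation: Optional[List[Dict[int, Logprob]]] = None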
@@ -701,12 +704,12 @@ def compute_logprobs(
     output_token_ids: Optional[list[int]],
 ) -> LogProbsResult:
     """
-    Compute top-K logprobs and ranks for each token position.
+    Compute top-K logprobs from logits when engine doesn't provide them directly.
 
-    Returns:
-        LogProbsResult, a NamedTuple containing:
-            - prompt: Optional[List[Dict[token_id, Logprob]]] logprobs for prompt tokens.
-            - generation: Optional[List[Dict[token_id, Logprob]]] logprobs for generated tokens.
+    Used for post-processing logits into logprobs.
+    - Prompt logprobs (from context_logits): always used.
+    - Generation logprobs (from generation_logits, TRT backend): used when backend doesn't compute them in sampler (e.g., TRT).
+    - Generation logprobs (PyTorch backend): not used; computed in sampler, not here.
     """
 
     def _topk_logprobs(logits: torch.Tensor, top_k: int,
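The truncated _topk_logprobs helper above is where logits become per-position logprob dicts. A minimal illustrative sketch of that conversion, not the actual helper (its signature and return shape in result.py may differ):

from typing import Dict, List

import torch

def topk_logprobs_sketch(logits: torch.Tensor, top_k: int) -> List[Dict[int, float]]:
    """Illustrative only: turn [num_tokens, vocab_size] logits into per-position
    {token_id: logprob} dicts holding the top_k most likely tokens."""
    logprobs = torch.log_softmax(logits.float(), dim=-1)        # normalize logits to logprobs
    top_vals, top_ids = torch.topk(logprobs, k=top_k, dim=-1)   # top-k per token position
    return [
        {int(tok): float(val) for tok, val in zip(ids, vals)}
        for ids, vals in zip(top_ids, top_vals)
    ]

# example: 3 token positions over a 5-token vocabulary
print(topk_logprobs_sketch(torch.randn(3, 5), top_k=2))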

tensorrt_llm/executor/worker.py

Lines changed: 12 additions & 4 deletions
@@ -1056,7 +1056,11 @@ def _get_params_for_first_rsp(
 def _get_logprobs(worker,
                   response: Union[tllm.Response, LlmResponse],
                   is_pytorch_backend=False) -> Optional[LogProbsResult]:
-    """Compute logprob and prompt logprob and clear out logits if applicable.
+    """Compute logprobs from response logits when needed.
+
+    Logprobs provenance varies by backend:
+    - PyTorch: Generation logprobs computed in sampler, only prompt logprobs computed here
+    - TRT: Both prompt and generation logprobs computed here from logits
     """
 
     logprobs_result = None
@@ -1069,10 +1073,14 @@ def _get_logprobs(worker,
     if logprob_params:
         if is_pytorch_backend:
             if not logprob_params.prompt_logprobs:
-                # generation logprobs are already calculated in PyTorch backend sampler
+                # PyTorch: generation logprobs computed in sampler, no post-processing needed
                 return
             else:
-                # Fallback: compute from context_logits if available
+                # PyTorch: compute only prompt logprobs from context logits
+                # This can be done as a postprocessing step instead of coupling it to the
+                # pytorch engine, because prompt_logprobs calculation is not complicated by
+                # generation sampling strategies. Therefore it is simpler to do it here than
+                # doing it in the pytorch engine and plumbing it through the response.
                 logprobs_result = compute_logprobs(
                     logprob_params.prompt_logprobs, None,
                     response.result.context_logits, None, None)
@@ -1082,7 +1090,7 @@ def _get_logprobs(worker,
             response.clear_context_logits()
         return logprobs_result
 
-    # trt backend
+    # TRT backend: compute both prompt and generation logprobs from logits
     logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                        logprob_params.logprobs,
                                        response.result.context_logits,
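The provenance rules in the docstring above can be restated as a small standalone sketch; the function name and return format here are illustrative only, not part of worker.py:

from typing import Dict

def logprobs_to_postprocess(is_pytorch_backend: bool,
                            want_prompt_logprobs: bool,
                            want_generation_logprobs: bool) -> Dict[str, bool]:
    """Which logprobs still need to be derived from logits after the engine ran."""
    if is_pytorch_backend:
        # PyTorch: the sampler already produced generation logprobs; only prompt
        # logprobs are post-processed here, from context_logits.
        return {"prompt": want_prompt_logprobs, "generation": False}
    # TRT: both prompt and generation logprobs are computed here from logits.
    return {"prompt": want_prompt_logprobs, "generation": want_generation_logprobs}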

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 6 deletions
@@ -574,12 +574,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                          is_gen_only: bool) -> None:
 
         if self.args.backend in ["pytorch", "_autodeploy"]:
-            # TODO: remove these checks after PyTorch backend
-            # fully support TopK prompt and generation logprobs.
-            # if sampling_params.prompt_logprobs:
-            #     raise ValueError(
-            #         f"`prompt_logprobs` in sampling_params is not supported in the PyTorch backend yet. Received `prompt_logprobs={sampling_params.prompt_logprobs}`. Please unset this field."
-            #     )
             if sampling_params.logprobs and sampling_params.logprobs > 1:
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."

tensorrt_llm/sampling_params.py

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@
 from pydantic import BaseModel
 
 from tensorrt_llm.bindings import executor as tllme
+from tensorrt_llm.logger import logger
 
 
 @dataclass(slots=True, kw_only=True)
@@ -453,6 +454,11 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputConfig:
         # we need to internally enable context logits for prompt logprobs computation
         # They will be dropped after computation if the user didn't explicitly request them
         if self.prompt_logprobs and not self.return_context_logits:
+            logger.info(
+                "Since prompt_logprobs is requested but return_context_logits is False, "
+                "internally enabling context logits for prompt logprobs computation. "
+                "context logits will be dropped after computation as the user didn't explicitly request them."
+            )
             config_kwargs["return_context_logits"] = True
         else:
             config_kwargs["return_log_probs"] = self._return_log_probs
