42 changes: 38 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -278,6 +278,17 @@ class LlmResponse:
    def has_error(self):
        return self.error_msg is not None

    def clear_context_logits(self):
        """Clear context logits from the response result.

        This is used to drop context logits after prompt_logprobs have been computed
        when the user didn't explicitly request them.
        """
        if self.result and hasattr(self.result, '_py_result'):
            py_result = self.result._py_result
            if hasattr(py_result, '_context_logits'):
                py_result._context_logits = None
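For context, a hypothetical caller sketch (not part of this PR, condensed from the worker-side changes further down in this diff) showing when this helper is expected to run:

```python
# Hypothetical sketch: derive prompt logprobs from context logits, then drop the
# logits if the user never asked for them. `logprob_params` and `compute_logprobs`
# mirror the worker.py changes shown later in this diff.
def finalize_prompt_logprobs(response, logprob_params, compute_logprobs):
    if not logprob_params.prompt_logprobs:
        return None
    logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None,
                                       response.result.context_logits, None, None)
    if logprob_params.drop_context_logits:
        # context logits were only enabled internally; free them now
        response.clear_context_logits()
    return logprobs_result
```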


class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
"""LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
@@ -376,10 +387,33 @@ def __init__(
    def is_generation_only_request(self):
        return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY

    def create_response(
            self,
            use_fast_logits=False,
            mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
    def create_response(self,
                        use_fast_logits=False,
                        mpi_world_rank=0) -> LlmResponse | None:
"""Create an LlmResponse from the current request state.

This method generates a response containing the request's execution results,
including generated tokens, logits, and completion status. It wraps the
parent class's serialized result in a PyTorch-specific LlmResponse object.

Args:
            use_fast_logits (bool, optional): Whether to use fast logits computation
                for improved performance. Defaults to False.

Review comment (Collaborator): The comment is a bit vague and probably incorrect. This boolean controls the direct transfer from the draft model to the target model for speculative decoding. I actually don't know if it applies to the PyTorch backend as it is implemented.
            mpi_world_rank (int, optional): The MPI world rank for distributed
                inference. Defaults to 0.

Review comment (Collaborator): Same thing: this is the leader rank of the draft model when using direct logits transfer.

        Returns:
            LlmResponse | None: An LlmResponse object containing the request results
                if there is valid output, otherwise None.
                The response includes:
                - request_id: The request identifier (parent ID for child requests)
                - result: LlmResult wrapping both serialized and PyTorch-specific results
                - client_id: The client identifier for request routing

        Note:
            Returns None if the serialized result is empty (len(result) == 0),
            indicating no output was generated for this request iteration.
        """
        result, is_final = super().create_serialized_result(
            use_fast_logits, mpi_world_rank)
        return LlmResponse(
27 changes: 21 additions & 6 deletions tensorrt_llm/executor/result.py
@@ -244,22 +244,32 @@ def _handle_sequence(self,
        if response_tensors.cum_log_probs is not None:
            output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]

        if logprobs_result:
        # prompt logprobs handling
        if logprobs_result and logprobs_result.prompt is not None:  # both backends
            output.prompt_logprobs = logprobs_result.prompt
        # generation logprobs handling (provenance varies by backend)
        if logprobs_result and logprobs_result.generation is not None:  # TRT backend
            # update logprobs from ResponseWrapper (TRT top logprobs WAR)
            output._last_logprobs_len = len(output.logprobs)
            output.prompt_logprobs = logprobs_result.prompt
            output.logprobs += logprobs_result.generation
        elif response_tensors.log_probs is not None:
            # handle logprobs directly from response tensors
        elif response_tensors.log_probs is not None:  # PyTorch backend
            # handle logprobs directly from response tensors given by the sampler
            output._last_logprobs_len = len(output.logprobs)
            output.logprobs = response_tensors.log_probs[src_idx]
            # In streaming mode, since out-of-order responses are not possible,
            # each streamed response_tensors.log_probs[src_idx] contains a
            # streamwise monotonically growing list of logprobs, so we need to
            # accumulate only the new ones unique to that particular streamed response
Review comment (Copilot AI, Sep 19, 2025): The slicing logic [output._last_logprobs_len:] to get only new logprobs in streaming mode could be error-prone if _last_logprobs_len becomes inconsistent. Consider adding an assertion to verify that output._last_logprobs_len <= len(response_tensors.log_probs[src_idx]) before slicing.

Suggested change:
            # so we need to accumulate only the new ones unique to that particular streamed response
            assert output._last_logprobs_len <= len(response_tensors.log_probs[src_idx]), (
                f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ({len(response_tensors.log_probs[src_idx])})"
            )

            output.logprobs += response_tensors.log_probs[src_idx][
                output._last_logprobs_len:]

Review comment (Collaborator, PR author): @hchings @LinPoly this lack of accumulation was causing https://nvbugs/5516660, fixed here.
        # overcome some WAR in the cpp executor
        if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
            # Check if logprobs is a list (not a dict or other structure)
            if len(output.logprobs) > output.length:
                # LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
                # Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
                output.logprobs = output.logprobs[:output.length]
            assert len(output.logprobs) == output.length

        if response_tensors.generation_logits is not None:
            output.generation_logits = response_tensors.generation_logits[
                src_idx, :output.length]
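For reference, here is a minimal plain-Python sketch (not part of the diff) of the streaming accumulation pattern the hunk above implements, including the guard suggested in the Copilot review:

```python
# `stream_snapshots` stands in for the monotonically growing
# response_tensors.log_probs[src_idx] seen across successive streamed responses.
stream_snapshots = [
    [-0.1],
    [-0.1, -0.7],
    [-0.1, -0.7, -0.3],
]

accumulated = []  # plays the role of output.logprobs

for snapshot in stream_snapshots:
    last_len = len(accumulated)         # plays the role of output._last_logprobs_len
    assert last_len <= len(snapshot)    # snapshots only grow; guard from the review
    accumulated += snapshot[last_len:]  # append only the new tail of the snapshot

assert accumulated == [-0.1, -0.7, -0.3]
```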
@@ -698,7 +708,12 @@ def compute_logprobs(
    output_token_ids: Optional[list[int]],
) -> LogProbsResult:
    """
    Compute top-K logprobs and ranks for each token position.
    Compute top-K logprobs from logits when the engine doesn't provide them directly.

    Used for post-processing logits into logprobs.
    - Prompt logprobs (from context_logits): always computed here (both backends).
    - Generation logprobs (from generation_logits): computed here only when the backend
      doesn't compute them in its sampler (i.e., the TRT backend).
    - Generation logprobs (PyTorch backend): not computed here; the sampler provides them.

    Returns:
        LogProbsResult, a NamedTuple containing:
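To make the post-processing concrete, here is a minimal sketch (assumed shapes and names, not TRT-LLM's actual implementation) of deriving top-K logprobs and token ranks from a logits tensor:

```python
import torch

def topk_logprobs_sketch(logits: torch.Tensor, k: int, token_ids: list[int]):
    """Minimal sketch: derive per-position top-k logprobs from raw logits.

    logits: [seq_len, vocab_size]; token_ids: the chosen token at each position.
    Shapes and the return structure are illustrative assumptions only.
    """
    logprobs = torch.log_softmax(logits.float(), dim=-1)  # [seq_len, vocab_size]
    topk_vals, topk_ids = logprobs.topk(k, dim=-1)        # top-k per position
    results = []
    for pos, tok in enumerate(token_ids):
        # rank of the chosen token = number of vocab entries scoring at least as high
        rank = int((logprobs[pos] >= logprobs[pos, tok]).sum().item())
        results.append({
            "token": tok,
            "logprob": float(logprobs[pos, tok]),
            "rank": rank,
            "top_k": dict(zip(topk_ids[pos].tolist(), topk_vals[pos].tolist())),
        })
    return results
```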
49 changes: 43 additions & 6 deletions tensorrt_llm/executor/worker.py
@@ -14,6 +14,7 @@

from tensorrt_llm.logger import logger

from .._torch.pyexecutor.llm_request import LlmResponse
from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
                      mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
@@ -1057,15 +1058,37 @@ def _get_params_for_first_rsp(
    return None, None


def _compute_pytorch_prompt_logprobs(
        generation_result: GenerationResult,
        response: LlmResponse) -> Optional[LogProbsResult]:
    """Compute prompt logprobs for the PyTorch backend (cached when streaming)."""
    logprob_params = generation_result._logprob_params  # should be present and non-None
    assert logprob_params is not None
    if generation_result._streaming:
        cached = getattr(generation_result, '_cached_prompt_logprobs', None)
        if cached is not None:
            # Generation logprobs, if requested, are provided directly in
            # response.result.log_probs by the sampler.
            return LogProbsResult(prompt=cached, generation=None)
    context_logits = response.result.context_logits
    assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested."
    logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None,
                                       context_logits, None, None)
    if generation_result._streaming:
        generation_result._cached_prompt_logprobs = logprobs_result.prompt

    return logprobs_result


def _get_logprobs(worker,
                  response: tllm.Response,
                  response: Union[tllm.Response, LlmResponse],
                  is_pytorch_backend=False) -> Optional[LogProbsResult]:
    """Compute logprob and prompt logprob and clear out logits if applicable.
    """Compute logprobs from response logits when needed.

    Logprobs provenance varies by backend:
    - PyTorch: generation logprobs are computed in the sampler; only prompt logprobs are computed here.
    - TRT: both prompt and generation logprobs are computed here from logits.
    """
    if is_pytorch_backend:
        # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
        # In the PyTorch backend, logprobs are already computed during runtime if requested.
        return None

    logprobs_result = None
    generation_result = worker._results.get(response.client_id, None)
@@ -1075,6 +1098,20 @@ def _get_logprobs(worker,

    logprob_params = getattr(generation_result, "_logprob_params", None)
    if logprob_params:
        if is_pytorch_backend:
            if not logprob_params.prompt_logprobs:
                # PyTorch: generation logprobs computed in the sampler, no post-processing needed
                return
Review comment (Collaborator): return None to be explicit?
            else:
                logprobs_result = _compute_pytorch_prompt_logprobs(
                    generation_result, response)

            if logprob_params.drop_context_logits:
                response.clear_context_logits()

            return logprobs_result

        # TRT backend: compute both prompt and generation logprobs from logits
        logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                           logprob_params.logprobs,
                                           response.result.context_logits,
12 changes: 4 additions & 8 deletions tensorrt_llm/llmapi/llm.py
@@ -569,17 +569,13 @@ def _prepare_sampling_params(
        sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
        return sampling_params

    def _check_arguments(self, prompt_len: int, query_len: int,
    def _check_arguments(self,
                         prompt_len: int,
                         query_len: int,
                         sampling_params: SamplingParams,
                         is_gen_only: bool) -> None:
                         is_gen_only: bool = False) -> None:

        if self.args.backend in ["pytorch", "_autodeploy"]:
            # TODO: remove these checks after PyTorch backend
            # fully support TopK prompt and generation logprobs.
            if sampling_params.prompt_logprobs:
                raise ValueError(
                    f"`prompt_logprobs` in sampling_params is not supported in the PyTorch backend yet. Received `prompt_logprobs={sampling_params.prompt_logprobs}`. Please unset this field."
                )
            if sampling_params.logprobs and sampling_params.logprobs > 1:
                raise ValueError(
                    f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
15 changes: 15 additions & 0 deletions tensorrt_llm/sampling_params.py
@@ -8,6 +8,7 @@
from pydantic import BaseModel

from tensorrt_llm.bindings import executor as tllme
from tensorrt_llm.logger import logger


@dataclass(slots=True, kw_only=True)
@@ -449,6 +450,20 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputConfig:

        if is_pytorch_backend:
            config_kwargs["return_log_probs"] = bool(self.logprobs)
            if self.prompt_logprobs and not self.return_context_logits:
                logger.info(
                    "Since prompt_logprobs is requested but return_context_logits is False, "
                    "internally enabling context logits for prompt logprobs computation. "
                    "context logits will be dropped after computation as the user didn't explicitly request them."
                )
                # TODO(venky): Find a more elegant way to do this.
                # NOTE: This is an internal hack, so we can entirely avoid introducing
                # `prompt_logprobs` into the executor bindings and further into
                # model engine / sampler.
                # This is because prompt_logprobs is a derived quantity from
                # context logits, and the capability to post-compute it
                # already exists in the worker (see _get_logprobs in worker.py).
Review comment (Copilot AI, Sep 19, 2025), on lines +459 to +465: This TODO comment indicates a design concern about the current implementation being a 'hack'. Consider creating a GitHub issue to track the technical debt for refactoring this approach to a more elegant solution.
config_kwargs["return_context_logits"] = True
Review comment (Collaborator): I don't think this is very elegant. Can we pass the prompt_logprobs option into the framework and use the logic prompt_logprobs or return_context_logits where relevant instead?

Reply (venkywonka, PR author, Sep 17, 2025): By "into the framework" you mean touch the executor bindings (OutputConfig) and the llm request data structures, right? I actively avoided that due to the unnecessary plumbing complexity it would incur, with almost no benefit to functionality. But I do agree that this doesn't seem elegant; it's merely operating under the current architectural constraints. For the purpose of unblocking this PR, I shall mark it as a TODO.

Reply (venkywonka, PR author): Added a TODO for this for now.

        else:
            config_kwargs["return_log_probs"] = self._return_log_probs
