[TRTLLM-7015] [feat] Enable prompt_logprobs in pytorch backend #7580
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Changes from all commits: a5fad18, 7762c23, 373fb19, 7a69b7a, 89a36be, e589e69, 3b3c434, ed6df8a, 142d29d, f3dcad0
@@ -278,6 +278,17 @@ class LlmResponse:
     def has_error(self):
         return self.error_msg is not None

+    def clear_context_logits(self):
+        """Clear context logits from the response result.
+
+        This is used to drop context logits after prompt_logprobs have been computed
+        when the user didn't explicitly request them.
+        """
+        if self.result and hasattr(self.result, '_py_result'):
+            py_result = self.result._py_result
+            if hasattr(py_result, '_context_logits'):
+                py_result._context_logits = None
+

 class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
     """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
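For intuition, here is a minimal, self-contained sketch of the hasattr-based clearing that clear_context_logits performs. The stand-in objects below are hypothetical and only model the attributes the method touches, not the real bindings classes:

from types import SimpleNamespace

# Hypothetical stand-ins for the real result objects; only the attributes
# touched by clear_context_logits() are modeled here.
py_result = SimpleNamespace(_context_logits="large tensor placeholder")
result = SimpleNamespace(_py_result=py_result)
response = SimpleNamespace(result=result)

# Mirror of the method body: drop context logits only if the wrapped
# PyTorch result is present and actually carries them.
if response.result and hasattr(response.result, '_py_result'):
    inner = response.result._py_result
    if hasattr(inner, '_context_logits'):
        inner._context_logits = None

assert py_result._context_logits is None  # logits are released for garbage collection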
@@ -376,10 +387,33 @@ def __init__(
     def is_generation_only_request(self):
         return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY

-    def create_response(
-            self,
-            use_fast_logits=False,
-            mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
+    def create_response(self,
+                        use_fast_logits=False,
+                        mpi_world_rank=0) -> LlmResponse | None:
+        """Create an LlmResponse from the current request state.
+
+        This method generates a response containing the request's execution results,
+        including generated tokens, logits, and completion status. It wraps the
+        parent class's serialized result in a PyTorch-specific LlmResponse object.
+
+        Args:
+            use_fast_logits (bool, optional): Whether to use fast logits computation
+                for improved performance. Defaults to False.
+            mpi_world_rank (int, optional): The MPI world rank for distributed
+                inference. Defaults to 0.
+
+        Returns:
+            LlmResponse | None: An LlmResponse object containing the request results
+                if there is valid output, otherwise None. The response includes:
+                - request_id: The request identifier (parent ID for child requests)
+                - result: LlmResult wrapping both serialized and PyTorch-specific results
+                - client_id: The client identifier for request routing
+
+        Note:
+            Returns None if the serialized result is empty (len(result) == 0),
+            indicating no output was generated for this request iteration.
+        """
         result, is_final = super().create_serialized_result(
             use_fast_logits, mpi_world_rank)
         return LlmResponse(

Review comment (on the mpi_world_rank description): Same thing: this is the leader rank of the draft model when using direct logits transfer.
@@ -244,22 +244,32 @@ def _handle_sequence(self,
         if response_tensors.cum_log_probs is not None:
             output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]

-        if logprobs_result:
-            # update logprobs from ResponseWrapper (TRT top logprobs WAR)
-            output._last_logprobs_len = len(output.logprobs)
-            output.prompt_logprobs = logprobs_result.prompt
-            output.logprobs += logprobs_result.generation
-        elif response_tensors.log_probs is not None:
-            # handle logprobs directly from response tensors
-            output._last_logprobs_len = len(output.logprobs)
-            output.logprobs = response_tensors.log_probs[src_idx]
+        # prompt logprobs handling
+        if logprobs_result and logprobs_result.prompt is not None:  # both backends
+            output.prompt_logprobs = logprobs_result.prompt
+        # generation logprobs handling (provenance varies by backend)
+        if logprobs_result and logprobs_result.generation is not None:  # TRT backend
+            # update logprobs from ResponseWrapper (TRT top logprobs WAR)
+            output._last_logprobs_len = len(output.logprobs)
+            output.logprobs += logprobs_result.generation
+        elif response_tensors.log_probs is not None:  # PyTorch backend
+            # handle logprobs directly from response tensors given by sampler
+            output._last_logprobs_len = len(output.logprobs)
+            # In streaming mode, since out-of-order responses are not possible,
+            # each streamed response_tensors.log_probs[src_idx] contains a
+            # streamwise monotonically growing list of logprobs, so we need to
+            # accumulate only the new ones unique to that particular streamed response.
+            output.logprobs += response_tensors.log_probs[src_idx][
+                output._last_logprobs_len:]
         # overcome some WAR in the cpp executor
         if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
-            # Check if logprobs is a list (not a dict or other structure)
             if len(output.logprobs) > output.length:
+                # LlmResult holds a reference to LogProbStorage, which may be updated by
+                # the worker before the result is serialized. Therefore, we treat extra
+                # logprobs/logits as expected and only consume what's needed.
                 output.logprobs = output.logprobs[:output.length]
             assert len(output.logprobs) == output.length

         if response_tensors.generation_logits is not None:
             output.generation_logits = response_tensors.generation_logits[
                 src_idx, :output.length]

Review comment (on the accumulation slice): @hchings @LinPoly this lack of accumulation was causing https://nvbugs/5516660, fixed here.
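The accumulation slice above is the heart of the fix: each streamed chunk repeats the full, monotonically growing logprob list, so only the tail past what was already consumed should be appended. A standalone sketch of that pattern, using made-up numbers rather than the real response tensors:

# Each streamed response repeats all logprobs seen so far (monotonically growing).
streamed_chunks = [
    [-0.1],
    [-0.1, -0.5],
    [-0.1, -0.5, -0.3, -0.2],
]

accumulated = []  # plays the role of output.logprobs
last_len = 0      # plays the role of output._last_logprobs_len

for chunk in streamed_chunks:
    last_len = len(accumulated)
    # Append only the entries that are new in this chunk; without the slice,
    # earlier entries would be duplicated on every stream iteration.
    accumulated += chunk[last_len:]

assert accumulated == [-0.1, -0.5, -0.3, -0.2]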
@@ -698,7 +708,12 @@ def compute_logprobs(
     output_token_ids: Optional[list[int]],
 ) -> LogProbsResult:
     """
-    Compute top-K logprobs and ranks for each token position.
+    Compute top-K logprobs from logits when the engine doesn't provide them directly.
+
+    Used for post-processing logits into logprobs:
+    - Prompt logprobs (from context_logits): always computed here, for both backends.
+    - Generation logprobs (from generation_logits): computed here only when the backend
+      doesn't compute them in its sampler (i.e., the TRT backend).
+    - Generation logprobs (PyTorch backend): not handled here; computed in the sampler.

     Returns:
         LogProbsResult, a NamedTuple containing:
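As a rough illustration of the post-processing this docstring describes, top-K logprobs can be derived from a logits tensor with a log-softmax followed by a top-k selection. The helper below is a simplified stand-in for the idea, not the library's actual compute_logprobs (which also tracks ranks and the sampled token's logprob):

import torch

def topk_logprobs_from_logits(logits: torch.Tensor, k: int) -> list[dict[int, float]]:
    """For each position, map the top-k token ids to their log-probabilities.

    logits: [num_tokens, vocab_size] (e.g., context_logits for prompt logprobs).
    """
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    top_vals, top_ids = torch.topk(logprobs, k, dim=-1)
    return [
        {int(tok): float(val) for tok, val in zip(ids, vals)}
        for ids, vals in zip(top_ids, top_vals)
    ]

# Example: 3 prompt positions over a toy vocabulary of 5 tokens.
dummy_context_logits = torch.randn(3, 5)
print(topk_logprobs_from_logits(dummy_context_logits, k=2))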
@@ -14,6 +14,7 @@

 from tensorrt_llm.logger import logger

+from .._torch.pyexecutor.llm_request import LlmResponse
 from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
@@ -1057,15 +1058,37 @@ def _get_params_for_first_rsp(
     return None, None


+def _compute_pytorch_prompt_logprobs(
+        generation_result: GenerationResult,
+        response: LlmResponse) -> Optional[LogProbsResult]:
+    """Compute prompt logprobs for the PyTorch backend (cached when streaming)."""
+    logprob_params = generation_result._logprob_params  # should be present and non-None
+    assert logprob_params is not None
+    if generation_result._streaming:
+        cached = getattr(generation_result, '_cached_prompt_logprobs', None)
+        if cached is not None:
+            return LogProbsResult(
+                prompt=cached, generation=None
+            )  # generation logprobs, if requested, are provided directly in response.result.log_probs from the sampler.
+    context_logits = response.result.context_logits
+    assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested."
+    logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None,
+                                       context_logits, None, None)
+    if generation_result._streaming:
+        generation_result._cached_prompt_logprobs = logprobs_result.prompt
+
+    return logprobs_result
+
+
 def _get_logprobs(worker,
-                  response: tllm.Response,
+                  response: Union[tllm.Response, LlmResponse],
                   is_pytorch_backend=False) -> Optional[LogProbsResult]:
-    """Compute logprob and prompt logprob and clear out logits if applicable.
+    """Compute logprobs from response logits when needed.
+
+    Logprobs provenance varies by backend:
+    - PyTorch: generation logprobs computed in sampler, only prompt logprobs computed here
+    - TRT: both prompt and generation logprobs computed here from logits
     """
-    if is_pytorch_backend:
-        # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
-        # In the PyTorch backend, logprobs are already computed during runtime if requested.
-        return None
-
     logprobs_result = None
     generation_result = worker._results.get(response.client_id, None)
@@ -1075,6 +1098,20 @@ def _get_logprobs(worker,

     logprob_params = getattr(generation_result, "_logprob_params", None)
     if logprob_params:
+        if is_pytorch_backend:
+            if not logprob_params.prompt_logprobs:
+                # PyTorch: generation logprobs computed in sampler, no post-processing needed
+                return
+            else:
+                logprobs_result = _compute_pytorch_prompt_logprobs(
+                    generation_result, response)
+
+                if logprob_params.drop_context_logits:
+                    response.clear_context_logits()
+
+                return logprobs_result
+
+        # TRT backend: compute both prompt and generation logprobs from logits
         logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                            logprob_params.logprobs,
                                            response.result.context_logits,
@@ -8,6 +8,7 @@
 from pydantic import BaseModel

 from tensorrt_llm.bindings import executor as tllme
+from tensorrt_llm.logger import logger


 @dataclass(slots=True, kw_only=True)
@@ -449,6 +450,20 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo

         if is_pytorch_backend:
             config_kwargs["return_log_probs"] = bool(self.logprobs)
+            if self.prompt_logprobs and not self.return_context_logits:
+                logger.info(
+                    "Since prompt_logprobs is requested but return_context_logits is False, "
+                    "internally enabling context logits for prompt logprobs computation. "
+                    "Context logits will be dropped after computation as the user didn't explicitly request them."
+                )
+                # TODO(venky): Find a more elegant way to do this.
+                # NOTE: This is an internal hack, so we can entirely avoid introducing
+                # `prompt_logprobs` into the executor bindings and further into
+                # model engine / sampler. This is because prompt_logprobs is a derived
+                # quantity from context logits, and the capability to post-compute it
+                # already exists in the worker. (see _get_logprobs in worker.py)
+                config_kwargs["return_context_logits"] = True
         else:
             config_kwargs["return_log_probs"] = self._return_log_probs

Review comment (Copilot, on lines +459 to +465): This TODO comment indicates a design concern about the current implementation being a 'hack'. Consider creating a GitHub issue to track the technical debt for refactoring this approach to a more elegant solution.

Review comment: I don't think this is very elegant. Can we pass the …
Reply: by "into the framework" you mean touch the executor bindings (OutputConfig) and llm request data structures right? I actively avoided it due to the unnecessary plumbing complexity it would incur for this, with almost no benefit to functionality.
Reply: added a TODO for this for now.
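Taken together, these changes let a PyTorch-backend user request prompt logprobs without also asking for context logits. A hedged usage sketch of the LLM API follows; the model name is an arbitrary placeholder and the exact output field layout is an assumption that may differ from the released API:

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder model for illustration

sampling_params = SamplingParams(
    max_tokens=16,
    logprobs=2,          # top-2 generation logprobs, computed in the sampler
    prompt_logprobs=2,   # top-2 prompt logprobs, post-computed from context logits
)

for output in llm.generate(["The capital of France is"], sampling_params):
    completion = output.outputs[0]
    print(completion.text)
    print(completion.logprobs)         # generation logprobs
    print(completion.prompt_logprobs)  # prompt logprobs; context logits are dropped internally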
Review comment: The comment is a bit vague and probably incorrect. This boolean controls the direct transfer from the draft model to the target model for speculative decoding. I actually don't know if it applies to the PyTorch backend as it is implemented.