diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index b20382006ec..db206bf2145 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -278,6 +278,17 @@ class LlmResponse:
     def has_error(self):
         return self.error_msg is not None
 
+    def clear_context_logits(self):
+        """Clear context logits from the response result.
+
+        This is used to drop context logits after prompt_logprobs have been computed
+        when the user didn't explicitly request them.
+        """
+        if self.result and hasattr(self.result, '_py_result'):
+            py_result = self.result._py_result
+            if hasattr(py_result, '_context_logits'):
+                py_result._context_logits = None
+
 
 class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
     """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
@@ -377,10 +388,36 @@ def __init__(
     def is_generation_only_request(self):
         return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
 
-    def create_response(
-            self,
-            use_fast_logits=False,
-            mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
+    def create_response(self,
+                        use_fast_logits=False,
+                        mpi_world_rank=0) -> LlmResponse | None:
+        """Create an LlmResponse from the current request state.
+
+        This method generates a response containing the request's execution results,
+        including generated tokens, logits, and completion status. It wraps the
+        parent class's serialized result in a PyTorch-specific LlmResponse object.
+
+        Args:
+            use_fast_logits (bool, optional, default=False): Only applicable to the TRT backend with speculative decoding enabled. When returning generation logits,
+                `use_fast_logits=True` replaces tensor payloads with tiny metadata so the target model pulls logits
+                directly (zero-copy/IPC), reducing overhead; ignored by the PyTorch backend.
+            mpi_world_rank (int, optional, default=0): Only applicable to the TRT backend with speculative decoding
+                enabled and `use_fast_logits=True`. The MPI world rank of the process hosting the draft
+                model that produces the generation logits. It lets logits be transferred from the draft model to the
+                target model without going through the serialization/transport path.
+
+        Returns:
+            LlmResponse | None: An LlmResponse object containing the request results
+                if there is valid output, otherwise None.
+                The response includes:
+                - request_id: The request identifier (parent ID for child requests)
+                - result: LlmResult wrapping both serialized and PyTorch-specific results
+                - client_id: The client identifier for request routing
+
+        Note:
+            Returns None if the serialized result is empty (len(result) == 0),
+            indicating no output was generated for this request iteration.
+ """ result, is_final = super().create_serialized_result( use_fast_logits, mpi_world_rank) return LlmResponse( diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 6e71fb2dc2f..ad4bab11818 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -11,6 +11,7 @@ from tensorrt_llm.logger import logger +from .._torch.pyexecutor.llm_request import LlmResponse from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, nvtx_range_debug) from ..bindings import executor as tllm @@ -701,24 +702,60 @@ def _get_params_for_first_rsp( return None, None +def _compute_pytorch_prompt_logprobs( + generation_result: GenerationResult, + response: LlmResponse) -> Optional[LogProbsResult]: + """Compute prompt logprobs for PyTorch backend (cached when streaming) """ + logprob_params = generation_result._logprob_params # should be present and non None + assert logprob_params is not None + if generation_result._streaming: + cached = getattr(generation_result, '_cached_prompt_logprobs', None) + if cached is not None: + return LogProbsResult( + prompt=cached, generation=None + ) # generation logprobs, if requested, is provided directly in response.result.log_probs from the sampler. + context_logits = response.result.context_logits + assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested." + logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None, + context_logits, None, None) + if generation_result._streaming: + generation_result._cached_prompt_logprobs = logprobs_result.prompt + + return logprobs_result + + def _get_logprobs(worker, - response: tllm.Response, + response: Union[tllm.Response, LlmResponse], is_pytorch_backend=False) -> Optional[LogProbsResult]: - """Compute logprob and prompt logprob and clear out logits if applicable. + """Compute logprobs from response logits when needed. + + Logprobs provenance varies by backend: + - PyTorch: Generation logprobs computed in sampler, only prompt logprobs computed here + - TRT: Both prompt and generation logprobs computed here from logits """ - if is_pytorch_backend: - # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime. - # In the PyTorch backend, logprobs are already computed during runtime if requested. 
-        return None
 
     logprobs_result = None
     generation_result = worker._results.get(response.client_id, None)
     if not generation_result:
-        return
+        return None
 
     logprob_params = getattr(generation_result, "_logprob_params", None)
     if logprob_params:
+        if is_pytorch_backend:
+            if not logprob_params.prompt_logprobs:
+                # PyTorch: generation logprobs are computed in the sampler, so no post-processing is needed.
+                return None
+            else:
+                logprobs_result = _compute_pytorch_prompt_logprobs(
+                    generation_result, response)
+
+            if logprob_params.drop_context_logits:
+                response.clear_context_logits()
+
+            return logprobs_result
+
+        # TRT backend: compute both prompt and generation logprobs from logits.
         logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                            logprob_params.logprobs,
                                            response.result.context_logits,
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index ddbe2f636aa..8a266c06361 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -244,22 +244,36 @@ def _handle_sequence(self,
         if response_tensors.cum_log_probs is not None:
             output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]
 
-        if logprobs_result:
+        # prompt logprobs handling
+        if logprobs_result and logprobs_result.prompt is not None:  # both backends
+            output.prompt_logprobs = logprobs_result.prompt
+        # generation logprobs handling (provenance varies by backend)
+        if logprobs_result and logprobs_result.generation is not None:  # TRT backend
             # update logprobs from ResponseWrapper (TRT top logprobs WAR)
             output._last_logprobs_len = len(output.logprobs)
-            output.prompt_logprobs = logprobs_result.prompt
             output.logprobs += logprobs_result.generation
-        elif response_tensors.log_probs is not None:
-            # handle logprobs directly from response tensors
+        elif response_tensors.log_probs is not None:  # PyTorch backend
+            # handle logprobs directly from response tensors given by the sampler
             output._last_logprobs_len = len(output.logprobs)
-            output.logprobs = response_tensors.log_probs[src_idx]
+            # In streaming mode, out-of-order responses are not possible, so each
+            # streamed response_tensors.log_probs[src_idx] contains a monotonically
+            # growing list of logprobs for the stream. Accumulate only the new
+            # entries unique to this particular streamed response.
+            assert output._last_logprobs_len <= len(
+                response_tensors.log_probs[src_idx]
+            ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
+                f"{len(response_tensors.log_probs[src_idx])})")
+            output.logprobs += response_tensors.log_probs[src_idx][
+                output._last_logprobs_len:]
         # overcome some WAR in the cpp executor
         if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
+            # Trim any extra logprobs beyond the generated output length.
             if len(output.logprobs) > output.length:
                 # LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
                 # Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
                 output.logprobs = output.logprobs[:output.length]
             assert len(output.logprobs) == output.length
+
         if response_tensors.generation_logits is not None:
             output.generation_logits = response_tensors.generation_logits[
                 src_idx, :output.length]
@@ -698,7 +712,12 @@ def compute_logprobs(
     output_token_ids: Optional[list[int]],
 ) -> LogProbsResult:
     """
-    Compute top-K logprobs and ranks for each token position.
+    Compute top-K logprobs from logits when the engine doesn't provide them directly.
+
+    Used for post-processing logits into logprobs:
+    - Prompt logprobs (from context_logits): computed here for both backends.
+    - Generation logprobs (from generation_logits): computed here only for the TRT backend, which doesn't compute them in the sampler.
+    - Generation logprobs (PyTorch backend): not computed here; the sampler provides them.
 
     Returns:
         LogProbsResult, a NamedTuple containing:
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index b2665b587ec..25632b02ecb 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -574,12 +574,6 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                          is_gen_only: bool) -> None:
 
         if self.args.backend in ["pytorch", "_autodeploy"]:
-            # TODO: remove these checks after PyTorch backend
-            # fully support TopK prompt and generation logprobs.
-            if sampling_params.prompt_logprobs:
-                raise ValueError(
-                    f"`prompt_logprobs` in sampling_params is not supported in the PyTorch backend yet. Received `prompt_logprobs={sampling_params.prompt_logprobs}`. Please unset this field."
-                )
             if sampling_params.logprobs and sampling_params.logprobs > 1:
                 raise ValueError(
                     f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py
index 361c0fc0c0f..231fa82f380 100644
--- a/tensorrt_llm/sampling_params.py
+++ b/tensorrt_llm/sampling_params.py
@@ -8,6 +8,7 @@
 from pydantic import BaseModel
 
 from tensorrt_llm.bindings import executor as tllme
+from tensorrt_llm.logger import logger
 
 
 @dataclass(slots=True, kw_only=True)
@@ -449,6 +450,20 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo
 
         if is_pytorch_backend:
             config_kwargs["return_log_probs"] = bool(self.logprobs)
+            if self.prompt_logprobs and not self.return_context_logits:
+                logger.info(
+                    "Since prompt_logprobs is requested but return_context_logits is False, "
+                    "internally enabling context logits for prompt logprobs computation. "
+                    "Context logits will be dropped after computation since the user didn't explicitly request them."
+                )
+                # TODO(venky): Find a more elegant way to do this.
+                # NOTE: This is an internal hack, so we can entirely avoid introducing
+                # `prompt_logprobs` into the executor bindings and further into the
+                # model engine / sampler.
+                # This is because prompt_logprobs is a derived quantity from
+                # context logits, and the capability to post-compute it
+                # already exists in the worker (see _get_logprobs in worker.py).
+                config_kwargs["return_context_logits"] = True
         else:
             config_kwargs["return_log_probs"] = self._return_log_probs
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index e99ac2b7b5a..c19ccfaa821 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -1804,14 +1804,20 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
                                      backend=None):
     LLM_CLASS = LLM
     llm_args_extra = {}
+    kv_cache_args_extra = {}
     if backend in ["pytorch", "autodeploy"]:
         LLM_CLASS = LLM_torch
+        if streaming:
+            # Needed so that context_logits / prompt_logprobs are not dropped
+            # on the 2nd call to llm.generate() in streaming mode.
+            kv_cache_args_extra["enable_block_reuse"] = False
     else:
         llm_args_extra["fast_build"] = True
 
     llm = LLM_CLASS(
         llama_model_path,
-        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
+        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4,
+                                      **kv_cache_args_extra),
         build_config=BuildConfig(gather_context_logits=True),
         tensor_parallel_size=tp_size,
         gather_generation_logits=True,
@@ -1864,7 +1870,7 @@ async def task(id: int, prompt: str):
         async for output in llm.generate_async(prompt,
                                                sampling_params,
                                                streaming=True):
-            logprobs_result_streaming += output.outputs[0].logprobs
+            logprobs_result_streaming += output.outputs[0].logprobs_diff
 
         # comparing streaming logprobs result to non-streaming
         assert logprobs_result_streaming == logprobs_result
@@ -1877,21 +1883,28 @@ async def main():
     asyncio.run(main())
 
 
-@pytest.mark.skip(reason="https://nvbugs/5516660")
 @force_ampere
 @pytest.mark.parametrize(
-    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
-    [(2, None, True, False), (None, 2, False, False)])
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
+    [
+        # TRT backend test cases
+        (2, None, True, False, "trt"),  # prompt_logprobs with context_logits
+        (None, 2, False, False, "trt"),  # generation logprobs only (top-2)
+        (2, None, False, False,
+         "trt"),  # prompt_logprobs without context_logits
+        (None, None, False, False, "trt"),  # no logprobs at all
+    ])
 def test_llm_return_logprobs(prompt_logprobs: Optional[int],
                              logprobs: Optional[int],
                              return_context_logits: bool,
-                             return_generation_logits: bool):
-    llm_return_logprobs_test_harness(prompt_logprobs, logprobs,
+                             return_generation_logits: bool, backend: str):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
                                      return_context_logits,
-                                     return_generation_logits)
+                                     return_generation_logits,
+                                     backend=backend)
 
 
-@pytest.mark.skip(reason="https://nvbugs/5516660")
 @force_ampere
 def test_llm_return_logprobs_streaming():
     llm_return_logprobs_test_harness(2, 2, False, True, streaming=True)
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index ef2de1350c1..3d03c97aba2 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -1,5 +1,6 @@
 import random
 from contextlib import contextmanager, nullcontext
+from typing import Optional
 
 import pytest
 
@@ -19,8 +20,9 @@
 from .test_llm import (_test_llm_capture_request_error, get_model_path,
                        global_kvcache_config, llama_model_path,
                        llm_get_stats_async_test_harness,
-                       llm_get_stats_test_harness, llm_test_harness, prompts,
-                       run_llm_abort_request,
+                       llm_get_stats_test_harness,
+                       llm_return_logprobs_test_harness, llm_test_harness,
+                       prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
 from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
@@ -892,3 +894,45 @@ def test_min_tokens(use_speculative: bool):
 
     assert len(res.outputs) == 1
     assert len(res.outputs[0].token_ids) == output_len
+
+
+@pytest.mark.parametrize(
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
+    [
+        (2, None, True, False,
+         "pytorch"),  # prompt_logprobs with context_logits
+        (None, 1, False, False,
+         "pytorch"),  # generation logprobs only (top-1, PyTorch limit)
+        (2, None, False, False,
+         "pytorch"),  # prompt_logprobs without context_logits
+        (None, None, False, False, "pytorch"),  # no logprobs at all
+    ])
+def test_llm_return_logprobs(prompt_logprobs: Optional[int],
+                             logprobs: Optional[int],
+                             return_context_logits: bool,
+                             return_generation_logits: bool, backend: str):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
+                                     return_context_logits,
+                                     return_generation_logits,
+                                     backend=backend)
+
+
+@pytest.mark.parametrize(
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
+    [
+        (None, 1, False,
+         False),  # generation logprobs only (top-1, PyTorch limit)
+        (2, None, True, False),  # prompt_logprobs with context_logits
+        (2, None, False, False),  # prompt_logprobs only
+        (2, 1, False, False),  # both prompt and generation logprobs
+    ])
+def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs,
+                                       return_context_logits,
+                                       return_generation_logits):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
+                                     return_context_logits,
+                                     return_generation_logits,
+                                     streaming=True,
+                                     backend="pytorch")
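
Usage note (illustrative, not part of the patch): with this change, `prompt_logprobs` can be requested on the PyTorch backend without also requesting context logits; `SamplingParams._get_output_config` enables context logits internally and the worker drops them again once prompt logprobs are post-computed. The sketch below shows the intended user-facing flow, following the LLM / SamplingParams API exercised by the tests above; the model identifier and prompt are placeholder assumptions, and it presumes `tensorrt_llm.LLM` selects the PyTorch backend (the tests import it as `LLM_torch`).

    from tensorrt_llm import LLM, SamplingParams

    # Placeholder checkpoint for illustration; the tests use a local TinyLlama path.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # prompt_logprobs no longer requires return_context_logits=True on the PyTorch
    # backend; logprobs stays within the current top-1 limit enforced in llm.py.
    params = SamplingParams(max_tokens=8, prompt_logprobs=2, logprobs=1)

    output = llm.generate("The capital of France is", params)
    seq = output.outputs[0]
    print(seq.prompt_logprobs)  # top-2 logprobs per prompt token, post-computed in the worker
    print(seq.logprobs)         # top-1 generation logprobs, produced by the sampler
    # Context logits were only enabled internally, so they are cleared from the response.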