diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
index 1d06ac0e860..baa51f47e73 100644
--- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
+++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -63,7 +63,10 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
     SizeType32 batchIdx{0};
     for (auto const& llmReq : contextRequests)
     {
-        auto const currentSequenceLen = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens();
+        auto const disaggFirstGenTokenSize
+            = llmReq->getContextPhaseParams() ? llmReq->getContextPhaseParams().value().getFirstGenTokens().size() : 0;
+        auto const currentSequenceLen
+            = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens() + disaggFirstGenTokenSize;
         // Get position of the current sequence in the decoder
         auto const seqSlot = llmReq->mSeqSlot.value();
         batchSlotsRange[batchIdx] = seqSlot;
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 29f1c5d3ac8..0bfba50a9c9 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -520,7 +520,6 @@ def create_py_executor_instance(
     cache_transceiver_config = executor_config.cache_transceiver_config
     kv_cache_transceiver = create_kv_cache_transceiver(
         mapping, kv_cache_manager, attention_type, cache_transceiver_config)
-
     return PyExecutor(
         resource_manager,
         scheduler,
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index b4dfdf25d45..31e17c1247d 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -751,8 +751,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):

         reqs_with_new_tokens = [
             r for r in reqs
-            if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0)
-                or self.is_trt_overlap)
+            if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
         ]

         # Add new tokens
@@ -821,7 +820,6 @@ def update_requests_multiple_beams_or_drafting(self,

             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam]
-                seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
                 num_new_tokens[beam] = min(
                     num_generated_tokens,
                     seq_len - request.get_num_tokens(beam))