Merged
5 changes: 4 additions & 1 deletion cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -63,7 +63,10 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
     SizeType32 batchIdx{0};
     for (auto const& llmReq : contextRequests)
     {
-        auto const currentSequenceLen = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens();
+        auto const disaggFirstGenTokenSize
+            = llmReq->getContextPhaseParams() ? llmReq->getContextPhaseParams().value().getFirstGenTokens().size() : 0;
+        auto const currentSequenceLen
+            = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens() + disaggFirstGenTokenSize;
         // Get position of the current sequence in the decoder
         auto const seqSlot = llmReq->mSeqSlot.value();
         batchSlotsRange[batchIdx] = seqSlot;
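In disaggregated serving, the generation instance receives first generation token(s) already produced by the context instance, and the decoder's notion of the current sequence length must count them. A minimal Python sketch of the same accounting as the C++ change above, using illustrative stand-in names rather than the real API:

from typing import Optional

class ContextPhaseParams:
    """Stand-in for the context-phase handoff; holds tokens the context
    instance already generated before the request reached this instance."""

    def __init__(self, first_gen_tokens: list[int]):
        self.first_gen_tokens = first_gen_tokens

def current_sequence_len(prompt_len: int, max_num_generated: int,
                         ctx_params: Optional[ContextPhaseParams]) -> int:
    # For a regular request ctx_params is None, the extra term is 0, and
    # the old formula (prompt length + generated tokens) is recovered.
    disagg_first_gen = len(ctx_params.first_gen_tokens) if ctx_params else 0
    return prompt_len + max_num_generated + disagg_first_gen

# Disaggregated request arriving with one token from the context instance:
assert current_sequence_len(10, 0, ContextPhaseParams([42])) == 11
# Regular request:
assert current_sequence_len(10, 3, None) == 13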
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/_util.py
@@ -520,7 +520,6 @@ def create_py_executor_instance(
     cache_transceiver_config = executor_config.cache_transceiver_config
     kv_cache_transceiver = create_kv_cache_transceiver(
         mapping, kv_cache_manager, attention_type, cache_transceiver_config)
-
     return PyExecutor(
         resource_manager,
         scheduler,
4 changes: 1 addition & 3 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -751,8 +751,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
 
         reqs_with_new_tokens = [
             r for r in reqs
-            if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0)
-                or self.is_trt_overlap)
+            if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
         ]
 
         # Add new tokens
@@ -821,7 +820,6 @@ def update_requests_multiple_beams_or_drafting(self,
             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                      beam]
-                seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
                 num_new_tokens[beam] = min(
                     num_generated_tokens,
                     seq_len - request.get_num_tokens(beam))
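Both sampler paths previously special-cased overlap scheduling: the single-beam path forced every request into the new-token list whenever is_trt_overlap was set, and the multi-beam path bumped the reported sequence length by one. This PR removes both adjustments, presumably because the decoder-side length computation above now accounts for the extra tokens directly. A small self-contained sketch of the remaining per-beam accounting, using hypothetical plain values in place of the real request and buffer objects:

# seq_slot -> sequence length as reported by the decoder buffers
sequence_lengths_host_data = {0: 12}
num_generated_tokens = 2  # tokens the decoder emitted this step
tokens_recorded = 10      # what request.get_num_tokens(beam) would return

seq_len = sequence_lengths_host_data[0]  # no "+ 1 if is_trt_overlap" fixup
# Accept at most as many tokens as the reported length actually covers.
num_new_tokens = min(num_generated_tokens, seq_len - tokens_recorded)
assert num_new_tokens == 2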