27 changes: 16 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
```diff
@@ -1096,25 +1096,30 @@ def _prepare_tp_inputs(
     previous_pos_indices = []
     for request in extend_requests:
         if next_draft_tokens_device is None or request.py_batch_idx is None:
+            # get token ids, including input token ids and draft token ids
+            if next_draft_tokens_device is None:
+                # skip dummy generation requests
+                # if next_draft_tokens_device is not None, but the batch index is None,
+                # it means the generation requests are dummy requests. For such cases,
+                # we do not update the input_ids and draft_tokens and just use the tokens
+                # from the previous batch.
+                input_ids.append(request.get_last_tokens(0))
+                input_ids.extend(request.py_draft_tokens)
+                draft_tokens.extend(request.py_draft_tokens)
+            # get other ids and lengths
             num_draft_tokens = len(request.py_draft_tokens)
-            input_ids.append(request.get_last_tokens(0))
-            gather_ids.append(len(input_ids) - 1)
-            sequence_lengths.append(1 + num_draft_tokens)
             past_seen_token_num = request.max_beam_num_tokens - 1
-            position_ids.append(past_seen_token_num)
             draft_lens.append(num_draft_tokens)
-            prompt_lengths.append(num_draft_tokens + 1)
-            # draft tokens
-            input_ids.extend(request.py_draft_tokens)
+            prompt_lengths.append(1 + num_draft_tokens)
+            sequence_lengths.append(1 + num_draft_tokens)
             gather_ids.extend(
                 list(
-                    range(
-                        len(input_ids) - num_draft_tokens, len(input_ids))))
+                    range(len(position_ids),
+                          len(position_ids) + 1 + self.max_draft_len)))
             position_ids.extend(
                 list(
-                    range(past_seen_token_num + 1,
+                    range(past_seen_token_num,
                           past_seen_token_num + 1 + num_draft_tokens)))
-            draft_tokens.extend(request.py_draft_tokens)
             num_cached_tokens_per_seq.append(past_seen_token_num)
             request.py_batch_idx = batch_idx
             batch_idx += 1
```
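To make the new index arithmetic concrete, here is a minimal sketch of what the revised loop computes for a single extend request. This is not the TensorRT-LLM API: the helper name `prepare_extend_indices` and all literal values are invented for illustration; only the `range` arithmetic for `gather_ids` and `position_ids` mirrors the diff.

```python
# Minimal sketch (hypothetical helper, not TensorRT-LLM code) of the index
# bookkeeping for one extend request under the revised logic.
def prepare_extend_indices(last_token: int, draft_tokens: list[int],
                           past_seen_token_num: int, max_draft_len: int,
                           input_ids: list[int], position_ids: list[int],
                           gather_ids: list[int]) -> None:
    """Append one request's tokens, positions, and gather indices in place."""
    num_draft_tokens = len(draft_tokens)
    # gather_ids reserves a fixed-size window of 1 + max_draft_len slots
    # starting where this request's positions begin, regardless of how many
    # draft tokens the request actually carries.
    gather_ids.extend(
        range(len(position_ids), len(position_ids) + 1 + max_draft_len))
    # Positions now start at past_seen_token_num (the last accepted token's
    # position) rather than past_seen_token_num + 1, so one extend covers the
    # last token plus every draft token.
    position_ids.extend(
        range(past_seen_token_num, past_seen_token_num + 1 + num_draft_tokens))
    input_ids.append(last_token)
    input_ids.extend(draft_tokens)

input_ids, position_ids, gather_ids = [], [], []
prepare_extend_indices(last_token=42, draft_tokens=[7, 8, 9],
                       past_seen_token_num=10, max_draft_len=3,
                       input_ids=input_ids, position_ids=position_ids,
                       gather_ids=gather_ids)
print(position_ids)  # [10, 11, 12, 13] -> last accepted token + 3 draft tokens
print(gather_ids)    # [0, 1, 2, 3]     -> fixed window of 1 + max_draft_len
```

As in the diff, `gather_ids` is computed before `position_ids` is extended, so the window starts at the request's first new position.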
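A plausible reading of the two behavioral changes: the old code gathered `1 + num_draft_tokens` logit positions per request, so the gather width varied with each request's actual draft count, while the new window is always `1 + self.max_draft_len`, keeping the gathered tensor's shape fixed across batches. That appears to be what lets dummy generation requests, which per the new comment reuse the previous batch's tokens, pad a batch without perturbing tensor shapes (e.g. for CUDA graph replay with the overlap scheduler). The `position_ids` change from `past_seen_token_num + 1` to `past_seen_token_num` compensates for dropping the separate `position_ids.append(past_seen_token_num)` line: the single extended range now covers the last accepted token plus each draft token.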