Merged
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -239,7 +239,7 @@ def __init__(
self.py_max_new_tokens = self.max_new_tokens
self.py_batch_idx = None
self.py_rewind_len = 0
self.py_draft_tokens = self.draft_tokens
self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
self.py_last_draft_tokens = None
self.py_decoding_iter = 0
self.is_attention_dp_dummy = False
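
The net effect of this change is that py_draft_tokens is always a list, never None, so downstream code can rely on len() and iteration without guards. A minimal sketch of the pattern with an illustrative stand-in class (not the real LlmRequest):

class _RequestSketch:
    # Illustrative stand-in, not the actual LlmRequest.
    def __init__(self, draft_tokens=None):
        # Normalize to an empty list so callers can always call len() or
        # iterate without a separate None check.
        self.py_draft_tokens = [] if draft_tokens is None else draft_tokens

    def has_draft_tokens(self) -> bool:
        # A length check replaces the old "is not None" test.
        return len(self.py_draft_tokens) > 0

assert not _RequestSketch().has_draft_tokens()
assert _RequestSketch([3, 7]).has_draft_tokens()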
97 changes: 63 additions & 34 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1087,7 +1087,8 @@ def _prepare_tp_inputs(
extend_requests = []
generation_requests = []
for request in scheduled_requests.generation_requests:
if request.py_draft_tokens is not None or next_draft_tokens_device is not None:
if len(request.py_draft_tokens
) > 0 or next_draft_tokens_device is not None:
extend_requests.append(request)
else:
generation_requests.append(request)
@@ -1107,32 +1108,43 @@ def _prepare_tp_inputs(
# will contain previous batch indices of generation requests
previous_batch_indices = []
previous_pos_indices = []
request_ids_with_previous_batch = []
num_extend_reqs_wo_previous_batch = 0
for request in extend_requests:
if next_draft_tokens_device is None or request.py_batch_idx is None:
num_draft_tokens = len(request.py_draft_tokens)
# the request has no previous device tensors:
# (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
# (2) request.py_batch_idx is None, which means the request has no previous batch.
# the second condition includes dummy generation requests created for CUDA graph padding or
# attention DP. These dummy generation requests should be at the head of generation_requests.
# TODO: move the dummy generation requests to the end of generation_requests to align with
# the logic for those requests in generation_requests.
# get token ids, including input token ids and draft token ids
input_ids.append(request.get_last_tokens(0))
gather_ids.append(len(input_ids) - 1)
sequence_lengths.append(1 + num_draft_tokens)
input_ids.extend(request.py_draft_tokens)
draft_tokens.extend(request.py_draft_tokens)
# get other ids and lengths
num_draft_tokens = len(request.py_draft_tokens)
past_seen_token_num = request.max_beam_num_tokens - 1
position_ids.append(past_seen_token_num)
draft_lens.append(num_draft_tokens)
prompt_lengths.append(num_draft_tokens + 1)
# draft tokens
input_ids.extend(request.py_draft_tokens)
prompt_lengths.append(request.py_prompt_len)
sequence_lengths.append(1 + num_draft_tokens)
gather_ids.extend(
list(
range(
len(input_ids) - num_draft_tokens, len(input_ids))))
range(len(position_ids),
len(position_ids) + 1 + self.max_draft_len)))
position_ids.extend(
list(
range(past_seen_token_num + 1,
range(past_seen_token_num,
past_seen_token_num + 1 + num_draft_tokens)))
draft_tokens.extend(request.py_draft_tokens)
num_cached_tokens_per_seq.append(past_seen_token_num)
request_ids.append(request.py_request_id)
# update batch index
request.py_batch_idx = batch_idx
batch_idx += 1
num_extend_reqs_wo_previous_batch += 1
else:
# batch index
# update batch index
previous_batch_idx = request.py_batch_idx
request.py_batch_idx = batch_idx
batch_idx += 1
@@ -1157,8 +1169,10 @@ def _prepare_tp_inputs(
num_cached_tokens_per_seq.append(past_seen_token_num +
self.max_draft_len + 1)
prompt_lengths.append(request.py_prompt_len)
request_ids_with_previous_batch.append(request.py_request_id)

request_ids.append(request.py_request_id)
# move requests with previous batch to the end of the list
request_ids.extend(request_ids_with_previous_batch)

sequence_lengths.extend([1] * len(generation_requests))
gather_ids.extend(
@@ -1191,12 +1205,20 @@ def _prepare_tp_inputs(
batch_idx += 1

num_tokens = len(input_ids)
num_draft_tokens = len(draft_tokens)
previous_batchs = len(previous_batch_indices)
# if there are requests that do not have a previous batch, copy input_ids and draft_tokens
if num_tokens > 0:
input_ids = torch.tensor(input_ids,
dtype=torch.int,
pin_memory=True)
self.input_ids_cuda[:num_tokens].copy_(input_ids, non_blocking=True)
if num_draft_tokens > 0:
draft_tokens = torch.tensor(draft_tokens,
dtype=torch.int,
pin_memory=True)
self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
non_blocking=True)
if next_draft_tokens_device is not None:
if len(previous_batch_indices) > 0:
previous_batch_indices = torch.tensor(previous_batch_indices,
@@ -1215,26 +1237,39 @@ def _prepare_tp_inputs(
non_blocking=True)
# previous draft tokens
previous_batch_draft_tokens = previous_batchs * self.max_draft_len
self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_(
next_draft_tokens_device[
self.draft_tokens_cuda[
num_draft_tokens:num_draft_tokens +
previous_batch_draft_tokens].copy_(next_draft_tokens_device[
self.previous_batch_indices_cuda[:previous_batchs], :].
flatten(),
non_blocking=True)
flatten(),
non_blocking=True)
# prepare data for the preprocess inputs
kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
1 + self.max_draft_len)
pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
pre_batch_start_idx = num_extend_reqs_wo_previous_batch
pre_batch_end_idx = pre_batch_start_idx + previous_batchs
previous_pos_indices = torch.tensor(previous_pos_indices,
dtype=torch.int,
pin_memory=True)
self.previous_pos_indices_cuda[:previous_batch_tokens].copy_(
previous_pos_indices, non_blocking=True)
self.previous_pos_id_offsets_cuda[:previous_batch_tokens].copy_(
new_tokens_lens_device[
self.previous_pos_indices_cuda[:previous_batch_tokens]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[:previous_batchs].copy_(
kv_len_offsets_device[
self.previous_batch_indices_cuda[:previous_batchs]],
non_blocking=True)
self.previous_pos_indices_cuda[
pre_tokens_start_idx:pre_tokens_end_idx].copy_(
previous_pos_indices, non_blocking=True)
self.previous_pos_id_offsets_cuda[
pre_tokens_start_idx:pre_tokens_end_idx].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
pre_tokens_start_idx:pre_tokens_end_idx]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[
pre_batch_start_idx:pre_batch_end_idx].copy_(
kv_len_offsets_device[
self.previous_batch_indices_cuda[:previous_batchs]],
non_blocking=True)
# for the requests that do not have a previous batch, set the previous_pos_id_offsets and
# previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
else:
# change the data to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda *= 0
@@ -1305,12 +1340,6 @@ def _prepare_tp_inputs(

if spec_metadata is not None:
total_draft_lens = sum(draft_lens)
if len(draft_tokens) > 0:
draft_tokens = torch.tensor(draft_tokens,
dtype=torch.int,
pin_memory=True)
self.draft_tokens_cuda[:len(draft_tokens)].copy_(
draft_tokens, non_blocking=True)
spec_metadata.draft_tokens = self.draft_tokens_cuda[:
total_draft_lens]
spec_metadata.request_ids = request_ids
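
The reordering above keeps extend requests that have no previous batch at the front and appends the request ids of those that do at the end, so the CUDA buffers receiving previous-batch data are written starting right after the no-previous-batch block. A rough sketch of that index arithmetic, assuming (as in the surrounding code) that each extend request occupies 1 + max_draft_len token slots and that previous_batch_tokens equals the previous-batch count times that width:

def previous_batch_slices(num_wo_prev_batch: int, num_prev_batch: int,
                          max_draft_len: int):
    # Each extend request contributes 1 target token plus max_draft_len
    # draft tokens to the token-indexed buffers.
    tokens_per_req = 1 + max_draft_len
    pre_tokens_start = num_wo_prev_batch * tokens_per_req
    pre_tokens_end = pre_tokens_start + num_prev_batch * tokens_per_req
    # Batch-indexed buffers (e.g. previous_kv_lens_offsets) use one slot
    # per request instead.
    pre_batch_start = num_wo_prev_batch
    pre_batch_end = pre_batch_start + num_prev_batch
    return (pre_tokens_start, pre_tokens_end), (pre_batch_start, pre_batch_end)

# Example: 2 extend requests without a previous batch, 3 with one,
# max_draft_len = 4.
token_slice, batch_slice = previous_batch_slices(2, 3, 4)
assert token_slice == (10, 25) and batch_slice == (2, 5)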
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1488,8 +1488,7 @@ def _pad_attention_dp_dummy_request(self):
request_ids=[0],
is_gen=not self.has_context_request,
prepare_resource=not self.has_context_request,
max_num_draft_tokens=0
if self.has_context_request else self.max_draft_tokens,
max_num_draft_tokens=self.max_draft_tokens,
)[0]
llm_request.is_attention_dp_dummy = True
self.active_requests.append(llm_request)
@@ -1525,7 +1524,8 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
req.decoding_iter = 1
req.py_decoding_iter = 1
first_gen_tokens = req.context_phase_params.first_gen_tokens
req.py_draft_tokens = req.context_phase_params.draft_tokens
ctx_draft_tokens = req.context_phase_params.draft_tokens
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
beam_width = req.sampling_config.beam_width
for beam in range(0, beam_width):
req.add_new_token(first_gen_tokens[beam], beam)
24 changes: 15 additions & 9 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -300,15 +300,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
req_beam_width, req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req.py_request_id)
if req.py_draft_tokens is not None:
for _ in range(len(req.py_draft_tokens)):
self.impl.add_token(req.py_request_id)
for _ in range(len(req.py_draft_tokens)):
self.impl.add_token(req.py_request_id)

for req in generation_batch:
self.impl.add_token(req.py_request_id)
if req.py_draft_tokens is not None:
for _ in range(len(req.py_draft_tokens)):
self.impl.add_token(req.py_request_id)
for _ in range(len(req.py_draft_tokens)):
self.impl.add_token(req.py_request_id)

def add_dummy_requests(
self,
@@ -328,7 +326,11 @@ def add_dummy_requests(
requests = []
for i, req_id in enumerate(request_ids):
sampling_params = SamplingParams()
token_num = token_nums[i] if token_nums is not None else 1
# Here 1 + max_num_draft_tokens is used to extend the prompt length to
# a non-zero number to avoid an illegal memory access issue in the MLA
# kernel during warmup.
token_num = token_nums[
i] if token_nums is not None else 1 + max_num_draft_tokens
encoder_input_tokens = [
1
] * token_num if self.impl.cross_kv else None
@@ -343,12 +345,16 @@ def add_dummy_requests(
req.paged_kv_block_ids = []
if prepare_resource:
self.impl.add_sequence(req_id, token_num, beam_width, req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req_id)
if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.prompt_len = token_num - 1
req.py_prompt_len = req.prompt_len
if max_num_draft_tokens > 0:
req.py_draft_tokens = [0] * max_num_draft_tokens
req.py_draft_tokens = [1] * max_num_draft_tokens
if prepare_resource:
for _ in range(max_num_draft_tokens):
self.impl.add_token(req_id)
requests.append(req)
return requests
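
With py_draft_tokens always a list, prepare_resources now unconditionally adds one KV-cache token per draft token, and dummy warmup requests reserve room for their draft tokens as well. A back-of-the-envelope sketch of how many KV slots one dummy generation request ends up reserving under this change, assuming add_sequence reserves token_num slots and each add_token reserves one more (the helper below is illustrative, not part of the real KVCacheManager API):

def dummy_gen_kv_slots(max_num_draft_tokens: int,
                       num_extra_kv_tokens: int) -> int:
    # token_num = 1 + max_num_draft_tokens keeps the warmup prompt non-empty.
    token_num = 1 + max_num_draft_tokens
    reserved = token_num                  # impl.add_sequence(req_id, token_num, ...)
    reserved += num_extra_kv_tokens       # one impl.add_token per extra KV token
    reserved += max_num_draft_tokens      # one impl.add_token per draft token
    return reserved

# e.g. 3 draft tokens and no extra KV tokens: (1 + 3) + 0 + 3 = 7 slots
assert dummy_gen_kv_slots(3, 0) == 7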

6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -297,7 +297,7 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
extend_requests = []
generation_requests = []
for request in scheduled_requests.generation_requests:
if request.py_draft_tokens is not None:
if len(request.py_draft_tokens) > 0:
extend_requests.append(request)
else:
generation_requests.append(request)
@@ -361,7 +361,9 @@ def _mixed_sample(self, scheduled_requests: ScheduledRequests,
for request in scheduled_requests.generation_requests:
if request.state == LlmRequestState.GENERATION_COMPLETE:
continue
assert request.py_draft_tokens is None, "Speculative decoding not supported in SeparateDecoder."
assert len(
request.py_draft_tokens
) == 0, "Speculative decoding not supported in SeparateDecoder."
token_logits = logits[idx:idx + 1, :]
new_token, probs = decode_single_request(request, token_logits)
new_tokens_device_array.append(new_token)
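
The sampler follows the same convention as the model engine and scheduler: generation requests are split into extend and plain generation requests by a length check rather than an is-not-None test. A minimal runnable sketch of that partition with a stand-in request type (illustration only, not the real LlmRequest):

from dataclasses import dataclass, field

@dataclass
class _Req:
    py_draft_tokens: list = field(default_factory=list)

def partition(generation_requests):
    extend_requests, plain_requests = [], []
    for req in generation_requests:
        # len() > 0 replaces the old "is not None" check.
        if len(req.py_draft_tokens) > 0:
            extend_requests.append(req)
        else:
            plain_requests.append(req)
    return extend_requests, plain_requests

extend, plain = partition([_Req([5, 6]), _Req()])
assert len(extend) == 1 and len(plain) == 1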
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/scheduler.py
@@ -184,7 +184,7 @@ def schedule(
self, active_requests: RequestList, inflight_request_ids: set[int]
) -> tuple[list[LlmRequest], list[LlmRequest]]:
for request in active_requests:
if request.py_draft_tokens is not None:
if len(request.py_draft_tokens) > 0:
request.draft_tokens = request.py_draft_tokens
return self.impl(active_requests, inflight_request_ids,
self.max_batch_size, self.max_num_tokens)
6 changes: 0 additions & 6 deletions tests/integration/test_lists/waives.txt
@@ -365,12 +365,7 @@ accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://n
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5231468)
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://nvbugs/5231310)
test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image] SKIP (https://nvbugs/5233423)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5234002)
examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-27b-it-fp8-bfloat16-8] SKIP (https://nvbugs/5234164)
full::GH200/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[disable_weight_only] SKIP (https://nvbugs/5250460)
full::GH200/examples/test_gemma.py::test_llm_gemma_1gpu_summary[gemma-2-27b-it-other-bfloat16-8] SKIP (https://nvbugs/5250460)
@@ -415,7 +410,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5285965)
examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (https://nvbugs/5289523)
examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (https://nvbugs/5289523)