From e799438964215b05bd0038ff6af6b521767c238a Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Thu, 15 May 2025 18:29:25 -0700 Subject: [PATCH 01/11] fix input ids when adding dummy requests. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 3 +-- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 2ba11b64402..88a1f6b77fe 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1488,8 +1488,7 @@ def _pad_attention_dp_dummy_request(self): request_ids=[0], is_gen=not self.has_context_request, prepare_resource=not self.has_context_request, - max_num_draft_tokens=0 - if self.has_context_request else self.max_draft_tokens, + max_num_draft_tokens=self.max_draft_tokens, )[0] llm_request.is_attention_dp_dummy = True self.active_requests.append(llm_request) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 1a9da786dba..e21a0b9e244 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -347,8 +347,11 @@ def add_dummy_requests( req.state = LlmRequestState.GENERATION_IN_PROGRESS req.prompt_len = token_num - 1 req.py_prompt_len = req.prompt_len - if max_num_draft_tokens > 0: - req.py_draft_tokens = [0] * max_num_draft_tokens + req.py_draft_tokens = [1] * max_num_draft_tokens + for _ in range(max_num_draft_tokens): + req.add_new_token(1, 0) + if prepare_resource: + self.impl.add_token(req_id) requests.append(req) return requests From 06c2d038fe81df6584f65b6925213a93af39b8ce Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Fri, 16 May 2025 01:36:37 -0700 Subject: [PATCH 02/11] fix token_nums. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index e21a0b9e244..d5259d9cf64 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -328,7 +328,8 @@ def add_dummy_requests( requests = [] for i, req_id in enumerate(request_ids): sampling_params = SamplingParams() - token_num = token_nums[i] if token_nums is not None else 1 + token_num = token_nums[ + i] if token_nums is not None else 1 + max_num_draft_tokens encoder_input_tokens = [ 1 ] * token_num if self.impl.cross_kv else None @@ -348,9 +349,8 @@ def add_dummy_requests( req.prompt_len = token_num - 1 req.py_prompt_len = req.prompt_len req.py_draft_tokens = [1] * max_num_draft_tokens - for _ in range(max_num_draft_tokens): - req.add_new_token(1, 0) - if prepare_resource: + if prepare_resource: + for _ in range(max_num_draft_tokens): self.impl.add_token(req_id) requests.append(req) return requests From f7f5c37f7dcc72f114c8655995ccf110ccf38985 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Wed, 14 May 2025 04:50:14 -0700 Subject: [PATCH 03/11] fix the dummy requests + overlap scheduler + spec decoding issue. 
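With the overlap scheduler and speculative decoding enabled, a generation request whose py_batch_idx is None is a dummy request (CUDA graph or attention DP padding) that has no valid tokens of its own; for such requests input_ids and draft_tokens are reused from the previous batch and only the per-request lengths and indices are recomputed. Since input_ids is no longer extended for every extend request, the gather indices are derived from the running length of position_ids rather than from len(input_ids).

A minimal sketch of the per-request index bookkeeping (illustrative only; names are simplified from _prepare_tp_inputs, and the real code sizes the gather range with self.max_draft_len):

    def extend_request_indices(position_ids, past_seen_token_num, num_draft_tokens):
        # one extend request contributes 1 + num_draft_tokens tokens this step
        gather_ids = list(range(len(position_ids),
                                len(position_ids) + 1 + num_draft_tokens))
        new_position_ids = list(range(past_seen_token_num,
                                      past_seen_token_num + 1 + num_draft_tokens))
        return gather_ids, new_position_ids

    # e.g. a request with 10 already-seen tokens and 2 draft tokens, placed
    # after 5 existing positions, gathers logits at [5, 6, 7] and runs at
    # positions [10, 11, 12]
    assert extend_request_indices([0] * 5, 10, 2) == ([5, 6, 7], [10, 11, 12])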
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 18eac341cab..f6e32316d3f 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1109,25 +1109,30 @@ def _prepare_tp_inputs( previous_pos_indices = [] for request in extend_requests: if next_draft_tokens_device is None or request.py_batch_idx is None: + # get token ids, including input token ids and draft token ids + if next_draft_tokens_device is None: + # skip dummy generation requests + # if next_draft_tokens_device is not None, but the batch index is None, + # it means the generation requests are dummy requests. For such cases, + # we do not update the input_ids and draft_tokens and just use the tokens + # from the previous batch. + input_ids.append(request.get_last_tokens(0)) + input_ids.extend(request.py_draft_tokens) + draft_tokens.extend(request.py_draft_tokens) + # get other ids and lengths num_draft_tokens = len(request.py_draft_tokens) - input_ids.append(request.get_last_tokens(0)) - gather_ids.append(len(input_ids) - 1) - sequence_lengths.append(1 + num_draft_tokens) past_seen_token_num = request.max_beam_num_tokens - 1 - position_ids.append(past_seen_token_num) draft_lens.append(num_draft_tokens) - prompt_lengths.append(num_draft_tokens + 1) - # draft tokens - input_ids.extend(request.py_draft_tokens) + prompt_lengths.append(1 + num_draft_tokens) + sequence_lengths.append(1 + num_draft_tokens) gather_ids.extend( list( - range( - len(input_ids) - num_draft_tokens, len(input_ids)))) + range(len(position_ids), + len(position_ids) + 1 + self.max_draft_len))) position_ids.extend( list( - range(past_seen_token_num + 1, + range(past_seen_token_num, past_seen_token_num + 1 + num_draft_tokens))) - draft_tokens.extend(request.py_draft_tokens) num_cached_tokens_per_seq.append(past_seen_token_num) request.py_batch_idx = batch_idx batch_idx += 1 From bdd28b551f7a526c81f9c75bea2f74a61be8dc31 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Thu, 22 May 2025 01:21:49 -0700 Subject: [PATCH 04/11] fix prepare inputs for the dummy requests in dp+mtp. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 78 +++++++++++-------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index f6e32316d3f..0e561ed4338 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1107,18 +1107,14 @@ def _prepare_tp_inputs( # will contain previous batch incices of generation requests previous_batch_indices = [] previous_pos_indices = [] + request_ids_with_previous_batch = [] + num_extend_reqs_wo_previous_batch = 0 for request in extend_requests: if next_draft_tokens_device is None or request.py_batch_idx is None: # get token ids, including input token ids and draft token ids - if next_draft_tokens_device is None: - # skip dummy generation requests - # if next_draft_tokens_device is not None, but the batch index is None, - # it means the generation requests are dummy requests. 
For such cases, - # we do not update the input_ids and draft_tokens and just use the tokens - # from the previous batch. - input_ids.append(request.get_last_tokens(0)) - input_ids.extend(request.py_draft_tokens) - draft_tokens.extend(request.py_draft_tokens) + input_ids.append(request.get_last_tokens(0)) + input_ids.extend(request.py_draft_tokens) + draft_tokens.extend(request.py_draft_tokens) # get other ids and lengths num_draft_tokens = len(request.py_draft_tokens) past_seen_token_num = request.max_beam_num_tokens - 1 @@ -1134,10 +1130,13 @@ def _prepare_tp_inputs( range(past_seen_token_num, past_seen_token_num + 1 + num_draft_tokens))) num_cached_tokens_per_seq.append(past_seen_token_num) + request_ids.append(request.py_request_id) + # update batch index request.py_batch_idx = batch_idx batch_idx += 1 + num_extend_reqs_wo_previous_batch += 1 else: - # batch index + # update batch index previous_batch_idx = request.py_batch_idx request.py_batch_idx = batch_idx batch_idx += 1 @@ -1162,8 +1161,10 @@ def _prepare_tp_inputs( num_cached_tokens_per_seq.append(past_seen_token_num + self.max_draft_len + 1) prompt_lengths.append(request.py_prompt_len) + request_ids_with_previous_batch.append(request.py_request_id) - request_ids.append(request.py_request_id) + # move requests with previous batch to the end of the list + request_ids.extend(request_ids_with_previous_batch) sequence_lengths.extend([1] * len(generation_requests)) gather_ids.extend( @@ -1196,12 +1197,20 @@ def _prepare_tp_inputs( batch_idx += 1 num_tokens = len(input_ids) + num_draft_tokens = len(draft_tokens) previous_batchs = len(previous_batch_indices) + # if exist requests that do not have previous batch, copy input_ids and draft_tokens if num_tokens > 0: input_ids = torch.tensor(input_ids, dtype=torch.int, pin_memory=True) self.input_ids_cuda[:num_tokens].copy_(input_ids, non_blocking=True) + if num_draft_tokens > 0: + draft_tokens = torch.tensor(draft_tokens, + dtype=torch.int, + pin_memory=True) + self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens, + non_blocking=True) if next_draft_tokens_device is not None: if len(previous_batch_indices) > 0: previous_batch_indices = torch.tensor(previous_batch_indices, @@ -1220,26 +1229,39 @@ def _prepare_tp_inputs( non_blocking=True) # previous draft tokens previous_batch_draft_tokens = previous_batchs * self.max_draft_len - self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_( - next_draft_tokens_device[ + self.draft_tokens_cuda[ + num_draft_tokens:num_draft_tokens + + previous_batch_draft_tokens].copy_(next_draft_tokens_device[ self.previous_batch_indices_cuda[:previous_batchs], :]. 
- flatten(), - non_blocking=True) + flatten(), + non_blocking=True) # prepare data for the preprocess inputs kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1 + pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * ( + 1 + self.max_draft_len) + pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens + pre_batch_start_idx = num_extend_reqs_wo_previous_batch + pre_batch_end_idx = pre_batch_start_idx + previous_batchs previous_pos_indices = torch.tensor(previous_pos_indices, dtype=torch.int, pin_memory=True) - self.previous_pos_indices_cuda[:previous_batch_tokens].copy_( - previous_pos_indices, non_blocking=True) - self.previous_pos_id_offsets_cuda[:previous_batch_tokens].copy_( - new_tokens_lens_device[ - self.previous_pos_indices_cuda[:previous_batch_tokens]], - non_blocking=True) - self.previous_kv_lens_offsets_cuda[:previous_batchs].copy_( - kv_len_offsets_device[ - self.previous_batch_indices_cuda[:previous_batchs]], - non_blocking=True) + self.previous_pos_indices_cuda[ + pre_tokens_start_idx:pre_tokens_end_idx].copy_( + previous_pos_indices, non_blocking=True) + self.previous_pos_id_offsets_cuda[ + pre_tokens_start_idx:pre_tokens_end_idx].copy_( + new_tokens_lens_device[self.previous_pos_indices_cuda[ + pre_tokens_start_idx:pre_tokens_end_idx]], + non_blocking=True) + self.previous_kv_lens_offsets_cuda[ + pre_batch_start_idx:pre_batch_end_idx].copy_( + kv_len_offsets_device[ + self.previous_batch_indices_cuda[:previous_batchs]], + non_blocking=True) + # for the requests that do not have previous batch, set the previous_pos_id_offsets and + # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs + self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0 + self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0 else: # change the data to zeros to skip the value changes in _preprocess_inputs self.previous_pos_id_offsets_cuda *= 0 @@ -1310,12 +1332,6 @@ def _prepare_tp_inputs( if spec_metadata is not None: total_draft_lens = sum(draft_lens) - if len(draft_tokens) > 0: - draft_tokens = torch.tensor(draft_tokens, - dtype=torch.int, - pin_memory=True) - self.draft_tokens_cuda[:len(draft_tokens)].copy_( - draft_tokens, non_blocking=True) spec_metadata.draft_tokens = self.draft_tokens_cuda[: total_draft_lens] spec_metadata.request_ids = request_ids From 2d4ace6fca1e2e511947f4859669fad944873518 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Thu, 22 May 2025 03:20:40 -0700 Subject: [PATCH 05/11] fix prompt_len. 
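For an extend request, 1 + num_draft_tokens is the number of tokens processed in this step (the sequence length), not the prompt length, so prompt_lengths should carry the request's real prompt length. A minimal sketch of the intended bookkeeping (hypothetical numbers, for illustration only):

    # a request with a 100-token prompt and 3 draft tokens in one extend step
    num_draft_tokens = 3
    py_prompt_len = 100
    sequence_length = 1 + num_draft_tokens  # new token + draft tokens this step
    prompt_length = py_prompt_len           # unchanged by speculative drafting
    assert (sequence_length, prompt_length) == (4, 100)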
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 0e561ed4338..cb329329e03 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1119,7 +1119,7 @@ def _prepare_tp_inputs( num_draft_tokens = len(request.py_draft_tokens) past_seen_token_num = request.max_beam_num_tokens - 1 draft_lens.append(num_draft_tokens) - prompt_lengths.append(1 + num_draft_tokens) + prompt_lengths.append(request.py_prompt_len) sequence_lengths.append(1 + num_draft_tokens) gather_ids.extend( list( From 49cd27273596c5565830ab9c48a7b8235aaa61d8 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Fri, 23 May 2025 02:56:05 -0700 Subject: [PATCH 06/11] add extra tokens. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index d5259d9cf64..3f19abce5ff 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -344,6 +344,8 @@ def add_dummy_requests( req.paged_kv_block_ids = [] if prepare_resource: self.impl.add_sequence(req_id, token_num, beam_width, req) + for _ in range(self.num_extra_kv_tokens): + self.impl.add_token(req_id) if is_gen: req.state = LlmRequestState.GENERATION_IN_PROGRESS req.prompt_len = token_num - 1 From a7ef0904b2cf5000c974c97e98f674222d62f7df Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Fri, 23 May 2025 05:14:02 -0700 Subject: [PATCH 07/11] add waived tests back. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 7898f0edeb2..02492fff569 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -365,12 +365,7 @@ accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://n accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5231468) accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://nvbugs/5231310) test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image] SKIP (https://nvbugs/5233423) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5234002) examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-27b-it-fp8-bfloat16-8] SKIP (https://nvbugs/5234164) full::GH200/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[disable_weight_only] SKIP (https://nvbugs/5250460) full::GH200/examples/test_gemma.py::test_llm_gemma_1gpu_summary[gemma-2-27b-it-other-bfloat16-8] SKIP (https://nvbugs/5250460) @@ -399,24 +394,6 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco triton_server/test_triton.py::test_qwen2_vl[qwen2_vl] SKIP triton_server/test_triton.py::test_gpt_ib_speculative_decoding_bls[gpt-ib-speculative-decoding-bls] SKIP triton_server/test_triton_llm.py::test_mistral_v1_multi_models[False-1---False-True-False-0-128-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--max_utilization-4096--1-1-1-False-ensemble] SKIP -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) 
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5286795) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5286795) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5285965) examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (https://nvbugs/5289523) examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (https://nvbugs/5289523) examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5289904) From cc11e4417869ced10040a7e401d6a451dcc18927 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Sun, 25 May 2025 23:14:32 -0700 Subject: [PATCH 08/11] update the py_draft_tokens default to an empty list and add some explanatory comments. 
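Defaulting py_draft_tokens to an empty list rather than None lets every consumer treat it uniformly as a list, so the "is not None" guards become plain length checks. A minimal self-contained sketch of the resulting call-site pattern (uses a hypothetical stand-in class, not the real LlmRequest):

    class _Req:
        # hypothetical stand-in for LlmRequest's py_draft_tokens handling
        def __init__(self, draft_tokens=None):
            self.py_draft_tokens = [] if draft_tokens is None else draft_tokens

    def draft_kv_tokens_to_add(req):
        # previously guarded with: if req.py_draft_tokens is not None: ...
        return len(req.py_draft_tokens)

    assert draft_kv_tokens_to_add(_Req()) == 0
    assert draft_kv_tokens_to_add(_Req([1, 2])) == 2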
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 2 +- tensorrt_llm/_torch/pyexecutor/model_engine.py | 10 +++++++++- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 13 +++++++------ tensorrt_llm/_torch/pyexecutor/sampler.py | 6 ++++-- tensorrt_llm/_torch/pyexecutor/scheduler.py | 2 +- 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 799a8867e55..ba865c3c8e9 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -239,7 +239,7 @@ def __init__( self.py_max_new_tokens = self.max_new_tokens self.py_batch_idx = None self.py_rewind_len = 0 - self.py_draft_tokens = self.draft_tokens + self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens self.py_last_draft_tokens = None self.py_decoding_iter = 0 self.is_attention_dp_dummy = False diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index cb329329e03..e9d893e8334 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1087,7 +1087,8 @@ def _prepare_tp_inputs( extend_requests = [] generation_requests = [] for request in scheduled_requests.generation_requests: - if request.py_draft_tokens is not None or next_draft_tokens_device is not None: + if len(request.py_draft_tokens + ) > 0 or next_draft_tokens_device is not None: extend_requests.append(request) else: generation_requests.append(request) @@ -1111,6 +1112,13 @@ def _prepare_tp_inputs( num_extend_reqs_wo_previous_batch = 0 for request in extend_requests: if next_draft_tokens_device is None or request.py_batch_idx is None: + # the request has no previous device tensors: + # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or + # (2) request.py_batch_idx is None, which means the request has no previous batch. + # the second condition includes dummy generation requests created for CUDA graph padding or + # attention DP. These dummy generation requests should be at the head of generation_requests. + # TODO: move the dummy generation requests to the end of generation_requests to align with + # the logic for those requests in generation_requests. 
# get token ids, including input token ids and draft token ids input_ids.append(request.get_last_tokens(0)) input_ids.extend(request.py_draft_tokens) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 3f19abce5ff..61d1a0b3786 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -300,15 +300,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req_beam_width, req) for _ in range(self.num_extra_kv_tokens): self.impl.add_token(req.py_request_id) - if req.py_draft_tokens is not None: - for _ in range(len(req.py_draft_tokens)): - self.impl.add_token(req.py_request_id) + for _ in range(len(req.py_draft_tokens)): + self.impl.add_token(req.py_request_id) for req in generation_batch: self.impl.add_token(req.py_request_id) - if req.py_draft_tokens is not None: - for _ in range(len(req.py_draft_tokens)): - self.impl.add_token(req.py_request_id) + for _ in range(len(req.py_draft_tokens)): + self.impl.add_token(req.py_request_id) def add_dummy_requests( self, @@ -328,6 +326,9 @@ def add_dummy_requests( requests = [] for i, req_id in enumerate(request_ids): sampling_params = SamplingParams() + # Here 1+max_num_draft_tokens is used to extend the prompt length to + # a non-zero number to skip illegal memory access issue in MLA kernel + # during warmup. token_num = token_nums[ i] if token_nums is not None else 1 + max_num_draft_tokens encoder_input_tokens = [ diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 69f7dfd6451..3483e75c8b0 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -297,7 +297,7 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1): extend_requests = [] generation_requests = [] for request in scheduled_requests.generation_requests: - if request.py_draft_tokens is not None: + if len(request.py_draft_tokens) > 0: extend_requests.append(request) else: generation_requests.append(request) @@ -361,7 +361,9 @@ def _mixed_sample(self, scheduled_requests: ScheduledRequests, for request in scheduled_requests.generation_requests: if request.state == LlmRequestState.GENERATION_COMPLETE: continue - assert request.py_draft_tokens is None, "Speculative decoding not supported in SeparateDecoder." + assert len( + request.py_draft_tokens + ) == 0, "Speculative decoding not supported in SeparateDecoder." token_logits = logits[idx:idx + 1, :] new_token, probs = decode_single_request(request, token_logits) new_tokens_device_array.append(new_token) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 80f36c4dfe9..beb25f1afdd 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -184,7 +184,7 @@ def schedule( self, active_requests: RequestList, inflight_request_ids: set[int] ) -> tuple[list[LlmRequest], list[LlmRequest]]: for request in active_requests: - if request.py_draft_tokens is not None: + if len(request.py_draft_tokens) > 0: request.draft_tokens = request.py_draft_tokens return self.impl(active_requests, inflight_request_ids, self.max_batch_size, self.max_num_tokens) From 7b974aaddd6d94f635dff8c15c59a0d366e6c98c Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 26 May 2025 06:03:50 -0700 Subject: [PATCH 09/11] fix py_draft_tokens in dis-agg. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 88a1f6b77fe..baaac5140de 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1524,7 +1524,8 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch): req.decoding_iter = 1 req.py_decoding_iter = 1 first_gen_tokens = req.context_phase_params.first_gen_tokens - req.py_draft_tokens = req.context_phase_params.draft_tokens + ctx_draft_tokens = req.context_phase_params.draft_tokens + req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens beam_width = req.sampling_config.beam_width for beam in range(0, beam_width): req.add_new_token(first_gen_tokens[beam], beam) From 2511aec80340494dfb9cf3a2796edf9010be3b62 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 27 May 2025 17:47:28 -0700 Subject: [PATCH 10/11] waive test_fp8_block_scales tests due to unstable errors. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 02492fff569..5bb87137317 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -394,6 +394,12 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco triton_server/test_triton.py::test_qwen2_vl[qwen2_vl] SKIP triton_server/test_triton.py::test_gpt_ib_speculative_decoding_bls[gpt-ib-speculative-decoding-bls] SKIP triton_server/test_triton_llm.py::test_mistral_v1_multi_models[False-1---False-True-False-0-128-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--max_utilization-4096--1-1-1-False-ensemble] SKIP +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5286795) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5286795) examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (https://nvbugs/5289523) examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (https://nvbugs/5289523) 
examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5289904) From bead11cdcbfd0cab368bb30102670bf872288cd1 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Wed, 28 May 2025 20:10:20 -0700 Subject: [PATCH 11/11] waive test_bfloat16 tests due to unstable errors in post-merge. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 5bb87137317..825f2b0d2e6 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -400,6 +400,17 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5286795) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5286795) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965) +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5285965) 
examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (https://nvbugs/5289523) examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (https://nvbugs/5289523) examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5289904)