From 35e3f0ba153f7f87dca8c1c260861ce177e70d03 Mon Sep 17 00:00:00 2001
From: yifeizhang-c <219273404+yifeizhang-c@users.noreply.github.com>
Date: Sat, 16 Aug 2025 00:52:06 +0800
Subject: [PATCH 1/2] [https://nvbugs/5394392][fix] Enlarge scheduler capacity
 under disagg bs == 1 (#6537)

Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
---
 .../batch_manager/assignReqSeqSlots.cpp       |  5 +++
 tensorrt_llm/_torch/pyexecutor/_util.py       |  8 +++-
 tensorrt_llm/_torch/pyexecutor/py_executor.py |  4 +-
 .../_torch/pyexecutor/seq_slot_manager.py     |  5 +++
 .../disagg_config_gen_only_bs1.yaml           | 37 +++++++++++++++++++
 .../defs/disaggregated/test_disaggregated.py  | 25 +++++++++++++
 .../test_lists/test-db/l0_dgx_h100.yml        |  1 +
 7 files changed, 83 insertions(+), 2 deletions(-)
 create mode 100644 tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml

diff --git a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp
index d25572ad5a6..514d100fe58 100644
--- a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp
+++ b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp
@@ -30,6 +30,11 @@ void tensorrt_llm::batch_manager::AssignReqSeqSlots::operator()(SequenceSlotMana
 {
     for (auto const& llmReq : requests)
     {
+        if (llmReq->isDisaggGenerationInitState())
+        {
+            // Skip assigning a sequence slot for DISAGG_GENERATION_INIT requests
+            continue;
+        }
         auto const isReqNew = (llmReq->isContextInitState() && !llmReq->mSeqSlot)
             || (llmReq->isDisaggGenerationTransmissionComplete());
         if (isReqNew && llmReq->getReturnPerfMetrics())
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 21fa9f91c1d..37f8e0410b0 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -524,8 +524,14 @@ def create_py_executor_instance(
     resource_manager.resource_managers.move_to_end(
         ResourceManagerType.KV_CACHE_MANAGER, last=True)
 
+    # When scheduler_capacity == 1, the attention DP dummy request prevents DISAGG_GENERATION_INIT from being scheduled.
+    # Enlarge the scheduler capacity so that DISAGG_GENERATION_INIT requests do not get stuck in the scheduler.
+    scheduler_capacity = max_num_sequences
+    if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
+        scheduler_capacity += 1
+
     capacity_scheduler = BindCapacityScheduler(
-        max_num_sequences,
+        scheduler_capacity,
         kv_cache_manager.impl if kv_cache_manager is not None else None,
         peft_cache_manager.impl if peft_cache_manager is not None else None,
         executor_config.scheduler_config.capacity_scheduler_policy,
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index d87dbef4e7d..4b7df3d9491 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1287,7 +1287,6 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
 
         for resource_mgr_type in (
                 ResourceManagerType.KV_CACHE_MANAGER,
-                ResourceManagerType.SEQ_SLOT_MANAGER,
                 ResourceManagerType.SPEC_RESOURCE_MANAGER,
                 ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
             if (resource_mgr_type in self.resource_manager.resource_managers
@@ -1307,6 +1306,9 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             if req.is_disagg_generation_transmission_complete:
                 cache_trans_complete_requests.append(req)
         if len(cache_trans_complete_requests) > 0:
+            self.resource_manager.resource_managers[
+                ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
+                    cache_trans_complete_requests)
             self._setup_sampler_step(cache_trans_complete_requests)
 
         for req in scheduled_batch.generation_requests:
diff --git a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py
index c43c9726412..a3f11e56423 100644
--- a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py
@@ -16,6 +16,11 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
 
     def prepare_resources(self, scheduled_batch: ScheduledRequests) -> None:
         for llm_req in scheduled_batch.all_requests():
+            if llm_req.is_disagg_generation_init_state:
+                logger.info(
+                    "Skip assigning sequence slot for DISAGG_GENERATION_INIT request."
+                )
+                continue
             if llm_req.seq_slot is None or llm_req.is_disagg_generation_transmission_complete:
                 llm_req.seq_slot = self.slot_manager.add_slot(
                     llm_req.request_id)
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml
new file mode 100644
index 00000000000..4efbc9a9493
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml
@@ -0,0 +1,37 @@
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+hostname: localhost
+port: 8000
+backend: "pytorch"
+cuda_graph_config: null
+free_gpu_memory_fraction: 0.2
+context_servers:
+  num_instances: 1
+  max_batch_size: 1
+  max_num_tokens: 3000
+  max_seq_len: 4096
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  enable_attention_dp: true
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  enable_attention_dp: true
+  max_batch_size: 1
+  max_num_tokens: 4096
+  max_seq_len: 4096
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: default
+  urls:
+    - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py
index a02d5a1a16c..c72152bc357 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated.py
@@ -47,6 +47,8 @@ def get_test_config(test_desc, example_dir, test_root):
         "gen_only": (2, f"{test_configs_root}/disagg_config_gen_only.yaml"),
         "gen_only_trt_backend":
         (2, f"{test_configs_root}/disagg_config_gen_only_trt_backend.yaml"),
+        "gen_only_bs1":
+        (4, f"{test_configs_root}/disagg_config_gen_only_bs1.yaml"),
         "4_ranks": (4, f"{test_configs_root}/disagg_config_ctxtp2_gentp1.yaml"),
         "4_ranks_trt_backend":
         (4,
@@ -387,6 +389,29 @@ def test_disaggregated_benchmark_gen_only_trt_backend(
                            cwd=llm_venv.get_working_directory())
 
 
+@pytest.mark.skip_less_device(4)
+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_genbs1(disaggregated_test_root,
+                              disaggregated_example_root, llm_venv,
+                              llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    env = llm_venv._new_env.copy()
+    env['TRTLLM_DISAGG_BENCHMARK_GEN_ONLY'] = '1'
+    run_disaggregated_test(disaggregated_example_root,
+                           "gen_only_bs1",
+                           env=env,
+                           cwd=llm_venv.get_working_directory())
+
+
 @pytest.mark.skip_less_device(2)
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index e7ec5539776..f65ef799622 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -33,6 +33,7 @@ l0_dgx_h100:
   - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0]
+  - disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]

From 8f1e34094c26b580d34988fabfbc64b4d846a279 Mon Sep 17 00:00:00 2001
From: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
Date: Mon, 18 Aug 2025 01:18:25 -0700
Subject: [PATCH 2/2] fix: pass a ScheduledRequests batch to
 SeqSlotManager.prepare_resources

Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 4b7df3d9491..7d9b247c945 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1306,9 +1306,11 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             if req.is_disagg_generation_transmission_complete:
                 cache_trans_complete_requests.append(req)
         if len(cache_trans_complete_requests) > 0:
+            trans_complete_to_prepare = ScheduledRequests()
+            trans_complete_to_prepare.context_requests = cache_trans_complete_requests
             self.resource_manager.resource_managers[
                 ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
-                    cache_trans_complete_requests)
+                    trans_complete_to_prepare)
             self._setup_sampler_step(cache_trans_complete_requests)
 
         for req in scheduled_batch.generation_requests:
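
Note: below is a standalone sketch of the scheduling deadlock this series
fixes; the names (Request, schedule) are hypothetical stand-ins, not the
TensorRT-LLM API. With attention DP enabled, the generation server always
holds one dummy request, so a scheduler capacity of max_batch_size == 1 can
never admit the DISAGG_GENERATION_INIT request; the first patch bumps the
capacity to 2 in exactly that case.

    from dataclasses import dataclass

    @dataclass
    class Request:
        name: str

    def schedule(waiting: list, capacity: int) -> list:
        # Toy capacity scheduler: admit requests in arrival order until
        # the capacity is exhausted.
        return waiting[:capacity]

    waiting = [Request("attention_dp_dummy"), Request("DISAGG_GENERATION_INIT")]

    # capacity == max_batch_size == 1: the dummy request pins the only slot,
    # so the generation-init request is never scheduled.
    print([r.name for r in schedule(waiting, 1)])  # ['attention_dp_dummy']

    # With the +1 bump applied by the first patch, both requests fit and
    # generation can start.
    print([r.name for r in schedule(waiting, 2)])  # ['attention_dp_dummy', 'DISAGG_GENERATION_INIT']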