Commit 35e3f0b

[https://nvbugs/5394392][fix] Enlarge scheduler capacity under disagg bs == 1 (#6537)
Signed-off-by: Yifei Zhang <[email protected]>
1 parent 7f7a301 commit 35e3f0b

File tree: 7 files changed (+83, −2 lines)

cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp

Lines changed: 5 additions & 0 deletions
@@ -30,6 +30,11 @@ void tensorrt_llm::batch_manager::AssignReqSeqSlots::operator()(SequenceSlotMan
 {
     for (auto const& llmReq : requests)
     {
+        if (llmReq->isDisaggGenerationInitState())
+        {
+            // Skip assigning sequence slot for DISAGG_GENERATION_INIT request
+            continue;
+        }
         auto const isReqNew = (llmReq->isContextInitState() && !llmReq->mSeqSlot)
             || (llmReq->isDisaggGenerationTransmissionComplete());
         if (isReqNew && llmReq->getReturnPerfMetrics())
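For readers less familiar with the C++ batch manager, here is a hedged restatement of the slot-assignment predicate above as plain Python; the function and parameter names are illustrative, not from the codebase:

# Illustrative restatement of the C++ predicate; names are hypothetical.
def needs_new_slot(is_context_init: bool, has_seq_slot: bool,
                   is_transmission_complete: bool) -> bool:
    # A request needs a fresh sequence slot when it is a brand-new context
    # request, or when its disaggregated KV-cache transfer just finished.
    return (is_context_init and not has_seq_slot) or is_transmission_complete


assert needs_new_slot(True, False, False)      # new context request
assert needs_new_slot(False, False, True)      # disagg transfer complete
assert not needs_new_slot(True, True, False)   # slot already assigned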

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 7 additions & 1 deletion
@@ -524,8 +524,14 @@ def create_py_executor_instance(
         resource_manager.resource_managers.move_to_end(
             ResourceManagerType.KV_CACHE_MANAGER, last=True)
 
+    # When scheduler_capacity == 1, the attention DP dummy request prevents the
+    # scheduling of DISAGG_GENERATION_INIT; enlarge capacity so it isn't stuck.
+    scheduler_capacity = max_num_sequences
+    if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
+        scheduler_capacity += 1
+
     capacity_scheduler = BindCapacityScheduler(
-        max_num_sequences,
+        scheduler_capacity,
         kv_cache_manager.impl if kv_cache_manager is not None else None,
         peft_cache_manager.impl if peft_cache_manager is not None else None,
         executor_config.scheduler_config.capacity_scheduler_policy,
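A minimal, self-contained sketch of the capacity adjustment above; the helper name and boolean parameters are stand-ins for the real `mapping` and `kv_cache_manager` objects:

# Sketch only: mirrors the guard in create_py_executor_instance.
def compute_scheduler_capacity(max_num_sequences: int,
                               enable_attention_dp: bool,
                               has_kv_cache_manager: bool) -> int:
    capacity = max_num_sequences
    # With attention DP at batch size 1, a dummy request can occupy the only
    # slot and starve DISAGG_GENERATION_INIT, so reserve one extra slot.
    if capacity == 1 and enable_attention_dp and has_kv_cache_manager:
        capacity += 1
    return capacity


assert compute_scheduler_capacity(1, True, True) == 2   # the fixed case
assert compute_scheduler_capacity(1, False, True) == 1  # no attention DP
assert compute_scheduler_capacity(8, True, True) == 8   # larger batch: unchanged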

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 1 deletion
@@ -1287,7 +1287,6 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
 
         for resource_mgr_type in (
                 ResourceManagerType.KV_CACHE_MANAGER,
-                ResourceManagerType.SEQ_SLOT_MANAGER,
                 ResourceManagerType.SPEC_RESOURCE_MANAGER,
                 ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
             if (resource_mgr_type in self.resource_manager.resource_managers
@@ -1307,6 +1306,9 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             if req.is_disagg_generation_transmission_complete:
                 cache_trans_complete_requests.append(req)
         if len(cache_trans_complete_requests) > 0:
+            self.resource_manager.resource_managers[
+                ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
+                    cache_trans_complete_requests)
             self._setup_sampler_step(cache_trans_complete_requests)
 
         for req in scheduled_batch.generation_requests:
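Taken together with the seq_slot_manager.py change below, the effect is that slot allocation is deferred from DISAGG_GENERATION_INIT time to transmission-complete time. A toy model of that flow, using a hypothetical stand-in class rather than the real manager:

# Toy stand-in; the real manager lives in seq_slot_manager.py.
class ToySeqSlotManager:
    def __init__(self, num_slots: int):
        self.free_slots = list(range(num_slots))

    def prepare_resources(self, requests):
        for req in requests:
            if req["state"] == "DISAGG_GENERATION_INIT":
                continue  # no slot yet; wait for KV-cache transmission
            req["seq_slot"] = self.free_slots.pop()


mgr = ToySeqSlotManager(num_slots=1)
req = {"state": "DISAGG_GENERATION_INIT"}
mgr.prepare_resources([req])           # slot is NOT consumed here
assert "seq_slot" not in req
req["state"] = "TRANSMISSION_COMPLETE"
mgr.prepare_resources([req])           # slot assigned only now
assert req["seq_slot"] == 0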

tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,11 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
 
     def prepare_resources(self, scheduled_batch: ScheduledRequests) -> None:
         for llm_req in scheduled_batch.all_requests():
+            if llm_req.is_disagg_generation_init_state:
+                logger.info(
+                    "Skip assigning sequence slot for DISAGG_GENERATION_INIT request."
+                )
+                continue
             if llm_req.seq_slot is None or llm_req.is_disagg_generation_transmission_complete:
                 llm_req.seq_slot = self.slot_manager.add_slot(
                     llm_req.request_id)
tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_bs1.yaml

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+hostname: localhost
+port: 8000
+backend: "pytorch"
+cuda_graph_config: null
+free_gpu_memory_fraction: 0.2
+context_servers:
+  num_instances: 1
+  max_batch_size: 1
+  max_num_tokens: 3000
+  max_seq_len: 4096
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  enable_attention_dp: true
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  enable_attention_dp: true
+  max_batch_size: 1
+  max_num_tokens: 4096
+  max_seq_len: 4096
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: default
+  urls:
+    - "localhost:8002"

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 25 additions & 0 deletions
@@ -47,6 +47,8 @@ def get_test_config(test_desc, example_dir, test_root):
         "gen_only": (2, f"{test_configs_root}/disagg_config_gen_only.yaml"),
         "gen_only_trt_backend":
         (2, f"{test_configs_root}/disagg_config_gen_only_trt_backend.yaml"),
+        "gen_only_bs1":
+        (4, f"{test_configs_root}/disagg_config_gen_only_bs1.yaml"),
         "4_ranks": (4, f"{test_configs_root}/disagg_config_ctxtp2_gentp1.yaml"),
         "4_ranks_trt_backend":
         (4,
@@ -387,6 +389,29 @@ def test_disaggregated_benchmark_gen_only_trt_backend(
                            cwd=llm_venv.get_working_directory())
 
 
+@pytest.mark.skip_less_device(4)
+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_genbs1(disaggregated_test_root,
+                              disaggregated_example_root, llm_venv,
+                              llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    env = llm_venv._new_env.copy()
+    env['TRTLLM_DISAGG_BENCHMARK_GEN_ONLY'] = '1'
+    run_disaggregated_test(disaggregated_example_root,
+                           "gen_only_bs1",
+                           env=env,
+                           cwd=llm_venv.get_working_directory())
+
+
 @pytest.mark.skip_less_device(2)
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
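To run just the new test locally, something like the following should work; the path is assumed relative to the repo root, and the skip_less_device(4) marker means at least 4 GPUs are required:

# Hypothetical invocation; adjust the path to your checkout layout.
import subprocess

subprocess.run(
    [
        "pytest",
        "tests/integration/defs/disaggregated/test_disaggregated.py",
        "-k", "test_disaggregated_genbs1",
    ],
    check=True,
)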

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ l0_dgx_h100:
   - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0]
+  - disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
