
Commit 935c2c1

[None] [fix] Minor fixes to slurm and benchmark scripts (#7453)
Signed-off-by: Kaiyu Xie <[email protected]>
1 parent: 14af1f0 · commit: 935c2c1

File tree

2 files changed: +19 -8 lines changed


examples/disaggregated/slurm/benchmark/gen_worker_config.py

Lines changed: 10 additions & 0 deletions
@@ -48,6 +48,11 @@ def gen_config_file(work_dir: str,
         server_port: Server port
     """
     ctx_config = {
+        'build_config': {
+            'max_batch_size': ctx_batch_size,
+            'max_num_tokens': ctx_max_num_tokens,
+            'max_seq_len': ctx_max_seq_len,
+        },
         'max_batch_size': ctx_batch_size,
         'max_num_tokens': ctx_max_num_tokens,
         'max_seq_len': ctx_max_seq_len,
@@ -79,6 +84,11 @@ def gen_config_file(work_dir: str,
         gen_moe_backend = "TRTLLM"
 
     gen_config = {
+        'build_config': {
+            'max_batch_size': gen_batch_size,
+            'max_num_tokens': gen_max_num_tokens,
+            'max_seq_len': gen_max_seq_len,
+        },
         'tensor_parallel_size': gen_tp_size,
         'moe_expert_parallel_size': gen_tp_size,
         'enable_attention_dp': True if gen_enable_attention_dp else False,
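
The effect of this first change is that the same engine limits now appear both at the top level of the worker config and nested under build_config. The snippet below is a minimal, hypothetical sketch, not the repository's gen_worker_config.py: it rebuilds only the context-worker dict shown in the diff and dumps it as YAML so the resulting structure can be inspected. It assumes PyYAML is available; the helper name build_ctx_config and the example values are invented for illustration.

# Hypothetical helper mirroring the ctx_config structure from the diff above;
# not the actual gen_worker_config.py implementation.
import yaml

def build_ctx_config(ctx_batch_size, ctx_max_num_tokens, ctx_max_seq_len):
    return {
        'build_config': {
            'max_batch_size': ctx_batch_size,
            'max_num_tokens': ctx_max_num_tokens,
            'max_seq_len': ctx_max_seq_len,
        },
        'max_batch_size': ctx_batch_size,
        'max_num_tokens': ctx_max_num_tokens,
        'max_seq_len': ctx_max_seq_len,
    }

if __name__ == '__main__':
    # Example values only; real runs take these from the slurm benchmark scripts.
    cfg = build_ctx_config(ctx_batch_size=8, ctx_max_num_tokens=8192,
                           ctx_max_seq_len=8192)
    print(yaml.safe_dump(cfg, sort_keys=False))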

tensorrt_llm/serve/scripts/benchmark_dataset.py

Lines changed: 9 additions & 8 deletions
@@ -494,11 +494,14 @@ def sample(
 
             # Filter out sequences that are too long or too short
             requests = []
-            for prompt, initial_prompt_len, cached_token_ids in zip(
-                    dataset, prompt_lengths, prompt_token_ids):
-                i = len(requests)
-                if i == num_requests:
-                    break
+            dataset_len = len(dataset)
+
+            for i in range(num_requests):
+                # Use modulo to cycle through the dataset when num_requests > dataset_len
+                dataset_idx = i % dataset_len
+                prompt = dataset[dataset_idx]
+                initial_prompt_len = prompt_lengths[dataset_idx]
+                cached_token_ids = prompt_token_ids[dataset_idx]
 
                 # Skip empty prompt
                 if initial_prompt_len == 0:
@@ -534,9 +537,6 @@ def sample(
                         prompt_len=total_input_len,
                         expected_output_len=int(output_lens[i]),
                     ))
-            assert len(requests) == num_requests, (
-                f"Only {len(requests)} requests sampled from sharegpt dataset, {num_requests} requests are needed"
-            )
         else:
             for i in range(num_requests):
                 inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
@@ -1131,6 +1131,7 @@ def sample(
         if parser_fn is None:
             raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
 
+        sampled_requests = []
         for item in self.data:
             if len(prompts) >= num_requests:
                 break
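
The second change is easier to see outside the diff. Previously the loop walked the dataset once, stopped after collecting num_requests entries, and then asserted that exactly num_requests had been sampled, which fails whenever the dataset holds fewer usable prompts than requested. The new loop runs num_requests times and wraps the index with a modulo, so prompts are reused once the dataset is exhausted. The following is a self-contained sketch of that pattern only, using made-up (prompt, prompt_len) tuples instead of the real SampleRequest objects in benchmark_dataset.py.

# Illustrative sketch of the modulo-cycling sampling loop; not the actual
# benchmark_dataset.py code.
def sample_requests(dataset, num_requests):
    dataset_len = len(dataset)
    requests = []
    for i in range(num_requests):
        # Cycle through the dataset when num_requests > dataset_len
        dataset_idx = i % dataset_len
        prompt, prompt_len = dataset[dataset_idx]
        # Skip empty prompts, mirroring the filter kept in the original loop
        if prompt_len == 0:
            continue
        requests.append((prompt, prompt_len))
    return requests

# A 3-entry dataset can now serve 7 requests instead of tripping the old assert.
print(len(sample_requests([("a", 1), ("b", 2), ("c", 3)], 7)))  # 7

Because prompts that fail the filters are still skipped, the sampled count can legitimately end up below num_requests, which is presumably why the assert was removed rather than adjusted.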

0 commit comments
