Commit 8a619be

[None] [chore] Make disagg example compatible with recommended usage (#7121)

Signed-off-by: Kaiyu Xie <[email protected]>
1 parent 7cfa475 commit 8a619be

8 files changed, +531 -516 lines changed

examples/disaggregated/slurm/benchmark/README.md

Lines changed: 42 additions & 16 deletions
@@ -4,13 +4,15 @@ This directory contains scripts to run disaggregated inference benchmarks using
 
 ## Overview
 
-The benchmarking process is orchestrated through a set of shell scripts and a Python script that work together:
+The benchmarking process is orchestrated through a set of shell scripts and Python scripts that work together:
 
 1. `submit.sh`: The main entry point for submitting benchmark jobs to SLURM. It runs a parameter sweep by calling `sbatch` with different configurations.
-2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates a configuration file, starts the server and workers, and runs the benchmark client.
-3. `gen_yaml.py`: A Python script that generates the `config.yaml` file needed by `trtllm-serve`. It determines the server and worker configuration based on SLURM environment variables and script arguments.
-4. `start_worker.sh`: A shell script responsible for starting a `trtllm-serve disaggregated_mpi_worker` on each allocated machine.
-5. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory).
+2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates configuration files, starts the server and workers, and runs the benchmark client.
+3. `gen_worker_config.py`: A Python script that generates the worker configuration YAML file needed by `trtllm-serve`. It determines the worker configuration based on SLURM environment variables and script arguments.
+4. `gen_server_config.py`: A Python script that generates the server configuration YAML file needed by `trtllm-serve`. It determines the server configuration based on the number of context and generation servers.
+5. `start_worker.sh`: A shell script responsible for starting disaggregated workers using `trtllm-serve` on each allocated machine.
+6. `start_server.sh`: A shell script responsible for starting the disaggregated server using `trtllm-serve` on each allocated machine.
+7. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory).
 
 ## File Descriptions
 
@@ -58,28 +60,52 @@ It takes the following arguments in order:
 24. `model_dir`: Model directory path.
 25. `trtllm_repo`: TensorRT-LLM repository path.
 
-### `gen_yaml.py`
+### `gen_worker_config.py`
 
-This Python script generates the `config.yaml` file that configures the `trtllm-serve` application. It reads SLURM environment variables (`SLURM_JOB_NODELIST`, `SLURM_TASKS_PER_NODE`) to distribute workers across nodes.
+This Python script generates the worker configuration YAML file that configures the `trtllm-serve` workers. It creates separate configurations for context and generation workers with different tensor parallelism, batch sizes, and other parameters.
 
 **Usage:**
 
-The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and server configurations.
+The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and worker configurations for both context and generation phases.
+
+### `gen_server_config.py`
+
+This Python script generates the server configuration YAML file that configures the `trtllm-serve` disaggregated server. It reads hostname information from the work directory and creates a configuration that specifies the URLs for context and generation servers.
+
+**Usage:**
+
+The script is called from within `start_server.sh`. It takes arguments for the number of context and generation servers and the work directory.
 
 ### `start_worker.sh`
 
 This script starts a `trtllm-serve disaggregated_mpi_worker`. It is launched by `srun` from the `disaggr_torch.slurm` script on all allocated nodes.
 
 **Arguments:**
 
-1. `config_file`: Path to the `config.yaml` file.
-2. `enable_pdl`: `true` or `false`.
-3. `ctx_gpus`: Number of GPUs used for the context phase.
-4. `work_dir`: (Optional) Directory to store nsys profiling output.
+1. `worker_type`: Either "CTX" or "GEN" to specify the worker type.
+2. `worker_index`: Index of the worker instance.
+3. `model_dir`: Path to the model directory.
+4. `worker_port`: Port for the worker to listen on.
+5. `benchmark_mode`: Benchmark mode setting.
+6. `concurrency`: Concurrency level.
+7. `enable_pdl`: `true` or `false`.
+8. `work_dir`: Work directory for logs and configuration.
+9. `nsys_on`: Whether to enable nsys profiling.
+
+### `start_server.sh`
+
+This script starts the `trtllm-serve disaggregated` server. It first generates the server configuration using `gen_server_config.py`, then starts the server process.
+
+**Arguments:**
+
+1. `num_ctx_servers`: Number of context servers.
+2. `num_gen_servers`: Number of generation servers.
+3. `work_dir`: Work directory for logs and configuration.
+4. `script_dir`: Directory containing the scripts.
 
 ### `run_benchmark.sh`
 
-This script orchestrates the execution of the benchmark client. It waits for the `config.yaml` to be created and for the server's `/health` endpoint to respond, then it runs the benchmark.
+This script orchestrates the execution of the benchmark client. It waits for the configuration files to be created and for the server's `/health` endpoint to respond, then it runs the benchmark.
 
 **Arguments:**
 
@@ -97,9 +123,9 @@ This script orchestrates the execution of the benchmark client. It waits for the
 2. The user runs `./submit.sh`.
 3. `submit.sh` submits one or more jobs to SLURM by calling `sbatch disaggr_torch.slurm` with different parameters.
 4. For each job, SLURM allocates resources and runs `disaggr_torch.slurm`.
-5. `disaggr_torch.slurm` runs `gen_yaml.py` to create a `config.yaml`.
-6. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers.
-7. `disaggr_torch.slurm` starts the main `trtllm-serve` process.
+5. `disaggr_torch.slurm` runs `gen_worker_config.py` to create worker configuration files.
+6. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers for both context and generation phases.
+7. `disaggr_torch.slurm` starts the main `trtllm-serve` process using `start_server.sh`, which generates the server configuration using `gen_server_config.py`.
 8. `disaggr_torch.slurm` runs `run_benchmark.sh` which waits for the server to be ready.
 9. `run_benchmark.sh` executes the benchmark for each concurrency level specified.
 10. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes.
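
To illustrate step 7, here is a minimal sketch (not part of the commit) of the disaggregated server configuration that `gen_server_config.py` writes to `server_config.yaml`, expressed as the equivalent Python dict built by the script added later in this commit. The hostnames and instance counts are placeholders; ports 8333 and 8336 are the script's defaults.

```python
# Illustrative only: mirrors the dict assembled by gen_server_config.py for
# one context server and one generation server. Node names are placeholders.
import yaml

server_config = {
    "hostname": "head-node",          # node where start_server.sh runs
    "port": 8333,                     # default --server_port
    "backend": "pytorch",
    "context_servers": {
        "num_instances": 1,
        "urls": ["ctx-node-0:8336"],  # worker hostname + default --worker_port
    },
    "generation_servers": {
        "num_instances": 1,
        "urls": ["gen-node-0:8336"],
    },
}
print(yaml.dump(server_config))
```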

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 87 additions & 46 deletions
@@ -7,6 +7,10 @@
 #SBATCH --job-name=${job_name} # add your job name here or specify in the sbatch command
 #SBATCH --time=02:00:00
 
+set -u
+set -e
+set -x
+
 # Context servers arguments
 num_ctx_servers=${1}
 ctx_tp_size=${2}
@@ -42,7 +46,10 @@ mounts=${23}
 workdir=${24}
 model_dir=${25}
 benchmark_mode=${26}
-trtllm_repo=${27}
+trtllm_repo=${27:-""}
+
+# Get GPUs per node dynamically from SLURM
+ntasks_per_node=${SLURM_NTASKS_PER_NODE:-4} # Default to 4 for GB200
 
 echo "================= parameters ================="
 echo "num_ctx_servers: ${num_ctx_servers}"
@@ -72,6 +79,7 @@ echo "workdir: ${workdir}"
 echo "model_dir: ${model_dir}"
 echo "benchmark_mode: ${benchmark_mode}"
 echo "trtllm_repo: ${trtllm_repo}"
+echo "ntasks_per_node: ${ntasks_per_node}"
 echo "==========================================="
 
 
@@ -80,8 +88,8 @@ gen_max_seq_len=$((isl + osl))
 ctx_gpu_frac=${ctx_gpu_memory_fraction}
 cache_transceiver_max_num_tokens=8448
 
-container_name=disaggr
-logdir=${workdir}/benchmark-${isl}-${osl}
+container_name=disaggregated_serving
+logdir=${workdir}/slurm-${SLURM_JOB_ID}/benchmark-${isl}-${osl}
 mkdir -p ${logdir}
 full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size}
 
@@ -107,13 +115,14 @@ if [ "${benchmark_mode}" != "gen_only" ] && [ "${benchmark_mode}" != "e2e" ]; th
     benchmark_mode="e2e"
 fi
 
-if [ -z "${TRT_LLM_GIT_COMMIT}" ]; then
+if [ -z "${TRT_LLM_GIT_COMMIT:-}" ]; then
     export TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown")
     echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"
 fi
 
 nsys_on=""
 # nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling
+
 # start the container
 srun -l --container-image=${container_image} \
     --container-name=${container_name} \
@@ -128,60 +137,92 @@ if [ -n "${trtllm_repo}" ]; then
     bash -c "cd ${trtllm_repo} && echo 'Running install operation...' && pip install -e . " 2>&1 | tee ${full_logdir}/install.log
 fi
 
-# generate the yaml file
-srun -l --container-name=${container_name} \
+echo "Generating YAML file for workers."
+srun -l -N 1 -n 1 \
+    --container-name=${container_name} \
     --container-mounts=${mounts} \
     --mpi=pmix --overlap \
-    python3 ${workdir}/gen_yaml.py --config ${full_logdir}/config.yaml \
-        --model ${model_dir} \
-        --num_ctx_servers ${num_ctx_servers} \
-        --ctx_tp_size ${ctx_tp_size} \
-        --ctx_pp_size ${ctx_pp_size} \
-        --ctx_batch_size ${ctx_batch_size} \
-        --ctx_max_num_tokens ${ctx_max_num_tokens} \
-        --ctx_max_seq_len ${ctx_max_seq_len} \
-        --ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
-        --cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
-        --num_gen_servers ${num_gen_servers} \
-        --gen_tp_size ${gen_tp_size} \
-        --gen_pp_size ${gen_pp_size} \
-        --gen_batch_size ${gen_batch_size} \
-        --gen_max_num_tokens ${gen_max_num_tokens} \
-        --gen_max_seq_len ${gen_max_seq_len} \
-        --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
-        --eplb_num_slots ${eplb_num_slots} \
-        $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
-        $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \
-        $(if [ "${mtp_size}" -gt 0 ]; then echo "--mtp_size ${mtp_size}"; fi)
+    python3 ${workdir}/gen_worker_config.py \
+        --work_dir ${full_logdir} \
+        --ctx_tp_size ${ctx_tp_size} \
+        --ctx_pp_size ${ctx_pp_size} \
+        --ctx_batch_size ${ctx_batch_size} \
+        --ctx_max_num_tokens ${ctx_max_num_tokens} \
+        --ctx_max_seq_len ${ctx_max_seq_len} \
+        --ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
+        --gen_tp_size ${gen_tp_size} \
+        --gen_pp_size ${gen_pp_size} \
+        --gen_batch_size ${gen_batch_size} \
+        --gen_max_num_tokens ${gen_max_num_tokens} \
+        --gen_max_seq_len ${gen_max_seq_len} \
+        --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
+        --eplb_num_slots ${eplb_num_slots} \
+        --mtp_size ${mtp_size} \
+        --cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
+        $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \
+        $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
+        2>&1 | tee ${full_logdir}/gen_worker_config.log
 
 echo "YAML file generated."
 
-hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}')
-echo "server host name: $hostname_value"
+ctx_nodes_num=$(((ctx_tp_size + ntasks_per_node - 1) / ntasks_per_node))
+gen_nodes_num=$(((gen_tp_size + ntasks_per_node - 1) / ntasks_per_node))
 
+all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort))
+total_nodes_num=${#all_nodes[@]}
+echo "all_nodes: ${all_nodes[@]}, total_nodes_num: ${total_nodes_num}"
 
-# start the workers
-srun -l --container-name=${container_name} \
+# get the node list for the gen workers
+total_gen_nodes_num=$((gen_nodes_num * num_gen_servers))
+gen_nodes=(${all_nodes[@]:0:${total_gen_nodes_num}})
+echo "gen_nodes: ${gen_nodes[@]}, total_gen_nodes_num: ${total_gen_nodes_num}"
+
+# get the node list for the ctx workers
+total_ctx_nodes_num=$((ctx_nodes_num * num_ctx_servers))
+ctx_nodes=(${all_nodes[@]:${total_gen_nodes_num}:${total_nodes_num}})
+echo "ctx_nodes: ${ctx_nodes[@]}, total_ctx_nodes_num: ${total_ctx_nodes_num}"
+
+rm -rf ${full_logdir}/hostnames
+
+# start the gen workers
+for i in $(seq 0 $((num_gen_servers - 1))); do
+    srun -l -N ${gen_nodes_num} \
+        --ntasks=${gen_tp_size} \
+        --ntasks-per-node=${ntasks_per_node} \
+        --container-image=${container_image} \
+        --container-name=${container_name} \
        --container-mounts=${mounts} \
-        --mpi=pmix --overlap \
-        bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${benchmark_mode} ${concurrency} ${nsys_on} &> ${full_logdir}/output_workers.log &
+        --mpi=pmix \
+        bash ${workdir}/start_worker.sh "GEN" ${i} ${model_dir} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_on} \
+        &> ${full_logdir}/output_gen_${i}.log &
+done
+
+# start the ctx workers
+for i in $(seq 0 $((num_ctx_servers - 1))); do
+    srun -l -N ${ctx_nodes_num} \
+        --ntasks=${ctx_tp_size} \
+        --ntasks-per-node=${ntasks_per_node} \
+        --container-image=${container_image} \
+        --container-name=${container_name} \
+        --container-mounts=${mounts} \
+        --mpi=pmix \
+        bash ${workdir}/start_worker.sh "CTX" ${i} ${model_dir} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_on} \
+        &> ${full_logdir}/output_ctx_${i}.log &
+done
 
 # start the server
 srun -l --container-name=${container_name} \
-    --container-mounts=${mounts} \
-    --mpi=pmix --overlap -N 1 -n 1 \
-    -w ${hostname_value} \
-    bash ${workdir}/start_server.sh ${full_logdir}/config.yaml &> ${full_logdir}/output_server.log &
+    --container-image=${container_image} \
+    --container-mounts=${mounts} \
+    --mpi=pmix --overlap -N 1 -n 1 \
+    bash ${workdir}/start_server.sh ${num_ctx_servers} ${num_gen_servers} ${full_logdir} ${workdir} \
+    &> ${full_logdir}/output_server.log &
 
 # start benchmarking
 srun -l --container-name=${container_name} \
-    --container-mounts=${mounts} \
-    --mpi=pmix --overlap -N 1 -n 1 \
-    bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} > ${full_logdir}/benchmark.log 2>&1
+    --container-mounts=${mounts} \
+    --mpi=pmix --overlap -N 1 -n 1 \
+    bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} \
+    &> ${full_logdir}/benchmark.log 2>&1
 
-# try to kill the server and workers
-srun -l --container-name=${container_name} \
-    --container-mounts=${mounts} \
-    --mpi=pmix --overlap \
-    kill -9 $(ps aux | grep '[t]rtllm-serve' | awk '{print $2}') >/dev/null 2>&1 || true
-wait
+scancel ${SLURM_JOB_ID}
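
The new node bookkeeping in this script reduces to a ceil division plus list slicing: GEN workers take the first nodes of the allocation and CTX workers take the rest. A minimal Python sketch with assumed example sizes (variable names mirror the shell variables; node names are placeholders):

```python
# Sketch of the node-allocation arithmetic above, with assumed example values.
gen_tp_size, ctx_tp_size = 8, 4          # tensor-parallel sizes (examples)
num_gen_servers, num_ctx_servers = 1, 2  # server counts (examples)
ntasks_per_node = 4                      # SLURM_NTASKS_PER_NODE default

# ceil-divide the TP size by ranks per node to get nodes per worker
gen_nodes_num = (gen_tp_size + ntasks_per_node - 1) // ntasks_per_node
ctx_nodes_num = (ctx_tp_size + ntasks_per_node - 1) // ntasks_per_node

# GEN workers occupy the first nodes of the allocation, CTX workers the rest
all_nodes = [f"node{i}" for i in range(4)]  # stand-in for scontrol output
total_gen_nodes_num = gen_nodes_num * num_gen_servers
gen_nodes = all_nodes[:total_gen_nodes_num]
ctx_nodes = all_nodes[total_gen_nodes_num:]
print(gen_nodes_num, ctx_nodes_num, gen_nodes, ctx_nodes)
```
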
examples/disaggregated/slurm/benchmark/gen_server_config.py (new file)

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import argparse
import os
import socket
import time

import yaml

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_ctx_servers",
                        type=int,
                        required=True,
                        help="Number of context servers")
    parser.add_argument("--num_gen_servers",
                        type=int,
                        required=True,
                        help="Number of generation servers")
    parser.add_argument("--work_dir",
                        type=str,
                        default="logs",
                        help="Work directory")
    parser.add_argument("--worker_port",
                        type=int,
                        default=8336,
                        help="Worker port")
    parser.add_argument("--server_port",
                        type=int,
                        default=8333,
                        help="Server port")
    args = parser.parse_args()

    # check that the work_dir exists
    if not os.path.exists(args.work_dir):
        raise ValueError(f"Work directory {args.work_dir} not found")

    # wait until the hostnames folder exists, checking every 10 seconds
    hostnames_folder = os.path.join(args.work_dir, "hostnames")
    while not os.path.exists(hostnames_folder):
        time.sleep(10)
        print(f"Waiting for hostnames folder {hostnames_folder} to be found")
    hostnames = os.listdir(hostnames_folder)
    # wait until the number of hostname files equals num_ctx_servers + num_gen_servers
    while len(hostnames) != args.num_ctx_servers + args.num_gen_servers:
        time.sleep(10)
        hostnames = os.listdir(hostnames_folder)
        print(
            f"Waiting for hostnames to be found in {hostnames_folder}, current length: {len(hostnames)}, expected length: {args.num_ctx_servers + args.num_gen_servers}"
        )
    print(f"All hostnames found in {hostnames_folder}")

    # read the ctx and gen hostnames from the hostname files
    ctx_hostnames = []
    gen_hostnames = []
    for hostname_file in hostnames:
        hostname_file_path = os.path.join(hostnames_folder, hostname_file)
        with open(hostname_file_path, 'r') as f:
            actual_hostname = f.read().strip()
            print(f"Hostname: {actual_hostname} in {hostname_file}")

        if hostname_file.startswith("CTX"):
            ctx_hostnames.append(actual_hostname)
        elif hostname_file.startswith("GEN"):
            gen_hostnames.append(actual_hostname)

    print(f"ctx_hostnames: {ctx_hostnames}")
    print(f"gen_hostnames: {gen_hostnames}")

    # the disaggregated server runs on the current node
    hostname = socket.gethostname()
    print(f"Current hostname: {hostname}")

    server_config = {
        'hostname': hostname,
        'port': args.server_port,
        'backend': 'pytorch',
        'context_servers': {
            'num_instances': args.num_ctx_servers,
            'urls': [f'{host}:{args.worker_port}' for host in ctx_hostnames]
        },
        'generation_servers': {
            'num_instances': args.num_gen_servers,
            'urls': [f'{host}:{args.worker_port}' for host in gen_hostnames]
        }
    }

    with open(os.path.join(args.work_dir, "server_config.yaml"), "w") as f:
        yaml.dump(server_config, f)
    print(
        f"Server config file {os.path.join(args.work_dir, 'server_config.yaml')} generated"
    )
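
This script is the consumer side of a simple file-based handshake: it blocks until `<work_dir>/hostnames` contains one file per worker. The producer side lives in `start_worker.sh`, which this page does not show, so the sketch below is a hypothetical illustration of what each worker is expected to write. The exact file-name format (e.g. `GEN_0`) is an assumption; the reader above only requires a `CTX`/`GEN` prefix and the worker's hostname as the file body.

```python
# Hypothetical producer-side sketch; not taken from start_worker.sh.
import os
import socket


def register_worker_hostname(work_dir: str, worker_type: str, worker_index: int) -> None:
    # drop a file named <worker_type>_<index> containing this node's hostname
    hostnames_dir = os.path.join(work_dir, "hostnames")
    os.makedirs(hostnames_dir, exist_ok=True)
    with open(os.path.join(hostnames_dir, f"{worker_type}_{worker_index}"), "w") as f:
        f.write(socket.gethostname())


register_worker_hostname("logs", "GEN", 0)
```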
