
Commit 6fd6704

Make disagg example compatible with recommended usage
Signed-off-by: Kaiyu Xie <[email protected]>
1 parent 2d40e87 commit 6fd6704

File tree

6 files changed: +300 -328 lines


examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 57 additions & 29 deletions
@@ -110,6 +110,7 @@ fi
 
 nsys_on=""
 # nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling
+
 # start the container
 srun -l --container-image=${container_image} \
     --container-name=${container_name} \
@@ -124,29 +125,28 @@ if [ -n "${trtllm_repo}" ]; then
     bash -c "cd ${trtllm_repo} && echo 'Running install operation...' && pip install -e . " 2>&1 | tee ${full_logdir}/install.log
 fi
 
-# generate the yaml file
-srun -l --container-name=${container_name} \
+echo "Generating YAML file for workers."
+srun -l -N 1 -n 1 \
+    --container-name=${container_name} \
     --container-mounts=${mounts} \
     --mpi=pmix --overlap \
-    python3 ${workdir}/gen_yaml.py --config ${full_logdir}/config.yaml \
-        --model ${model_dir} \
-        --num_ctx_servers ${num_ctx_servers} \
-        --ctx_tp_size ${ctx_tp_size} \
-        --ctx_batch_size ${ctx_batch_size} \
-        --ctx_max_num_tokens ${ctx_max_num_tokens} \
-        --ctx_max_seq_len ${ctx_max_seq_len} \
-        --ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
-        --cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
-        --num_gen_servers ${num_gen_servers} \
-        --gen_tp_size ${gen_tp_size} \
-        --gen_batch_size ${gen_batch_size} \
-        --gen_max_num_tokens ${gen_max_num_tokens} \
-        --gen_max_seq_len ${gen_max_seq_len} \
-        --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
-        --eplb_num_slots ${eplb_num_slots} \
-        $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
-        $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \
-        $(if [ "${mtp_size}" -gt 0 ]; then echo "--mtp_size ${mtp_size}"; fi)
+    python3 ${workdir}/gen_yaml.py \
+        --work_dir ${full_logdir} \
+        --ctx_tp_size ${ctx_tp_size} \
+        --ctx_batch_size ${ctx_batch_size} \
+        --ctx_max_num_tokens ${ctx_max_num_tokens} \
+        --ctx_max_seq_len ${ctx_max_seq_len} \
+        --ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
+        --gen_tp_size ${gen_tp_size} \
+        --gen_batch_size ${gen_batch_size} \
+        --gen_max_num_tokens ${gen_max_num_tokens} \
+        --gen_max_seq_len ${gen_max_seq_len} \
+        --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
+        --eplb_num_slots ${eplb_num_slots} \
+        --mtp_size ${mtp_size} \
+        --cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
+        $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \
+        $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi)
 
 echo "YAML file generated."
 
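Note on the flag-forwarding idiom above: each `$(if ...)` command substitution expands to the option string when its shell variable is "true" and to nothing otherwise, so gen_yaml.py only ever receives switches that are actually enabled. A minimal standalone sketch of the idiom (variable and script names are illustrative, not taken from the repo):

    gen_enable_attention_dp="true"
    flags="$(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi)"
    # ${flags} is "--gen_enable_attention_dp" when enabled, empty otherwise
    echo python3 gen_yaml.py ${flags}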
@@ -155,17 +155,45 @@ echo "server host name: $hostname_value"
 
 
 # start the workers
-srun -l --container-name=${container_name} \
+pid_list=""
+
+# start the ctx workers
+for i in $(seq 0 $((num_ctx_servers - 1))); do
+    srun -l -N ${ctx_nodes_num} \
+        --ntasks=${ctx_tp_size} \
+        --ntasks-per-node=${gpus_per_node} \
+        --segment=${ctx_nodes_num} \
+        --container-image=${container_image} \
+        --container-name=${container_name}_ctx_${i} \
         --container-mounts=${mounts} \
-    --mpi=pmix --overlap \
-    bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${benchmark_mode} ${concurrency} ${nsys_on} &> ${full_logdir}/output_workers.log &
+        --mpi=pmix \
+        bash ${workdir}/start_worker.sh "CTX" ${i} ${model_path} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_folder} \
+        &> ${full_logdir}/output_ctx_${i}.log &
+    pid_list="${pid_list} $!"
+done
+
+# start the gen workers
+for i in $(seq 0 $((num_gen_servers - 1))); do
+    srun -l -N ${gen_nodes_num} \
+        --ntasks=${gen_tp_size} \
+        --ntasks-per-node=${gpus_per_node} \
+        --segment=${gen_nodes_num} \
+        --container-image=${container_image} \
+        --container-name=${container_name}_gen_${i} \
+        --container-mounts=${mounts} \
+        --mpi=pmix \
+        bash ${workdir}/start_worker.sh "GEN" ${i} ${model_path} "8336" ${benchmark_mode} ${concurrency} ${enable_pdl} ${full_logdir} ${nsys_folder} \
+        &> ${full_logdir}/output_gen_${i}.log &
+    pid_list="${pid_list} $!"
+done
 
 # start the server
-srun -l --container-name=${container_name} \
-    --container-mounts=${mounts} \
-    --mpi=pmix --overlap -N 1 -n 1 \
-    -w ${hostname_value} \
-    bash ${workdir}/start_server.sh ${full_logdir}/config.yaml &> ${full_logdir}/output_server.log &
+srun -l --container-name=${container_name}_server \
+    --container-image=${container_image} \
+    --container-mounts=${mounts} \
+    --mpi=pmix --overlap -N 1 -n 1 \
+    bash ${workdir}/start_server.sh ${num_ctx_servers} ${num_gen_servers} ${full_logdir} ${work_dir} \
+    &> ${full_logdir}/output_server.log &
 
 # start benchmarking
 srun -l --container-name=${container_name} \
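The worker launches above replace the old config-file argument with a purely positional interface. A hypothetical sketch of how start_worker.sh might unpack those arguments, inferred only from the two call sites (all names here are assumptions, not the script's actual variables):

    #!/bin/bash
    # Hypothetical positional unpacking; order mirrors the call sites above.
    role=$1            # "CTX" or "GEN"
    instance_id=$2     # loop index i
    model_path=$3
    port=$4            # "8336" at both call sites
    benchmark_mode=$5
    concurrency=$6
    enable_pdl=$7
    log_dir=$8         # ${full_logdir}
    nsys_folder=$9     # optional Nsys output location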

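Each backgrounded srun appends its PID to pid_list, but the diff excerpt ends before the PIDs are consumed. A common follow-up, assumed here rather than shown in this commit, is to block on those PIDs so the batch job does not exit while workers are still running:

    # Assumed pattern (not part of this diff): wait for all worker steps
    for pid in ${pid_list}; do
        wait ${pid}
    done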