86 changes: 48 additions & 38 deletions examples/disaggregated/slurm/disaggr_torch.slurm
@@ -2,43 +2,54 @@
#SBATCH --nodes=2
#SBATCH --ntasks=8
#SBATCH --ntasks-per-node=4
#SBATCH --partition=${partition} # add your partition here
#SBATCH --account=${account} # add your account here
#SBATCH --partition=${partition} # add your partition here or specify in the sbatch command
#SBATCH --account=${account} # add your account here or specify in the sbatch command
#SBATCH --job-name=${job_name} # add your job name here or specify in the sbatch command
#SBATCH --time=02:00:00
#SBATCH --job-name=${job_name} # add your job name here

isl=1024
osl=1024
multi_round=10
gen_yaml_file=gen_yaml.py
streaming=true
container_image=${container_image} # add your container image here
mount_dir=${mount_dir} # add your mount directory here
workdir=${workdir} # add your path to the slurm scripts here
model_dir=${model_dir} # add your model directory here

mounts=${mount_dir}:${mount_dir}
logdir=${workdir}/benchmark-${isl}-${osl}/
mkdir -p ${logdir}

container_name=disaggr-test

num_ctx_servers=$1
ctx_tp_size=$2
ctx_batch_size=$3
ctx_max_num_tokens=$4
ctx_enable_attention_dp=$5
num_gen_servers=$6
gen_tp_size=$7
gen_batch_size=$8
gen_max_num_tokens=$9
# Context servers arguments
num_ctx_servers=${1}
ctx_tp_size=${2}
ctx_batch_size=${3}
ctx_max_num_tokens=${4}
ctx_enable_attention_dp=${5}

# Generation servers arguments
num_gen_servers=${6}
gen_tp_size=${7}
gen_batch_size=${8}
gen_max_num_tokens=${9}
gen_enable_attention_dp=${10}
gen_gpu_memory_fraction=${11}

# Other arguments
eplb_num_slots=${12}
mtp_size=${13}

# Benchmarking arguments
concurrency=${14}
isl=${15}
osl=${16}
multi_round=${17}
streaming=${18}

# User specific arguments
container_image=${19}
mounts=${20}
workdir=${21}
model_dir=${22}

ctx_max_seq_len=$((isl + 1))
gen_max_seq_len=$((isl + osl))
ctx_gpu_frac=0.75
cache_transceiver_max_num_tokens=8448

container_name=disaggr
logdir=${workdir}/benchmark-${isl}-${osl}/
mkdir -p ${logdir}
full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size}

full_logdir=${logdir}/dep${gen_tp_size}_concurrency${concurrency}_eplb${eplb_num_slots}_mtp${mtp_size}
echo "concurrency: ${concurrency}"

ctx_gpus=$((num_ctx_servers * ctx_tp_size))
gen_gpus=$((num_gen_servers * gen_tp_size))
@@ -49,9 +60,10 @@ enable_pdl=false
if [ "${gen_enable_attention_dp}" = "false" ]; then
enable_pdl=true
echo "enable_pdl: ${enable_pdl}"
full_logdir=${logdir}/tep${gen_tp_size}_concurrency${concurrency}_eplb${eplb_num_slots}_mtp${mtp_size}
full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_tep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size}
fi
mkdir -p ${full_logdir}
echo "Log will be saved to: ${full_logdir}"

nsys_on=""
# nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling
@@ -67,16 +79,20 @@ srun -l --container-image=${container_image} \
srun -l --container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap \
python3 ${workdir}/${gen_yaml_file} --config ${full_logdir}/config.yaml \
python3 ${workdir}/gen_yaml.py --config ${full_logdir}/config.yaml \
--model ${model_dir} \
--num_ctx_servers ${num_ctx_servers} \
--ctx_tp_size ${ctx_tp_size} \
--ctx_batch_size ${ctx_batch_size} \
--ctx_max_num_tokens ${ctx_max_num_tokens} \
--ctx_max_seq_len ${ctx_max_seq_len} \
--ctx_free_gpu_memory_fraction ${ctx_gpu_frac} \
--cache_transceiver_max_num_tokens ${cache_transceiver_max_num_tokens} \
--num_gen_servers ${num_gen_servers} \
--gen_tp_size ${gen_tp_size} \
--gen_batch_size ${gen_batch_size} \
--gen_max_num_tokens ${gen_max_num_tokens} \
--gen_max_seq_len ${gen_max_seq_len} \
--gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
--eplb_num_slots ${eplb_num_slots} \
$(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
@@ -88,17 +104,11 @@ echo "YAML file generated."
hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}')
echo "server host name: $hostname_value"

# try to kill the server and workers
srun -l --container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap \
pkill -f "trtllm-serve" || true

# start the workers
srun -l --container-name=${container_name} \
--container-mounts=${mounts} \
--mpi=pmix --overlap \
bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${concurrency}" "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log &
bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log &

# start the server
srun -l --container-name=${container_name} \
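The SLURM script now receives every tuning knob as a positional argument instead of a hard-coded variable, and the partition, account, and job name are expected on the sbatch command line. A minimal submission sketch follows; every value (parallel sizes, batch sizes, sequence lengths, paths, image name) is illustrative only, not a tuned recommendation.

# Hypothetical submission of the refactored script -- all values below are examples.
args=(
    1 4 4 4480 true            # num_ctx_servers ctx_tp_size ctx_batch_size ctx_max_num_tokens ctx_enable_attention_dp
    1 8 1024 1024 true 0.8     # num_gen_servers gen_tp_size gen_batch_size gen_max_num_tokens gen_enable_attention_dp gen_gpu_memory_fraction
    0 0                        # eplb_num_slots mtp_size
    1024 1024 1024 10 true     # concurrency isl osl multi_round streaming
    my_image.sqsh              # container_image
    /mnt:/mnt                  # mounts
    /mnt/slurm_scripts         # workdir (location of gen_yaml.py and the start_*.sh helpers)
    /mnt/models/my_model       # model_dir
)
sbatch --partition=my_partition --account=my_account --job-name=disaggr-bench \
    disaggr_torch.slurm "${args[@]}"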
61 changes: 49 additions & 12 deletions examples/disaggregated/slurm/gen_yaml.py
@@ -125,17 +125,21 @@ def gen_config_file(config_path: str,
ctx_tp_size: int,
ctx_batch_size: int,
ctx_max_num_tokens: int,
ctx_max_seq_len: int,
ctx_free_gpu_memory_fraction: float,
ctx_enable_attention_dp: bool,
num_gen_servers: int,
gen_tp_size: int,
gen_batch_size: int,
gen_max_num_tokens: int,
gen_max_seq_len: int,
gen_enable_attention_dp: bool,
gen_gpu_memory_fraction: float,
eplb_num_slots: int,
mtp_size: int = 0,
worker_start_port: int = 8001,
server_port: int = 8000) -> None:
server_port: int = 8000,
cache_transceiver_max_num_tokens: int = 4608) -> None:
"""
Generate configuration YAML file for disaggregated inference.

@@ -146,6 +150,8 @@ def gen_config_file(config_path: str,
ctx_tp_size: Tensor parallel size for context servers
ctx_batch_size: Batch size for context servers
ctx_max_num_tokens: Max number of tokens for context servers
ctx_max_seq_len: Max sequence length for context servers
ctx_free_gpu_memory_fraction: Free GPU memory fraction for context servers
ctx_enable_attention_dp: Enable attention DP for context servers
num_gen_servers: Number of generation servers
gen_tp_size: Tensor parallel size for generation servers
@@ -161,7 +167,11 @@ def gen_config_file(config_path: str,
1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024, 2048, gen_batch_size
]

gen_moe_backend = "WIDEEP"
gen_moe_backend = "CUTLASS"
if gen_tp_size >= 16 and gen_enable_attention_dp:
gen_moe_backend = "WIDEEP"
if not gen_enable_attention_dp:
gen_moe_backend = "TRTLLM"

config = {
'model': model_path,
@@ -172,20 +182,22 @@
'num_instances': num_ctx_servers,
'max_batch_size': ctx_batch_size,
'max_num_tokens': ctx_max_num_tokens,
'max_seq_len': 1152,
'max_seq_len': ctx_max_seq_len,
'free_gpu_memory_fraction': ctx_free_gpu_memory_fraction,
'tensor_parallel_size': ctx_tp_size,
'moe_expert_parallel_size': ctx_tp_size,
'enable_attention_dp': ctx_enable_attention_dp,
'pipeline_parallel_size': 1,
'print_iter_log': True,
'disable_overlap_scheduler': True,
'kv_cache_config': {
'free_gpu_memory_fraction': 0.85,
'enable_block_reuse': False,
'free_gpu_memory_fraction': ctx_free_gpu_memory_fraction,
'dtype': 'fp8',
},
'cache_transceiver_config': {
'max_tokens_in_buffer': cache_transceiver_max_num_tokens,
'backend': 'default',
'max_tokens_in_buffer': 8320,
},
},
'generation_servers': {
@@ -196,23 +208,26 @@
'pipeline_parallel_size': 1,
'max_batch_size': gen_batch_size,
'max_num_tokens': gen_max_num_tokens,
'max_seq_len': 2176,
'max_seq_len': gen_max_seq_len,
'free_gpu_memory_fraction': gen_gpu_memory_fraction,
'cuda_graph_config': {
'enable_padding': True,
'batch_sizes': gen_cuda_graph_batch_sizes,
},
'print_iter_log': True,
'kv_cache_config': {
'enable_block_reuse': False,
'free_gpu_memory_fraction': gen_gpu_memory_fraction,
'dtype': 'fp8',
},
'moe_config': {
'backend': gen_moe_backend,
},
'cache_transceiver_config': {
'max_tokens_in_buffer': cache_transceiver_max_num_tokens,
'backend': 'default',
'max_tokens_in_buffer': 8320,
},
'stream_interval': 20,
}
}

@@ -236,6 +251,9 @@ def gen_config_file(config_path: str,
# set the hostname to the first node
config['hostname'] = nodes[0]

if gen_tp_size == 8 and not gen_enable_attention_dp:
config['generation_servers']['allreduce_strategy'] = "MNNVL"

if eplb_num_slots > 0:
moe_load_balancer_file = os.path.join(os.path.dirname(config_path),
"moe_load_balancer.yaml")
@@ -290,6 +308,14 @@ def gen_config_file(config_path: str,
type=int,
required=True,
help="Max number of tokens for context servers")
parser.add_argument("--ctx_max_seq_len",
type=int,
required=True,
help="Max sequence length for context servers")
parser.add_argument("--ctx_free_gpu_memory_fraction",
type=float,
required=True,
help="Free GPU memory fraction for context servers")
parser.add_argument("--ctx_enable_attention_dp",
dest='ctx_enable_attention_dp',
action='store_true',
@@ -310,6 +336,10 @@ def gen_config_file(config_path: str,
type=int,
required=True,
help="Max number of tokens for generation servers")
parser.add_argument("--gen_max_seq_len",
type=int,
required=True,
help="Max sequence length for generation servers")
parser.add_argument("--gen_enable_attention_dp",
dest='gen_enable_attention_dp',
action='store_true',
@@ -334,13 +364,20 @@ def gen_config_file(config_path: str,
type=int,
default=8333,
help="Server port")
parser.add_argument("--cache_transceiver_max_num_tokens",
type=int,
default=4608,
help="Max number of tokens for cache transceiver")

args = parser.parse_args()

gen_config_file(args.config, args.model, args.num_ctx_servers,
args.ctx_tp_size, args.ctx_batch_size,
args.ctx_max_num_tokens, args.ctx_enable_attention_dp,
args.num_gen_servers, args.gen_tp_size, args.gen_batch_size,
args.gen_max_num_tokens, args.gen_enable_attention_dp,
args.gen_gpu_memory_fraction, args.eplb_num_slots,
args.mtp_size, args.worker_start_port, args.server_port)
args.ctx_max_num_tokens, args.ctx_max_seq_len,
args.ctx_free_gpu_memory_fraction,
args.ctx_enable_attention_dp, args.num_gen_servers,
args.gen_tp_size, args.gen_batch_size,
args.gen_max_num_tokens, args.gen_max_seq_len,
args.gen_enable_attention_dp, args.gen_gpu_memory_fraction,
args.eplb_num_slots, args.mtp_size, args.worker_start_port,
args.server_port, args.cache_transceiver_max_num_tokens)
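For a quick sanity check of the new flags (--ctx_max_seq_len, --ctx_free_gpu_memory_fraction, --gen_max_seq_len, --gen_gpu_memory_fraction, --cache_transceiver_max_num_tokens), the generator can be exercised on its own. This is only a sketch of the flag surface with made-up values mirroring the srun call above (ctx_max_seq_len = isl + 1 = 1025, gen_max_seq_len = isl + osl = 2048); in practice the script runs inside the allocation, where it can resolve the node list for the hostname field.

# Hypothetical standalone run -- paths and sizes are examples, not recommendations.
python3 examples/disaggregated/slurm/gen_yaml.py \
    --config /tmp/config.yaml \
    --model /mnt/models/my_model \
    --num_ctx_servers 1 --ctx_tp_size 4 --ctx_batch_size 4 --ctx_max_num_tokens 4480 \
    --ctx_max_seq_len 1025 --ctx_free_gpu_memory_fraction 0.75 \
    --num_gen_servers 1 --gen_tp_size 8 --gen_batch_size 1024 --gen_max_num_tokens 1024 \
    --gen_max_seq_len 2048 --gen_gpu_memory_fraction 0.8 \
    --eplb_num_slots 0 \
    --cache_transceiver_max_num_tokens 8448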
6 changes: 3 additions & 3 deletions examples/disaggregated/slurm/start_server.sh
@@ -13,9 +13,9 @@ config_file=$1
# Check and replace hostname settings in config_file
if [ -f "$config_file" ]; then
# Use sed to find hostname line and check if replacement is needed
if grep -q "^hostname:" "$config_file"; then
if grep -q "hostname:" "$config_file"; then
# Extract current hostname value from config
current_hostname=$(grep "^hostname:" "$config_file" | sed 's/^hostname:[ ]*//' | awk '{print $1}')
current_hostname=$(grep "hostname:" "$config_file" | sed 's/.*hostname:[ ]*//' | awk '{print $1}')

if [ "$current_hostname" != "$short_hostname" ]; then
echo "Replacing hostname '$current_hostname' with '$short_hostname' in $config_file"
@@ -31,4 +31,4 @@ else
echo "Config file $config_file not found"
fi

trtllm-serve disaggregated -c ${config_file} -t 1800 -r 1800
trtllm-serve disaggregated -c ${config_file} -t 1800 -r 7200
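The hostname check above no longer anchors the grep to the start of the line, so an indented or nested hostname: key is also picked up. A quick check of the relaxed extraction against a made-up config:

# Hypothetical check of the new hostname extraction (pattern not anchored to line start).
printf 'server:\n  hostname: node-001\n' > /tmp/cfg.yaml
grep "hostname:" /tmp/cfg.yaml | sed 's/.*hostname:[ ]*//' | awk '{print $1}'   # prints: node-001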
14 changes: 5 additions & 9 deletions examples/disaggregated/slurm/start_worker.sh
@@ -1,17 +1,13 @@
#! /bin/bash

config_file=$1
concurrency=$2
enable_pdl=$3
ctx_gpus=$4
work_dir=$5
enable_pdl=$2
ctx_gpus=$3
work_dir=$4
unset UCX_TLS
echo "config_file: ${config_file}, concurrency: ${concurrency}, enable_pdl: ${enable_pdl}, ctx_gpus: ${ctx_gpus}, work_dir: ${work_dir}"
echo "config_file: ${config_file}, enable_pdl: ${enable_pdl}, ctx_gpus: ${ctx_gpus}, work_dir: ${work_dir}"

export TLLM_LOG_LEVEL=INFO
export TRTLLM_USE_UCX_KVCACHE=1
export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
export TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1
export TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER=1

if [ "${enable_pdl}" = "true" ]; then
@@ -27,7 +23,7 @@ else
nsys_file=${work_dir}/nsys_worker_proc_${SLURM_PROCID}
export TLLM_PROFILE_RECORD_GC=1
export TLLM_NVTX_DEBUG=1
if [ "${SLURM_PROCID}" -ge "${ctx_gpus}" ]; then
if [ ${SLURM_PROCID} -ge ${ctx_gpus} ]; then
export TLLM_PROFILE_START_STOP=200-250
nsys_prefix="nsys profile -e \"NSYS_MPI_STORE_TEAMS_PER_RANK=1\" -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
echo "nsys_prefix: ${nsys_prefix}"
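start_worker.sh now takes four positional arguments; the separate concurrency argument (and the TLLM_BENCHMARK_REQ_QUEUES_SIZE export it fed) has been dropped. A hypothetical direct invocation with made-up paths:

# Hypothetical manual call; positions map to config_file, enable_pdl, ctx_gpus, work_dir.
# work_dir appears to be consulted only when Nsys profiling is enabled.
bash start_worker.sh /mnt/logs/config.yaml true 4 /mnt/logs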