Disaggregated serving is driven by the `trtllm-serve` command. In contrast to standard `trtllm-serve` usage, it requires running the command multiple times to separately start the router and the worker (context and generation) serving components. This document focuses on this approach and provides a detailed guide on how to use it.
To run TRT-LLM in disaggregated mode, you must first launch context (prefill) and generation (decode) servers using `trtllm-serve`.
Depending on your deployment environment, this can be done in different ways.
Please note that disaggregated serving is currently an experimental feature, so the usage described in this document may change in the future.
## Launching context and generation servers using multiple independent `trtllm-serve` commands
You can use multiple `trtllm-serve` commands to launch the context and generation servers that will be used for disaggregated serving. For example, you could launch two context servers and one generation server as follows:
### Configuration File
The `trtllm-serve` command supports an `extra-llm-config.yaml` parameter. In this extra LLM configuration file, the `cache_transceiver_config` field is specific to disaggregated serving: it specifies the additional parameters required for the KV cache transmission process.
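As an illustration, a minimal extra LLM configuration file for a disaggregated worker might look like the sketch below. The `backend: UCX` value matches the configuration examples later in this document; any other field shown (such as a `max_tokens_in_buffer` buffer size) is an assumption and may differ across versions.

```yaml
# extra-llm-config.yaml (illustrative sketch; exact fields may vary by version)
cache_transceiver_config:
  backend: UCX                 # KV cache transfer backend, as used in the examples below
  max_tokens_in_buffer: 2048   # assumed tuning knob for the transfer buffer size
```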
For non-SLURM clusters, particularly single-node, multi-GPU setups, standard mode is recommended. In such cases, the system does not enforce limits on process creation or termination.
Suppose we have three CUDA devices on the same machine. The first two devices each launch one context model, and the third device launches one generation model. In this case, execute the following commands.
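A sketch of those launch commands is shown below. The model name and ports come from the `disagg_config.yaml` example later in this document; the exact `trtllm-serve` flags (such as `--extra_llm_api_options` pointing at the extra LLM configuration file) are assumptions and may differ in your version.

```
# Context server 1 (illustrative flags; port matches disagg_config.yaml below)
CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8001 --extra_llm_api_options ./extra-llm-config.yaml &

# Context server 2
CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8002 --extra_llm_api_options ./extra-llm-config.yaml &

# Generation server
CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8003 --extra_llm_api_options ./extra-llm-config.yaml &
```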
Once the context and generation servers are launched, you can launch the disaggregated server, which will accept requests from clients and do the orchestration between context and generation servers. The disaggregated server can be launched with:

```bash
trtllm-serve disaggregated -c disagg_config.yaml
```

where `disagg_config.yaml` contains information about the context and generation servers. For the current example, it would look like:

```yaml
hostname: localhost
port: 8000
backend: pytorch
context_servers:
  num_instances: 2
  urls:
    - "localhost:8001"
    - "localhost:8002"
generation_servers:
  num_instances: 1
  urls:
    - "localhost:8003"
```
Clients can then send requests to the disaggregated server at `localhost:8000`, which is an OpenAI API compatible endpoint.
#### Sending requests to the disaggregated server
Once the context, generation and disaggregated servers are launched, you can send requests to the disaggregated server using curl:

```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
      "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      "prompt": "NVIDIA is a great company because",
      "max_tokens": 16,
      "temperature": 0
    }' -w "\n"
```

Alternatively, you can use the provided client, which parses prompts from a file and sends requests to the disaggregated server specified in the `disagg_config.yaml` file at the `chat` endpoint.
### Launching disaggregated servers on SLURM clusters
To simplify usage, TensorRT-LLM internally relies on MPI to spawn processes. However, some clusters do not offer such process flexibility. For these cases, we provide the `trtllm-llmapi-launch` tool, which launches all processes at once. When using TensorRT-LLM on a SLURM cluster, please refer to the following method.
#### Single-Node Execution
After starting the node and entering interactive mode, run your `trtllm-serve` commands through `trtllm-llmapi-launch` to prevent process spawning. Additionally, we offer a fully executable script; please refer to [Disaggregated SLURM Scripts](./slurm/simple_example/).
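A sketch of such an invocation, reusing the example model and port from above (the `trtllm-llmapi-launch` wrapper is the tool mentioned in this section; the remaining flags are assumptions and may differ in your version):

```
trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8001 --extra_llm_api_options ./extra-llm-config.yaml
```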
## Dynamic scaling (Prototype)
Currently, TensorRT-LLM supports dynamic addition and removal of servers by leveraging ETCD. To enable this feature, start the context and generation servers with the additional flags `--metadata_server_config_file` and `--server_role`.
Before launching the context and generation servers, you should first start the ETCD server. By default, the ETCD server listens for client requests at `localhost:2379`.
```bash
etcd
```
After this, you can enable the dynamic scaling feature for the use case above by passing a metadata server configuration file to each server. The `hostname` and `port` in that file must match those used when starting the ETCD server. The `health_check_timeout` parameter specifies how long to wait without a healthy response before a server is considered dead; by default, TensorRT-LLM performs two checks before marking a server as dead. The `refresh_interval` parameter determines how often the latest server list is fetched from the ETCD server.
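For illustration, a metadata server configuration file matching the defaults described above might look like the sketch below. The file name, the timeout and interval values, and the exact field spellings are assumptions; the parameters themselves are the ones discussed in this section.

```yaml
# metadata-server-config.yaml (illustrative sketch)
hostname: localhost        # must match the address of the running ETCD server
port: 2379                 # default ETCD client port
health_check_timeout: 5.0  # assumed seconds without a healthy response before a server is considered dead
refresh_interval: 3.0      # assumed seconds between fetches of the latest server list
```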
### Dynamically adding servers
Users can add servers by directly launching them with `trtllm-serve`. For example, you can start an additional generation server with the same `--metadata_server_config_file` and the appropriate `--server_role`. TensorRT-LLM will automatically register any newly launched server with the ETCD server, allowing the router to send new requests to the added server.
### Dynamically removing servers
When removing servers, special attention is required in the current version. First remove the corresponding key from the ETCD server; after you see the log message "Server xxxx is removed", you can safely shut down the server. This workflow will be improved soon.
## Startup Procedure with MPI Worker (Deprecated)
In the past, we used `disaggregated_mpi_worker` to allow context and generation nodes to operate within the same MPI world. However, this approach conflicts with the dynamic addition and removal of nodes. As a result, `disaggregated_mpi_worker` has been marked as deprecated, and the corresponding examples will be gradually removed.
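For reference, the deprecated worker was started with a single MPI command spanning all ranks. A sketch is shown below; the subcommand name follows the `disaggregated_mpi_worker` mentioned above, while the `mpirun` flag spelling is an assumption.

```
mpirun -n <total_num_ranks> trtllm-serve disaggregated_mpi_worker -c disagg_config.yaml
```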
where `<total_num_ranks>` is the sum of `TP*PP` for all context and generation servers. For the example above, `total_num_ranks` is 3
since `TP` and `PP` are 1 for the two context servers and the one generation server.
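As a quick sanity check, `total_num_ranks` can be computed from the configuration file itself. The sketch below assumes PyYAML is installed and, as a hypothetical convention, defaults `tensor_parallel_size` and `pipeline_parallel_size` to 1 when a field is omitted:

```python
import yaml

def total_num_ranks(config_text: str) -> int:
    """Sum num_instances * TP * PP over the context and generation server groups."""
    cfg = yaml.safe_load(config_text)
    total = 0
    for key in ("context_servers", "generation_servers"):
        group = cfg.get(key, {})
        tp = group.get("tensor_parallel_size", 1)
        pp = group.get("pipeline_parallel_size", 1)
        total += group.get("num_instances", 1) * tp * pp
    return total

example = """
context_servers:
  num_instances: 2
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
"""
print(total_num_ranks(example))  # 2*1*1 + 1*1*1 = 3
```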
The `disagg_config.yaml` file must now contain the configuration parameters of the context and generation servers. For example, it could look like:

```yaml
hostname: localhost
port: 8000
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
backend: "pytorch"
use_cuda_graph: False
disable_overlap_scheduler: True
context_servers:
  num_instances: 2
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  kv_cache_config:
    free_gpu_memory_fraction: 0.9
  cache_transceiver_config:
    backend: UCX
  urls:
    - "localhost:8001"
    - "localhost:8002"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  cache_transceiver_config:
    backend: UCX
  urls:
    - "localhost:8003"
```
Once the context and generation servers are launched, you can again launch the disaggregated server with:

```bash
trtllm-serve disaggregated -c disagg_config.yaml
```
The MPI communication backend for KV cache transfer has been deprecated and may not be supported in the future. When using the MPI backend, the environment variable `TRTLLM_USE_MPI_KVCACHE=1` should be set to avoid conflicts between mpi4py and KV cache transfer.