@@ -51,14 +51,15 @@ def __init__(
        """
        self.cfg = config.cache_config
        self.max_num_seqs = max_num_seqs
-        self.stop_flags = [True] * max_num_seqs
+        self.stop_flags = [True] * max_num_seqs  # flag set to True if the slot has not been taken
        self.enable_prefix_cache = config.cache_config.enable_prefix_caching
        self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
-        self.tasks_list = [None] * max_num_seqs
+        self.tasks_list = [None] * max_num_seqs  # task slots
        self.req_dict = dict()
        # current batch status of the engine
        self.real_bsz = 0
        llm_logger.info(f"{self.info()}")
+        main_process_metrics.max_batch_size.set(max_num_seqs)

    def reset_cache_config(self, cfg):
        """
@@ -222,18 +223,18 @@ def allocate_resources_for_new_tasks(self, tasks):
        Returns:
            list: processed task list
        """
-
-        allocated_position = 0
-        processing_task_index = 0
+        llm_logger.debug(f"Allocating resources for a batch of new tasks: {tasks}")
+        allocated_position = 0  # number of tasks already allocated; also the position in the request slots
+        processing_task_index = 0  # index of the task currently being processed
        processed_tasks = list()
-        while allocated_position < self.max_num_seqs:
-            if processing_task_index >= len(tasks):
+        while allocated_position < self.max_num_seqs:  # loop until resources are allocated for all tasks
+            if processing_task_index >= len(tasks):  # if all tasks have been tried, don't give a second chance
                break

            can_insert = False
            while allocated_position < self.max_num_seqs:
                if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
-                    can_insert = True
+                    can_insert = True  # if there is an empty slot, try to allocate resources for the current task
                    break
                allocated_position += 1
            if can_insert:
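Worth noting: `sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1` sums a one-element slice, so it is just a roundabout test of a single flag, equivalent to `self.stop_flags[allocated_position]`. The same scan in plain form (a hypothetical helper, not part of the diff):

```python
def find_free_slot(stop_flags, start):
    """Return the index of the first free slot at or after `start`, or None."""
    for idx in range(start, len(stop_flags)):
        if stop_flags[idx]:  # True means the slot is free
            return idx
    return None

# Example: slots 0 and 1 taken, slot 2 free.
assert find_free_slot([False, False, True, False], 0) == 2
```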
@@ -243,7 +244,8 @@ def allocate_resources_for_new_tasks(self, tasks):
            task.set("seed", random.randint(0, 9223372036854775807))
            task.idx = allocated_position

-            if self.enable_prefix_cache:
+            if self.enable_prefix_cache:  # if prefix caching is enabled
+                # 1. request enough blocks for the current task
                cache_prepare_time = time.time()
                common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
                    task,
@@ -253,39 +255,42 @@ def allocate_resources_for_new_tasks(self, tasks):
                if unique_block_ids is None:
                    llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
                    return
-
+                # 2. record cache-hit information and return the number of tokens already in the cache
                cached_len = self._record_request_cache_info(task, common_block_ids, unique_block_ids, hit_info)
                task.cache_prepare_time = time.time() - cache_prepare_time
-
+                # 3. handle prefill/decode disaggregation, if enabled
                if task.disaggregate_info is not None:
                    if task.disaggregate_info["role"] == "prefill":
+                        # record the slot position of the current task, indexed by request id
                        self.req_dict[task.request_id] = allocated_position
                        task.disaggregate_info["block_tables"] = task.block_tables
                        self._delete_cached_data(task, cached_len)
                    elif task.disaggregate_info["role"] == "decode":
                        self.req_dict[task.request_id] = allocated_position
                        task.disaggregate_info["block_tables"] = task.need_block_tables
                else:
+                    # remove cached tokens from the prompt token ids to avoid recomputing their KV values
                    self._delete_cached_data(task, cached_len)

-            else:
+            else:  # if prefix caching is disabled
+                # 1. directly allocate empty blocks from the cache, if any are available
                block_tables = self._get_block_tables(task.prompt_token_ids_len)
                if not block_tables:
                    llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
-                    continue
+                    continue  # retry
                else:
                    task.block_tables = block_tables
                    task.need_block_tables = task.block_tables
-
+                # 2. handle prefill/decode disaggregation, if enabled
                if task.disaggregate_info is not None:
                    task.disaggregate_info["block_tables"] = block_tables
                    if task.disaggregate_info["role"] == "prefill":
                        self.req_dict[task.request_id] = allocated_position
                    elif task.disaggregate_info["role"] == "decode":
                        self.req_dict[task.request_id] = allocated_position

-            processed_tasks.append(task)
-            self.stop_flags[allocated_position] = False
+            processed_tasks.append(task)  # add the current task to the processed list
+            self.stop_flags[allocated_position] = False  # mark the slot as occupied
            task.inference_start_time = time.time()
            task.inference_time_cost = -1.0
            task.tokens_all_num = 0
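`_delete_cached_data(task, cached_len)` is called in both branches above but its body is not part of this diff. Based on the comment attached to it, a plausible sketch, assuming only the `prompt_token_ids`/`prompt_token_ids_len` fields that appear elsewhere in the hunk:

```python
def _delete_cached_data(self, task, cached_len):
    # Hypothetical sketch (the real implementation is not shown in this diff):
    # drop the prompt prefix whose KV values are already cached, so the
    # prefill pass only runs over the uncached suffix.
    if cached_len > 0:
        task.prompt_token_ids = task.prompt_token_ids[cached_len:]
        task.prompt_token_ids_len = len(task.prompt_token_ids)
```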
@@ -299,11 +304,18 @@ def allocate_resources_for_new_tasks(self, tasks):
            processing_task_index += 1

        # batch size when the statistical engine is inferring
+        # determined by the index of the last occupied slot
        for i in range(self.max_num_seqs - 1, -1, -1):
            if not self.stop_flags[i]:
                self.real_bsz = i + 1
                break

+        # report available blocks, batch size, and cache usage to the metrics
+        task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
+        main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
+        main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
+        main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
+
        llm_logger.info(
            f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
        )
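Because the loop scans backward and stops at the first occupied slot it meets, `real_bsz` is the highest occupied slot index plus one rather than a count of occupied slots; free slots below it still ride along in the batch. A small illustration:

```python
def real_batch_size(stop_flags):
    """Highest occupied slot index + 1 (0 if every slot is free)."""
    for i in range(len(stop_flags) - 1, -1, -1):
        if not stop_flags[i]:  # False means slot i is occupied
            return i + 1
    return 0

# Slots 0 and 3 occupied, 1 and 2 free: the engine still runs a batch of 4.
assert real_batch_size([False, True, True, False]) == 4
```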
@@ -335,6 +347,11 @@ def _record_request_cache_info(self, task, common_block_ids, unique_block_ids, hit_info):
        task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
        task.cache_info = (cache_block_num, no_cache_block_num)

+        # Report the number of cached tokens to Prometheus metrics
+        main_process_metrics.prefix_cache_token_num.inc(task.num_cached_tokens)
+        main_process_metrics.prefix_gpu_cache_token_num.inc(task.gpu_cache_token_num)
+        main_process_metrics.prefix_cpu_cache_token_num.inc(task.cpu_cache_token_num)
+
        cached_len = len(common_block_ids) * self.cfg.block_size
        task.block_tables = common_block_ids + unique_block_ids
        task.need_block_tables = unique_block_ids
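`main_process_metrics` itself is defined outside this diff. The `.set()`/`.inc()` split suggests gauges for point-in-time state and counters for accumulating totals, as in `prometheus_client`. A minimal sketch under that assumption (the metric name strings are illustrative, not taken from the codebase):

```python
from prometheus_client import Counter, Gauge

class MainProcessMetrics:
    """Hypothetical shape of main_process_metrics, inferred from the diff."""

    def __init__(self):
        # Gauges: values that move up and down with engine state, hence .set().
        self.max_batch_size = Gauge("max_batch_size", "Maximum number of sequence slots")
        self.batch_size = Gauge("batch_size", "Slots currently occupied")
        self.available_gpu_block_num = Gauge("available_gpu_block_num", "Free KV-cache blocks")
        self.gpu_cache_usage_perc = Gauge("gpu_cache_usage_perc", "GPU KV-cache usage fraction")
        # Counters: monotonically increasing totals, hence .inc() in the diff.
        self.prefix_cache_token_num = Counter("prefix_cache_token_num", "Tokens served from the prefix cache")
        self.prefix_gpu_cache_token_num = Counter("prefix_gpu_cache_token_num", "Prefix-cache tokens hit in GPU")
        self.prefix_cpu_cache_token_num = Counter("prefix_cpu_cache_token_num", "Prefix-cache tokens hit in CPU")

main_process_metrics = MainProcessMetrics()
```

Counters only go up, which is why the per-request cache-hit token totals use `.inc()`, while batch size and block availability, which move in both directions, use `.set()`.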