@@ -57,14 +57,15 @@ def __init__(
         self.logger = llm_logger
         self.cfg = config.cache_config
         self.max_num_seqs = max_num_seqs
-        self.stop_flags = [True] * max_num_seqs
+        self.stop_flags = [True] * max_num_seqs  # flag set to True if the slot has not been taken
         self.enable_prefix_cache = config.cache_config.enable_prefix_caching
         self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
-        self.tasks_list = [None] * max_num_seqs
+        self.tasks_list = [None] * max_num_seqs  # task slots
         self.req_dict = dict()
         # current batch status of the engine
         self.real_bsz = 0
         self.logger.info(f"{self.info()}")
+        main_process_metrics.max_batch_size.set(max_num_seqs)
 
     def reset_cache_config(self, cfg):
         """
@@ -228,72 +229,76 @@ def allocate_resources_for_new_tasks(self, tasks):
         Returns:
             list: processed task list
         """
-
-        allocated_position = 0
-        processing_task_index = 0
+        llm_logger.debug(f"Allocating resources for a batch of new tasks: {tasks}")
+        allocated_position = 0  # number of tasks that have been allocated, also the position in request slots
+        processing_task_index = 0  # current task
         processed_tasks = list()
-        while allocated_position < self.max_num_seqs:
-            if processing_task_index >= len(tasks):
+        while allocated_position < self.max_num_seqs:  # loop until all tasks have been allocated resources
+            if processing_task_index >= len(tasks):  # if all tasks have been tried, don't give a second chance
                 break
 
             can_insert = False
             while allocated_position + 1 <= self.max_num_seqs:
                 if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
-                    can_insert = True
+                    can_insert = True  # if there is an empty slot, try to allocate resources for the current task
                     break
                 allocated_position += 1
             if can_insert:
                 if self.stop_flags[allocated_position]:
 
-                    task = tasks[processing_task_index]
+                    task = tasks[processing_task_index]  # retrieve current task
 
                     if task.get("seed") is None:
                         task.set("seed", random.randint(0, 9223372036854775807))
                     task.idx = allocated_position
 
-                    if self.enable_prefix_cache:
+                    if self.enable_prefix_cache:  # if prefix caching is enabled
+                        # 1. request enough blocks for the current task
                         cache_prepare_time = time.time()
                         common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
                             task, self.cfg.block_size, self.cfg.dec_token_num
                         )
                         if unique_block_ids is None:
                             self.logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
                             return
-
+                        # 2. record cache hit information, and return the number of tokens already in cache
                         cached_len = self._record_request_cache_info(
                             task, common_block_ids, unique_block_ids, hit_info
                         )
                         task.cache_prepare_time = time.time() - cache_prepare_time
-
+                        # 3. if prefill/decode disaggregation is enabled
                         if task.disaggregate_info is not None:
                             if task.disaggregate_info["role"] == "prefill":
+                                # record the slot position for the current task, indexed by request id
                                 self.req_dict[task.request_id] = allocated_position
                                 task.disaggregate_info["block_tables"] = task.block_tables
                                 self._delete_cached_data(task, cached_len)
                             elif task.disaggregate_info["role"] == "decode":
                                 self.req_dict[task.request_id] = allocated_position
                                 task.disaggregate_info["block_tables"] = task.need_block_tables
                         else:
+                            # remove cached tokens from prompt token ids to avoid KV recomputation
                             self._delete_cached_data(task, cached_len)
 
-                    else:
+                    else:  # if prefix caching is disabled
+                        # 1. directly allocate empty blocks from the cache, if there are any
                         block_tables = self._get_block_tables(task.prompt_token_ids_len)
                         if not block_tables:
                             llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
-                            continue
+                            continue  # retry
                         else:
                             task.block_tables = block_tables
                             task.need_block_tables = task.block_tables
-
+                        # 2. if prefill/decode disaggregation is enabled
                         if task.disaggregate_info is not None:
                             task.disaggregate_info["block_tables"] = block_tables
                             if task.disaggregate_info["role"] == "prefill":
                                 self.req_dict[task.request_id] = allocated_position
                             elif task.disaggregate_info["role"] == "decode":
                                 self.req_dict[task.request_id] = allocated_position
 
-                    processed_tasks.append(task)
-                    self.stop_flags[allocated_position] = False
+                    processed_tasks.append(task)  # add current task
+                    self.stop_flags[allocated_position] = False  # mark the slot as occupied
                     task.inference_start_time = time.time()
                     task.inference_time_cost = -1.0
                     task.tokens_all_num = 0
@@ -307,11 +312,18 @@ def allocate_resources_for_new_tasks(self, tasks):
             processing_task_index += 1
 
         # batch size when the statistical engine is inferring
+        # determine the real batch size from the index of the last occupied slot
         for i in range(self.max_num_seqs - 1, -1, -1):
             if not self.stop_flags[i]:
                 self.real_bsz = i + 1
                 break
 
+        # report block usage, batch size and KV cache usage to Prometheus metrics
+        task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
+        main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
+        main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
+        main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
+
         self.logger.info(
             f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
         )
@@ -343,6 +355,11 @@ def _record_request_cache_info(self, task, common_block_ids, unique_block_ids, h
         task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
         task.cache_info = (cache_block_num, no_cache_block_num)
 
+        # Report the number of cached tokens to Prometheus metrics
+        main_process_metrics.prefix_cache_token_num.inc(task.num_cached_tokens)
+        main_process_metrics.prefix_gpu_cache_token_num.inc(task.gpu_cache_token_num)
+        main_process_metrics.prefix_cpu_cache_token_num.inc(task.cpu_cache_token_num)
+
         cached_len = len(common_block_ids) * self.cfg.block_size
         task.block_tables = common_block_ids + unique_block_ids
         task.need_block_tables = unique_block_ids
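The patch relies on a `main_process_metrics` singleton exposing the gauges and counters it sets. Its definition is not part of this diff; the sketch below shows one plausible shape for such a registry using `prometheus_client`, with metric names and help strings invented here purely for illustration. Gauges are refreshed on every allocation pass (`.set()`), while the prefix-cache counters only accumulate (`.inc()`), matching how the patch uses them.

```python
# Hypothetical sketch of the metrics registry assumed by the diff above.
# Attribute names mirror the calls in the patch; metric names, labels and
# the real FastDeploy implementation may differ.
from prometheus_client import Counter, Gauge


class MainProcessMetrics:
    def __init__(self):
        # Gauges: values that move up and down with current engine state
        self.max_batch_size = Gauge(
            "fastdeploy_max_batch_size", "Maximum number of concurrent request slots")
        self.batch_size = Gauge(
            "fastdeploy_batch_size", "Number of request slots currently occupied")
        self.available_gpu_block_num = Gauge(
            "fastdeploy_available_gpu_block_num", "KV cache blocks not held by any task")
        self.gpu_cache_usage_perc = Gauge(
            "fastdeploy_gpu_cache_usage_perc", "Fraction of GPU KV cache blocks in use")
        # Counters: monotonically increasing totals
        self.prefix_cache_token_num = Counter(
            "fastdeploy_prefix_cache_token_num", "Total tokens served from the prefix cache")
        self.prefix_gpu_cache_token_num = Counter(
            "fastdeploy_prefix_gpu_cache_token_num", "Prefix-cache tokens hit in GPU blocks")
        self.prefix_cpu_cache_token_num = Counter(
            "fastdeploy_prefix_cpu_cache_token_num", "Prefix-cache tokens hit in CPU blocks")


# Module-level singleton imported by the resource manager
main_process_metrics = MainProcessMetrics()
```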