@@ -325,6 +325,27 @@ def _get_request_id(self):
325
325
self ._next_req_id = (self ._next_req_id + 1 ) & ((1 << 64 ) - 1 )
326
326
return self ._next_req_id
327
327
328
+ def _generate_child_request_ids (
329
+ self , request : ExecutorRequest ) -> List [int ] | None :
330
+ """ Generate child request IDs if needed. """
331
+ child_req_ids = None
332
+ sampling_config = request .sampling_config
333
+ beam_width = (sampling_config .beam_width
334
+ if sampling_config .beam_width else 1 )
335
+ num_return_sequences = (sampling_config .num_return_sequences
336
+ if sampling_config .num_return_sequences else 1 )
337
+
338
+ # Create child requests if beam_width == 1 and num_return_sequences > 1.
339
+ if beam_width == 1 and num_return_sequences > 1 :
340
+ child_req_ids = []
341
+ for _ in range (num_return_sequences - 1 ):
342
+ child_req_id = self ._get_request_id ()
343
+ if self .enable_iter_perf_stats :
344
+ self .start_times [child_req_id ] = time .time ()
345
+ child_req_ids .append (child_req_id )
346
+
347
+ return child_req_ids
348
+
328
349
def enqueue_requests (self , requests : List [ExecutorRequest ]):
329
350
"""
330
351
Enqueue new requests
@@ -339,21 +360,8 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
339
360
if self .enable_iter_perf_stats :
340
361
self .start_times [req_id ] = time .time ()
341
362
342
- # Generate child request IDs if needed
343
- child_req_ids = None
344
- sampling_config = request .sampling_config
345
- beam_width = sampling_config .beam_width
346
- num_return_sequences = sampling_config .num_return_sequences or beam_width
347
-
348
- if beam_width == 1 and num_return_sequences > 1 :
349
- # Reserve request ids for child requests.
350
- child_req_ids = []
351
- for _ in range (num_return_sequences - 1 ):
352
- child_req_id = self ._get_request_id ()
353
- if self .enable_iter_perf_stats :
354
- self .start_times [child_req_id ] = time .time ()
355
- child_req_ids .append (child_req_id )
356
-
363
+ # Reserve child request ids if needed.
364
+ child_req_ids = self ._generate_child_request_ids (request )
357
365
self .request_queue .put (
358
366
RequestQueueItem (req_id ,
359
367
request ,
@@ -476,23 +484,8 @@ def enqueue_request(self,
476
484
if self .enable_iter_perf_stats :
477
485
self .start_times [req_id ] = time .time ()
478
486
479
- # Generate child request IDs if needed
480
- child_req_ids = None
481
- sampling_config = request .sampling_config
482
- beam_width = (sampling_config .beam_width
483
- if sampling_config .beam_width else 1 )
484
- num_return_sequences = (sampling_config .num_return_sequences if
485
- sampling_config .num_return_sequences else 1 )
486
-
487
- # Only create child requests if beam_width == 1 and num_return_sequences > 1
488
- if beam_width == 1 and num_return_sequences > 1 :
489
- child_req_ids = []
490
- for i in range (num_return_sequences - 1 ):
491
- child_req_id = self ._get_request_id ()
492
- if self .enable_iter_perf_stats :
493
- self .start_times [child_req_id ] = time .time ()
494
- child_req_ids .append (child_req_id )
495
-
487
+ # Reserve child request ids if needed.
488
+ child_req_ids = self ._generate_child_request_ids (request )
496
489
self .request_queue .put (
497
490
RequestQueueItem (req_id ,
498
491
request ,
0 commit comments