@@ -57,5 +57,6 @@
class RequestQueueItem:
    id: int
    request: Optional[ExecutorRequest] = None
+    child_req_ids: Optional[list] = None
    is_canceled_request: bool = False
    query: Optional[list] = None  # only used in `StarAttention`
@@ -88,6 +89,13 @@ def _get_from_request_queue(
    return items


+def _get_num_child_requests(request: ExecutorRequest) -> int:
+    sampling_config = request.sampling_config
+    logger.info(sampling_config)
+    return 0 if sampling_config.beam_width > 1 else (
+        sampling_config.num_return_sequences or 1) - 1
+
+
def _get_from_waiting_queue(
    waiting_queue: deque[RequestQueueItem],
    max_req_count: int,
@@ -108,8 +116,9 @@ def _get_from_waiting_queue(
    items = []
    req_count = 0
    while req_count < max_req_count and waiting_queue:
-        items.append(waiting_queue.popleft())
-        req_count += 1
+        req_item = waiting_queue.popleft()
+        items.append(req_item)
+        req_count += 1 + _get_num_child_requests(req_item.request)
    return items

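Note on the two hunks above: `_get_num_child_requests` treats beam search as a single request (all beams live in one request), while sampling with `num_return_sequences = n` expands into `n - 1` child requests, and `_get_from_waiting_queue` now counts those children against `max_req_count`. A minimal sketch of that arithmetic; the `SamplingConfig` stand-in below is hypothetical and only mirrors the two fields the helper reads:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingConfig:  # hypothetical stand-in; the real config has more fields
    beam_width: int = 1
    num_return_sequences: Optional[int] = None

def num_children(cfg: SamplingConfig) -> int:
    # Beam search returns all sequences from one request -> no children;
    # otherwise each extra returned sequence becomes one child request.
    return 0 if cfg.beam_width > 1 else (cfg.num_return_sequences or 1) - 1

assert num_children(SamplingConfig()) == 0                        # default: 1 sequence
assert num_children(SamplingConfig(num_return_sequences=3)) == 2  # 2 children
assert num_children(SamplingConfig(beam_width=4, num_return_sequences=4)) == 0
```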
@@ -359,9 +368,16 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
            start_time = time.time()
            for request in requests:
                self.start_times[self.next_req_id] = start_time
-                self.request_queue.put(
-                    RequestQueueItem(self.next_req_id, request))
+                req_id = self.next_req_id
                req_ids.append(self.next_req_id)
+
+                child_req_ids = []
+                num_child_requests = _get_num_child_requests(request)
+                for _ in range(num_child_requests):
+                    self.next_req_id += 1
+                    child_req_ids.append(self.next_req_id)
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, child_req_ids))
                self.next_req_id += 1
        finally:
            self.enqueue_lock.release()
@@ -472,14 +488,23 @@ def enqueue_request(self,
        try:
            self.enqueue_lock.acquire()
            assert self.active, "PyExecutor has already been shutdown."
+            logger.info(
+                f"Enqueuing new Executor request with id {self.next_req_id}")
            req_id = self.next_req_id
            if self.enable_iter_perf_stats:
                self.start_times[req_id] = time.time()

            if query is not None:
-                self.request_queue.put(RequestQueueItem(req_id, request, query))
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, [], False, query))
            else:
-                self.request_queue.put(RequestQueueItem(req_id, request))
+                child_req_ids = []
+                num_child_requests = _get_num_child_requests(request)
+                for _ in range(num_child_requests):
+                    self.next_req_id += 1
+                    child_req_ids.append(self.next_req_id)
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, child_req_ids))
            self.next_req_id += 1
        finally:
            self.enqueue_lock.release()
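Both enqueue paths above reserve a contiguous id block per request: the parent keeps `req_id`, each child claims the next id, and the final `self.next_req_id += 1` steps past the block before the next request arrives. A toy sketch of that bookkeeping, assuming two requests that each want three sequences (two children apiece):

```python
# Hypothetical mini-model of the id bookkeeping in enqueue_requests/enqueue_request.
next_req_id = 0
queue = []  # (parent_id, child_req_ids) pairs, mimicking RequestQueueItem

for num_children in (2, 2):  # num_return_sequences=3 -> 2 children per request
    req_id = next_req_id
    child_req_ids = []
    for _ in range(num_children):
        next_req_id += 1
        child_req_ids.append(next_req_id)
    queue.append((req_id, child_req_ids))
    next_req_id += 1  # step past the parent's block

assert queue == [(0, [1, 2]), (3, [4, 5])]
```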
@@ -1506,12 +1531,15 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
            else:
                raise NotImplementedError(f'unsupport cp type {cp_type}')
        else:
-            return [
-                executor_request_to_llm_request(
-                    req_item.id, req_item.request,
+            req_with_children = []
+            for req_item in new_requests:
+                req = executor_request_to_llm_request(
+                    req_item.id, req_item.request, req_item.child_req_ids,
                    self._should_exclude_last_generation_logits())
-                for req_item in new_requests
-            ]
+                req_with_children.append(req)
+                for child in req.children:
+                    req_with_children.append(child)
+            return req_with_children

    @nvtx_range("_schedule")
    def _schedule(self):
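`_merge_requests` thus flattens each converted request and its children into a single list for the scheduler. A shape-only sketch; `Req` is a hypothetical stand-in for the converted LLM request type, assuming only that it exposes a `children` list as the diff does:

```python
class Req:  # hypothetical stand-in for the converted LLM request
    def __init__(self, name, n_children=0):
        self.name = name
        self.children = [Req(f"{name}.child{i}") for i in range(n_children)]

def merge(requests):
    # Parent first, then its children, exactly as the loop in the hunk above.
    flat = []
    for req in requests:
        flat.append(req)
        flat.extend(req.children)
    return [r.name for r in flat]

assert merge([Req("a", n_children=2), Req("b")]) == ["a", "a.child0", "a.child1", "b"]
```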
@@ -1977,7 +2005,7 @@ def _handle_canceled_requests(self):
            if req.id not in self.canceled_req_ids)

        for request in self.active_requests:
-            req_id = request.py_request_id
+            req_id = request.py_request_id if not request.is_child else request.parent_request_id
            if req_id in self.canceled_req_ids:
                # Mark requests as finished, then, we reuse all existing code
                # to clean up the KV cache resources.
@@ -1991,7 +2019,7 @@
        self.canceled_req_ids.clear()

    @nvtx_range("_enqueue_responses")
-    def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
+    def _enqueue_responses(self, responses: List[Tuple[int, LlmResponse]]):
        if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
            return
@@ -2003,18 +2031,18 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
        else:
            responses_list = self.dist.allgather(responses)
            if self.dist.rank == 0 or self.gather_all_responses:
-                gather_responses = {}
+                gather_responses = []
                if responses_list is not None:
                    for resp in responses_list:
                        if resp is not None:
-                            gather_responses.update(resp)
+                            gather_responses.extend(resp)  # flatten each rank's list of (id, response) pairs
                responses = gather_responses
        logger.debug(
            f'after gather, rank = {self.dist.rank}, responses = {responses}')

        if self.dist.rank == 0 or self.gather_all_responses:
            with self.response_cv:
-                for req_id, resp in responses.items():
+                for req_id, resp in responses:
                    if req_id in self.responses.keys():
                        self.responses[req_id].append(resp)
                    else:
@@ -2023,20 +2051,20 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):

    @nvtx_range("_handle_first_token_response")
    def _handle_first_token_response(self, scheduled_batch):
-        new_responses = {}
+        new_responses = []
        for req in scheduled_batch.generation_requests:
            if req.py_decoding_iter == 1:
                logger.debug(
                    f'Send first token response for request {req.py_request_id}'
                )
                response = req.create_response(False, self.dist.rank)
-                new_responses.update({req.py_request_id: response})
+                new_responses.append((req.py_request_id, response))

        self._enqueue_responses(new_responses)

    @nvtx_range("_handle_responses")
    def _handle_responses(self):
-        new_responses = {}
+        new_responses = []
        requests_to_terminate = []
        new_active_requests = []
        logger.debug(
@@ -2070,14 +2098,17 @@ def _handle_responses(self):
                    request.py_decoding_iter % self.stream_interval == 0:
                response = request.create_response(False, self.dist.rank)
                if response:
-                    request_done = response.result.is_final
-                    new_responses.update({req_id: response})
+                    request_done = request.is_finished
+                    new_responses.append((req_id, response))

            if request_done:
                if request.is_disagg_context_transmission_state:
                    self.ctx_in_transmission_requests.append(request)
                else:
-                    requests_to_terminate.append(request)
+                    if response.result.is_final:
+                        requests_to_terminate.append(request)
+                        for child in request.children:
+                            requests_to_terminate.append(child)
            else:
                new_active_requests.append(request)
        self.active_requests = new_active_requests
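A likely motivation for retyping `_enqueue_responses` from `Dict[int, LlmResponse]` to `List[Tuple[int, LlmResponse]]` (my reading; the commit does not state it): child requests respond under the parent's id, so one iteration can yield several responses with the same key, which a dict would silently overwrite. A toy illustration with strings standing in for `LlmResponse` objects:

```python
responses = [(7, "seq-0"), (7, "seq-1"), (8, "seq-0")]  # id 7 has one child

as_dict = dict(responses)                   # the old Dict[int, ...] shape
assert as_dict == {7: "seq-1", 8: "seq-0"}  # one of id 7's sequences is lost

by_id = {}
for req_id, resp in responses:              # mirrors the loop under response_cv
    by_id.setdefault(req_id, []).append(resp)
assert by_id == {7: ["seq-0", "seq-1"], 8: ["seq-0"]}
```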