@@ -276,10 +276,19 @@ def create_response(
276
276
return None
277
277
else :
278
278
return LlmResponse (request_id = request .py_request_id ,
279
- result = LlmResult (result , request .py_result ),
279
+ result = LlmResult (result , request .py_result ,
280
+ result .is_final ),
280
281
client_id = request .py_client_id )
281
282
282
283
284
+ def finish_by (request : Union [
285
+ 'LlmRequest' , tensorrt_llm .bindings .internal .batch_manager .LlmRequest ],
286
+ reason : FinishReason , beam : int ) -> None :
287
+ """CPP finish by reason does not support beam_width > 1"""
288
+ request .state = LlmRequestState .GENERATION_COMPLETE
289
+ request .set_finished_reason (reason , beam )
290
+
291
+
283
292
class LlmRequest (tensorrt_llm .bindings .internal .batch_manager .LlmRequest ):
284
293
"""LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
285
294
but detour some features to Python implementation"""
@@ -298,6 +307,7 @@ def __init__(
298
307
stop_words_list : list [list [int ]] | None = None ,
299
308
is_draft : bool = False ,
300
309
** kwargs ):
310
+
301
311
self .py_logits_post_processors = kwargs .pop ("py_logits_post_processors" ,
302
312
None )
303
313
# Multimodal data
@@ -377,8 +387,9 @@ def create_child_request(self, request_id: int):
377
387
child_request .is_cuda_graph_dummy = self .is_cuda_graph_dummy
378
388
child_request .is_dummy = self .is_dummy
379
389
380
- # Override create_response to return the child request
390
+ # Mimic the behavior of the original LlmRequest.
381
391
child_request .create_response = partial (create_response , child_request )
392
+ child_request .finish_by = partial (finish_by , child_request )
382
393
383
394
return child_request
384
395
@@ -394,8 +405,7 @@ def is_dummy(self):
394
405
395
406
def finish_by (self , reason : FinishReason , beam : int ) -> None :
396
407
"""CPP finish by reason does not support beam_width > 1"""
397
- self .state = LlmRequestState .GENERATION_COMPLETE
398
- self .set_finished_reason (reason , beam )
408
+ finish_by (self , reason , beam )
399
409
400
410
401
411
def convert_wordlist (word_list ) -> List [List [int ]]:
0 commit comments