Skip to content

Commit 7c1d650

Browse files
factor out child req id generation logic
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 8a6e429 commit 7c1d650

File tree

1 file changed

+25
-32
lines changed

1 file changed

+25
-32
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,27 @@ def _get_request_id(self):
325325
self._next_req_id = (self._next_req_id + 1) & ((1 << 64) - 1)
326326
return self._next_req_id
327327

328+
def _generate_child_request_ids(
329+
self, request: ExecutorRequest) -> List[int] | None:
330+
""" Generate child request IDs if needed. """
331+
child_req_ids = None
332+
sampling_config = request.sampling_config
333+
beam_width = (sampling_config.beam_width
334+
if sampling_config.beam_width else 1)
335+
num_return_sequences = (sampling_config.num_return_sequences
336+
if sampling_config.num_return_sequences else 1)
337+
338+
# Create child requests if beam_width == 1 and num_return_sequences > 1.
339+
if beam_width == 1 and num_return_sequences > 1:
340+
child_req_ids = []
341+
for _ in range(num_return_sequences - 1):
342+
child_req_id = self._get_request_id()
343+
if self.enable_iter_perf_stats:
344+
self.start_times[child_req_id] = time.time()
345+
child_req_ids.append(child_req_id)
346+
347+
return child_req_ids
348+
328349
def enqueue_requests(self, requests: List[ExecutorRequest]):
329350
"""
330351
Enqueue new requests
@@ -339,21 +360,8 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
339360
if self.enable_iter_perf_stats:
340361
self.start_times[req_id] = time.time()
341362

342-
# Generate child request IDs if needed
343-
child_req_ids = None
344-
sampling_config = request.sampling_config
345-
beam_width = sampling_config.beam_width
346-
num_return_sequences = sampling_config.num_return_sequences or beam_width
347-
348-
if beam_width == 1 and num_return_sequences > 1:
349-
# Reserve request ids for child requests.
350-
child_req_ids = []
351-
for _ in range(num_return_sequences - 1):
352-
child_req_id = self._get_request_id()
353-
if self.enable_iter_perf_stats:
354-
self.start_times[child_req_id] = time.time()
355-
child_req_ids.append(child_req_id)
356-
363+
# Reserve child request ids if needed.
364+
child_req_ids = self._generate_child_request_ids(request)
357365
self.request_queue.put(
358366
RequestQueueItem(req_id,
359367
request,
@@ -476,23 +484,8 @@ def enqueue_request(self,
476484
if self.enable_iter_perf_stats:
477485
self.start_times[req_id] = time.time()
478486

479-
# Generate child request IDs if needed
480-
child_req_ids = None
481-
sampling_config = request.sampling_config
482-
beam_width = (sampling_config.beam_width
483-
if sampling_config.beam_width else 1)
484-
num_return_sequences = (sampling_config.num_return_sequences if
485-
sampling_config.num_return_sequences else 1)
486-
487-
# Only create child requests if beam_width == 1 and num_return_sequences > 1
488-
if beam_width == 1 and num_return_sequences > 1:
489-
child_req_ids = []
490-
for i in range(num_return_sequences - 1):
491-
child_req_id = self._get_request_id()
492-
if self.enable_iter_perf_stats:
493-
self.start_times[child_req_id] = time.time()
494-
child_req_ids.append(child_req_id)
495-
487+
# Reserve child request ids if needed.
488+
child_req_ids = self._generate_child_request_ids(request)
496489
self.request_queue.put(
497490
RequestQueueItem(req_id,
498491
request,

0 commit comments

Comments
 (0)