Skip to content

Commit ffcdc9a

Browse files
QiJune and lancelly
authored and committed
chore: add _prepare_and_schedule_batch function in PyExecutor (NVIDIA#6365)
Signed-off-by: junq <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 0f68252 commit ffcdc9a

File tree

1 file changed

+50
-81
lines changed

1 file changed

+50
-81
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 50 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,50 @@ def _executor_loop_pp(self):
790790
self.active_requests,
791791
previous_batch)
792792

793+
def _prepare_and_schedule_batch(self):
    """Fetch pending requests, run one scheduling pass, and return the batch.

    Shared front-half of the executor loops (`_executor_loop` and
    `_executor_loop_overlap`): pulls new requests off the queue, updates
    disaggregated-generation transfer state, pads attention-DP dummy
    requests, prepares draft requests (when a drafter is configured), and
    invokes the scheduler.

    Returns:
        tuple: ``(scheduled_batch, iter_stats)``.
            * ``scheduled_batch`` — the batch produced by ``self._schedule()``;
              ``None`` signals the caller to break out of its loop.
            * ``iter_stats`` — per-iteration perf stats, or ``None`` when
              ``enable_iter_perf_stats`` is off (or when stopping).
    """
    new_requests = self._fetch_new_requests()
    # Stop condition is checked right after the fetch so a shutdown request
    # queued behind real work is still observed promptly.
    if self.should_stop_processing:
        return None, None

    if self.kv_cache_transceiver:
        self._check_disagg_gen_transfer_status()

    iter_stats = None
    if self.enable_iter_perf_stats:
        iter_stats = self._get_init_iter_stats(
            len(new_requests),
            self.executor_request_queue.
            get_new_active_requests_queue_latency())

    self._pad_attention_dp_dummy_request()

    if self.drafter is not None:
        self._prepare_draft_requests(self.active_requests)

    scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
    )

    if self.kv_cache_transceiver:
        # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
        self._prepare_disagg_gen_init(fitting_disagg_gen_init_requests)

        if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
            logger.warning(
                "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
            )
            self.kv_cache_transceiver.check_context_transfer_status(1)
    else:
        # Without a KV-cache transceiver an empty schedule cannot be
        # recovered by waiting on transfers, so treat it as fatal.
        assert scheduled_batch.batch_size > 0, (
            "fail to schedule any pending request, "
            "probably run out of resource.")

    self.num_scheduled_requests = scheduled_batch.batch_size
    # Lazy %-style args: avoids formatting three interpolations per
    # iteration when DEBUG logging is disabled (this runs every loop tick).
    logger.debug(
        'has %d active_request, scheduled %d context requests and '
        '%d generation requests', len(self.active_requests),
        len(scheduled_batch.context_requests),
        len(scheduled_batch.generation_requests))
    return scheduled_batch, iter_stats
793837
def _executor_loop(self):
794838
torch.cuda.set_device(self.device_id)
795839
with self._profiler() as profile_step:
@@ -800,48 +844,10 @@ def _executor_loop(self):
800844
profile_step()
801845
if self.enable_iter_perf_stats:
802846
iter_start_time = time.time()
803-
new_requests = self._fetch_new_requests()
804-
if self.should_stop_processing:
805-
break
806-
807-
if self.kv_cache_transceiver:
808-
self._check_disagg_gen_transfer_status()
809-
810-
if self.enable_iter_perf_stats:
811-
iter_stats = self._get_init_iter_stats(
812-
len(new_requests),
813-
self.executor_request_queue.
814-
get_new_active_requests_queue_latency())
815-
816-
self._pad_attention_dp_dummy_request()
817-
818-
if self.drafter is not None:
819-
self._prepare_draft_requests(self.active_requests)
820-
821-
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
822-
)
823-
824-
if self.kv_cache_transceiver:
825-
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
826-
self._prepare_disagg_gen_init(
827-
fitting_disagg_gen_init_requests)
828-
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
829-
logger.warning(
830-
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
831-
)
832-
self.kv_cache_transceiver.check_context_transfer_status(
833-
1)
834-
else:
835-
assert scheduled_batch.batch_size > 0, (
836-
"fail to schedule any pending request, "
837-
"probably run out of resource.")
838847

839-
self.num_scheduled_requests = scheduled_batch.batch_size
840-
logger.debug(
841-
f'has {len(self.active_requests)} active_request, '
842-
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
843-
f'{len(scheduled_batch.generation_requests)} generation requests'
844-
)
848+
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
849+
if scheduled_batch is None:
850+
break
845851

846852
self._pause_requests(scheduled_batch.paused_requests)
847853

@@ -944,47 +950,10 @@ def _executor_loop_overlap(self):
944950
profile_step()
945951
if self.enable_iter_perf_stats:
946952
iter_start_time = time.time()
947-
new_requests = self._fetch_new_requests()
948-
if self.should_stop_processing:
949-
break
950-
951-
if self.kv_cache_transceiver:
952-
self._check_disagg_gen_transfer_status()
953-
954-
if self.enable_iter_perf_stats:
955-
iter_stats = self._get_init_iter_stats(
956-
len(new_requests),
957-
self.executor_request_queue.
958-
get_new_active_requests_queue_latency())
959953

960-
self._pad_attention_dp_dummy_request()
961-
962-
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
963-
)
964-
965-
if self.kv_cache_transceiver:
966-
967-
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
968-
self._prepare_disagg_gen_init(
969-
fitting_disagg_gen_init_requests)
970-
971-
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
972-
logger.warning(
973-
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
974-
)
975-
self.kv_cache_transceiver.check_context_transfer_status(
976-
1)
977-
else:
978-
assert scheduled_batch.batch_size > 0, (
979-
"fail to schedule any pending request, "
980-
"probably run out of resource.")
981-
982-
self.num_scheduled_requests = scheduled_batch.batch_size
983-
logger.debug(
984-
f'has {len(self.active_requests)} active_request, '
985-
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
986-
f'{len(scheduled_batch.generation_requests)} generation requests'
987-
)
954+
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
955+
if scheduled_batch is None:
956+
break
988957

989958
self._pause_requests(scheduled_batch.paused_requests)
990959

0 commit comments

Comments
 (0)