Description
Motivation
Currently, PyExecutor, a pure-Python class, handles the main event loop. Its flexibility has let us add features such as the overlap scheduler and attention data parallelism quickly.
However, it still relies on many pybind classes, including LlmRequest, KVCacheManager and Scheduler.
To improve flexibility further, we want to migrate more components from C++ to Python.
Analysis
LlmRequest, KVCacheManager and Scheduler are tightly coupled. LlmRequest maintains many state tensors, including output_tokens, chunk_size, state, etc., and both KVCacheManager and Scheduler read and write these members internally.
We previously tried implementing a pure-Python CapacityScheduler, but it introduced too many pybind calls on LlmRequest. We observed that pybind calls are about 2x-3x slower than pure Python calls, so this pure-Python CapacityScheduler is not enabled because of its large host overhead.
Given the complexity of KVCacheManager, we decided to re-implement LlmRequest and Scheduler in pure Python as a first step. At the same time, we will remove LlmRequest from the KVCacheManager interface.
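One way to picture the KVCacheManager interface change is an API keyed by plain request IDs and token counts, so the cache manager never touches LlmRequest members. The following is a minimal sketch under that assumption; ToyKVCacheManager, resize_sequence and free_sequence are hypothetical names, and the fixed tokens-per-block layout is a simplification, not the actual KVCacheManager API.

```python
from typing import Dict, List


class ToyKVCacheManager:
    """Hypothetical block-based KV-cache manager keyed by request ID.

    Callers pass an int request ID and a token count, so this class never
    reads or writes LlmRequest members.
    """

    def __init__(self, num_blocks: int, tokens_per_block: int) -> None:
        self.tokens_per_block = tokens_per_block
        self.free_blocks: List[int] = list(range(num_blocks))
        # request_id -> indices of the blocks that request currently owns
        self.allocated: Dict[int, List[int]] = {}

    def blocks_for(self, num_tokens: int) -> int:
        # Ceiling division: a partially filled block still occupies a whole block.
        return -(-num_tokens // self.tokens_per_block)

    def resize_sequence(self, request_id: int, num_tokens: int) -> bool:
        """Grow the request's allocation to cover num_tokens (never shrinks).

        Returns False and leaves the allocation unchanged if there are not
        enough free blocks.
        """
        owned = self.allocated.get(request_id, [])
        extra = self.blocks_for(num_tokens) - len(owned)
        if extra > len(self.free_blocks):
            return False
        new_blocks = [self.free_blocks.pop() for _ in range(max(extra, 0))]
        self.allocated[request_id] = owned + new_blocks
        return True

    def free_sequence(self, request_id: int) -> None:
        # Return every block owned by the request to the free pool.
        self.free_blocks.extend(self.allocated.pop(request_id, []))
```

For example, with 4 blocks of 16 tokens, resize_sequence(7, 40) takes 3 blocks; a second request that needs 3 blocks is rejected until free_sequence(7) returns them to the pool.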
Proposed Solution
- Introduce a new flag, enable_pure_python_scheduler, in PyTorchConfig to enable the pure-Python scheduler
Because migrating all the components and tuning performance will take time, the pure-Python scheduler will be hidden from users initially.
- Refactor LlmRequest to maintain all state tensors on the Python side
As a first step, all state tensors will be "duplicated" on the Python side. Each member function of LlmRequest will dispatch to either the pure Python path or the pybind path, depending on the enable_pure_python_scheduler flag.
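```python
import tensorrt_llm.bindings  # pybind extension; internal.batch_manager is reached by attribute access


class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):

    def __init__(self, *args, enable_pure_python_scheduler: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.enable_pure_python_scheduler = enable_pure_python_scheduler
        # Python-side duplicates of the C++ state, seeded from the C++ object.
        self.py_request_id = self.request_id
        self.py_state = self.state
        self.py_tokens = [
            list(super().get_tokens(beam))
            for beam in range(self.sampling_config.beam_width)
        ]

    def get_tokens(self, beam_idx: int):
        if self.enable_pure_python_scheduler:
            # Pure Python path: read the Python-side copy, no pybind call.
            return self.py_tokens[beam_idx]
        # pybind path: call the C++ implementation on the base class.
        # Calling self.get_tokens here would recurse forever.
        return super().get_tokens(beam_idx)
```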
- Implement a pure-Python Scheduler
- Decouple LlmRequest from the KVCacheManager interface
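The pure-Python Scheduler step could start from something as small as the following capacity check, which reads only plain Python attributes and therefore makes no pybind calls. It is a sketch, not TensorRT-LLM's CapacityScheduler: PyRequest, RequestState, schedule_capacity and the worst-case block reservation rule are all assumptions made for illustration.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Tuple


class RequestState(Enum):
    # Hypothetical subset of request states, standing in for LlmRequest.state.
    CONTEXT_INIT = auto()
    GENERATION_IN_PROGRESS = auto()
    GENERATION_COMPLETE = auto()


@dataclass
class PyRequest:
    """Hypothetical pure-Python request: every field is a plain attribute."""
    request_id: int
    prompt_len: int
    max_new_tokens: int
    state: RequestState = RequestState.CONTEXT_INIT


def blocks_for(num_tokens: int, tokens_per_block: int) -> int:
    # Ceiling division: a partially filled block still occupies a whole block.
    return -(-num_tokens // tokens_per_block)


def schedule_capacity(
    requests: List[PyRequest],
    free_blocks: int,
    tokens_per_block: int,
    max_num_requests: int,
) -> Tuple[List[PyRequest], List[PyRequest]]:
    """Greedy capacity check that reserves worst-case KV blocks per request.

    free_blocks counts blocks not yet reserved by any running request.
    Completed requests appear in neither returned list.
    """
    scheduled: List[PyRequest] = []
    deferred: List[PyRequest] = []
    # Running requests reserved their blocks on admission, so keep them first.
    for req in requests:
        if req.state == RequestState.GENERATION_IN_PROGRESS:
            (scheduled if len(scheduled) < max_num_requests else deferred).append(req)
    # New requests are admitted only if their full footprint still fits;
    # one that does not fit is deferred, and later ones may still be admitted.
    for req in requests:
        if req.state != RequestState.CONTEXT_INIT:
            continue
        cost = blocks_for(req.prompt_len + req.max_new_tokens, tokens_per_block)
        if len(scheduled) < max_num_requests and cost <= free_blocks:
            free_blocks -= cost
            scheduled.append(req)
        else:
            deferred.append(req)
    return scheduled, deferred
```

Reserving prompt_len + max_new_tokens up front means an admitted request never has to be evicted, at the cost of lower KV-cache utilization; a policy that reserves less and evicts on demand would make the opposite trade-off.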
Future Work
Performance tuning will take some time. After that, we will evaluate whether to enable the pure-Python Scheduler by default.