                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
@@ -60,7 +60,8 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -81,29 +82,51 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args
 
         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)
 
         if isinstance(engine, list):
             engine = engine[self.rank]
 
-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
+        def _create_py_executor(comm_ranks, device_ids):
 
-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
+            executor_config = llm_args.get_executor_config(hf_model_dir)
+            # Persist so downstream code (e.g., default max_tokens deduction) has access
+            self._executor_config = executor_config
 
-        def _create_engine():
-            device_id = self.global_rank % torch.cuda.device_count()
-            torch.cuda.set_device(device_id)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            executor_config.parallel_config = tllm.ParallelConfig(
+                participant_ids=comm_ranks, device_ids=device_ids)
+            args = {
+                "executor_config": executor_config,
+                "checkpoint_dir": executor_config.hf_model_dir,
+            }
+            assert hasattr(
+                executor_config, "backend"
+            ), "executor_config should be with backend in _create_py_executor"
+            if executor_config.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
+                    create_py_executor
+                create_executor = create_py_executor
+                args["lora_config"] = lora_config
+                args[
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
+            else:
+                raise ValueError(
+                    f"Unsupported backend config: {executor_config.backend}")
+            return create_executor(**args)
+
+        def _create_engine(comm_ranks, device_ids):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
 
-            # Make sure C++ executor would use same devices/ranks as py_executor
-            global_rank = global_mpi_rank()
-            comm_ranks = mpi_comm().allgather(global_rank)
-            device_ids = mpi_comm().allgather(device_id)
 
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)
@@ -122,14 +145,7 @@ def _create_engine():
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
             }
-            if executor_config.backend == "pytorch":
-                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
-                    create_py_executor
-                create_executor = create_py_executor
-                args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
+            if executor_config.backend == "_autodeploy":
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
@@ -138,7 +154,17 @@ def _create_engine():
                     f"Unsupported backend config: {executor_config.backend}")
             return create_executor(**args)
 
-        self.engine = _create_engine()
+        device_id = self.global_rank % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        # Make sure C++ executor would use same devices/ranks as py_executor
+        global_rank = global_mpi_rank()
+        comm_ranks = mpi_comm().allgather(global_rank)
+        device_ids = mpi_comm().allgather(device_id)
+
+        self.engine = _create_py_executor(
+            comm_ranks, device_ids) if llm_args is not None else _create_engine(
+                comm_ranks, device_ids)
 
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
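
Note: the net effect of the __init__ hunks above is that engine creation is now dispatched on whether llm_args was supplied, with the MPI rank/device gathering shared by both paths. A minimal, self-contained sketch of that dispatch follows; the helper name select_engine_factory is illustrative and not part of this change.

    from typing import Any, Callable, Optional, Sequence

    Factory = Callable[[Sequence[int], Sequence[int]], Any]

    def select_engine_factory(llm_args: Optional[Any],
                              create_py_executor: Factory,
                              create_trt_engine: Factory) -> Factory:
        # llm_args present -> PyTorch path (executor_config is derived from
        # llm_args.get_executor_config(hf_model_dir)); otherwise the legacy
        # TRT-engine path with an externally supplied executor_config is used.
        return create_py_executor if llm_args is not None else create_trt_engine

    # Usage: both factories receive the same MPI-gathered comm_ranks/device_ids.
    # engine = select_engine_factory(llm_args, _create_py_executor,
    #                                _create_engine)(comm_ranks, device_ids)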
@@ -430,14 +456,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
             context_phase_params = request.disaggregated_params.get_context_phase_params(
             )
 
-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )
 
         assert request.id is not None
 
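
For clarity, the rewritten guard above is equivalent to the standalone predicate sketched below; this is an illustrative restatement only, and the helper name and boolean parameters are not part of the change.

    def reject_context_only_if_overlap(is_pytorch_backend: bool,
                                       disable_overlap_scheduler: bool,
                                       is_disaggregated: bool,
                                       is_context_only: bool) -> None:
        # Mirrors _enqueue_request above: with the PyTorch backend, overlap
        # scheduling enabled, and a disaggregated setup, context-only requests
        # are rejected. The overlap flag is now read from TorchLlmArgs instead
        # of executor_config.pytorch_backend_config.
        if (is_pytorch_backend and not disable_overlap_scheduler
                and is_disaggregated and is_context_only):
            raise ValueError(
                "Context only requests are not supported in pytorch backend "
                "when overlap is enabled.")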
@@ -641,7 +669,8 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +797,8 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            llm_args=llm_args)
 
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())
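
Call sites that previously forwarded garbage_collection_gen0_threshold now pass hf_model_dir plus a TorchLlmArgs instance, which carries the GC threshold (and the backend selection) itself. A hypothetical wiring sketch, assuming TorchLlmArgs accepts model and garbage_collection_gen0_threshold keyword arguments as suggested by the attribute accesses in this diff; the path and threshold value are placeholders.

    from pathlib import Path

    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

    hf_model_dir = Path("/path/to/hf/checkpoint")   # placeholder
    llm_args = TorchLlmArgs(model=str(hf_model_dir),
                            garbage_collection_gen0_threshold=20000)

    # worker_main now receives both values and forwards them to
    # GenerationExecutorWorker (see the hunks above); the remaining
    # worker_main arguments are unchanged and omitted here.
    worker_kwargs = dict(hf_model_dir=hf_model_dir, llm_args=llm_args)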