Commit ad49221

[None][chore] Part 1: Create PyExecutor from TorchLlmArgs
Signed-off-by: leslie-fang25 <[email protected]>
1 parent e270884 commit ad49221

5 files changed: +157 -137 lines changed

tensorrt_llm/executor/executor.py

Lines changed: 9 additions & 13 deletions
@@ -21,6 +21,7 @@
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..disaggregated_params import DisaggregatedParams
+from ..llmapi.llm_args import TorchLlmArgs
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                   need_spawn_mpi_workers)
@@ -354,7 +355,8 @@ def create(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy
@@ -381,6 +383,8 @@ def create(
             "engine": engine,
             "executor_config": executor_config,
             "batched_logits_processor": batched_logits_processor,
+            "hf_model_dir": hf_model_dir,
+            "llm_args": llm_args,
         }

         if lora_config:
@@ -398,9 +402,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)

         # WAR: For the performance of gathering logits, we use single process worker
         # for TP1 to avoid the large overhead of IPC.
@@ -411,9 +413,7 @@ def create(
                 "Using single process worker for TP1, this may hurt streaming generation performance."
             )
             return GenerationExecutorWorker(**worker_kwargs,
-                                            is_llm_executor=is_llm_executor,
-                                            garbage_collection_gen0_threshold=
-                                            garbage_collection_gen0_threshold)
+                                            is_llm_executor=is_llm_executor)

         # For single-gpu case:
         # Partition the workload to multiple process for streaming performance.
@@ -425,9 +425,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=None,  # use mpi4py
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)
         else:
             ctx = multiprocessing.get_context("spawn")
             # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -438,9 +436,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)

     def wait_first_completed(
             self, futures: List[GenerationResult]
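
Note: the snippet below is not part of the commit; it is a minimal, standalone sketch of the new calling convention shown in this diff. garbage_collection_gen0_threshold no longer travels as its own argument to create(); it rides inside llm_args, and hf_model_dir is forwarded next to it. TorchLlmArgsStub and build_worker_kwargs are illustrative stand-ins, not the real TorchLlmArgs or GenerationExecutor.create, and the checkpoint path is hypothetical.

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass
class TorchLlmArgsStub:
    # Only the fields this diff actually relies on.
    backend: str = "pytorch"
    garbage_collection_gen0_threshold: Optional[int] = 20000


def build_worker_kwargs(engine: Any, executor_config: Any,
                        batched_logits_processor: Any,
                        hf_model_dir: Optional[Path],
                        llm_args: Optional[TorchLlmArgsStub]) -> Dict[str, Any]:
    # Mirrors the updated worker_kwargs assembly in create().
    return {
        "engine": engine,
        "executor_config": executor_config,
        "batched_logits_processor": batched_logits_processor,
        "hf_model_dir": hf_model_dir,
        "llm_args": llm_args,
    }


kwargs = build_worker_kwargs(engine=None, executor_config=None,
                             batched_logits_processor=None,
                             hf_model_dir=Path("/models/example"),  # hypothetical path
                             llm_args=TorchLlmArgsStub())
assert kwargs["llm_args"].garbage_collection_gen0_threshold == 20000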

tensorrt_llm/executor/proxy.py

Lines changed: 4 additions & 5 deletions
@@ -45,7 +45,6 @@ def __init__(
         worker_cls: type = GenerationExecutorWorker,
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> None:
         postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
         )
@@ -87,14 +86,14 @@ def __init__(

         self.model_world_size = model_world_size

-        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+        self.garbage_collection_gen0_threshold = worker_kwargs[
+            "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
+                "llm_args", None) is not None else None

         worker_kwargs = dict(**worker_kwargs,
                              worker_queues=self._setup_queues(),
                              postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=False,
-                             garbage_collection_gen0_threshold=self.
-                             garbage_collection_gen0_threshold)
+                             is_llm_executor=False)

         if "log_level" not in worker_kwargs:
             worker_kwargs["log_level"] = logger.level
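
The proxy no longer receives garbage_collection_gen0_threshold as a constructor argument; it reads it off the llm_args entry in worker_kwargs when one is present. The following self-contained sketch (stub class, not the real TorchLlmArgs) restates that conditional lookup as a plain function:

from typing import Any, Dict, Optional


def gc_gen0_threshold(worker_kwargs: Dict[str, Any]) -> Optional[int]:
    # Same decision as the new proxy code: take the threshold from llm_args
    # when it was provided, otherwise fall back to None.
    llm_args = worker_kwargs.get("llm_args", None)
    return (llm_args.garbage_collection_gen0_threshold
            if llm_args is not None else None)


class _ArgsStub:
    garbage_collection_gen0_threshold = 20000  # hypothetical value


assert gc_gen0_threshold({"llm_args": _ArgsStub()}) == 20000
assert gc_gen0_threshold({}) is None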

tensorrt_llm/executor/worker.py

Lines changed: 64 additions & 34 deletions
@@ -18,7 +18,7 @@
                     mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
@@ -60,7 +60,8 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -81,29 +82,51 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args

         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)

         if isinstance(engine, list):
             engine = engine[self.rank]

-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
+        def _create_py_executor(comm_ranks, device_ids):

-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
+            executor_config = llm_args.get_executor_config(hf_model_dir)
+            # Persist so downstream code (e.g., default max_tokens deduction) has access
+            self._executor_config = executor_config

-        def _create_engine():
-            device_id = self.global_rank % torch.cuda.device_count()
-            torch.cuda.set_device(device_id)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            executor_config.parallel_config = tllm.ParallelConfig(
+                participant_ids=comm_ranks, device_ids=device_ids)
+            args = {
+                "executor_config": executor_config,
+                "checkpoint_dir": executor_config.hf_model_dir,
+            }
+            assert hasattr(
+                executor_config, "backend"
+            ), "executor_config should be with backend in _create_py_executor"
+            if executor_config.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
+                    create_py_executor
+                create_executor = create_py_executor
+                args["lora_config"] = lora_config
+                args[
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
+            else:
+                raise ValueError(
+                    f"Unsupported backend config: {executor_config.backend}")
+            return create_executor(**args)
+
+        def _create_engine(comm_ranks, device_ids):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)

-            # Make sure C++ executor would use same devices/ranks as py_executor
-            global_rank = global_mpi_rank()
-            comm_ranks = mpi_comm().allgather(global_rank)
-            device_ids = mpi_comm().allgather(device_id)
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)

@@ -122,14 +145,7 @@ def _create_engine():
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
             }
-            if executor_config.backend == "pytorch":
-                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
-                    create_py_executor
-                create_executor = create_py_executor
-                args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
+            if executor_config.backend == "_autodeploy":
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
@@ -138,7 +154,17 @@ def _create_engine():
                     f"Unsupported backend config: {executor_config.backend}")
             return create_executor(**args)

-        self.engine = _create_engine()
+        device_id = self.global_rank % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        # Make sure C++ executor would use same devices/ranks as py_executor
+        global_rank = global_mpi_rank()
+        comm_ranks = mpi_comm().allgather(global_rank)
+        device_ids = mpi_comm().allgather(device_id)
+
+        self.engine = _create_py_executor(
+            comm_ranks, device_ids) if llm_args is not None else _create_engine(
+                comm_ranks, device_ids)

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -430,14 +456,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
             context_phase_params = request.disaggregated_params.get_context_phase_params(
             )

-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )

         assert request.id is not None

@@ -641,7 +669,8 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +797,8 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            llm_args=llm_args)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())
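
With this change the worker has two construction paths: _create_py_executor when llm_args is given (the executor config is then derived via llm_args.get_executor_config(hf_model_dir)), and the original _create_engine otherwise; both receive the MPI-gathered comm_ranks and device_ids. The self-contained sketch below only illustrates that dispatch; the factories are lambdas standing in for the real creators, not the upstream functions:

from typing import Callable, List, Optional

EngineFactory = Callable[[List[int], List[int]], object]


def select_engine_factory(llm_args: Optional[object],
                          create_py_executor: EngineFactory,
                          create_engine: EngineFactory) -> EngineFactory:
    # llm_args present -> PyTorch path, config derived from llm_args
    # llm_args absent  -> legacy path driven by an ExecutorConfig
    return create_py_executor if llm_args is not None else create_engine


calls = []
factory = select_engine_factory(
    llm_args=object(),  # stand-in for a TorchLlmArgs instance
    create_py_executor=lambda ranks, devices: calls.append(("py", ranks, devices)),
    create_engine=lambda ranks, devices: calls.append(("cpp", ranks, devices)),
)
factory([0, 1], [0, 1])  # comm_ranks, device_ids as gathered over MPI
assert calls == [("py", [0, 1], [0, 1])]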

tensorrt_llm/llmapi/llm.py

Lines changed: 5 additions & 83 deletions
@@ -37,8 +37,7 @@
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
-from .tokenizer import (TokenizerBase, _llguidance_tokenizer_info,
-                        _xgrammar_tokenizer_info)
+from .tokenizer import TokenizerBase, _xgrammar_tokenizer_info
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 from .utils import (append_docstring, exception_handler, get_device_count,
                     print_colored_debug, set_api_status)
@@ -958,90 +957,13 @@ def _build_model(self):
             self.tokenizer)
         self._tokenizer = self.input_processor.tokenizer

-        max_batch_size = self.args.max_batch_size
-        max_num_tokens = self.args.max_num_tokens
-        max_seq_len = self.args.max_seq_len
-
-        kwargs = {}
-        if self._on_trt_backend:
-            kwargs[
-                "batching_type"] = self.args.batching_type or tllm.BatchingType.INFLIGHT
-
-        self._executor_config = tllm.ExecutorConfig(
-            max_beam_width=self.args.max_beam_width,
-            scheduler_config=PybindMirror.maybe_to_pybind(
-                self.args.scheduler_config),
-            max_batch_size=max_batch_size,
-            max_num_tokens=max_num_tokens,
-            gather_generation_logits=self.args.gather_generation_logits,
-            fail_fast_on_attention_window_too_large=getattr(
-                self.args, 'fail_fast_on_attention_window_too_large', False),
-            **kwargs)
-
-        if self.args.kv_cache_config is not None:
-            self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
-                self.args.kv_cache_config)
-        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
-            # Disable KV cache reuse for deterministic mode
-            self._executor_config.kv_cache_config.enable_block_reuse = False
-            self._executor_config.kv_cache_config.enable_partial_reuse = False
-        if self.args.peft_cache_config is not None:
-            self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
-                self.args.peft_cache_config)
-        if self.args.decoding_config is not None:
-            self._executor_config.decoding_config = self.args.decoding_config
-        if self.args.guided_decoding_backend == 'xgrammar':
-            self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
-                backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
-                XGRAMMAR,
-                **_xgrammar_tokenizer_info(self.tokenizer))
-        elif self.args.guided_decoding_backend == 'llguidance':
-            self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
-                backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
-                LLGUIDANCE,
-                **_llguidance_tokenizer_info(self.tokenizer))
-        elif self.args.guided_decoding_backend is not None:
-            raise ValueError(
-                f"Unsupported guided decoding backend {self.args.guided_decoding_backend}"
-            )
-
-        if self._on_trt_backend:
-            self._executor_config.normalize_log_probs = self.args.normalize_log_probs
-            self._executor_config.enable_chunked_context = self.args.enable_chunked_prefill
-            self._executor_config.max_beam_width = self.args.max_beam_width
-        if self.args.cache_transceiver_config is not None:
-            self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
-                self.args.cache_transceiver_config)
-        from tensorrt_llm._torch.pyexecutor.config import update_executor_config
-
-        spec_config = self.args.speculative_config
-        max_batch_size = self._executor_config.max_batch_size
-
-        if spec_config is not None and spec_config.decoding_type == "AUTO":
-            from tensorrt_llm._torch.speculative import suggest_spec_config
-            spec_config = suggest_spec_config(max_batch_size)
-
-        update_executor_config(
-            self._executor_config,
-            backend=self.args.backend,
-            pytorch_backend_config=self.args.get_pytorch_backend_config()
-            if self.args.backend in ["pytorch", "_autodeploy"] else None,
-            mapping=self.args.parallel_config.to_mapping(),
-            speculative_config=spec_config,
-            hf_model_dir=self._hf_model_dir,
-            max_input_len=self.args.max_input_len,
-            max_seq_len=max_seq_len,
-            checkpoint_format=None if self.args.backend == "_autodeploy" else
-            self.args.checkpoint_format,
-            checkpoint_loader=None if self.args.backend == "_autodeploy" else
-            self.args.checkpoint_loader)
+        assert isinstance(self.args, TorchLlmArgs)

         # TODO: revisit gather_context_logits
         return_logits = self.args.gather_generation_logits
-
         self._executor = self._executor_cls.create(
             self._engine_dir,
-            executor_config=self._executor_config,
+            executor_config=None,
             batched_logits_processor=self.args.batched_logits_processor,
             model_world_size=self.args.parallel_config.world_size,
             mpi_session=self.mpi_session,
@@ -1054,8 +976,8 @@ def _build_model(self):
             ),
             is_llm_executor=True,
             lora_config=self.args.lora_config,
-            garbage_collection_gen0_threshold=self.args.
-            garbage_collection_gen0_threshold)
+            hf_model_dir=self._hf_model_dir,
+            llm_args=self.args)

     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend.
