
Commit a15af87

[None][refactor] Refactor Torch Compile Backend, MoeLoadBalancer and warmup Logic (#6615)
Signed-off-by: yizhang-nv <[email protected]>
Signed-off-by: Yi Zhang <[email protected]>
1 parent 71e28ea commit a15af87


14 files changed: +271 −185 lines changed


tensorrt_llm/_torch/compilation/backend.py

Lines changed: 4 additions & 6 deletions
@@ -37,7 +37,7 @@ def __init__(
         enable_inductor=True,
         enable_userbuffers=False,
         enable_piecewise_cuda_graph: bool = False,
-        cuda_graph_batch_sizes: Optional[List[int]] = None,
+        capture_num_tokens: Optional[List[int]] = None,
         max_num_streams: int = 1,
     ) -> None:
         super().__init__()
@@ -48,14 +48,12 @@ def __init__(
         self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
         self.rank = tensorrt_llm.mpi_rank()
         self.enable_inductor = enable_inductor
-        self.cuda_graph_batch_sizes = (cuda_graph_batch_sizes
-                                       if cuda_graph_batch_sizes is not None
-                                       else [])
+        self.capture_num_tokens = capture_num_tokens or []
         self.piecewise_cuda_graph = enable_piecewise_cuda_graph
         self.no_optimization = False
         # We only need to create aux streams.
         self.aux_streams = Backend.Streams(
-            [torch.cuda.Stream() for i in range(max_num_streams - 1)])
+            [torch.cuda.Stream() for _ in range(max_num_streams - 1)])
         self.events = Backend.Events()
         inductor_config.enable_auto_functionalized_v2 = False
@@ -125,7 +123,7 @@ def optimize(
             example_inputs,
             self.enable_inductor,
             self.input_num_tokens,
-            self.cuda_graph_batch_sizes,
+            self.capture_num_tokens,
             self._graph_pool_handle,
             len(self.aux_streams) + 1,
         )

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 11 additions & 12 deletions
@@ -14,8 +14,7 @@
 from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
                      make_weak_ref)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
-from .utils import (get_enable_piecewise_cuda_graph_capture_flag,
-                    is_call_function)
+from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function


 class PiecewiseInterpreter(Interpreter):
@@ -25,7 +24,7 @@ def __init__(
         module: GraphModule,
         enable_inductor: bool,
         compile_time_num_tokens: Union[int | torch.SymInt],
-        cuda_graph_batch_sizes: list[int],
+        capture_num_tokens: list[int],
         exclude_modules_id: list[int],
         graph_pool_handle: tuple[int, int],
         garbage_collect_values: bool = True,
@@ -37,7 +36,7 @@ def __init__(
         self.fake_mode = detect_fake_mode()

         self.compile_time_num_tokens = compile_time_num_tokens
-        self.cuda_graph_batch_sizes = cuda_graph_batch_sizes
+        self.capture_num_tokens = capture_num_tokens
         self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
         self.graph_pool_handle = graph_pool_handle
         self.enable_inductor = enable_inductor
@@ -86,7 +85,7 @@ def call_module(self, target, args, kwargs):
                 target,
                 self.compile_time_num_tokens,
                 runtime_num_tokens_idx,
-                self.cuda_graph_batch_sizes,
+                self.capture_num_tokens,
                 self.graph_pool_handle,
                 compile_fx(submod, args) if self.enable_inductor else submod,
                 self.enable_inductor,
@@ -120,7 +119,7 @@ def __init__(
         name: str,
         compile_time_num_tokens: Union[int | torch.SymInt],
         runtime_num_tokens_idx: tuple[int],
-        cuda_graph_batch_sizes: List[int],
+        capture_num_tokens: List[int],
         graph_pool_handle,
         default_callable: Callable,
         enable_inductor: bool,
@@ -139,9 +138,9 @@ def __init__(

         self.entries: dict[int, Entry] = {}

-        for bs in cuda_graph_batch_sizes:
-            self.entries[bs] = Entry(
-                bs,
+        for num_tokens in capture_num_tokens:
+            self.entries[num_tokens] = Entry(
+                num_tokens,
                 enable_inductor=self.enable_inductor,
                 callable=default_callable,
             )
@@ -167,7 +166,7 @@ def __call__(self, *args):

         if entry.cuda_graph is None:

-            if not get_enable_piecewise_cuda_graph_capture_flag():
+            if not get_capture_piecewise_cuda_graph_flag():
                 return entry.callable(*args)

             if entry.warmup_count < 3:
@@ -228,7 +227,7 @@ def piecewise_optimizer(
     example_inputs: List[torch.Tensor],
     enable_inductor: bool,
     input_num_tokens: Union[int | torch.SymInt],
-    cuda_graph_batch_sizes: Sequence[int],
+    capture_num_tokens: Sequence[int],
     graph_pool_handle: tuple[int, int],
     max_num_streams: int = 1,
 ) -> tuple[GraphModule, int]:
@@ -269,7 +268,7 @@ def piecewise_optimizer(
         gm,
         enable_inductor,
         input_num_tokens,
-        cuda_graph_batch_sizes,
+        capture_num_tokens,
         exclude_modules_id,
         graph_pool_handle,
         max_num_streams=max_num_streams,

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 13 additions & 2 deletions
@@ -1,3 +1,4 @@
+import contextlib
 from typing import Callable, List, Union

 import torch
@@ -33,16 +34,26 @@ def is_call_function(node: Node, target: Union[List[Callable], Callable]):
 _enable_piecewise_cuda_graph_capture = False


-def set_enable_piecewise_cuda_graph_capture_flag(enable: bool):
+def set_capture_piecewise_cuda_graph_flag(enable: bool):
     global _enable_piecewise_cuda_graph_capture
     _enable_piecewise_cuda_graph_capture = enable


-def get_enable_piecewise_cuda_graph_capture_flag() -> bool:
+def get_capture_piecewise_cuda_graph_flag() -> bool:
     global _enable_piecewise_cuda_graph_capture
     return _enable_piecewise_cuda_graph_capture


+@contextlib.contextmanager
+def capture_piecewise_cuda_graph(enable: bool):
+    prev_enable = get_capture_piecewise_cuda_graph_flag()
+    set_capture_piecewise_cuda_graph_flag(enable)
+    try:
+        yield
+    finally:
+        set_capture_piecewise_cuda_graph_flag(prev_enable)
+
+
 def inplace_info():
     inplace_map = {
         torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: {
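Note: the new capture_piecewise_cuda_graph context manager replaces manual set/get flag pairs around warmup. A minimal usage sketch, assuming the import path of the file above; run_model is a hypothetical stand-in for the warmed-up forward call:

from tensorrt_llm._torch.compilation.utils import capture_piecewise_cuda_graph

def warmup_with_capture(run_model, inputs):
    # Enable piecewise CUDA-graph capture only for this call; the previous
    # flag value is restored on exit, even if run_model raises.
    with capture_piecewise_cuda_graph(True):
        return run_model(inputs)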

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 9 additions & 0 deletions
@@ -8,6 +8,7 @@

 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
+from ..modules.multi_stream_utils import do_multi_stream
 from ..utils import (fp4_scale_infer_shape,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)
@@ -925,25 +926,33 @@ def get_stream(stream_id: int):

 @torch.library.custom_op("trtllm::set_stream", mutates_args=())
 def set_stream(stream_id: int) -> None:
+    if not do_multi_stream():
+        return
     stream = get_stream(stream_id)
     assert stream is not None
     torch.cuda.set_stream(stream)


 @torch.library.custom_op("trtllm::record_event", mutates_args=())
 def record_event(event_idx: int) -> None:
+    if not do_multi_stream():
+        return
     event = get_event(event_idx)
     event.record()


 @torch.library.custom_op("trtllm::wait_event", mutates_args=())
 def wait_event(event_idx: int) -> None:
+    if not do_multi_stream():
+        return
     event = get_event(event_idx)
     event.wait()


 @torch.library.custom_op("trtllm::record_stream", mutates_args=())
 def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
+    if not do_multi_stream():
+        return
     stream = get_stream(stream_id)
     assert stream is not None
     tensor.record_stream(stream)
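Note: with these guards the trtllm stream/event ops become no-ops whenever the thread-local multi-stream flag is off, so single-stream warmup and fallback paths can reuse the same traced graph. A hedged sketch of the effect at a call site, assuming with_multi_stream from multi_stream_utils and an event already registered for index 0 elsewhere:

import torch
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream

with with_multi_stream(False):
    torch.ops.trtllm.record_event(0)  # returns immediately; nothing is recorded

with with_multi_stream(True):
    torch.ops.trtllm.record_event(0)  # looks up event 0 and records it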

tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py

Lines changed: 14 additions & 14 deletions
@@ -9,12 +9,12 @@

 import tensorrt_llm
 import tensorrt_llm.bindings.internal.runtime as _tbr
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping

 from ...distributed import AllReduce
 from ...utils import EventType
+from ..multi_stream_utils import do_multi_stream


 def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
@@ -472,7 +472,7 @@ def start_wait_gpu_stage(self):
         assert self.func_called_count["start_wait_gpu_stage"] == 0
         self.func_called_count["start_wait_gpu_stage"] += 1
         if self.updates_enabled:
-            if is_graph_capturing():
+            if do_multi_stream():
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -491,7 +491,7 @@ def done_wait_gpu_stage(self):
         assert self.func_called_count["done_wait_gpu_stage"] == 0
         self.func_called_count["done_wait_gpu_stage"] += 1
         if self.updates_enabled:
-            if is_graph_capturing():
+            if do_multi_stream():
                 self.event_dict[EventType.MoeBalancer].wait()

     def start_set_cpu_stage(self):
@@ -502,7 +502,7 @@ def start_set_cpu_stage(self):
         assert self.func_called_count["start_set_cpu_stage"] == 0
         self.func_called_count["start_set_cpu_stage"] += 1
         if self.updates_enabled:
-            if is_graph_capturing():
+            if do_multi_stream():
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -522,7 +522,7 @@ def done_set_cpu_stage(self):
             self.func_called_count[name] = 0
         self.statistic_flag_tensor = None
         if self.updates_enabled:
-            if is_graph_capturing():
+            if do_multi_stream():
                 self.event_dict[EventType.MoeBalancer].wait()

     def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,
@@ -544,7 +544,7 @@ def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,
                 (self.expert_count, ),
                 dtype=torch.int32,
                 device=torch.device('cuda'))
-            if is_graph_capturing():
+            if do_multi_stream():
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -569,7 +569,7 @@ def get_local_statistic_tensor(self) -> Optional[torch.Tensor]:
         assert self.func_called_count["update_local_statistic"] > 0
         self.func_called_count["get_local_statistic_tensor"] += 1
         if self.updates_enabled:
-            if is_graph_capturing():
+            if do_multi_stream():
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.MoeBalancer].record()
                 self.event_dict[EventType.MoeBalancer].wait()
@@ -598,7 +598,7 @@ def _update_statistic():
                 self.single_layer_load_balancer_ptr)

         if self.updates_enabled:
-            if is_graph_capturing():
+            if do_multi_stream():
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -636,7 +636,7 @@ def _update_statistic():
         if self.updates_enabled:
             self.update_local_statistic(local_raw_expert_ids, is_first_stage,
                                         is_last_stage)
-            if is_graph_capturing():
+            if do_multi_stream():
                 with torch.cuda.stream(self.aux_stream):
                     _update_statistic()
             else:
@@ -660,7 +660,7 @@ def update_statistic_with_global_ids(self,
         assert self.func_called_count["update_statistic_with_local_ids"] == 0
         self.func_called_count["update_statistic_with_global_ids"] += 1
         if self.updates_enabled:
-            if is_graph_capturing():
+            if do_multi_stream():
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -851,8 +851,8 @@ def set_warm_up_iter_count(self, iter_count: int):
         """
         self.load_balancer_impl.set_warm_up_iter_count(iter_count)

-    def set_next_iter_info(self, enable_statistic: Optional[bool],
-                           enable_update_weights: Optional[bool]):
+    def set_iter_info(self, enable_statistic: Optional[bool],
+                      enable_update_weights: Optional[bool]):
         if enable_statistic is not None:
             self.enable_statistic = enable_statistic
         if enable_update_weights is not None:
@@ -998,8 +998,8 @@ def __enter__(self):
         """
         if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing(
         ):
-            self.moe_load_balancer.set_next_iter_info(self.enable_statistic,
-                                                      self.enable_updates)
+            self.moe_load_balancer.set_iter_info(self.enable_statistic,
+                                                 self.enable_updates)
             self.moe_load_balancer.start_iter()
         return self
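Note: the do_multi_stream() branches above all gate the same record-on-main, wait-on-aux handoff. A self-contained sketch of that generic PyTorch two-stream pattern, with a placeholder tensor and kernel rather than the balancer's actual statistics update:

import torch

aux_stream = torch.cuda.Stream()
main_to_aux = torch.cuda.Event()

x = torch.ones(1024, device="cuda")
main_to_aux.record()                  # mark progress on the current (main) stream
with torch.cuda.stream(aux_stream):
    main_to_aux.wait()                # aux stream waits for the main stream's work
    y = x * 2                         # then runs concurrently with later main-stream work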

tensorrt_llm/_torch/modules/multi_stream_utils.py

Lines changed: 30 additions & 3 deletions
@@ -1,8 +1,35 @@
+import threading
+from contextlib import contextmanager
 from typing import Any, Callable, Optional

 import torch

-from ..pyexecutor.cuda_graph_runner import is_graph_capturing
+
+class do_multi_stream_local(threading.local):
+
+    def __init__(self):
+        self.do_multi_stream = False
+
+
+_local = do_multi_stream_local()
+
+
+def set_do_multi_stream(enable: bool):
+    _local.do_multi_stream = enable
+
+
+def do_multi_stream() -> bool:
+    return _local.do_multi_stream
+
+
+@contextmanager
+def with_multi_stream(enable: bool):
+    prev_do_multi_stream = _local.do_multi_stream
+    set_do_multi_stream(enable)
+    try:
+        yield
+    finally:
+        set_do_multi_stream(prev_do_multi_stream)


 def maybe_execute_in_parallel(
@@ -30,9 +57,9 @@ def maybe_execute_in_parallel(
         tuple[Any, Any]: the return values of fn0() and fn1()
     """

-    do_multi_stream = is_graph_capturing() and aux_stream is not None
+    multi_stream = do_multi_stream() and aux_stream is not None

-    if do_multi_stream:
+    if multi_stream:
         event0.record()
         result0 = fn0()

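Note: maybe_execute_in_parallel now keys off an explicit thread-local flag instead of CUDA-graph capture state, so callers (for example, the warmup path) decide when overlapping on the aux stream is allowed. A minimal sketch of the flag's scoping behavior:

from tensorrt_llm._torch.modules.multi_stream_utils import (do_multi_stream,
                                                            with_multi_stream)

assert do_multi_stream() is False      # thread-local default
with with_multi_stream(True):
    assert do_multi_stream() is True   # enabled only inside the context
assert do_multi_stream() is False      # previous value restored on exit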

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 1 addition & 1 deletion
@@ -242,8 +242,8 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
             torch_used_bytes = torch.cuda.memory_stats(
             )["allocated_bytes.all.current"]
         finally:
-            py_executor.shutdown()
             py_executor.is_warmup = False
+            py_executor.shutdown()
             py_executor.enable_iter_perf_stats = origin_iter_stats
             py_executor.set_gather_responses(False)

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ class PyTorchConfig:
     torch_compile_fullgraph: bool = True
     torch_compile_inductor_enabled: bool = False
     torch_compile_piecewise_cuda_graph: bool = False
+    torch_compile_piecewise_cuda_graph_num_tokens: Optional[List[int]] = None
     # When torch compile is enabled, userbuffers is enabled by default
     torch_compile_enable_userbuffers: bool = True
     torch_compile_max_num_streams: int = 1
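Note: the new field lets users pick the token counts at which piecewise CUDA graphs are captured. A hedged configuration sketch; the values below are illustrative and all other PyTorchConfig fields keep the defaults shown above:

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

config = PyTorchConfig(
    torch_compile_piecewise_cuda_graph=True,
    # Capture one piecewise CUDA graph per listed token count (illustrative values).
    torch_compile_piecewise_cuda_graph_num_tokens=[64, 128, 256],
)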

0 commit comments
