9
9
10
10
import tensorrt_llm
11
11
import tensorrt_llm .bindings .internal .runtime as _tbr
12
- from tensorrt_llm ._torch .pyexecutor .cuda_graph_runner import is_graph_capturing
13
12
from tensorrt_llm .logger import logger
14
13
from tensorrt_llm .mapping import Mapping
15
14
16
15
from ...distributed import AllReduce
17
16
from ...utils import EventType
17
+ from ..multi_stream_utils import do_multi_stream
18
18
19
19
20
20
def _tensor_to_weight (t : torch .Tensor ) -> _tbr .MoeWeight :
@@ -472,7 +472,7 @@ def start_wait_gpu_stage(self):
472
472
assert self .func_called_count ["start_wait_gpu_stage" ] == 0
473
473
self .func_called_count ["start_wait_gpu_stage" ] += 1
474
474
if self .updates_enabled :
475
- if is_graph_capturing ():
475
+ if do_multi_stream ():
476
476
self .event_dict [EventType .Main ].record ()
477
477
with torch .cuda .stream (self .aux_stream ):
478
478
self .event_dict [EventType .Main ].wait ()
@@ -491,7 +491,7 @@ def done_wait_gpu_stage(self):
491
491
assert self .func_called_count ["done_wait_gpu_stage" ] == 0
492
492
self .func_called_count ["done_wait_gpu_stage" ] += 1
493
493
if self .updates_enabled :
494
- if is_graph_capturing ():
494
+ if do_multi_stream ():
495
495
self .event_dict [EventType .MoeBalancer ].wait ()
496
496
497
497
def start_set_cpu_stage (self ):
@@ -502,7 +502,7 @@ def start_set_cpu_stage(self):
502
502
assert self .func_called_count ["start_set_cpu_stage" ] == 0
503
503
self .func_called_count ["start_set_cpu_stage" ] += 1
504
504
if self .updates_enabled :
505
- if is_graph_capturing ():
505
+ if do_multi_stream ():
506
506
self .event_dict [EventType .Main ].record ()
507
507
with torch .cuda .stream (self .aux_stream ):
508
508
self .event_dict [EventType .Main ].wait ()
@@ -522,7 +522,7 @@ def done_set_cpu_stage(self):
522
522
self .func_called_count [name ] = 0
523
523
self .statistic_flag_tensor = None
524
524
if self .updates_enabled :
525
- if is_graph_capturing ():
525
+ if do_multi_stream ():
526
526
self .event_dict [EventType .MoeBalancer ].wait ()
527
527
528
528
def update_local_statistic (self , local_raw_expert_ids : torch .Tensor ,
@@ -544,7 +544,7 @@ def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,
544
544
(self .expert_count , ),
545
545
dtype = torch .int32 ,
546
546
device = torch .device ('cuda' ))
547
- if is_graph_capturing ():
547
+ if do_multi_stream ():
548
548
self .event_dict [EventType .Main ].record ()
549
549
with torch .cuda .stream (self .aux_stream ):
550
550
self .event_dict [EventType .Main ].wait ()
@@ -569,7 +569,7 @@ def get_local_statistic_tensor(self) -> Optional[torch.Tensor]:
569
569
assert self .func_called_count ["update_local_statistic" ] > 0
570
570
self .func_called_count ["get_local_statistic_tensor" ] += 1
571
571
if self .updates_enabled :
572
- if is_graph_capturing ():
572
+ if do_multi_stream ():
573
573
with torch .cuda .stream (self .aux_stream ):
574
574
self .event_dict [EventType .MoeBalancer ].record ()
575
575
self .event_dict [EventType .MoeBalancer ].wait ()
@@ -598,7 +598,7 @@ def _update_statistic():
598
598
self .single_layer_load_balancer_ptr )
599
599
600
600
if self .updates_enabled :
601
- if is_graph_capturing ():
601
+ if do_multi_stream ():
602
602
self .event_dict [EventType .Main ].record ()
603
603
with torch .cuda .stream (self .aux_stream ):
604
604
self .event_dict [EventType .Main ].wait ()
@@ -636,7 +636,7 @@ def _update_statistic():
636
636
if self .updates_enabled :
637
637
self .update_local_statistic (local_raw_expert_ids , is_first_stage ,
638
638
is_last_stage )
639
- if is_graph_capturing ():
639
+ if do_multi_stream ():
640
640
with torch .cuda .stream (self .aux_stream ):
641
641
_update_statistic ()
642
642
else :
@@ -660,7 +660,7 @@ def update_statistic_with_global_ids(self,
660
660
assert self .func_called_count ["update_statistic_with_local_ids" ] == 0
661
661
self .func_called_count ["update_statistic_with_global_ids" ] += 1
662
662
if self .updates_enabled :
663
- if is_graph_capturing ():
663
+ if do_multi_stream ():
664
664
self .event_dict [EventType .Main ].record ()
665
665
with torch .cuda .stream (self .aux_stream ):
666
666
self .event_dict [EventType .Main ].wait ()
@@ -851,8 +851,8 @@ def set_warm_up_iter_count(self, iter_count: int):
851
851
"""
852
852
self .load_balancer_impl .set_warm_up_iter_count (iter_count )
853
853
854
- def set_next_iter_info (self , enable_statistic : Optional [bool ],
855
- enable_update_weights : Optional [bool ]):
854
+ def set_iter_info (self , enable_statistic : Optional [bool ],
855
+ enable_update_weights : Optional [bool ]):
856
856
if enable_statistic is not None :
857
857
self .enable_statistic = enable_statistic
858
858
if enable_update_weights is not None :
@@ -998,8 +998,8 @@ def __enter__(self):
998
998
"""
999
999
if self .moe_load_balancer is not None and not self .moe_load_balancer .is_static_routing (
1000
1000
):
1001
- self .moe_load_balancer .set_next_iter_info (self .enable_statistic ,
1002
- self .enable_updates )
1001
+ self .moe_load_balancer .set_iter_info (self .enable_statistic ,
1002
+ self .enable_updates )
1003
1003
self .moe_load_balancer .start_iter ()
1004
1004
return self
1005
1005
0 commit comments