@@ -308,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
     def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            print(
+                f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}"
+            )
             return False

         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                 and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
+            print(
+                f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
+            )
             return False

+        print(f"all to all type {self.alltoall_method_type}")
         return self.enable_alltoall

     def _get_quant_method(self):
@@ -323,9 +330,18 @@ def _get_quant_method(self):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
+            print(
+                f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
+            )
             if get_sm_version() == 100:
+                print(
+                    f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
+                )
                 return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
             else:
+                print(
+                    f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
+                )
                 return DeepSeekFP8BlockScalesFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_nvfp4():
             return NVFP4CutlassFusedMoEMethod()
@@ -399,6 +415,10 @@ def forward_chunk(

         is_first_call, is_last_call = repeating_info

+        print(
+            f"wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}"
+        )
+
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()

@@ -475,7 +495,7 @@ def forward_chunk(
             self.dummy_allreduce()
         token_count = x.shape[0]
         alltoall_info = None
-        if is_last_call:
+        if self.layer_load_balancer and is_last_call:
             loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
             )
         else:
@@ -650,35 +670,7 @@ def forward_chunk(
        )

         # Original fused_moe call (preserved as reference)
-        final_hidden_states = torch.ops.trtllm.fused_moe(
-            x,
-            token_selected_slots,
-            token_final_scales,
-            w3_w1_weight.view(weight_dtype),
-            None,  # w3_w1_bias
-            w2_weight.view(weight_dtype),
-            None,  # w2_bias
-            output_dtype,
-            quant_scales=quant_scales,
-            input_sf=x_sf,
-            swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
-            min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
-            tuner_num_tokens=tuner_num_tokens,
-            tuner_top_k=tuner_top_k,
-        )
-
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        # final_hidden_states = self.moe_backend.run_moe(
+        # final_hidden_states = torch.ops.trtllm.fused_moe(
         #     x,
         #     token_selected_slots,
         #     token_final_scales,
@@ -703,9 +695,38 @@ def forward_chunk(
         #     tune_max_num_tokens=self.tune_max_num_tokens,
         #     tuner_num_tokens=tuner_num_tokens,
         #     tuner_top_k=tuner_top_k,
-        #     module=self,  # Additional parameter for backend to access module properties
         # )

+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        final_hidden_states = self.moe_backend.run_moe(
+            x,
+            token_selected_slots,
+            token_final_scales,
+            w3_w1_weight.view(weight_dtype),
+            None,  # w3_w1_bias
+            w2_weight.view(weight_dtype),
+            None,  # w2_bias
+            output_dtype,
+            quant_scales=quant_scales,
+            input_sf=x_sf,
+            swizzled_input_sf=False,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            ep_size=ep_size,
+            ep_rank=ep_rank,
+            cluster_size=cluster_size,
+            cluster_rank=cluster_rank,
+            enable_alltoall=use_all_to_all,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+            use_w4_group_scaling=use_w4_group_scaling,
+            min_latency_mode=False,
+            tune_max_num_tokens=self.tune_max_num_tokens,
+            tuner_num_tokens=tuner_num_tokens,
+            tuner_top_k=tuner_top_k,
+            module=
+            self,  # Additional parameter for backend to access module properties
+        )
+
         if self.layer_load_balancer and is_last_call:
             self.layer_load_balancer.start_set_cpu_stage()

@@ -784,6 +805,10 @@ def forward(
             all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
             repeating_info=(is_first_call, is_last_call))
+        # Print all of the information on one line
+        print(
+            f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}"
+        )
         outputs = self.reducescatter_or_allreduce(
             outputs,
             use_all_to_all,