
Commit 36f4493

debug
Signed-off-by: xxi <[email protected]>

    modified:   tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
    modified:   tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
    modified:   tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py

Signed-off-by: xxi <[email protected]>

    modified:   tensorrt_llm/_torch/distributed/ops.py
    modified:   tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
    modified:   tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
    modified:   tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
1 parent 373e11f commit 36f4493

File tree

4 files changed: +70 −36 lines changed

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 5 additions & 0 deletions
@@ -240,6 +240,11 @@ def reducescatter(
         if isinstance(input, torch.Tensor):
             assert input.shape[dim] == sum_split_size
         else:
+            for val in input:
+                if val is not None and val.shape[dim] != sum_split_size:
+                    print(
+                        f"[reducescatter] val.shape={val.shape}, dim={dim}, val.shape[dim]={val.shape[dim]}, sum_split_size={sum_split_size}, sizes={sizes}"
+                    )
             assert all([
                 val.shape[dim] == sum_split_size for val in input
                 if val is not None
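
The added block reports the first mismatching tensor shape before the existing assertion fires, which makes the reduce-scatter size check easier to diagnose. A minimal, self-contained sketch of the same check (the helper name check_split_sizes and the sample tensors are illustrative, not part of ops.py):

from typing import List, Optional

import torch


def check_split_sizes(inputs: List[Optional[torch.Tensor]], dim: int,
                      sizes: List[int]) -> None:
    """Print any tensor whose extent along `dim` differs from sum(sizes), then assert."""
    sum_split_size = sum(sizes)
    for val in inputs:
        if val is not None and val.shape[dim] != sum_split_size:
            # Same idea as the added print: surface the offending shape before the assert.
            print(f"[reducescatter] val.shape={tuple(val.shape)}, dim={dim}, "
                  f"expected {sum_split_size}, sizes={sizes}")
    assert all(val.shape[dim] == sum_split_size for val in inputs
               if val is not None)


# Passing case: every tensor has 6 rows, matching sum(sizes) == 6.
check_split_sizes([torch.zeros(6, 4), None, torch.zeros(6, 4)], dim=0, sizes=[2, 2, 2])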

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 56 additions & 31 deletions
@@ -308,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
     def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            print(
+                f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}"
+            )
             return False
 
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                 and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
+            print(
+                f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
+            )
             return False
 
+        print(f"all to all type {self.alltoall_method_type}")
         return self.enable_alltoall
 
     def _get_quant_method(self):
@@ -323,9 +330,18 @@ def _get_quant_method(self):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
+            print(
+                f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
+            )
             if get_sm_version() == 100:
+                print(
+                    f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
+                )
                 return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
             else:
+                print(
+                    f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
+                )
                 return DeepSeekFP8BlockScalesFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_nvfp4():
             return NVFP4CutlassFusedMoEMethod()
@@ -399,6 +415,10 @@ def forward_chunk(
 
         is_first_call, is_last_call = repeating_info
 
+        print(
+            f"wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}"
+        )
+
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
@@ -475,7 +495,7 @@ def forward_chunk(
                 self.dummy_allreduce()
             token_count = x.shape[0]
             alltoall_info = None
-            if is_last_call:
+            if self.layer_load_balancer and is_last_call:
                 loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                 )
             else:
@@ -650,35 +670,7 @@ def forward_chunk(
         )
 
         # Original fused_moe call (preserved as reference)
-        final_hidden_states = torch.ops.trtllm.fused_moe(
-            x,
-            token_selected_slots,
-            token_final_scales,
-            w3_w1_weight.view(weight_dtype),
-            None,  # w3_w1_bias
-            w2_weight.view(weight_dtype),
-            None,  # w2_bias
-            output_dtype,
-            quant_scales=quant_scales,
-            input_sf=x_sf,
-            swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
-            min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
-            tuner_num_tokens=tuner_num_tokens,
-            tuner_top_k=tuner_top_k,
-        )
-
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        # final_hidden_states = self.moe_backend.run_moe(
+        # final_hidden_states = torch.ops.trtllm.fused_moe(
         #     x,
         #     token_selected_slots,
         #     token_final_scales,
@@ -703,9 +695,38 @@ def forward_chunk(
         #     tune_max_num_tokens=self.tune_max_num_tokens,
         #     tuner_num_tokens=tuner_num_tokens,
         #     tuner_top_k=tuner_top_k,
-        #     module=self,  # Additional parameter for backend to access module properties
         # )
 
+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        final_hidden_states = self.moe_backend.run_moe(
+            x,
+            token_selected_slots,
+            token_final_scales,
+            w3_w1_weight.view(weight_dtype),
+            None,  # w3_w1_bias
+            w2_weight.view(weight_dtype),
+            None,  # w2_bias
+            output_dtype,
+            quant_scales=quant_scales,
+            input_sf=x_sf,
+            swizzled_input_sf=False,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            ep_size=ep_size,
+            ep_rank=ep_rank,
+            cluster_size=cluster_size,
+            cluster_rank=cluster_rank,
+            enable_alltoall=use_all_to_all,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+            use_w4_group_scaling=use_w4_group_scaling,
+            min_latency_mode=False,
+            tune_max_num_tokens=self.tune_max_num_tokens,
+            tuner_num_tokens=tuner_num_tokens,
+            tuner_top_k=tuner_top_k,
+            module=
+            self,  # Additional parameter for backend to access module properties
+        )
+
         if self.layer_load_balancer and is_last_call:
             self.layer_load_balancer.start_set_cpu_stage()
 
@@ -784,6 +805,10 @@ def forward(
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
                 repeating_info=(is_first_call, is_last_call))
+            # Print all of the information on one line
+            print(
+                f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}"
+            )
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
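
Beyond the debug prints, the substantive change in forward_chunk is that the MoE computation now goes through self.moe_backend.run_moe(..., module=self) instead of calling torch.ops.trtllm.fused_moe directly, keeping the argument list identical so backends stay drop-in. A rough sketch of that dispatch shape under simplified assumptions (the class names and the passthrough compute below are illustrative, not the repository's actual interface):

from abc import ABC, abstractmethod
from typing import Any

import torch


class MoEBackendSketch(ABC):
    """Simplified stand-in for a pluggable MoE backend."""

    @abstractmethod
    def compute_moe(self, x: torch.Tensor, *, module: Any = None, **kwargs) -> torch.Tensor:
        ...

    def run_moe(self, x: torch.Tensor, *args, module: Any = None, **kwargs) -> torch.Tensor:
        # Concrete entry point shared by all backends; the heavy lifting is
        # deferred to compute_moe(), which each backend overrides.
        return self.compute_moe(x, module=module, **kwargs)


class PassthroughBackend(MoEBackendSketch):

    def compute_moe(self, x: torch.Tensor, *, module: Any = None, **kwargs) -> torch.Tensor:
        return x  # placeholder for the actual grouped-GEMM MoE computation


# The caller keeps the fused_moe-style argument list and adds module=self so the
# backend can read properties such as tp_size/ep_size from the calling module.
backend = PassthroughBackend()
out = backend.run_moe(torch.randn(4, 8), module=None)
print(out.shape)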

tensorrt_llm/_torch/modules/fused_moe/moe_backend.py

Lines changed: 6 additions & 5 deletions
@@ -98,7 +98,6 @@ def compute_moe(
             Computed MoE output tensor
         """
 
-    @abstractmethod
     def run_moe(
         self,
         # Positional arguments (same order as torch.ops.trtllm.fused_moe)
@@ -542,10 +541,11 @@ def __init__(self):
         super().__init__()
         # Import DeepGemm specific functions
         import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
-        from tensorrt_llm import deep_gemm
-        self.deep_gemm = deep_gemm
         self.fp8_utils = fp8_utils
 
+        from .fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm
+        self.deepgemm_fp8_group_blockwise_gemm = deepgemm_fp8_group_blockwise_gemm
+
     def finalize_tactic(
         self,
         module: Any,
@@ -664,6 +664,7 @@ def compute_moe(
         Note: This assumes the data has already been gathered/alltoall'd
         by the WideEP forward_chunk method.
         """
+
         # Import necessary functions for DeepGemm
         from .fused_moe_deepgemm import (masked_index_copy_group_quant_fp8,
                                          preprocess_after_permute, set_strides,
@@ -750,7 +751,7 @@ def compute_moe(
         h1 = set_strides(workspace["workspace_1"], expert_size_per_partition,
                          m_max, intermediate_size * 2)
 
-        self.deep_gemm.deepgemm_fp8_group_blockwise_gemm(
+        self.deepgemm_fp8_group_blockwise_gemm(
             d=h1,
             a=act_input_fp8,
             b=w3_w1_weight,
@@ -783,7 +784,7 @@ def compute_moe(
         h3 = set_strides(workspace["workspace_1"], expert_size_per_partition,
                          m_max, hidden_size)
 
-        self.deep_gemm.deepgemm_fp8_group_blockwise_gemm(
+        self.deepgemm_fp8_group_blockwise_gemm(
             d=h3,
             a=act_input_fp8,
             b=w2_weight,
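
The backend now binds deepgemm_fp8_group_blockwise_gemm from .fused_moe_deepgemm once in __init__, and the two grouped-GEMM call sites reference that attribute directly rather than resolving it through a stored deep_gemm module object. A small sketch of the pattern, with a stub standing in for the real kernel (the stub and class name are illustrative):

class DeepGemmBackendSketch:
    """Bind the grouped-GEMM callable once in __init__ so call sites stay short."""

    def __init__(self):
        # Stand-in for: from .fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm
        def grouped_gemm_stub(*, d, a, b):
            d.append((a, b))  # the real kernel writes the GEMM result into d

        self.deepgemm_fp8_group_blockwise_gemm = grouped_gemm_stub

    def compute(self):
        results = []
        # Both call sites use the bound attribute, mirroring the h1 and h3 calls.
        self.deepgemm_fp8_group_blockwise_gemm(d=results, a="act_fp8", b="w3_w1")
        self.deepgemm_fp8_group_blockwise_gemm(d=results, a="act_fp8", b="w2")
        return results


print(DeepGemmBackendSketch().compute())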

tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py

Lines changed: 3 additions & 0 deletions
@@ -960,6 +960,9 @@ def maybe_create_moe_load_balancer(
     in_supported_model_arch = model_arch in moe_model_arch_list
     using_smart_router = mapping and mapping.moe_cluster_size > 1
     moe_load_balancer = nullcontext()
+    print(
+        f"maybe_create_moe_load_balancer: in_supported_model_arch={in_supported_model_arch}, using_ep={using_ep}, using_smart_router={using_smart_router}, model_config.moe_load_balancer={model_config.moe_load_balancer}"
+    )
     if in_supported_model_arch and using_ep and not using_smart_router and model_config.moe_load_balancer is not None:
         model_config.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size)
         if model_config.moe_load_balancer.layer_updates_per_iter > 0:
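
The print above exposes the three conditions that decide whether a real load balancer is constructed or the nullcontext() fallback is returned. Since this commit adds unconditional print calls across several files, one option, shown here only as a sketch, is to gate them behind an environment variable; the TLLM_DEBUG_MOE name below is an assumption, not an existing switch:

import os

# Hypothetical switch: set TLLM_DEBUG_MOE=1 to enable the diagnostics added in this commit.
_DEBUG_MOE = os.environ.get("TLLM_DEBUG_MOE", "0") == "1"


def debug_print(msg: str) -> None:
    """Print only when the debug switch is on, so the prints can stay in the tree."""
    if _DEBUG_MOE:
        print(msg, flush=True)


# Usage mirroring the line added to maybe_create_moe_load_balancer:
in_supported_model_arch, using_ep, using_smart_router = True, True, False
debug_print(f"maybe_create_moe_load_balancer: in_supported_model_arch={in_supported_model_arch}, "
            f"using_ep={using_ep}, using_smart_router={using_smart_router}")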
