Commit 99b56a7

do not delete fused_moe_cutlass hopper test

Signed-off-by: Mindy Li <[email protected]>
1 parent c1d9877 commit 99b56a7

File tree

1 file changed: +136 -1 lines changed

tests/unittest/_torch/modules/test_fused_moe.py
Lines changed: 136 additions & 1 deletion
@@ -14,7 +14,8 @@
 from mpi4py import MPI
 from mpi4py.futures import MPIPoolExecutor
 from utils.util import (skip_neither_ada_nor_hopper_unittest,
-                        skip_pre_blackwell, skip_pre_hopper)
+                        skip_non_hopper_unittest, skip_pre_blackwell,
+                        skip_pre_hopper)
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.model_config import ModelConfig
@@ -693,6 +694,140 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
     return True
 
 
+@skip_non_hopper_unittest
+@pytest.mark.parametrize(
+    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
+    product(
+        [torch.bfloat16],
+        [72],
+        [128, 256, 384, 512, 1024, 2048, 4096, 8192],
+        [2560],
+        [DefaultMoeRoutingMethod],
+        [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
+    ),
+)
+def test_fused_moe_fp8_blockwise_cutlass(dtype,
+                                         num_experts,
+                                         seq_len,
+                                         hidden_size,
+                                         RoutingMethodCls,
+                                         WeightLoadingMode,
+                                         mapping=None):
+    SEQ_LEN = seq_len
+    HIDDEN_SIZE = hidden_size
+    INTERMEDIATE_SIZE = 1536
+    NUM_EXPERTS = num_experts
+    TOP_K = 6
+
+    routing_method = RoutingMethodCls(top_k=TOP_K)
+
+    mapping = mapping or Mapping()
+    mapping.rank = mpi_rank()
+    torch.cuda.set_device(mapping.rank)
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+
+    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda")
+    # Note: we use some special values to init x and the weights; otherwise the test can report a false positive.
+    set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE)
+
+    x = x.cuda()
+    router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS),
+                                dtype=dtype,
+                                device="cuda")
+
+    weights = {}
+
+    if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+        weights['gate_up_proj'] = {}
+        weights['down_proj'] = {}
+        weights['gate_up_proj_weight_scale'] = {}
+        weights['down_proj_weight_scale'] = {}
+
+    for expert_id in range(NUM_EXPERTS):
+        w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                dtype=dtype,
+                                device="cuda")
+        w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
+                                dtype=dtype,
+                                device="cuda")
+        w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                dtype=dtype,
+                                device="cuda")
+        set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
+        set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE)
+        set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
+
+        w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8(w1_weight)
+        w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8(w2_weight)
+        w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8(w3_weight)
+        w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
+        weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
+        weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
+        weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
+        weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
+        weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
+
+        if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+            weights['gate_up_proj'][expert_id] = torch.cat(
+                [w3_weight_fp8, w1_weight_fp8],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj'][expert_id] = w2_weight_fp8.transpose(
+                0, 1).contiguous()
+            weights['gate_up_proj_weight_scale'][expert_id] = torch.cat(
+                [w3_weight_scale, w1_weight_scale],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj_weight_scale'][
+                expert_id] = w2_weight_scale.transpose(0, 1).contiguous()
+        elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA:
+            weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
+            weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
+            weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
+
+    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
+
+    fused_moe = CutlassFusedMoE(
+        num_experts=NUM_EXPERTS,
+        routing_method=routing_method,
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        dtype=dtype,
+        reduce_results=True,
+        model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
+    )
+    fused_moe.cuda()
+    fused_moe.load_weights([weights])
+
+    ref_fused_moe = RefGatedMLPFusedMoE(
+        num_experts=NUM_EXPERTS,
+        routing_method=routing_method,
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        dtype=dtype,
+        model_config=ModelConfig(quant_config=quant_config),
+        # Note: using the deepgemm mm causes accuracy errors, so we use the trtllmgen mm here.
+        use_cute_dsl_blockscaling_mm=True,
+    )
+    ref_fused_moe.load_weights([weights])
+    ref_fused_moe.cuda()
+
+    with torch.inference_mode():
+        output = fused_moe.forward(x, router_logits)
+        ref_output = ref_fused_moe.forward(x, router_logits)
+
+    # Compare the CUTLASS fused MoE output against the reference implementation.
+    torch.cuda.synchronize()
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    return True
+
+
 @skip_pre_blackwell
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
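For context on the quantization step in the restored test: per_block_cast_to_fp8 is a helper defined elsewhere in the test file, so its implementation is not part of this diff. The sketch below shows one common way such a per-block FP8 cast is written; the 128x128 block size, the 448.0 E4M3 maximum, and the _sketch name are illustrative assumptions, not taken from the repository.

import torch

# Hypothetical illustration (not the test file's actual helper): cast a 2D
# weight to float8_e4m3fn in 128x128 blocks, returning one float32 scale per
# block, in the spirit of the FP8_BLOCK_SCALES scheme used by the test.
def per_block_cast_to_fp8_sketch(w: torch.Tensor, block: int = 128):
    rows, cols = w.shape
    assert rows % block == 0 and cols % block == 0, "sketch assumes block-aligned shapes"
    # Reinterpret the weight as a grid of (block x block) tiles.
    tiles = w.float().view(rows // block, block, cols // block, block)
    # Per-tile absolute maximum, kept broadcastable against `tiles`.
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    fp8_max = 448.0  # largest magnitude representable in float8_e4m3fn
    w_fp8 = (tiles * (fp8_max / amax)).to(torch.float8_e4m3fn).view(rows, cols)
    scales = (amax / fp8_max).squeeze(1).squeeze(-1)  # (rows//block, cols//block)
    return w_fp8, scales


# Example with the test's shapes: a (1536, 2560) weight yields a (12, 20)
# grid of block scales.
w = torch.randn(1536, 2560, dtype=torch.bfloat16)
w_fp8, w_scale = per_block_cast_to_fp8_sketch(w)
print(w_fp8.dtype, w_scale.shape)  # torch.float8_e4m3fn torch.Size([12, 20])

The test then registers each expert's FP8 tensor plus its per-block scales under the weight / weight_scale keys (or weight_scale_inv in VANILLA mode), which is what CutlassFusedMoE.load_weights consumes in the diff above.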

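To run only the restored test locally, something like the following should work on a Hopper GPU (the test is gated by @skip_non_hopper_unittest); the working directory, the -k expression, and the surrounding test environment are assumptions, not part of this commit.

# Hypothetical invocation: select the restored test by name from the repo root.
import pytest

pytest.main([
    "tests/unittest/_torch/modules/test_fused_moe.py",
    "-k", "test_fused_moe_fp8_blockwise_cutlass",
    "-v",
])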