@@ -14,7 +14,8 @@
 from mpi4py import MPI
 from mpi4py.futures import MPIPoolExecutor
 from utils.util import (skip_neither_ada_nor_hopper_unittest,
-                        skip_pre_blackwell, skip_pre_hopper)
+                        skip_non_hopper_unittest, skip_pre_blackwell,
+                        skip_pre_hopper)

 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.model_config import ModelConfig
@@ -693,6 +694,149 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
     return True


+@skip_non_hopper_unittest
+@pytest.mark.parametrize(
+    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
+    product(
+        [torch.bfloat16],
+        [72],
+        [128, 256, 384, 512, 1024, 2048, 4096, 8192],
+        [2560],
+        [DefaultMoeRoutingMethod],
+        [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
+    ),
+)
+def test_fused_moe_fp8_blockwise_cutlass(dtype,
+                                         num_experts,
+                                         seq_len,
+                                         hidden_size,
+                                         RoutingMethodCls,
+                                         WeightLoadingMode,
+                                         mapping=None):
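+    """Compare the CUTLASS fused-MoE kernel against a per-expert reference."""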
+    SEQ_LEN = seq_len
+    HIDDEN_SIZE = hidden_size
+    INTERMEDIATE_SIZE = 1536
+    NUM_EXPERTS = num_experts
+    TOP_K = 6
+
+    routing_method = RoutingMethodCls(top_k=TOP_K)
+
+    mapping = mapping or Mapping()
+    mapping.rank = mpi_rank()
+    torch.cuda.set_device(mapping.rank)
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+
+    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda")
+    # Note: x and the weights are initialized with special values; plain random init can make the test fail spuriously.
+    set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE)
+
+    x = x.cuda()
+    router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS),
+                                dtype=dtype,
+                                device="cuda")
+
+    weights = {}
+
+    if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+        weights['gate_up_proj'] = {}
+        weights['down_proj'] = {}
+        weights['gate_up_proj_weight_scale'] = {}
+        weights['down_proj_weight_scale'] = {}
+
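+    # Build per-expert weights: w1/w3 project hidden -> intermediate (gated MLP),
+    # w2 projects intermediate -> hidden.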
+    for expert_id in range(NUM_EXPERTS):
+        w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                dtype=dtype,
+                                device="cuda")
+        w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
+                                dtype=dtype,
+                                device="cuda")
+        w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                dtype=dtype,
+                                device="cuda")
+        set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
+        set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE)
+        set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
+
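+        # Quantize each weight to FP8 with per-block scale factors.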
+        w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8(w1_weight)
+        w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8(w2_weight)
+        w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8(w3_weight)
+        w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
+        weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
+        weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
+        weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
+        weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
+        weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
+
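+        # FUSED_GATE_UP_PROJ: w3 and w1 (and their block scales) are stored as
+        # one fused, transposed gate_up_proj tensor per expert.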
+        if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+            weights['gate_up_proj'][expert_id] = torch.cat(
+                [w3_weight_fp8, w1_weight_fp8],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj'][expert_id] = w2_weight_fp8.transpose(
+                0, 1).contiguous()
+            weights['gate_up_proj_weight_scale'][expert_id] = torch.cat(
+                [w3_weight_scale, w1_weight_scale],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj_weight_scale'][
+                expert_id] = w2_weight_scale.transpose(0, 1).contiguous()
+        elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA:
+            weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
+            weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
+            weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
+
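+    # Both the fused module and the reference share this quantization config.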
+    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
+
+    fused_moe = CutlassFusedMoE(
+        num_experts=NUM_EXPERTS,
+        routing_method=routing_method,
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        dtype=dtype,
+        reduce_results=True,
+        model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
+    )
+    fused_moe.cuda()
+    fused_moe.load_weights([weights])
+
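+    # Reference implementation: the same MoE computed as per-expert gated MLPs.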
+    ref_fused_moe = RefGatedMLPFusedMoE(
+        num_experts=NUM_EXPERTS,
+        routing_method=routing_method,
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        dtype=dtype,
+        model_config=ModelConfig(quant_config=quant_config),
+        # Note: the DeepGEMM matmul causes accuracy errors here, so the CUTE DSL block-scaling matmul is used instead.
+        use_cute_dsl_blockscaling_mm=True,
+    )
+    ref_fused_moe.load_weights([weights])
+    ref_fused_moe.cuda()
+
+    with torch.inference_mode():
+        output = fused_moe.forward(x, router_logits)
+        ref_output = ref_fused_moe.forward(x, router_logits)
+
+    # Compare the fused kernel output against the reference.
+    torch.cuda.synchronize()
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    return True
+
+
 @skip_pre_blackwell
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")