 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -1709,26 +1709,27 @@ def apply(
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
                                  global_num_experts, expert_map))

-        invoke_fused_moe_kernel(hidden_states,
-                                w1,
-                                intermediate_cache1,
-                                a1q_scale,
-                                w1_scale,
-                                w1_zp,
-                                None,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                False,
-                                top_k_num,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=self.use_fp8_w8a8,
-                                use_int8_w8a8=self.use_int8_w8a8,
-                                use_int8_w8a16=self.use_int8_w8a16,
-                                use_int4_w4a16=self.use_int4_w4a16,
-                                per_channel_quant=self.per_act_token_quant,
-                                block_shape=self.block_shape)
+        invoke_fused_moe_kernel(
+            hidden_states,
+            w1,
+            intermediate_cache1,
+            a1q_scale,
+            w1_scale,
+            w1_zp,
+            None,  # topk_weights
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,  # mul_routed_weights
+            top_k_num,
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=self.use_fp8_w8a8,
+            use_int8_w8a8=self.use_int8_w8a8,
+            use_int8_w8a16=self.use_int8_w8a16,
+            use_int4_w4a16=self.use_int4_w4a16,
+            per_channel_quant=self.per_act_token_quant,
+            block_shape=self.block_shape)

         self.activation(activation, intermediate_cache2,
                         intermediate_cache1.view(-1, N))
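A note on the two arguments annotated above: the first (up-projection) call passes None / False, so routing weights are not applied in that GEMM. A minimal, hypothetical sketch of the scaling that topk_weights / mul_routed_weight enable (the real Triton kernel works on sorted, padded token blocks; the layout below is simplified for illustration):

# Illustrative sketch only -- not the Triton kernel; shapes are an assumption.
import torch

def apply_routed_weight(expert_out: torch.Tensor,
                        topk_weights: torch.Tensor,
                        mul_routed_weight: bool) -> torch.Tensor:
    # expert_out:   [num_tokens, top_k, dim]  per-expert GEMM outputs (assumed)
    # topk_weights: [num_tokens, top_k]       router weights
    if not mul_routed_weight:
        # Matches the first call above (None / False): leave outputs unscaled.
        return expert_out
    return expert_out * topk_weights.unsqueeze(-1)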
@@ -1745,11 +1746,11 @@ def apply(
                                 a2q_scale,
                                 w2_scale,
                                 w2_zp,
-                                None,
+                                topk_weights,
                                 sorted_token_ids,
                                 expert_ids,
                                 num_tokens_post_padded,
-                                False,
+                                not apply_router_weight_on_input,
                                 1,
                                 config,
                                 compute_type=compute_type,
@@ -1760,12 +1761,7 @@ def apply(
                                 per_channel_quant=self.per_act_token_quant,
                                 block_shape=self.block_shape)

-        TopKWeightAndReduceContiguous().apply(
-            output=output,
-            fused_expert_output=intermediate_cache3,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+        ops.moe_sum(intermediate_cache3, output)


 def modular_triton_fused_moe(
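Net effect of the hunks above: the routed-weight multiply moves into the second invoke_fused_moe_kernel call (skipped when apply_router_weight_on_input has already folded the weights into the activations), so the epilogue no longer needs TopKWeightAndReduceContiguous and reduces to an unweighted sum over the top-k experts via ops.moe_sum. A rough reference for that final step, assuming a [num_tokens, top_k, hidden] layout for intermediate_cache3 (a sketch, not the custom op itself):

# Reference sketch of the new epilogue; the real ops.moe_sum is a fused op.
import torch

def moe_sum_reference(intermediate_cache3: torch.Tensor,
                      output: torch.Tensor) -> None:
    # intermediate_cache3: [num_tokens, top_k, hidden] (assumed layout)
    # output:              [num_tokens, hidden]
    torch.sum(intermediate_cache3, dim=1, out=output)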