
Commit e369637

TritonExperts opt

Author: Varun Sundar Rabindranath (committed)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>

1 parent c4145c6 commit e369637

File tree

1 file changed (+25 −29)

1 file changed

+25
-29
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 25 additions & 29 deletions
@@ -26,7 +26,7 @@
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -1709,26 +1709,27 @@ def apply(
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
                                  global_num_experts, expert_map))

-        invoke_fused_moe_kernel(hidden_states,
-                                w1,
-                                intermediate_cache1,
-                                a1q_scale,
-                                w1_scale,
-                                w1_zp,
-                                None,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                False,
-                                top_k_num,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=self.use_fp8_w8a8,
-                                use_int8_w8a8=self.use_int8_w8a8,
-                                use_int8_w8a16=self.use_int8_w8a16,
-                                use_int4_w4a16=self.use_int4_w4a16,
-                                per_channel_quant=self.per_act_token_quant,
-                                block_shape=self.block_shape)
+        invoke_fused_moe_kernel(
+            hidden_states,
+            w1,
+            intermediate_cache1,
+            a1q_scale,
+            w1_scale,
+            w1_zp,
+            None,  # topk_weights
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,  # mul_routed_weights
+            top_k_num,
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=self.use_fp8_w8a8,
+            use_int8_w8a8=self.use_int8_w8a8,
+            use_int8_w8a16=self.use_int8_w8a16,
+            use_int4_w4a16=self.use_int4_w4a16,
+            per_channel_quant=self.per_act_token_quant,
+            block_shape=self.block_shape)

         self.activation(activation, intermediate_cache2,
                         intermediate_cache1.view(-1, N))
@@ -1745,11 +1746,11 @@ def apply(
                                 a2q_scale,
                                 w2_scale,
                                 w2_zp,
-                                None,
+                                topk_weights,
                                 sorted_token_ids,
                                 expert_ids,
                                 num_tokens_post_padded,
-                                False,
+                                not apply_router_weight_on_input,
                                 1,
                                 config,
                                 compute_type=compute_type,
@@ -1760,12 +1761,7 @@ def apply(
                                 per_channel_quant=self.per_act_token_quant,
                                 block_shape=self.block_shape)

-        TopKWeightAndReduceContiguous().apply(
-            output=output,
-            fused_expert_output=intermediate_cache3,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+        ops.moe_sum(intermediate_cache3, output)


 def modular_triton_fused_moe(
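
What the diff amounts to: the second invoke_fused_moe_kernel call now receives topk_weights and mul_routed_weights=not apply_router_weight_on_input, so the router weights are multiplied into the expert outputs inside the kernel (unless they were already applied to the inputs). That leaves nothing for the separate TopKWeightAndReduceContiguous finalize step to do beyond a reduction over the top-k dimension, which ops.moe_sum performs directly into output. Below is a minimal PyTorch sketch of that equivalence, not the Triton kernel itself; the (num_tokens, top_k, hidden_size) shapes and helper names are assumptions for illustration only.

# Sketch only: illustrates why weighting inside the second GEMM lets the
# finalize step collapse to a plain sum over the top-k dimension.
import torch

def finalize_old(expert_out, topk_weights):
    # Old path (hypothetical stand-in for TopKWeightAndReduceContiguous):
    # expert outputs are unweighted, so finalize multiplies by the router
    # weights and reduces over top-k.
    return (expert_out * topk_weights.unsqueeze(-1)).sum(dim=1)

def finalize_new(weighted_expert_out):
    # New path: weights were already applied inside the kernel, so finalize
    # is just a sum over top-k (the role ops.moe_sum plays in the commit).
    return weighted_expert_out.sum(dim=1)

tokens, top_k, hidden = 4, 2, 8          # assumed toy shapes
cache3 = torch.randn(tokens, top_k, hidden)
weights = torch.rand(tokens, top_k)
old = finalize_old(cache3, weights)
new = finalize_new(cache3 * weights.unsqueeze(-1))
assert torch.allclose(old, new)

The apply_router_weight_on_input case follows the same logic: if the weights were already folded into the activations before the first GEMM, the kernel is told not to multiply them again (mul_routed_weights=False), and the final sum is still sufficient.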
