Commit c5fd979

Author: Varun Sundar Rabindranath (committed)

do reduction in experts

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
1 parent 53fa457 commit c5fd979

File tree: 9 files changed (+187, -136 lines)
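
This commit moves the top-k weighting and reduction out of the PrepareAndFinalize finalize step and into the experts' apply() implementations: the experts now receive topk_weights and apply_router_weight_on_input, write an already-reduced (M, K) output, and finalize_weight_and_reduce_impl() switches to TopKWeightAndReduceNoOP. For orientation, the reduction that TopKWeightAndReduceContiguous is asked to perform in the hunks below amounts to the following minimal sketch (a standalone illustration, not the actual vLLM implementation; the handling of apply_router_weight_on_input is an assumption):

import torch

def topk_weight_and_reduce(fused_expert_output: torch.Tensor,
                           topk_weights: torch.Tensor,
                           apply_router_weight_on_input: bool) -> torch.Tensor:
    # fused_expert_output: (M, topk, K) per-expert outputs, contiguous layout.
    # topk_weights: (M, topk) router scores.
    if not apply_router_weight_on_input:
        # Assumed convention: if the router weights were already folded into
        # the activations before dispatch, do not apply them a second time.
        fused_expert_output = fused_expert_output * topk_weights.unsqueeze(-1)
    # Sum over the top-k experts to produce the final (M, K) output.
    return fused_expert_output.sum(dim=1)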

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 2 additions & 0 deletions
@@ -255,6 +255,7 @@ def apply(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -268,6 +269,7 @@ def apply(
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 16 additions & 24 deletions
@@ -129,30 +129,22 @@ def workspace_shapes(
         return self.batched_triton_experts.workspace_shapes(
             a, aq, M, N, K, topk, global_num_experts, local_num_experts)

-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
         assert experts is not None
-        experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
-                      global_num_experts, expert_map, w1_scale, w2_scale,
-                      w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
-                      workspace2, expert_tokens_meta)
+        experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
+                      activation, global_num_experts, expert_map, w1_scale,
+                      w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
+                      workspace2, expert_tokens_meta,
+                      apply_router_weight_on_input)

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 11 additions & 20 deletions
@@ -291,26 +291,17 @@ def workspace_shapes(
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)

-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 22 additions & 11 deletions
@@ -13,7 +13,7 @@
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.utils import has_deep_gemm, round_up
 from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
@@ -89,8 +89,7 @@ def supports_expert_map(self) -> bool:
         return True

     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()

     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
@@ -103,9 +102,9 @@ def workspace_shapes(
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
-        workspace1 = (M_sum, max(N * 2, K))
-        workspace2 = (M_sum, max(N, K))
-        output = (M, topk, K)
+        workspace1 = (M_sum, max(N, K))
+        workspace2 = (M_sum, max(N * 2, K))
+        output = (M, K)
         return (workspace1, workspace2, output, a.dtype)

     def apply(
@@ -114,6 +113,7 @@ def apply(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -127,11 +127,14 @@ def apply(
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         assert self.block_shape is not None

         a1q = hidden_states
         _, N, K = w1.size()
+        M, _ = output.size()
+        num_topk = topk_ids.size(1)

         if global_num_experts == -1:
             global_num_experts = w1.size(0)
@@ -158,11 +161,12 @@
         # Note: M_sum is different than the pre-permuted shape of a1q.
         M_sum = a1q.size(0)

-        mm1_out = _resize_cache(workspace13, (M_sum, N))
-        act_out = _resize_cache(workspace2, (M_sum, N // 2))
-        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+        mm1_out = _resize_cache(workspace2, (M_sum, N))
+        act_out = _resize_cache(workspace13, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
                                   (M_sum, N // 2))
-        mm2_out = _resize_cache(workspace2, (M_sum, K))
+        mm2_out = _resize_cache(workspace13, (M_sum, K))
+        perm_out = _resize_cache(workspace2, (M * num_topk, K))

         m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)
@@ -178,7 +182,14 @@
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)

-        torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
+        torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
+
+        TopKWeightAndReduceContiguous().apply(
+            output=output,
+            fused_expert_output=perm_out,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input)


 def deep_gemm_moe_fp8(
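
To make the new workspace sizing concrete, here is a rough standalone re-computation of what DeepGemmExperts.workspace_shapes() returns after this change. The block_m default, the example dimensions, and the mapping of workspace1/workspace2 onto the workspace13/workspace2 buffers used in apply() are assumptions for illustration; only the shape formulas mirror the hunk above.

def deep_gemm_workspace_shapes(M: int, N: int, K: int, topk: int,
                               num_experts: int, block_m: int = 128):
    # Pad M * topk so each expert's rows can start on a block_m boundary.
    M_sum = (M * topk) + num_experts * (block_m - 1)
    M_sum = ((M_sum + block_m - 1) // block_m) * block_m  # round_up(M_sum, block_m)
    workspace1 = (M_sum, max(N, K))      # assumed to back act_out and mm2_out
    workspace2 = (M_sum, max(N * 2, K))  # assumed to back mm1_out, quant_out, perm_out
    output = (M, K)                      # already reduced; previously (M, topk, K)
    return workspace1, workspace2, output

# Example: 8 tokens, top-2 routing, N=4096, K=2048, 16 local experts
# -> ((2048, 4096), (2048, 8192), (8, 2048))
print(deep_gemm_workspace_shapes(M=8, N=4096, K=2048, topk=2, num_experts=16))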

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 8 additions & 6 deletions
@@ -697,15 +697,16 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         return t.to(f32) * group_broadcast(scale, t.shape)

     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
-              activation: str, global_num_experts: int,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
               expert_map: Optional[torch.Tensor],
               w1_scale: Optional[torch.Tensor],
               w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
               w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         assert hidden_states.dim() == 3
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
@@ -900,15 +901,16 @@ def workspace_shapes(
         return (workspace13, workspace2, output, a.dtype)

     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
-              activation: str, global_num_experts: int,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
               expert_map: Optional[torch.Tensor],
               w1_scale: Optional[torch.Tensor],
               w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
               w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 20 additions & 11 deletions
@@ -26,7 +26,7 @@
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -1606,8 +1606,7 @@ def supports_expert_map(self) -> bool:
         return True

     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()

     def workspace_shapes(
         self,
@@ -1620,9 +1619,9 @@ def workspace_shapes(
         global_num_experts: int,
         local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
-        workspace1 = (M, topk, max(N * 2, K))
-        workspace2 = (M, topk, N)
-        output = (M, topk, K)
+        workspace1 = (M, topk, max(N, K))
+        workspace2 = (M, topk, max(N * 2, K))
+        output = (M, K)
         return (workspace1, workspace2, output, a.dtype)

     def apply(
@@ -1631,6 +1630,7 @@ def apply(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -1644,6 +1644,7 @@ def apply(
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         # Check constraints.
         if self.use_int4_w4a16:
@@ -1696,12 +1697,13 @@
             raise ValueError(
                 f"Unsupported compute_type: {hidden_states.dtype}")

-        # We can reuse the memory between these because by the time we need
-        # cache3, we're done with cache1
-        intermediate_cache1 = _resize_cache(workspace13,
+        # Note that the output tensor might be in workspace1
+        intermediate_cache1 = _resize_cache(workspace2,
                                             (num_tokens, top_k_num, N))
-        intermediate_cache2 = _resize_cache(workspace2,
+        intermediate_cache2 = _resize_cache(workspace13,
                                             (num_tokens * top_k_num, N // 2))
+        intermediate_cache3 = _resize_cache(workspace2,
+                                            (num_tokens, top_k_num, K))

         sorted_token_ids, expert_ids, num_tokens_post_padded = (
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
@@ -1739,7 +1741,7 @@

         invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
-                                output,
+                                intermediate_cache3,
                                 a2q_scale,
                                 w2_scale,
                                 w2_zp,
@@ -1758,6 +1760,13 @@
                                 per_channel_quant=self.per_act_token_quant,
                                 block_shape=self.block_shape)

+        TopKWeightAndReduceContiguous().apply(
+            output=output,
+            fused_expert_output=intermediate_cache3,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input)
+

 def modular_triton_fused_moe(
     use_fp8_w8a8: bool,
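
Because both the Triton and DeepGEMM experts now hand back the fully reduced (M, K) tensor, finalize_weight_and_reduce_impl() returns TopKWeightAndReduceNoOP instead of TopKWeightAndReduceDelegate, so the finalize step no longer weights or reduces anything. A minimal sketch of what such a no-op amounts to; the class body and exact signature are illustrative assumptions modeled on the keyword call to TopKWeightAndReduceContiguous().apply(...) above, not the actual vLLM class:

from typing import Optional

import torch

class NoOpWeightAndReduceSketch:
    """Pass-through finalize: the experts already weighted and reduced."""

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        # Nothing to weight and nothing to sum here; both were done inside
        # the experts' apply(). Just surface the already-reduced (M, K) result.
        if output is None:
            return fused_expert_output
        output.copy_(fused_expert_output)
        return output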
