Commit 37d06c1

apply
1 parent 04410f2


tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 12 additions & 30 deletions
@@ -588,43 +588,25 @@ def forward_chunk(
                 x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                                   self.scaling_vector_size)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                token_num = x_col
+                hidden_size = x_row
                 assert x_sf is not None and self.has_nvfp4
-                token_num = x_row
-                hidden_size = x_col
                 assert hidden_size % 32 == 0
-                x_sf_dtype = x_sf.dtype
-                x_dtype = x.dtype
-                assert x_sf_dtype == torch.uint8 and x_dtype == torch.uint8
-                x_sf = x_sf.view(torch.bfloat16)
+                assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
                 assert x_sf.shape[0] == token_num and x_sf.shape[
                     1] == hidden_size // 16 // 2
-                x = x.view(torch.bfloat16)
-                assert x.shape[0] == token_num and x.shape[1] == hidden_size // 4
-                # DeepEP LL dispatch only supports bf16 tensors with a hidden size of 2560, 4096, 5120, or 7168 as input. A hidden size of 2560 is sufficient to accommodate packed FP4 data.
-                packed_hidden_size = 2560
-                assert x.shape[1] + x_sf.shape[1] <= packed_hidden_size
-                fp4_packed_tensor = torch.empty((token_num, packed_hidden_size),
-                                                dtype=torch.bfloat16,
-                                                device=x.device)
-                fp4_packed_tensor[:, :x.shape[1]] = x
-                fp4_packed_tensor[:,
-                                  x.shape[1]:x.shape[1] + x_sf.shape[1]] = x_sf
+                assert x.shape[0] == token_num and x.shape[1] == hidden_size // 2
 
                 deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
 
                 assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
-                fp4_packed_tensor, recv_expert_count, deep_ep_handle = \
-                    self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
-                deep_ep_handle = list(deep_ep_handle)
-                deep_ep_handle[3] = hidden_size
-                deep_ep_handle = tuple(deep_ep_handle)
-
-                assert fp4_packed_tensor.ndim == 3 and fp4_packed_tensor.shape[
-                    2] == packed_hidden_size
-                x_sf = fp4_packed_tensor[:, :, x.shape[1]:x.shape[1] +
-                                         x_sf.shape[1]].contiguous()
-                x = fp4_packed_tensor[:, :, :x.shape[1]].contiguous()
+                x, x_sf, recv_expert_count, deep_ep_handle = \
+                    self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+                assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
+                assert x.dim() == 3 and x_sf.dim() == 3
+                assert x.shape[2] == hidden_size // 2 and x_sf.shape[2] == hidden_size // 16 // 2
+
                 mask = torch.arange(
                     x.shape[1], dtype=torch.int32, device=x.device).expand(
                         x.shape[0], x.shape[1]) < recv_expert_count.unsqueeze(1)
@@ -634,9 +616,9 @@ def forward_chunk(
                     x.shape[0] * (self.mapping.moe_ep_rank + 1),
                     dtype=torch.int32,
                     device=x.device).unsqueeze(1), self.num_slots)
-                x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]).view(x_dtype)
+                x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
                 x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
-                                    x_sf.shape[2]).view(x_sf_dtype)
+                                    x_sf.shape[2])
                 x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                                   self.scaling_vector_size)
                 token_selected_slots = token_selected_slots.view(x.shape[0], 1)
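Net effect of the change: instead of reinterpreting the FP4 payload and its scale factors as bf16, packing them into a single 2560-wide bf16 tensor for low_latency_dispatch, and slicing them back out (and patching the hidden size into deep_ep_handle) after the dispatch, the code now passes both uint8 tensors straight to low_latency_dispatch_fp4. For reference, a minimal sketch of the packing arithmetic the removed path relied on, with illustrative sizes and assuming the standard NVFP4 layout of two 4-bit values per uint8 byte:

import torch

# Illustrative sizes only; the real values come from the activations being
# dispatched. hidden_size must be divisible by 32 (asserted in the diff).
token_num, hidden_size = 8, 4096

# NVFP4 payload: two 4-bit values per byte -> hidden_size // 2 uint8 columns;
# the diff asserts hidden_size // 16 // 2 uint8 columns for the scale factors.
x = torch.zeros(token_num, hidden_size // 2, dtype=torch.uint8)
x_sf = torch.zeros(token_num, hidden_size // 16 // 2, dtype=torch.uint8)

# Removed path: reinterpret both buffers as bf16 (halving the column count)
# and copy them side by side into one bf16 tensor of width 2560, the smallest
# hidden size accepted by DeepEP's bf16-only low_latency_dispatch
# (2560, 4096, 5120, or 7168 per the removed comment).
x_bf16 = x.view(torch.bfloat16)      # (token_num, hidden_size // 4)
sf_bf16 = x_sf.view(torch.bfloat16)  # (token_num, hidden_size // 16 // 4)
assert x_bf16.shape[1] + sf_bf16.shape[1] <= 2560
packed = torch.zeros(token_num, 2560, dtype=torch.bfloat16)
packed[:, :x_bf16.shape[1]] = x_bf16
packed[:, x_bf16.shape[1]:x_bf16.shape[1] + sf_bf16.shape[1]] = sf_bf16

Dispatching the uint8 tensors directly drops the two packing copies and the .contiguous() slices on the receive side, and removes the need to patch the hidden size into deep_ep_handle; the trade-off is that the DeepEP buffer must provide low_latency_dispatch_fp4.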
