Commit 118f9e7

DeepEP LL combine FP4
Signed-off-by: Yilin Zhang <[email protected]>
1 parent 2923eb8 commit 118f9e7

3 files changed: +41 -17 lines changed

cpp/tensorrt_llm/deep_ep/CMakeLists.txt

Lines changed: 10 additions & 3 deletions
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01)
+set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

@@ -19,8 +19,15 @@ foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
   set(CUDA_ARCH_MINOR ${CMAKE_MATCH_2})
   set(CUDA_ARCH_POSTFIX ${CMAKE_MATCH_3})
   if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9)
-    list(APPEND DEEP_EP_CUDA_ARCHITECTURES
-         "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}")
+    # The FP4-related conversion instructions in DeepEP require SM100a, SM110a,
+    # or SM120a.
+    if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0)
+      list(APPEND DEEP_EP_CUDA_ARCHITECTURES
+           "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}a${CUDA_ARCH_POSTFIX}")
+    else()
+      list(APPEND DEEP_EP_CUDA_ARCHITECTURES
+           "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}")
+    endif()
   endif()
 endforeach()
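The new branch appends the architecture-specific `a` suffix only when the arch major version is at least 10 and the minor version is 0 (giving SM100a, SM110a, SM120a); all other SM90-and-newer targets keep their plain arch string. Below is a minimal Python sketch of that selection rule, purely for illustration: the real logic is the CMake `foreach` above, and the regex standing in for `CMAKE_MATCH_1/2/3` plus the `-real`/`-virtual` postfix handling are assumptions.

import re
from typing import Optional

def deep_ep_arch(entry: str) -> Optional[str]:
    """Illustrative mirror of the CMake branch above (not the actual build code)."""
    # Assumed format of a CMAKE_CUDA_ARCHITECTURES item, e.g. "90-real" or "100".
    m = re.fullmatch(r"(\d+)(\d)(-real|-virtual)?", entry)
    if not m:
        return None
    major, minor = int(m.group(1)), int(m.group(2))
    postfix = m.group(3) or ""
    if major < 9:
        return None  # DeepEP is only built for SM90 and newer
    if major >= 10 and minor == 0:
        # FP4 conversion instructions require the arch-specific "a" variant
        return f"{major}{minor}a{postfix}"
    return f"{major}{minor}{postfix}"

assert deep_ep_arch("90-real") == "90-real"     # Hopper: unchanged
assert deep_ep_arch("100-real") == "100a-real"  # Blackwell: SM100a
assert deep_ep_arch("120") == "120a"
assert deep_ep_arch("80") is None               # below SM90: skipped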

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

Lines changed: 17 additions & 11 deletions
@@ -154,34 +154,40 @@ def low_latency_dispatch(self, hidden_states: torch.Tensor,
         # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_expert_count, handle

+    def low_latency_combine(self, hidden_states: torch.Tensor,
+                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
+                            handle: Tuple):
+        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
+        combined_hidden_states, event, hook = \
+            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
+        assert event.event is None
+        assert hook is None
+
+        # NOTES: the same behavior as described in the dispatch kernel
+        return combined_hidden_states
+
     def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor,
                                  scales: torch.Tensor, topk_idx: torch.Tensor,
                                  num_max_dispatch_tokens_per_rank: int,
                                  num_experts: int):
         assert num_experts == self.num_experts

-        # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
         recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
             self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
         assert event.event is None
         assert hook is None

-        # NOTES: the actual tensor will not be received only if you call `hook()`,
-        # it is useful for double-batch overlapping, but **without any SM occupation**
-        # If you don't want to overlap, please set `return_recv_hook=False`
-        # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_scales, recv_expert_count, handle

-    def low_latency_combine(self, hidden_states: torch.Tensor,
-                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: Tuple):
-        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
+    def low_latency_combine_fp4(self, hidden_states: torch.Tensor,
+                                global_scales: torch.Tensor,
+                                topk_idx: torch.Tensor,
+                                topk_weights: torch.Tensor, handle: Tuple):
         combined_hidden_states, event, hook = \
-            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
+            self.buffer.low_latency_combine_fp4(hidden_states, global_scales, topk_idx, topk_weights, handle)
         assert event.event is None
         assert hook is None

-        # NOTES: the same behavior as described in the dispatch kernel
         return combined_hidden_states

     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int,
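The new `low_latency_combine_fp4` wrapper is the return leg of the existing FP4 dispatch path. A hedged usage sketch follows: `buffer` is an instance of the wrapper class above, while the tensor arguments and the `run_experts` stand-in for the grouped GEMM are assumptions, not code from this commit.

import torch

def moe_low_latency_fp4(buffer, hidden_states, scales, topk_idx, topk_weights,
                        num_max_dispatch_tokens_per_rank, num_experts,
                        run_experts):
    # 1) FP4 dispatch: route quantized tokens and their block scales to experts.
    recv_hidden, recv_scales, recv_count, handle = buffer.low_latency_dispatch_fp4(
        hidden_states, scales, topk_idx,
        num_max_dispatch_tokens_per_rank, num_experts)

    # 2) Expert computation on the received payload (grouped GEMM, not shown).
    expert_out = run_experts(recv_hidden, recv_scales, recv_count)

    # 3) Low-precision combine: one global scale per token row, mirroring the
    #    caller in fused_moe_wide_ep.py below.
    global_scales = (448 * 6) / expert_out.abs().max(
        dim=-1, keepdim=True).values.to(torch.float32)
    return buffer.low_latency_combine_fp4(expert_out, global_scales,
                                          topk_idx, topk_weights, handle)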

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 14 additions & 3 deletions
@@ -184,11 +184,15 @@ def __init__(
             f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
             key="alltoall_method_type")
         self.use_postquant_alltoall = False
+        self.use_low_precision_combine = False
         if self.enable_alltoall:
             qm = self.quant_config.quant_mode
             self.use_postquant_alltoall = (os.environ.get(
                 "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
                                            == "1") and qm.has_nvfp4()
+            self.use_low_precision_combine = (os.environ.get(
+                "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0")
+                                              == "1") and qm.has_nvfp4()
             # TODO: support alltoall without allgather for top_k % 4 != 0
             self.enable_alltoall_without_allgather = (
                 os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER",
@@ -685,9 +689,16 @@ def forward_chunk(
             final_hidden_states = final_hidden_states.view(
                 self.expert_size_per_partition,
                 num_tokens_per_expert_for_fused_moe, self.hidden_size)
-            final_hidden_states = self.deep_ep_buffer.low_latency_combine(
-                final_hidden_states, deep_ep_topk_idx, deep_ep_topk_weights,
-                deep_ep_handle)
+            if self.use_low_precision_combine:
+                global_scales = (448 * 6) / final_hidden_states.abs().max(
+                    dim=-1, keepdim=True).values.to(torch.float32)
+                final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4(
+                    final_hidden_states, global_scales, deep_ep_topk_idx,
+                    deep_ep_topk_weights, deep_ep_handle)
+            else:
+                final_hidden_states = self.deep_ep_buffer.low_latency_combine(
+                    final_hidden_states, deep_ep_topk_idx,
+                    deep_ep_topk_weights, deep_ep_handle)
         else:
             raise NotImplementedError(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
