@@ -154,34 +154,40 @@ def low_latency_dispatch(self, hidden_states: torch.Tensor,
         # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_expert_count, handle
 
+    def low_latency_combine(self, hidden_states: torch.Tensor,
+                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
+                            handle: Tuple):
+        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
+        combined_hidden_states, event, hook = \
+            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
+        assert event.event is None
+        assert hook is None
+
+        # NOTES: the same behavior as described in the dispatch kernel
+        return combined_hidden_states
+
     def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor,
                                  scales: torch.Tensor, topk_idx: torch.Tensor,
                                  num_max_dispatch_tokens_per_rank: int,
                                  num_experts: int):
         assert num_experts == self.num_experts
 
-        # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
         recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
             self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
         assert event.event is None
         assert hook is None
 
-        # NOTES: the actual tensor will not be received only if you call `hook()`,
-        # it is useful for double-batch overlapping, but **without any SM occupation**
-        # If you don't want to overlap, please set `return_recv_hook=False`
-        # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_scales, recv_expert_count, handle
 
-    def low_latency_combine(self, hidden_states: torch.Tensor,
-                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: Tuple):
-        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
+    def low_latency_combine_fp4(self, hidden_states: torch.Tensor,
+                                global_scales: torch.Tensor,
+                                topk_idx: torch.Tensor,
+                                topk_weights: torch.Tensor, handle: Tuple):
         combined_hidden_states, event, hook = \
-            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
+            self.buffer.low_latency_combine_fp4(hidden_states, global_scales, topk_idx, topk_weights, handle)
         assert event.event is None
         assert hook is None
 
-        # NOTES: the same behavior as described in the dispatch kernel
         return combined_hidden_states
 
     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int,
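For context, here is a minimal usage sketch of how the dispatch/combine pair touched by this hunk might be wired together in a MoE forward pass. This is illustrative only: the full signature of `low_latency_dispatch` is not visible in this diff (its argument order below is assumed from the fp4 variant), and `wrapper` and `run_experts` are hypothetical stand-ins for the surrounding wrapper object and the per-expert grouped-GEMM computation.

import torch

def moe_forward_sketch(wrapper, hidden_states: torch.Tensor,
                       topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                       num_max_dispatch_tokens_per_rank: int,
                       num_experts: int) -> torch.Tensor:
    # Dispatch tokens to their experts across ranks; `handle` carries the
    # routing metadata that the matching combine call needs later.
    # (Dispatch argument order is an assumption based on the fp4 variant above.)
    recv_hidden_states, recv_expert_count, handle = wrapper.low_latency_dispatch(
        hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)

    # Placeholder for the per-expert computation (e.g. grouped GEMMs over the
    # received, expert-major layout). Identity here, purely for illustration.
    def run_experts(x: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
        return x

    expert_out = run_experts(recv_hidden_states, recv_expert_count)

    # Send expert outputs back to their source ranks and reduce them with the
    # top-k routing weights.
    return wrapper.low_latency_combine(expert_out, topk_idx, topk_weights, handle)

The key design point the asserts in the diff rely on is that these wrappers call the buffer with event/hook overlapping disabled, so dispatch and combine complete synchronously and the returned `event.event` and `hook` are expected to be `None`.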