@@ -90,11 +90,14 @@ class ExpertParallel(ParallelStyle):
90
90
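a2a_dispatch_impl (str): The implementation of the all-to-all used for token dispatch. Default is "default". Options are ["default","mxfp8"].
a2a_combine_impl (str): The implementation of the all-to-all used for token combine. Default is "default". Options are ["default","mxfp8"].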
91
91
"""
92
92
93
- def __init__ (self , a2a_impl : str = "default" ):
93
+ def __init__ (
94
+ self , a2a_dispatch_impl : str = "default" , a2a_combine_impl : str = "default"
95
+ ):
94
96
super ().__init__ ()
95
97
self .input_splits = None
96
98
self .output_splits = None
97
- self .a2a_func = self ._get_a2a_func (a2a_impl )
99
+ self .a2a_dispatch_func = self ._get_a2a_func (a2a_dispatch_impl )
100
+ self .a2a_combine_func = self ._get_a2a_func (a2a_combine_impl )
98
101
99
102
def _get_a2a_func (self , a2a_impl : str ):
100
103
if a2a_impl == "default" :
@@ -143,6 +146,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
143
146
self .input_splits = input_splits .tolist ()
144
147
self .output_splits = output_splits .tolist ()
145
148
149
+ routed_input = self .a2a_dispatch_func (
150
+ routed_input ,
151
+ self .output_splits ,
152
+ self .input_splits ,
153
+ device_mesh .get_group (),
154
+ )
155
+
146
156
# NOTE: After this all-to-all, the routed input is put on proper EP rank.
147
157
# However, the num_tokens_per_expert_group is not of the final target format
148
158
# [#tokens for local expert 0, #tokens for local expert 1, ...]
@@ -152,12 +162,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
152
162
# We need to perform another shuffle to get the correct format -- this is done via the function
153
163
# generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
154
164
# each expert gets locally is a multiple of ALIGN_SIZE_M.
155
- routed_input = self .a2a_func (
156
- routed_input ,
157
- self .output_splits ,
158
- self .input_splits ,
159
- device_mesh .get_group (),
160
- )
165
+
161
166
return routed_input , num_tokens_per_expert_group
162
167
163
168
@staticmethod
@@ -170,7 +175,7 @@ def _partition_fn(name, mod, device_mesh):
170
175
# performing all-to-all combine on the output
171
176
def _token_combine (self , mod , routed_output , device_mesh ):
172
177
# For a2a combine, input splits and output splits are opposite of a2a dispatch.
173
- routed_output = self .a2a_func (
178
+ routed_output = self .a2a_combine_func (
174
179
routed_output ,
175
180
self .input_splits ,
176
181
self .output_splits ,
0 commit comments