 )
 from torch.distributed.tensor.parallel import ParallelStyle

-from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import _round_up

@@ -90,18 +89,19 @@ class ExpertParallel(ParallelStyle):
         a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
     """

-    def __init__(self, a2a_impl: str = "default"):
+    def __init__(
+        self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default"
+    ):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
-        self.a2a_func = self._get_a2a_func(a2a_impl)
+        self.a2a_dispatch_func = self._get_a2a_func(a2a_dispatch_impl)
+        self.a2a_combine_func = self._get_a2a_func(a2a_combine_impl)

     def _get_a2a_func(self, a2a_impl: str):
         if a2a_impl == "default":
-            logger.info("Using default all-to-all implementation")
             return all_to_all_single_autograd
         elif a2a_impl == "mxfp8":
-            logger.info("Using mxfp8 all-to-all implementation")
             from torchao.prototype.moe_training.kernels.mxfp8.comms import (
                 to_mxfp8_a2a_dequant,
             )
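Not part of the diff, but a minimal sketch of how the split knobs could be exercised. The module path is an assumption, and whether mixing implementations between dispatch and combine is a recommended configuration is not something the diff states.

# Sketch only, not from this PR. Module path below is assumed.
from torchtitan.distributed.expert_parallel import ExpertParallel

# Both keywords default to "default", so existing call sites keep prior behavior.
ep_default = ExpertParallel()

# The split lets the two all-to-alls use different kernels, e.g. an mxfp8
# dispatch with a default (bf16) combine; the mxfp8 option requires torchao,
# since _get_a2a_func imports its kernels at construction time.
ep_mixed = ExpertParallel(a2a_dispatch_impl="mxfp8", a2a_combine_impl="default")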
@@ -143,6 +143,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         self.input_splits = input_splits.tolist()
         self.output_splits = output_splits.tolist()

+        routed_input = self.a2a_dispatch_func(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
+
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
         # [#tokens for local expert 0, #tokens for local expert 1, ...]
@@ -152,12 +159,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # We need to perform another shuffle to get the correct format -- this is done via the function
         # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
         # each expert gets locally is a multiple of ALIGN_SIZE_M.
-        routed_input = self.a2a_func(
-            routed_input,
-            self.output_splits,
-            self.input_splits,
-            device_mesh.get_group(),
-        )
+
         return routed_input, num_tokens_per_expert_group

     @staticmethod
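The dispatch call moved above reuses `self.input_splits` and `self.output_splits` computed just before it. As a standalone toy of where those counts come from (not code from this PR; sizes and names are illustrative only), each rank collapses its per-expert routing counts into per-EP-rank send counts:

# Toy illustration, not from the PR.
import torch

ep_size, experts_per_rank = 2, 2
# tokens this rank routes to each of the ep_size * experts_per_rank global experts
num_tokens_per_expert = torch.tensor([3, 1, 2, 4])
# per-destination-rank send counts: sum over each remote rank's local experts
input_splits = num_tokens_per_expert.view(ep_size, experts_per_rank).sum(dim=1)
print(input_splits.tolist())  # [4, 6]: 4 tokens go to EP rank 0, 6 to EP rank 1
# output_splits (receive counts) would come from exchanging these counts
# across EP ranks first.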
@@ -170,7 +172,7 @@ def _partition_fn(name, mod, device_mesh):
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
         # For a2a combine, input splits and output splits are opposite of a2a dispatch.
-        routed_output = self.a2a_func(
+        routed_output = self.a2a_combine_func(
             routed_output,
             self.input_splits,
             self.output_splits,
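To make the "opposite of a2a dispatch" comment concrete, a small standalone check (not from the PR): if a matrix row holds what each EP rank sends during dispatch, the received counts are the transposed row, and combine simply swaps the two split lists back.

# Toy check, not from the PR.
import torch

# row r: tokens EP rank r sends to each EP rank during dispatch
sent = torch.tensor([[1, 3], [2, 2]])
# row r: tokens EP rank r receives from each EP rank (transpose of `sent`)
received = sent.t()

for rank in range(2):
    input_splits = sent[rank].tolist()       # dispatch send counts
    output_splits = received[rank].tolist()  # dispatch receive counts
    # combine returns every token to its source rank, so the counts swap roles:
    # dispatch calls a2a(x, output_splits, input_splits, group),
    # combine calls a2a(y, input_splits, output_splits, group), as in the diff.
    print(f"rank {rank}: dispatch sends {input_splits}, combine sends {output_splits}")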