@@ -81,76 +81,75 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:


class ExpertParallel(ParallelStyle):
-    AllToAllImpl = Literal["default", "mxfp8"]
+    """
+    ExpertParallel is a parallel style for MoE, where each experts
+    are distributed across ranks along a given axis of the device mesh.

-    def __init__(self, a2a_impl: AllToAllImpl = "default", max_output_len_per_rank: int = -1):
-        """
-        ExpertParallel is a parallel style for MoE, where each experts
-        are distributed across ranks along a given axis of the device mesh.
+    Args:
+        a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
+    """

-        Args:
-            a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
-            max_output_len_per_rank (int): The maximum length of the output tensor per rank. Default is -1. Required for mxfp8 all-to-all, otherwise not used.
-        """
+    def __init__(self, a2a_impl: str = "default"):
        super().__init__()
        self.input_splits = None
        self.output_splits = None
        self.a2a_impl = a2a_impl
-        self.max_output_len_per_rank = max_output_len_per_rank

    # performing all-to-all dispatch on the input
    def _token_dispatch(self, mod, inputs, device_mesh):
        # annotate module input placements/sharding with input_layouts
        routed_input, num_tokens_per_expert = inputs
        ep_size = device_mesh.shape[0]

-        def default_a2a(routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-            # generate the input splits and output splits for all-to-all
-            with torch.no_grad():
-                num_tokens_per_expert_group = all_to_all_single(
-                    num_tokens_per_expert,
-                    None,
-                    None,
-                    group=device_mesh.get_group(),
-                )
-                # Need to wait explicitly because it is used by a triton kernel later
-                # which doesn't realize that AsyncCollectiveTensor needs unwrapping
-                num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-                    num_tokens_per_expert_group
-                )
-                input_splits = (
-                    num_tokens_per_expert.view(ep_size, -1)
-                    .sum(dim=1)
-                    .to(torch.device("cpu"), non_blocking=True)
-                )
-                # NOTE: this would incur a device-to-host sync
-                output_splits = (
-                    num_tokens_per_expert_group.view(ep_size, -1)
-                    .sum(dim=1)
-                    .to(torch.device("cpu"), non_blocking=False)
-                )
-                self.input_splits = input_splits.tolist()
-                self.output_splits = output_splits.tolist()
+        # generate the input splits and output splits for all-to-all
+        with torch.no_grad():
+            num_tokens_per_expert_group = all_to_all_single(
+                num_tokens_per_expert,
+                None,
+                None,
+                group=device_mesh.get_group(),
+            )
+            # Need to wait explicitly because it is used by a triton kernel later
+            # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
+            input_splits = (
+                num_tokens_per_expert.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=True)
+            )
+            # NOTE: this would incur a device-to-host sync
+            output_splits = (
+                num_tokens_per_expert_group.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=False)
+            )
+            self.input_splits = input_splits.tolist()
+            self.output_splits = output_splits.tolist()
+
+            # EP degree is small, doing max() on the host is fine since the data is already on the host after *.tolist() d2h sync above.
+            self.max_output_tokens_per_ep_rank = max(self.output_splits)
+            self.input_splits_on_device = input_splits.to(routed_input.device)

+        if self.a2a_impl == "mxfp8":
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                mxfp8_on_device_all_to_all_v,
+            )
+            routed_input, _ = mxfp8_on_device_all_to_all_v(
+                routed_input,
+                self.input_splits_on_device,
+                self.max_output_tokens_per_ep_rank,
+                device_mesh.get_group().group_name,
+            )
+        else:
            # perform all-to-all
            routed_input = all_to_all_single_autograd(
                routed_input,
                self.output_splits,
                self.input_splits,
                device_mesh.get_group(),
            )
-            return routed_input, num_tokens_per_expert_group
-
-        def mxfp8_a2a(routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-            from torchao.prototype.moe_training.kernels.mxfp8.comms import mxfp8_on_device_all_to_all_v
-            assert self.max_output_len_per_rank > 0, "max_output_len_per_rank must be positive for mxfp8 all-to-all"
-            routed_input, num_tokens_per_expert_group = mxfp8_on_device_all_to_all_v(
-                routed_input,
-                num_tokens_per_expert,
-                self.max_output_len_per_rank,
-                device_mesh.get_group().group_name,
-            )
-            return routed_input, num_tokens_per_expert_group

        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
        # However, the num_tokens_per_expert_group is not of the final target format
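To make the inlined dispatch logic above easier to follow, here is a minimal, self-contained sketch of the splits arithmetic on a single rank. The concrete counts are invented, and `num_tokens_per_expert_group` merely stands in for what the all-to-all over the per-expert counts would return; only the `view`/`sum` bookkeeping and the `max()` bound used to size the fixed-length mxfp8 output buffer are illustrated.

```python
import torch

# Hypothetical single-rank example: 2 EP ranks, 4 experts (2 per rank).
ep_size = 2
num_tokens_per_expert = torch.tensor([3, 1, 4, 2])        # tokens this rank routes to each global expert
num_tokens_per_expert_group = torch.tensor([3, 1, 5, 6])  # stand-in for the counts received from peers

# Tokens sent to each EP rank: sum the counts over that rank's local experts.
input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)         # tensor([4, 6])
# Tokens received from each EP rank.
output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)  # tensor([4, 11])

# Upper bound used to size the fixed-length output buffer of the mxfp8 all-to-all.
max_output_tokens_per_ep_rank = max(output_splits.tolist())               # 11
print(input_splits.tolist(), output_splits.tolist(), max_output_tokens_per_ep_rank)
```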
@@ -173,12 +172,24 @@ def _partition_fn(name, mod, device_mesh):

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
-            routed_output,
-            self.input_splits,
-            self.output_splits,
-            device_mesh.get_group(),
-        )
+        if self.a2a_impl == "mxfp8":
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                mxfp8_on_device_all_to_all_v,
+            )
+
+            routed_output, _ = mxfp8_on_device_all_to_all_v(
+                routed_output,
+                self.input_splits_on_device,
+                self.max_output_tokens_per_ep_rank,
+                device_mesh.get_group().group_name,
+            )
+        else:
+            routed_output = all_to_all_single_autograd(
+                routed_output,
+                self.input_splits,
+                self.output_splits,
+                device_mesh.get_group(),
+            )
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
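The combine path is the mirror image of the dispatch: the default branch reuses `all_to_all_single_autograd` with `input_splits` and `output_splits` swapped, which undoes the earlier token permutation. Below is a toy, single-process illustration (not from this diff) of why swapping the splits inverts the exchange; two "ranks" are simulated with plain lists.

```python
import torch

# rank r sends input_splits[r][p] tokens to rank p (numbers are invented)
input_splits = [[4, 6], [3, 5]]
tokens = [torch.arange(10), torch.arange(100, 108)]  # rank 0 holds 10 tokens, rank 1 holds 8

# dispatch: rank p receives the p-th chunk from every rank
send_chunks = [list(torch.split(tokens[r], input_splits[r])) for r in range(2)]
received = [torch.cat([send_chunks[r][p] for r in range(2)]) for p in range(2)]
output_splits = [[input_splits[r][p] for r in range(2)] for p in range(2)]  # [[4, 3], [6, 5]]

# combine: send everything back, using the dispatch output_splits as the new input splits
back_chunks = [list(torch.split(received[p], output_splits[p])) for p in range(2)]
returned = [torch.cat([back_chunks[p][r] for p in range(2)]) for r in range(2)]
assert all(torch.equal(returned[r], tokens[r]) for r in range(2))
```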
@@ -190,6 +201,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
            output_fn=self._token_combine,
        )

+
# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ExpertParallel):
    def __init__(
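For context on how this style is consumed, a hypothetical wiring sketch (not part of this diff): the `ExpertParallel` style defined above would typically be attached to each MoE block's experts module over a 1-D expert-parallel mesh via `parallelize_module`. The module layout (`model.layers`, `block.moe.experts`) and the helper name are assumptions for illustration; an initialized process group is required, and `ExpertParallel` refers to the class in this file.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_expert_parallel(model: nn.Module, ep_degree: int, use_mxfp8: bool = False) -> None:
    # One 1-D mesh over the expert-parallel ranks (assumes torch.distributed is initialized).
    ep_mesh = init_device_mesh("cuda", (ep_degree,), mesh_dim_names=("ep",))
    for block in model.layers:  # assumed module layout
        parallelize_module(
            module=block.moe.experts,  # assumed module layout
            device_mesh=ep_mesh,
            # Select the quantized all-to-all by passing a2a_impl="mxfp8".
            parallelize_plan=ExpertParallel(a2a_impl="mxfp8" if use_mxfp8 else "default"),
        )
```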