
Commit af79e9a

a2a dispatch and combine configurable separately
1 parent 4d51ae2 commit af79e9a

File tree

3 files changed (+35, -16 lines)

torchtitan/config/job_config.py

Lines changed: 9 additions & 3 deletions

@@ -399,10 +399,16 @@ class Parallelism:
     Note that this is still an experimental feature.
     """
 
-    expert_parallel_a2a_impl: Literal["default", "mxfp8"] = "default"
+    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
     """
-    MXFP8 all-to-all removes the need for device-to-host sync and optimizes network bandwidth usage
-    by using dynamic MXFP8 quantization on the all-to-all inputs, then dequantizes the outputs.
+    MXFP8 all-to-all optimizes network bandwidth usage by using dynamic MXFP8 quantization on the all-to-all
+    inputs, then dequantizing the outputs.
+    """
+
+    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
+    """
+    MXFP8 all-to-all optimizes network bandwidth usage by using dynamic MXFP8 quantization on the all-to-all
+    inputs, then dequantizing the outputs.
     """
 
 
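
The two options can be set independently, so token dispatch and token combine may use different all-to-all implementations. Below is a minimal sketch of setting them programmatically, assuming the Parallelism dataclass above can be instantiated directly with keyword arguments and that the import path mirrors the file path shown; only the two field names come from this diff, everything else is illustrative.

# Hypothetical usage sketch, not part of this commit.
from torchtitan.config.job_config import Parallelism

# Quantize the dispatch all-to-all with MXFP8 while keeping the combine
# all-to-all on the default (unquantized) implementation.
parallelism = Parallelism(
    expert_parallel_a2a_dispatch_impl="mxfp8",
    expert_parallel_a2a_combine_impl="default",
)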

torchtitan/distributed/expert_parallel.py

Lines changed: 14 additions & 9 deletions

@@ -90,11 +90,14 @@ class ExpertParallel(ParallelStyle):
         a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
     """
 
-    def __init__(self, a2a_impl: str = "default"):
+    def __init__(
+        self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default"
+    ):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
-        self.a2a_func = self._get_a2a_func(a2a_impl)
+        self.a2a_dispatch_func = self._get_a2a_func(a2a_dispatch_impl)
+        self.a2a_combine_func = self._get_a2a_func(a2a_combine_impl)
 
     def _get_a2a_func(self, a2a_impl: str):
         if a2a_impl == "default":
@@ -143,6 +146,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         self.input_splits = input_splits.tolist()
         self.output_splits = output_splits.tolist()
 
+        routed_input = self.a2a_dispatch_func(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
+
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
         # [#tokens for local expert 0, #tokens for local expert 1, ...]
@@ -152,12 +162,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # We need to perform another shuffle to get the correct format -- this is done via the function
         # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
         # each expert gets locally is a multiple of ALIGN_SIZE_M.
-        routed_input = self.a2a_func(
-            routed_input,
-            self.output_splits,
-            self.input_splits,
-            device_mesh.get_group(),
-        )
+
         return routed_input, num_tokens_per_expert_group
 
     @staticmethod
@@ -170,7 +175,7 @@ def _partition_fn(name, mod, device_mesh):
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
         # For a2a combine, input splits and output splits are opposite of a2a dispatch.
-        routed_output = self.a2a_func(
+        routed_output = self.a2a_combine_func(
             routed_output,
             self.input_splits,
             self.output_splits,
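
With dispatch and combine resolved separately, _token_dispatch and _token_combine can use different all-to-all implementations (for example, MXFP8 on dispatch only). Below is a standalone sketch of that selection pattern, with placeholder callables standing in for the real default and MXFP8 all-to-all functions; nothing beyond what the diff shows is assumed about torchtitan.

from typing import Callable

def _get_a2a_func(a2a_impl: str) -> Callable:
    # Placeholder registry; the real method returns the default or MXFP8
    # all-to-all collective.
    impls: dict[str, Callable] = {
        "default": lambda x, out_splits, in_splits, group: x,
        "mxfp8": lambda x, out_splits, in_splits, group: x,
    }
    if a2a_impl not in impls:
        raise ValueError(f"Unknown a2a_impl: {a2a_impl!r}")
    return impls[a2a_impl]

class ExpertParallelSketch:
    def __init__(self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default"):
        # Each direction resolves its own all-to-all function, mirroring the diff.
        self.a2a_dispatch_func = _get_a2a_func(a2a_dispatch_impl)
        self.a2a_combine_func = _get_a2a_func(a2a_combine_impl)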

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 12 additions & 4 deletions

@@ -107,7 +107,8 @@ def parallelize_llama(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
-            a2a_impl=job_config.parallelism.expert_parallel_a2a_impl,
+            a2a_dispatch_impl=job_config.parallelism.expert_parallel_a2a_dispatch_impl,
+            a2a_combine_impl=job_config.parallelism.expert_parallel_a2a_combine_impl,
         )
 
     model_compile_enabled = (
@@ -439,7 +440,8 @@ def apply_moe_ep_tp(
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
-    a2a_impl: str = "default",
+    a2a_dispatch_impl: str = "default",
+    a2a_combine_impl: str = "default",
 ):
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
@@ -489,13 +491,19 @@ def apply_moe_ep_tp(
         elif tp_mesh is None:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
         elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
             experts_mesh = ep_mesh
-            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
 
         parallelize_module(
             module=transformer_block.moe.experts,
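
The values flow from job_config.parallelism through apply_moe_ep_tp into the ExpertParallel plan, so mixed configurations are possible. For illustration only, a hedged sketch of building the plan with MXFP8 dispatch and default combine follows; the import path mirrors torchtitan/distributed/expert_parallel.py, and whether mixing helps in practice depends on the workload.

from torchtitan.distributed.expert_parallel import ExpertParallel

experts_plan = ExpertParallel(
    a2a_dispatch_impl="mxfp8",    # MXFP8-quantized token dispatch
    a2a_combine_impl="default",   # full-precision token combine
)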

0 commit comments