Commit bdbdd45 (parent: 4d51ae2)

a2a dispatch and combine configurable separately

3 files changed: +37 -19 lines

torchtitan/config/job_config.py

Lines changed: 9 additions & 3 deletions
@@ -399,10 +399,16 @@ class Parallelism:
     Note that this is still an experimental feature.
     """
 
-    expert_parallel_a2a_impl: Literal["default", "mxfp8"] = "default"
+    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
     """
-    MXFP8 all-to-all removes the need for device-to-host sync and optimizes network bandwidth usage
-    by using dynamic MXFP8 quantization on the all-to-all inputs, then dequantizes the outputs.
+    MXFP8 all-to-all optimizes network bandwidth usage by using dynamic MXFP8 quantization on the all-to-all
+    inputs, then dequantizing the outputs.
+    """
+
+    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
+    """
+    MXFP8 all-to-all optimizes network bandwidth usage by using dynamic MXFP8 quantization on the all-to-all
+    inputs, then dequantizing the outputs.
     """
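
A brief usage sketch (not part of the commit): the two fields can now be set independently in Python, assuming Parallelism is a plain dataclass whose other fields all have defaults.

from torchtitan.config.job_config import Parallelism

# Quantize the dispatch all-to-all with MXFP8, keep the combine unquantized.
parallelism = Parallelism(
    expert_parallel_a2a_dispatch_impl="mxfp8",
    expert_parallel_a2a_combine_impl="default",
)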

torchtitan/distributed/expert_parallel.py

Lines changed: 14 additions & 12 deletions
@@ -22,7 +22,6 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 
-from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import _round_up
 
 
@@ -90,18 +89,19 @@ class ExpertParallel(ParallelStyle):
         a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
     """
 
-    def __init__(self, a2a_impl: str = "default"):
+    def __init__(
+        self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default"
+    ):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
-        self.a2a_func = self._get_a2a_func(a2a_impl)
+        self.a2a_dispatch_func = self._get_a2a_func(a2a_dispatch_impl)
+        self.a2a_combine_func = self._get_a2a_func(a2a_combine_impl)
 
     def _get_a2a_func(self, a2a_impl: str):
         if a2a_impl == "default":
-            logger.info("Using default all-to-all implementation")
             return all_to_all_single_autograd
         elif a2a_impl == "mxfp8":
-            logger.info("Using mxfp8 all-to-all implementation")
             from torchao.prototype.moe_training.kernels.mxfp8.comms import (
                 to_mxfp8_a2a_dequant,
             )
@@ -143,6 +143,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         self.input_splits = input_splits.tolist()
         self.output_splits = output_splits.tolist()
 
+        routed_input = self.a2a_dispatch_func(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
+
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
         # [#tokens for local expert 0, #tokens for local expert 1, ...]
@@ -152,12 +159,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # We need to perform another shuffle to get the correct format -- this is done via the function
         # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
         # each expert gets locally is a multiple of ALIGN_SIZE_M.
-        routed_input = self.a2a_func(
-            routed_input,
-            self.output_splits,
-            self.input_splits,
-            device_mesh.get_group(),
-        )
+
         return routed_input, num_tokens_per_expert_group
 
     @staticmethod
@@ -170,7 +172,7 @@ def _partition_fn(name, mod, device_mesh):
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
         # For a2a combine, input splits and output splits are opposite of a2a dispatch.
-        routed_output = self.a2a_func(
+        routed_output = self.a2a_combine_func(
             routed_output,
             self.input_splits,
             self.output_splits,
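
To summarize the shape of this change, here is a minimal, self-contained sketch with hypothetical names (not torchtitan code): each direction resolves its own all-to-all callable, and the combine call passes the split lists in the opposite order from dispatch, as the comment in the diff notes.

from typing import Callable, Literal

A2AImpl = Literal["default", "mxfp8"]

def default_a2a(tokens, output_splits, input_splits, group):
    """Stand-in for the unquantized all-to-all (all_to_all_single_autograd in the diff)."""
    raise NotImplementedError

def mxfp8_a2a(tokens, output_splits, input_splits, group):
    """Stand-in for the quantize -> all-to-all -> dequantize path (to_mxfp8_a2a_dequant in the diff)."""
    raise NotImplementedError

_IMPLS: dict[A2AImpl, Callable] = {"default": default_a2a, "mxfp8": mxfp8_a2a}

class SplitA2A:
    """One all-to-all callable per direction, resolved independently from config."""

    def __init__(self, dispatch_impl: A2AImpl = "default", combine_impl: A2AImpl = "default"):
        self.dispatch_func = _IMPLS[dispatch_impl]
        self.combine_func = _IMPLS[combine_impl]

    def dispatch(self, tokens, input_splits, output_splits, group):
        # Dispatch: send input_splits[i] tokens to rank i, receive output_splits[i] from rank i.
        return self.dispatch_func(tokens, output_splits, input_splits, group)

    def combine(self, tokens, input_splits, output_splits, group):
        # Combine reverses the dispatch, so the split lists swap roles.
        return self.combine_func(tokens, input_splits, output_splits, group)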

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 14 additions & 4 deletions
@@ -107,7 +107,8 @@ def parallelize_llama(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
-            a2a_impl=job_config.parallelism.expert_parallel_a2a_impl,
+            a2a_dispatch_impl=job_config.parallelism.expert_parallel_a2a_dispatch_impl,
+            a2a_combine_impl=job_config.parallelism.expert_parallel_a2a_combine_impl,
         )
 
     model_compile_enabled = (
@@ -439,8 +440,11 @@ def apply_moe_ep_tp(
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
-    a2a_impl: str = "default",
+    a2a_dispatch_impl: str = "default",
+    a2a_combine_impl: str = "default",
 ):
+    logger.info(f"Using all-to-all dispatch: {a2a_dispatch_impl}")
+    logger.info(f"Using all-to-all combine: {a2a_combine_impl}")
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
             continue
@@ -489,13 +493,19 @@ def apply_moe_ep_tp(
         elif tp_mesh is None:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
         elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
             experts_mesh = ep_mesh
-            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
 
         parallelize_module(
             module=transformer_block.moe.experts,
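
End to end, the two job-config fields flow into the ExpertParallel plan roughly as below. This is a sketch only, assuming an already-initialized expert-parallel DeviceMesh (`ep_mesh`), a loaded `job_config`, and a `transformer_block` with MoE experts in scope.

from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed.expert_parallel import ExpertParallel

# Build one plan with independently chosen dispatch/combine all-to-all implementations...
experts_plan = ExpertParallel(
    a2a_dispatch_impl=job_config.parallelism.expert_parallel_a2a_dispatch_impl,
    a2a_combine_impl=job_config.parallelism.expert_parallel_a2a_combine_impl,
)

# ...and apply it to the MoE experts over the expert-parallel mesh.
parallelize_module(
    module=transformer_block.moe.experts,
    device_mesh=ep_mesh,
    parallelize_plan=experts_plan,
)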
