
Commit 4527f8b

preallocated mxfp8 all-to-all output buffer not big enough
1 parent 24ebc9b commit 4527f8b

File tree: 2 files changed, +71 -60 lines changed


torchtitan/distributed/expert_parallel.py

Lines changed: 67 additions & 55 deletions
@@ -81,76 +81,75 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:


 class ExpertParallel(ParallelStyle):
-    AllToAllImpl = Literal["default", "mxfp8"]
+    """
+    ExpertParallel is a parallel style for MoE, where each experts
+    are distributed across ranks along a given axis of the device mesh.

-    def __init__(self, a2a_impl: AllToAllImpl = "default", max_output_len_per_rank: int = -1):
-        """
-        ExpertParallel is a parallel style for MoE, where each experts
-        are distributed across ranks along a given axis of the device mesh.
+    Args:
+        a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
+    """

-        Args:
-            a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
-            max_output_len_per_rank (int): The maximum length of the output tensor per rank. Default is -1. Required for mxfp8 all-to-all, otherwise not used.
-        """
+    def __init__(self, a2a_impl: str = "default"):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
         self.a2a_impl = a2a_impl
-        self.max_output_len_per_rank = max_output_len_per_rank

     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
         ep_size = device_mesh.shape[0]

-        def default_a2a(routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-            # generate the input splits and output splits for all-to-all
-            with torch.no_grad():
-                num_tokens_per_expert_group = all_to_all_single(
-                    num_tokens_per_expert,
-                    None,
-                    None,
-                    group=device_mesh.get_group(),
-                )
-                # Need to wait explicitly because it is used by a triton kernel later
-                # which doesn't realize that AsyncCollectiveTensor needs unwrapping
-                num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-                    num_tokens_per_expert_group
-                )
-                input_splits = (
-                    num_tokens_per_expert.view(ep_size, -1)
-                    .sum(dim=1)
-                    .to(torch.device("cpu"), non_blocking=True)
-                )
-                # NOTE: this would incur a device-to-host sync
-                output_splits = (
-                    num_tokens_per_expert_group.view(ep_size, -1)
-                    .sum(dim=1)
-                    .to(torch.device("cpu"), non_blocking=False)
-                )
-                self.input_splits = input_splits.tolist()
-                self.output_splits = output_splits.tolist()
+        # generate the input splits and output splits for all-to-all
+        with torch.no_grad():
+            num_tokens_per_expert_group = all_to_all_single(
+                num_tokens_per_expert,
+                None,
+                None,
+                group=device_mesh.get_group(),
+            )
+            # Need to wait explicitly because it is used by a triton kernel later
+            # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
+            input_splits = (
+                num_tokens_per_expert.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=True)
+            )
+            # NOTE: this would incur a device-to-host sync
+            output_splits = (
+                num_tokens_per_expert_group.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=False)
+            )
+            self.input_splits = input_splits.tolist()
+            self.output_splits = output_splits.tolist()
+
+            # EP degree is small, doing max() on the host is fine since the data is already on the host after *.tolist() d2h sync above.
+            self.max_output_tokens_per_ep_rank = max(self.output_splits)
+            self.input_splits_on_device = input_splits.to(routed_input.device)

+        if self.a2a_impl == "mxfp8":
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                mxfp8_on_device_all_to_all_v,
+            )
+            routed_input, _ = mxfp8_on_device_all_to_all_v(
+                routed_input,
+                self.input_splits_on_device,
+                self.max_output_tokens_per_ep_rank,
+                device_mesh.get_group().group_name,
+            )
+        else:
             # perform all-to-all
             routed_input = all_to_all_single_autograd(
                 routed_input,
                 self.output_splits,
                 self.input_splits,
                 device_mesh.get_group(),
             )
-            return routed_input, num_tokens_per_expert_group
-
-        def mxfp8_a2a(routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-            from torchao.prototype.moe_training.kernels.mxfp8.comms import mxfp8_on_device_all_to_all_v
-            assert self.max_output_len_per_rank > 0, "max_output_len_per_rank must be positive for mxfp8 all-to-all"
-            routed_input, num_tokens_per_expert_group = mxfp8_on_device_all_to_all_v(
-                routed_input,
-                num_tokens_per_expert,
-                self.max_output_len_per_rank,
-                device_mesh.get_group().group_name,
-            )
-            return routed_input, num_tokens_per_expert_group

         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
@@ -173,12 +172,24 @@ def _partition_fn(name, mod, device_mesh):

     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
-            routed_output,
-            self.input_splits,
-            self.output_splits,
-            device_mesh.get_group(),
-        )
+        if self.a2a_impl == "mxfp8":
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                mxfp8_on_device_all_to_all_v,
+            )
+
+            routed_output, _ = mxfp8_on_device_all_to_all_v(
+                routed_output,
+                self.input_splits_on_device,
+                self.max_output_tokens_per_ep_rank,
+                device_mesh.get_group().group_name,
+            )
+        else:
+            routed_output = all_to_all_single_autograd(
+                routed_output,
+                self.input_splits,
+                self.output_splits,
+                device_mesh.get_group(),
+            )
         return routed_output

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
@@ -190,6 +201,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             output_fn=self._token_combine,
         )

+
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
 class ExpertTensorParallel(ExpertParallel):
     def __init__(

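Why this addresses the "preallocated buffer not big enough" problem: the old mxfp8 path sized its on-device all-to-all output from a caller-supplied max_output_len_per_rank (left as a "0 # TODO" placeholder in parallelize.py, removed below), whereas the new _token_dispatch derives the capacity from the output_splits it already computes, via max(self.output_splits). The following single-process sketch mirrors that split and sizing arithmetic; the collective is simulated by slicing a per-rank count matrix, and the helper name simulate_split_sizing plus the example token counts are illustrative assumptions, not part of this commit.

import torch

def simulate_split_sizing(counts_per_rank: torch.Tensor, ep_size: int):
    # counts_per_rank[r, e] = number of tokens rank r routes to global expert e
    experts_per_rank = counts_per_rank.shape[1] // ep_size
    results = []
    for rank in range(ep_size):
        num_tokens_per_expert = counts_per_rank[rank]
        # input_splits: how many tokens this rank sends to each EP rank
        input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1).tolist()
        # what all_to_all_single(num_tokens_per_expert) would hand this rank:
        # every rank's counts for the experts hosted on this rank
        num_tokens_per_expert_group = counts_per_rank[
            :, rank * experts_per_rank : (rank + 1) * experts_per_rank
        ].reshape(-1)
        # output_splits: how many tokens this rank receives from each EP rank
        output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1).tolist()
        # the fix: size the preallocated all-to-all output from the real splits
        # instead of a user-supplied max_output_len_per_rank
        max_output_tokens_per_ep_rank = max(output_splits)
        results.append((input_splits, output_splits, max_output_tokens_per_ep_rank))
    return results

if __name__ == "__main__":
    # 2 EP ranks, 2 experts per rank -> 4 global experts (made-up counts)
    counts = torch.tensor([[3, 1, 4, 2],
                           [0, 5, 2, 6]])
    for rank, (inp, out, cap) in enumerate(simulate_split_sizing(counts, ep_size=2)):
        print(f"rank {rank}: input_splits={inp} output_splits={out} buffer_rows={cap}")

With these counts, rank 0 receives 4 and 5 rows from the two ranks and rank 1 receives 6 and 8, so the per-rank buffer sizes come out as 5 and 8 rows respectively, exactly the max(output_splits) that the patched dispatch computes.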
torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 4 additions & 5 deletions
@@ -94,7 +94,6 @@ def parallelize_llama(
        )
        maybe_enable_async_tp(job_config, world_mesh["tp"])

-    max_output_len_per_rank = 0 # TODO
    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
@@ -108,7 +107,7 @@ def parallelize_llama(
                else None
            ),
            etp_enabled=parallel_dims.etp_enabled,
-            max_output_len_per_rank=max_output_len_per_rank,
+            a2a_impl=job_config.parallelism.expert_parallel_a2a_impl,
        )

    model_compile_enabled = (
@@ -440,7 +439,7 @@ def apply_moe_ep_tp(
    ep_mesh: DeviceMesh | None,
    ep_tp_mesh: DeviceMesh | None,
    etp_enabled: bool,
-    max_output_len_per_rank: int,
+    a2a_impl: str,
):
    for transformer_block in model.layers.values():
        if not transformer_block.moe_enabled:
@@ -490,13 +489,13 @@ def apply_moe_ep_tp(
        elif tp_mesh is None:
            experts_mesh = ep_mesh
            # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel()
+            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
        elif etp_enabled:
            experts_mesh = ep_tp_mesh
            experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
        else:
            experts_mesh = ep_mesh
-            experts_plan = ExpertParallel()
+            experts_plan = ExpertParallel(a2a_impl=a2a_impl)

        parallelize_module(
            module=transformer_block.moe.experts,
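On the call-site side, the per-rank buffer length no longer needs to be threaded through parallelize_llama; the only remaining knob is the all-to-all implementation string read from job_config.parallelism.expert_parallel_a2a_impl. Below is a hedged wiring sketch of how that choice reaches the experts module, mirroring the selection logic in apply_moe_ep_tp; it assumes torchtitan is installed and that the caller already has an EP device mesh and an experts module, and the helper name apply_expert_parallel_sketch is illustrative, not torchtitan API.

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed.expert_parallel import ExpertParallel


def apply_expert_parallel_sketch(
    experts: nn.Module,
    ep_mesh: DeviceMesh,
    a2a_impl: str = "default",  # "default" or "mxfp8", per the new docstring
) -> nn.Module:
    # The parallel style now carries the all-to-all implementation choice;
    # with "mxfp8", dispatch/combine go through torchao's
    # mxfp8_on_device_all_to_all_v, and the output buffer is sized from
    # max(output_splits) computed at dispatch time.
    experts_plan = ExpertParallel(a2a_impl=a2a_impl)
    return parallelize_module(
        module=experts,
        device_mesh=ep_mesh,
        parallelize_plan=experts_plan,
    )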