
Commit a48e631

overallocate max output tokens per ep rank
1 parent 701f8fc commit a48e631

1 file changed: +8 -4 lines


torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 8 additions & 4 deletions
@@ -94,10 +94,12 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

-    # Assume 2x tokens per EP rank in the worst case.
-    # TODO: explore other options
+    # Worst case = single expert receives all tokens
+    # TODO: explore using token dropping to avoid this huge overallocation
     max_tokens_per_ep_rank = (
-        job_config.training.seq_len * job_config.training.local_batch_size * 2
+        job_config.training.seq_len
+        * job_config.training.local_batch_size
+        * model.model_args.moe_args.num_experts
     )
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
@@ -504,7 +506,9 @@ def apply_moe_ep_tp(
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
            experts_mesh = ep_mesh
-            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+            experts_plan = ExpertParallel(
+                a2a_impl=a2a_impl, max_tokens_per_ep_rank=max_tokens_per_ep_rank
+            )

         parallelize_module(
             module=transformer_block.moe.experts,
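
For a sense of scale, here is a minimal sketch of the sizing change, using hypothetical values for seq_len, local_batch_size, and num_experts (the real values come from job_config.training and model.model_args.moe_args):

# Hypothetical values standing in for the torchtitan config fields.
seq_len = 8192          # job_config.training.seq_len (assumed)
local_batch_size = 1    # job_config.training.local_batch_size (assumed)
num_experts = 16        # model.model_args.moe_args.num_experts (assumed)

local_tokens = seq_len * local_batch_size

# Old bound: assume at most 2x the local tokens land on any one EP rank.
old_max_tokens_per_ep_rank = local_tokens * 2            # 16_384

# New bound: worst case, a single expert receives all tokens.
new_max_tokens_per_ep_rank = local_tokens * num_experts  # 131_072

print(old_max_tokens_per_ep_rank, new_max_tokens_per_ep_rank)

Under these assumed values the new bound is num_experts / 2 = 8x larger than the old 2x heuristic, trading buffer memory for an output shape that stays valid under any routing skew; the TODO in the diff points at token dropping as a way to shrink this overallocation later.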
