1 file changed (+8, -4 lines): torchtitan/experiments/llama4/infra

@@ -94,10 +94,12 @@ def parallelize_llama(
     )
     maybe_enable_async_tp(job_config, world_mesh["tp"])

-    # Assume 2x tokens per EP rank in the worst case.
-    # TODO: explore other options
+    # Worst case = single expert receives all tokens
+    # TODO: explore using token dropping to avoid this huge overallocation
     max_tokens_per_ep_rank = (
-        job_config.training.seq_len * job_config.training.local_batch_size * 2
+        job_config.training.seq_len
+        * job_config.training.local_batch_size
+        * model.model_args.moe_args.num_experts
     )
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
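For context, a minimal sketch of what the new worst-case bound implies, using made-up config values (the real ones come from job_config.training and model.model_args.moe_args): the per-EP-rank token cap now scales with the number of experts instead of a fixed 2x factor, which is the "huge overallocation" the updated TODO refers to.

# Toy sizing comparison; values are illustrative only.
seq_len = 8192
local_batch_size = 4
num_experts = 16

old_cap = seq_len * local_batch_size * 2            # previous "2x tokens per EP rank" assumption
new_cap = seq_len * local_batch_size * num_experts  # all tokens routed to a single expert

print(old_cap, new_cap, new_cap // old_cap)  # 65536 524288 8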
@@ -504,7 +506,9 @@ def apply_moe_ep_tp(
         experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
     else:
         experts_mesh = ep_mesh
-        experts_plan = ExpertParallel(a2a_impl=a2a_impl)
+        experts_plan = ExpertParallel(
+            a2a_impl=a2a_impl, max_tokens_per_ep_rank=max_tokens_per_ep_rank
+        )

     parallelize_module(
         module=transformer_block.moe.experts,
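The diff does not show how ExpertParallel consumes max_tokens_per_ep_rank internally, but the general motivation for passing a worst-case cap into an expert-parallel plan is that the token-dispatch all-to-all can then run with buffers whose shapes do not depend on the routing outcome of a particular batch. A simplified, hypothetical sketch of that pattern follows; all names and the padding scheme are illustrative, not torchtitan's implementation.

import torch
import torch.distributed as dist

def dispatch_tokens_static(
    tokens: torch.Tensor,         # [num_local_tokens, hidden_dim]
    dest_rank: torch.Tensor,      # [num_local_tokens], destination EP rank per token
    ep_group: dist.ProcessGroup,
    max_tokens_per_ep_rank: int,  # worst-case cap computed from the job config
) -> torch.Tensor:
    ep_size = dist.get_world_size(ep_group)
    hidden = tokens.shape[-1]
    per_dest = max_tokens_per_ep_rank // ep_size  # fixed slot size per peer rank

    # Pack tokens into fixed-size, zero-padded slots, one slot per destination rank.
    send = tokens.new_zeros(ep_size, per_dest, hidden)
    for r in range(ep_size):
        sel = tokens[dest_rank == r][:per_dest]
        send[r, : sel.shape[0]] = sel

    # Every slot has the same size, so the collective's shapes are static and
    # never depend on how skewed the routing happens to be for this batch.
    recv = send.new_empty(ep_size * per_dest, hidden)
    dist.all_to_all_single(recv, send.view(ep_size * per_dest, hidden), group=ep_group)
    return recv.view(ep_size, per_dest, hidden)  # padding rows are zeros

The tradeoff is memory: the fixed buffers must be sized for the worst case, which is why the updated comment flags token dropping as a possible way to shrink the cap.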