
Conversation


@danielvegamyhre commented Sep 26, 2025

Summary

  • Make EP a2a dispatch and a2a combine each separately configurable to use either the "default" or "mxfp8" impl.
  • The "mxfp8" impl uses torchao's new to_mxfp8_a2a_dequant, which has the exact same API as the functional collective all_to_all_single_autograd and is differentiable, so it can be used as a drop-in replacement for the default a2a impl (see the sketch after this list).
  • torchao's to_mxfp8_a2a_dequant works as follows:
    • quantize the inputs to mxfp8
    • all_to_all_single on the e4m3 data
    • all_to_all_single on the e8m0 scales
    • dequantize the outputs back to the original precision
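
A minimal sketch (not the exact code from this PR) of how the two impls can sit behind one call site, assuming the functional-collective-style signature (input, output_split_sizes, input_split_sizes, group); the torchao import path below is an assumption and may differ in your torchao version:

import torch
from torch.distributed._functional_collectives import all_to_all_single_autograd

# NOTE: assumed import location; check your torchao version for the actual path.
from torchao.prototype.moe_training import to_mxfp8_a2a_dequant


def token_a2a(
    x: torch.Tensor,
    output_splits: list[int],
    input_splits: list[int],
    group,
    impl: str = "default",
) -> torch.Tensor:
    # Both impls share the same signature, so the mxfp8 path is a drop-in
    # replacement: it quantizes x to mxfp8, runs one all_to_all_single on the
    # e4m3 data and one on the e8m0 scales, then dequantizes back to x.dtype.
    if impl == "mxfp8":
        return to_mxfp8_a2a_dequant(x, output_splits, input_splits, group)
    return all_to_all_single_autograd(x, output_splits, input_splits, group)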

Performance

  • Single node benchmarks with 4xB200

  • Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8

  • Reduced num layers from 48 -> 2 to avoid OOM in the single-node setting

  • Debug model config:

llama4_configs = {
    "debugmodel": TransformerModelArgs(
        dim=5120,
        n_layers=2,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        max_seq_len=10485760,
        moe_args=MoEArgs(num_experts=16),
        interleave_moe_layer_step=1,
    ),
}
| Configuration | Throughput (Median Tokens/s) | Max Memory (GiB) |
| --- | --- | --- |
| bf16 baseline | 49381.0 | 145.55 |
| MXFP8 for Linears only | 52038.0 | 146.62 |
| MXFP8 for Grouped GEMMs only | 69350.0 | 144.71 |
| MXFP8 for Linears + Grouped GEMMs | 70747.0 | 145.32 |
| MXFP8 for Linears + Grouped GEMMs + A2A Dispatch | 72602.5 | 145.45 |
| MXFP8 for Linears + Grouped GEMMs + A2A Dispatch + A2A Combine | 73152.0 | 146.08 |
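
For reference, the full mxfp8 configuration is a ~1.48x median-throughput speedup over the bf16 baseline (73152.0 / 49381.0 ≈ 1.48), and quantizing a2a dispatch + combine adds ~3.4% on top of the Linears + Grouped GEMMs config (73152.0 / 70747.0 ≈ 1.034), with peak memory essentially flat across all rows.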

Additional context on design/implementation choices

  • Note: both the default and mxfp8 impls require a d2h sync to get input_splits/output_splits on the host for the a2a call (see the sketch after this list).
    • I also explored a no-sync/on-device implementation using Triton + Symmetric memory, and got it working e2e in a torchtitan PoC: [mxfp8 moe training] mxfp8 a2a working e2e in torchtitan llama4 training; improve tests + bench scripts ao#3088
    • I found that this design, which preallocates over-sized symmetric memory buffers for exchanging a variable number of tokens (avoiding the syncs required for exact allocation, at the risk of crashing or dropping tokens if the overflow-factor heuristic is wrong), is fundamentally in conflict with the torchtitan MoE design of doing a d2h sync to safely do exact allocation. Extracting the variable-size outputs from the padded buffers causes a d2h sync anyway (regressing perf below baseline), and we can't avoid it, since otherwise downstream ops break due to shape mismatches; essentially the whole model would need to be designed around the static padded shapes.
    • Therefore, we chose to integrate this more straightforward impl, which is natively compatible with the non-experimental titan MoE design.
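
For context on the d2h sync mentioned above, here is a minimal sketch (not the PR's code, names are illustrative) of how it arises: the per-rank token counts live on the GPU, but the a2a call needs Python ints for its split sizes, so .tolist() blocks on a device-to-host copy.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single_autograd


def a2a_with_exact_allocation(x: torch.Tensor, send_counts: torch.Tensor, group) -> torch.Tensor:
    # send_counts: int64 tensor on the GPU with one entry per EP rank
    # (how many tokens this rank sends to each peer).
    input_splits = send_counts.tolist()  # d2h sync #1

    # Exchange the counts so each rank knows how many tokens it will receive,
    # which allows exact-size output allocation (the titan MoE design).
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=group)
    output_splits = recv_counts.tolist()  # d2h sync #2

    # Same call shape for both the default and mxfp8 impls; only the payload differs.
    return all_to_all_single_autograd(x, output_splits, input_splits, group)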

Additional background on motivation

  • MoE performance literature shows that ~47% of average runtime for flagship OSS MoE models (Qwen2, Phi-3.5, Mixtral 8x7B) is due to exposed MoE comms.
  • In the torchtitan Llama4 debug model with EP=4, ~30% of MoE training runtime is a2a comms, most of it exposed (see trace screenshot below), which directionally corroborates this.
  • We can optimize this via (1) quantizing the comms to minimize data sent over NVLink/IB, (2) avoiding the d2h sync that occurs in implementations which move a2a output splits from device to host to compute the exact preallocation needed for incoming tokens, and (3) finer-grained overlapping techniques.
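
For a rough sense of avenue (1): assuming the standard 32-element MX block size, mxfp8 sends 8 bits of e4m3 data plus one shared 8-bit e8m0 scale per 32 elements (8 + 8/32 = 8.25 bits/element) versus 16 bits/element for bf16, i.e. roughly a 1.9x reduction in bytes moved over NVLink/IB per a2a.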

30% of llama4 model profiled runtime is all2all comms

  • FSDP=4, EP=4, dim=5120, num_experts=16, seq_len=8192, local_batch_size=8
[profiler trace screenshot]

47% avg runtime devoted to MoE comms in profiled OSS models

[chart screenshot]

@meta-cla bot added the "CLA Signed" label Sep 26, 2025
@danielvegamyhre force-pushed the mx-a2a branch 2 times, most recently from 4527f8b to bba9c6a on September 27, 2025 00:13
@danielvegamyhre force-pushed the mx-a2a branch 2 times, most recently from fde6de2 to a48e631 on September 29, 2025 22:27
@danielvegamyhre changed the title from "[WIP] Support mxfp8 on device all_to_all_v in expert parallel" to "Support mxfp8 on device all_to_all_v in expert parallel" Sep 29, 2025
@danielvegamyhre marked this pull request as draft September 29, 2025 23:04
@danielvegamyhre changed the title from "Support mxfp8 on device all_to_all_v in expert parallel" to "Support mxfp8 all to all in expert parallel" Sep 30, 2025
@danielvegamyhre changed the title from "Support mxfp8 all to all in expert parallel" to "[mxfp8 MoE training] Support mxfp8 all to all in expert parallel" Sep 30, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review September 30, 2025 21:17
@danielvegamyhre danielvegamyhre marked this pull request as draft September 30, 2025 22:36