[roadmap/tracker] Low precision MoE training #2147

@danielvegamyhre

Description

Creating this issue as a roadmap/tracker for enabling float8 training for MoEs with token-choice routing. It covers both core requirements and ideas for additional performance optimizations.
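
For context, here is a minimal sketch of token-choice routing feeding a grouped GEMM, the pattern this roadmap targets. The shapes and the router are illustrative assumptions (not torchao's implementation), and `torch._grouped_mm` is a private PyTorch op whose signature may change:

```python
import torch

# Illustrative shapes; not taken from any real model config.
num_tokens, dim, hidden = 64, 128, 256
num_experts, top_k = 8, 2

x = torch.randn(num_tokens, dim, device="cuda", dtype=torch.bfloat16)
router = torch.nn.Linear(dim, num_experts, device="cuda", dtype=torch.bfloat16)
w = torch.randn(num_experts, dim, hidden, device="cuda", dtype=torch.bfloat16)

# Token-choice routing: each token independently picks its top-k experts.
topk_scores, topk_idx = router(x).topk(top_k, dim=-1)

# Sort the (token, expert) pairs by expert so each expert's tokens are
# contiguous -- the layout a grouped GEMM consumes.
sorted_experts, perm = topk_idx.flatten().sort()
x_grouped = x.repeat_interleave(top_k, dim=0)[perm]

# Cumulative per-expert token counts as int32 offsets. Low precision
# variants typically also pad each group (e.g. to a multiple of 16)
# to satisfy kernel alignment; that bookkeeping is omitted here.
offs = torch.bincount(sorted_experts, minlength=num_experts).cumsum(0).to(torch.int32)

# One grouped GEMM across all experts, replacing a Python loop of matmuls.
out = torch._grouped_mm(x_grouped, w, offs=offs)
```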

UPDATE 07/22/2025: revised priorities to reflect a shift in focus toward mxfp8.

This is not an exhaustive list, but it highlights the primary milestones and requirements.

Compute

Communication

I looked at traces and validated that "all-to-all dispatch and shuffle -> grouped GEMM -> all-to-all combine and unshuffle" is a sequentially dependent chain, so in theory faster, low precision comms should improve performance. There is some overlap with the shared expert computation, but it is not 100% overlap, so there is room for optimization. This will be especially important if/when the all-to-alls span multiple nodes, where inter-node network bandwidth is lower than intra-node NVLink bandwidth.
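
To make that dependency chain concrete, below is a rough sketch of the pipeline using `torch.distributed.all_to_all_single`. The names `moe_dispatch_combine`, `in_splits`, `out_splits`, and `grouped_experts` are hypothetical; the split sizes would come from the router, and the comments mark where quantizing to float8/mxfp8 would cut bytes on the wire:

```python
import torch
import torch.distributed as dist

def moe_dispatch_combine(x_sorted, in_splits, out_splits, grouped_experts):
    """Sequential EP pipeline: dispatch -> grouped GEMM -> combine.

    x_sorted:   (sum(in_splits), dim) tokens sorted by destination rank
    in_splits:  tokens this rank sends to each peer (from the router)
    out_splits: tokens this rank receives from each peer
    """
    # All-to-all dispatch + shuffle. Quantizing x_sorted to float8/mxfp8
    # here (and dequantizing on receipt) is the low precision comms idea.
    recv = x_sorted.new_empty(sum(out_splits), x_sorted.shape[-1])
    dist.all_to_all_single(recv, x_sorted, out_splits, in_splits)

    # Expert computation: one grouped GEMM over the local experts' tokens.
    # Nothing downstream can start until the comms above complete.
    y = grouped_experts(recv)

    # All-to-all combine + unshuffle: split sizes are reversed.
    out = y.new_empty(sum(in_splits), y.shape[-1])
    dist.all_to_all_single(out, y, in_splits, out_splits)
    return out
```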

Torchao UX

Compile support

  • Compile support for torch._grouped_mm
  • Differentiable _scaled_grouped_mm can compile with fullgraph=True (see the sketch after this list)
  • E2E compilation of each TransformerBlock in torchtitan after MoE conversion via the tensor subclass approach (fullgraph=False)
  • E2E compilation of each TransformerBlock in torchtitan after MoE conversion via the tensor subclass approach (fullgraph=True)
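
As a rough illustration of the fullgraph=True milestone, the sketch below compiles a differentiable grouped GEMM forward and backward with no graph breaks. It uses plain bf16 `torch._grouped_mm` as a stand-in for the scaled variant, needs a recent PyTorch build, and is not a statement that the milestone is done:

```python
import torch

# Stand-in for the differentiable scaled grouped GEMM; fullgraph=True
# asserts there are no graph breaks in forward or backward.
@torch.compile(fullgraph=True)
def moe_expert_ffn(x, w, offs):
    return torch._grouped_mm(x, w, offs=offs)

x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(4, 128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
offs = torch.tensor([16, 32, 48, 64], device="cuda", dtype=torch.int32)

moe_expert_ffn(x, w, offs).sum().backward()  # exercises both compiled graphs
```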

Distributed support
