This is a running list of planned features for low precision training. As features are completed, we plan to delete them from this list to keep things simple.
## float8
### performance
- [in progress] optimize torch.compile performance for float8 tensorwise scaling/casting kernels
- [in progress] ensure that float8 rowwise scaling is performant with TP and async TP ([Async TP] Fuse all-gather-matmuls for float8 rowwise training, pytorch#149990); see the sketch after this list
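A minimal sketch of how the two code paths above are exercised today, assuming current torchao naming (`convert_to_float8_training` and `Float8LinearConfig.from_recipe_name`); torch.compile is what generates the scaling/casting kernels being optimized, and recipe/config names may shift as this work lands:

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# assumes a GPU with float8 support (e.g. H100)
model = nn.Sequential(
    nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)
).cuda().to(torch.bfloat16)

# rowwise scaling recipe; omitting the config selects the tensorwise default
config = Float8LinearConfig.from_recipe_name("rowwise")
convert_to_float8_training(model, config=config)

# torch.compile generates the float8 scaling/casting kernels tracked above
model = torch.compile(model)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
y = model(x)
```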
### distributed
- [planned] verify integration with PP
### new features
- [2025-Q2] float8 grouped gemm support
- [2025-Q2] better story for float8 training -> float8 inference (see the sketch after this list)
- productionize no-compile version of float8 training (https://github.com/pytorch/ao/tree/main/torchao/prototype/float8nocompile, priority TBD)
- [2025-Q2] weight gradient accumulation in float32
- float8 SDPA (priority TBD)
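For the float8 training -> float8 inference item, a rough sketch of today's two-step flow that a better story would streamline, assuming current torchao entry points (`convert_to_float8_training` for training, `quantize_` with a float8 inference config for serving); the inference config name varies across torchao versions:

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

def make_model():
    return nn.Sequential(
        nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)
    ).cuda().to(torch.bfloat16)

# 1) training: swap nn.Linear modules for float8 training, then train as usual
train_model = make_model()
convert_to_float8_training(train_model)
# ... training loop elided ...

# 2) inference: today this typically means loading the trained checkpoint back
#    into a plain bf16 model and quantizing it separately for float8 serving
infer_model = make_model()
# infer_model.load_state_dict(trained_state_dict)  # checkpoint name is illustrative
quantize_(infer_model, float8_dynamic_activation_float8_weight())
```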
### ecosystem
- [in progress] add torchtune integration ((WIP/RFC) FP8 full finetune distributed, torchtune#2404)
### other
- [2025-Q2] expose float8 training via the quantize_ API (hypothetical sketch after this list)
- [2025-Q2] migrate torchao.float8 code to torchao.quantization for better unification with the rest of torchao, in a BC-preserving way
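For the quantize_ item, a hypothetical sketch of what a unified entry point could look like once torchao.float8 moves under torchao.quantization; `quantize_` and `convert_to_float8_training` exist today, while the training config name below is a placeholder, not a landed API:

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.quantization import quantize_

model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))

# today: float8 training has its own entry point, separate from quantize_
convert_to_float8_training(model)

# planned direction: the same workflow routed through quantize_, mirroring how
# inference quantization is configured ("Float8TrainingConfig" is a placeholder)
# quantize_(model, Float8TrainingConfig())
```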
## MX
### pytorch/pytorch
- [in progress] fp4_x2 dtype
- [in progress] torch._scaled_mm for nvfp4, wrapping cuBLAS (see the sketch after this list)
- [in progress] inductor performance work for fusing mx block scaling into surrounding ops (request for faster inductor kernels for blockwise reduction across dim1 -> write, pytorch#149982)
- [2025-Q1] PT2 integration for e8m0 and fp4_x2
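For the torch._scaled_mm item, a minimal sketch of the existing tensorwise float8 path that the nvfp4 wrapper would extend; argument names and layout requirements can vary across PyTorch versions, and a GPU with float8 support (e.g. H100) is assumed:

```python
import torch

M, K, N = 256, 512, 128  # dims must be multiples of 16 for _scaled_mm
a = torch.randn(M, K, device="cuda")
b = torch.randn(N, K, device="cuda")

# per-tensor (tensorwise) scales derived from each operand's amax
f8_max = torch.finfo(torch.float8_e4m3fn).max
a_scale = a.abs().amax().float() / f8_max
b_scale = b.abs().amax().float() / f8_max

a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b / b_scale).to(torch.float8_e4m3fn)

# the second operand must be column-major; .t() on a row-major tensor gives that
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    scale_a=a_scale,
    scale_b=b_scale,
    out_dtype=torch.bfloat16,
)
# out approximates a @ b.t() in bfloat16
```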
### pytorch/torchao
- [in progress] performance: MX single node performance tracker (#1768)
- [in progress] expose MX in the quantize_ API
### pytorch/torchtitan
- [in progress] integrate mx training ([not for land] enable torchao's mxfp8 training recipe, torchtitan#1015)