
float8 rowwise vanilla TP low throughput #1207

@danielvegamyhre (Contributor)

Bug description

Llama3 8b on 4xH100s with per op SAC, using FSDP=2, TP=2

  • bf16: 5378 TPS, 45.68 GiB peak memory
  • float8 rowwise: 5189 TPS, 45.67 GiB peak memory
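
For context, the float8 rowwise run converts nn.Linear modules to Float8Linear via torchao; below is a minimal sketch of that conversion outside torchtitan, with placeholder layer shapes (API names are from torchao's float8 module and may differ slightly across versions):

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# toy stand-in for the model; torchtitan performs this conversion internally
model = nn.Sequential(
    nn.Linear(4096, 14336, bias=False),
    nn.Linear(14336, 4096, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# "rowwise" recipe: dynamic per-row (axiswise) e4m3 scaling for input, weight, grad_output
config = Float8LinearConfig.from_recipe_name("rowwise")

# mirror torchtitan's --float8.filter_fqns: skip conversion for matching module names
filter_fqns = ["output"]
convert_to_float8_training(
    model,
    config=config,
    module_filter_fn=lambda mod, fqn: not any(f in fqn for f in filter_fqns),
)
```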

Versions

  • torch 2.8.0a0+gite21ad6e
  • torchtitan @ HEAD
  • torchao 0.11.0

Activity

danielvegamyhre commented on Jun 18, 2025

I started analyzing traces, comparing bf16 vs fp8 tensorwise vs fp8 rowwise for vanilla TP.
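
For reference, traces like the ones below can be collected with torch.profiler and inspected in Perfetto or chrome://tracing. A minimal sketch with a stand-in model and training step (torchtitan has its own profiling hooks, so this is only illustrative):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)  # stand-in model

def train_step():
    # stand-in for a real fwd/bwd training step
    x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
    model(x).sum().backward()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=lambda p: p.export_chrome_trace("trace.json"),
) as prof:
    for _ in range(5):
        train_step()
        prof.step()
```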

Looking at the TP region specifically (FFN) for the forward pass, I see (1) the expected FSDP all-gather of the FFN weights occurring after the attention fwd, and (2) the expected exposed all-gather and reduce-scatter comms seen in TP+SP.

For the bf16 trace, these are the only comms in the FFN forward. However, in fp8 tensorwise and rowwise there are additional all-reduces for syncing the scales (expected).

Forward pass

Durations from the start of the input all-gather to the end of the output reduce-scatter:

  • bf16: 1757us
  • fp8 tensorwise: 1543us
  • fp8 rowwise: 1777us

Looking at the 3 GEMMs in the FFN, I see the bf16 GEMMs are about 380us each vs about 180us each for fp8 rowwise, so we are seeing a solid speedup there and the GEMMs aren't the reason perf is flat.

This difference is pretty small: the fp8 tensorwise fwd pass is only 1.11x faster than rowwise, yet the e2e TPS is 1.5x+ higher. So this is not the whole story.
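
As a side note, the GEMM-level speedup can be sanity-checked in isolation by microbenchmarking a rowwise-scaled float8 matmul against bf16 with torch._scaled_mm. A rough sketch, assuming an H100 and a recent PyTorch build; the shapes and the timing helper are illustrative, not the exact ones from the trace:

```python
import torch

# illustrative per-rank shapes for one sharded FFN GEMM; not the exact traced ones
M, K, N = 16384, 4096, 7168

a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# rowwise (axiswise) scaling: one fp32 scale per row of each operand
fp8_max = torch.finfo(torch.float8_e4m3fn).max
a_scale = a.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / fp8_max
b_scale = b.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / fp8_max
a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b / b_scale).to(torch.float8_e4m3fn)

def bench_us(fn, iters=50):
    # naive CUDA-event timing, returns average microseconds per call
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1e3

bf16_us = bench_us(lambda: a @ b.t())
fp8_us = bench_us(lambda: torch._scaled_mm(
    a_fp8, b_fp8.t(),          # second operand must be column-major
    scale_a=a_scale,           # shape (M, 1)
    scale_b=b_scale.t(),       # shape (1, N)
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
))
print(f"bf16 GEMM: {bf16_us:.0f}us, fp8 rowwise GEMM: {fp8_us:.0f}us")
```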

Backward pass

In the backward pass I see major differences.

  • The backward for attention -> FFN together takes 7320us for tensorwise and 10128us for rowwise, so the tensorwise bwd step is 1.38x faster than rowwise.
  • For rowwise there are 2 blocking all-gathers scheduled back to back, and the subsequent computation seems to be blocked on / dependent on both. For tensorwise, after the first all-gather runs, some compute runs, then the 2nd all-gather runs, etc.

Tensorwise: [trace screenshot]

Rowwise: [trace screenshot]

danielvegamyhre commented on Jun 18, 2025

Looking at attention bwd traces more deeply, I see the following:

Tensorwise: [trace screenshot]

Rowwise: [trace screenshot]

The backward for attention (and the Q/K/V projections) is 2917us in fp8 tensorwise and 4375us in fp8 rowwise. This 1.49x speedup of tensorwise over rowwise aligns with the roughly 1.5x difference we see in overall TPS throughput:

FP8 Tensorwise backward for attention: [trace screenshot]

FP8 Rowwise backward for attention: [trace screenshot]

Analyzing the components of the backward in more detail, to understand the source of the 4375us - 2917us = 1458us gap:

  • flash attention bwd is roughly 1400us in both tensorwise and rowwise, so that's not contributing to the gap.
  • the triton kernels and 2 GEMMs for computing dWO and its input grad use 321us vs 789us, a difference of 468us, which is 468us/1458us = ~32% of our gap.
  • the triton kernels and 6 GEMMs after the flash attn bwd for computing dWQ, dWK, dWV, (dW and input gradient for each of 3 linears) are 855us vs 1539us, a difference of 684us, which is 684us/1458us = ~47% of our gap.

In total, 20% + 32% + 47% = 99% of our gap. So basically the entire perf difference for attention backward can be attributed to overhead in the Float8Linear backward passes (dynamic quant kernels, etc).
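
For reference, one way to attribute a gap like this is to sum GPU kernel durations by name from the exported chrome traces. A rough sketch; the helper and file names are hypothetical:

```python
import json
from collections import defaultdict

def kernel_time_by_name(trace_path: str) -> dict[str, float]:
    """Sum GPU kernel durations (us) by kernel name from an exported chrome trace."""
    with open(trace_path) as f:
        events = json.load(f)["traceEvents"]
    totals: dict[str, float] = defaultdict(float)
    for e in events:
        # complete ("X") events in the "kernel" category carry a duration ("dur") in microseconds
        if e.get("ph") == "X" and e.get("cat") == "kernel":
            totals[e["name"]] += e.get("dur", 0.0)
    return dict(totals)

# usage: compare the two trace dumps, e.g.
# tensorwise = kernel_time_by_name("tensorwise_trace.json")
# rowwise = kernel_time_by_name("rowwise_trace.json")
```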

danielvegamyhre commented on Jun 18, 2025

I can see that both the attention and FFN bwd take longer for rowwise, so I'm checking whether the linears are large enough after sharding the weights with TP=8 to benefit from fp8 rowwise.

Rerunning the Llama3 70b job on MAST and dumping the model (link):

[trainer0|0]:    (0): TransformerBlock(
[trainer0|0]:      (attention): Attention(
[trainer0|0]:        (wq): Float8Linear(in_features=8192, out_features=8192, bias=False, cast_configs=i:dyn_axs_e4m3,w:dyn_axs_e4m3,go:dyn_axs_e4m3")
[trainer0|0]:        (wk): Float8Linear(in_features=8192, out_features=1024, bias=False, cast_configs=i:dyn_axs_e4m3,w:dyn_axs_e4m3,go:dyn_axs_e4m3")
[trainer0|0]:        (wv): Float8Linear(in_features=8192, out_features=1024, bias=False, cast_configs=i:dyn_axs_e4m3,w:dyn_axs_e4m3,go:dyn_axs_e4m3")
[trainer0|0]:        (wo): Float8Linear(in_features=8192, out_features=8192, bias=False, cast_configs=i:dyn_axs_e4m3,w:dyn_axs_e4m3,go:dyn_axs_e4m3")
[trainer0|0]:        (sdpa): ScaledDotProductAttention()
[trainer0|0]:      )
[trainer0|0]:      (feed_forward): FeedForward(
[trainer0|0]:        (w1): Float8Linear(in_features=8192, out_features=28672, bias=False, cast_configs=i:dyn_axs_e4m3,w:dyn_axs_e4m3,go:dyn_axs_e4m3")
[trainer0|0]:        (w2): Float8Linear(in_features=28672, out_features=8192, bias=False, cast_configs=i:dyn_axs_e4m3,w:dyn_axs_e4m3,go:dyn_axs_e4m3")
[trainer0|0]:        (w3): Float8Linear(in_features=8192, out_features=28672, bias=False, cast_configs=i:dyn_axs_e4m3,w:dyn_axs_e4m3,go:dyn_axs_e4m3")
[trainer0|0]:      )
[trainer0|0]:      (attention_norm): RMSNorm((8192,), eps=1e-05, elementwise_affine=True)
[trainer0|0]:      (ffn_norm): RMSNorm((8192,), eps=1e-05, elementwise_affine=True)
[trainer0|0]:    )

Examining the float8 linear shapes and referencing the fp8 rowwise perf table, we can see that attention.wk and attention.wv have far too small an N dim to see any perf benefit from fp8 rowwise, and in fact will see a significant slowdown.
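
To make the sharding effect concrete, here is a quick back-of-the-envelope, assuming the dump above shows unsharded shapes and the standard TP plan where wq/wk/wv are colwise-sharded (the actual perf thresholds come from the linked table, not from this snippet):

```python
# per-rank output dim (N) of each colwise-sharded attention projection under TP=8
TP = 8
out_features = {"attention.wq": 8192, "attention.wk": 1024, "attention.wv": 1024}

for fqn, n in out_features.items():
    print(f"{fqn}: local N = {n // TP}")
# attention.wq: local N = 1024
# attention.wk: local N = 128
# attention.wv: local N = 128
# -> the wk/wv GEMMs are only 128 columns wide per rank, too narrow for rowwise
#    float8 to amortize the per-row quantization overhead
```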

Next I'll try rerunning the job without converting these to Float8Linears.

danielvegamyhre commented on Jun 18, 2025

Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16 with --float8.filter_fqns="output,attention.wk,attention.wv":

danielvegamyhre commented on Jun 18, 2025

@tianyu-l @vkuzo I looked into this (see RCA above) and TL;DR is the default filter_fqns for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.

For Llama3 70b with TP=8, the attention.wk and attention.wv linear layers have a small enough N dimension that there is actually a significant slowdown (estimating ~40% slowdown based on this fp8 rowwise perf reference table I generated recently).

Rerunning fp8 rowwise with --float8.filter_fqns="output,attention.wk,attention.wv" shows vanilla TP gives ~10%+ TPS increase over the bf16 baseline and Async TP yields ~15%+ increase over bf16, which is more aligned with my expectations.

Solution

This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria:

  1. dims not divisible by 16 (hardware requirement for float8)
  2. dim sizes below thresholds that will result in worse perf for that given recipe, using simple heuristics based on the linked recipe perf tables above.
  3. fqn matches one of the user defined filter_fqns
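
A rough sketch of what such a filter function could look like; the names and thresholds here are illustrative, not the actual pytorch/ao#2410 implementation (torchao's convert_to_float8_training accepts a module_filter_fn(module, fqn) callback):

```python
import torch.nn as nn

def auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns=()) -> bool:
    """Return True if this Linear should be converted to Float8Linear (rowwise recipe)."""
    if not isinstance(mod, nn.Linear):
        return False
    N, K = mod.out_features, mod.in_features
    # 1. float8 GEMMs require dims divisible by 16
    if N % 16 != 0 or K % 16 != 0:
        return False
    # 2. skip layers whose dims are too small to benefit from rowwise float8
    #    (illustrative threshold; the real heuristics come from the recipe perf tables)
    if N < 2048 or K < 2048:
        return False
    # 3. respect the user-defined fqn filters
    if any(f in fqn for f in filter_fqns):
        return False
    return True

# e.g. passed to torchao as:
# convert_to_float8_training(model, config=config,
#     module_filter_fn=lambda m, fqn: auto_filter_for_rowwise(m, fqn, filter_fqns))
```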

I integrated a PoC into torchtitan and the auto filter improved fp8 rowwise perf in both the local Llama3 8b run and the Llama3 70b MAST run, compared to the default filter_fn we have now.

It prevents users from hitting this common footgun, while preserving the flexibility to define model-specific filter fqns.

Results

See #1207 (comment) showing Llama3 70b fp8 rowwise w/ TP=8 improves TPS ~10% over bf16 baseline (previously 0%). Confirmed async TP + fp8 rowwise is still getting an additional +5% TPS over the new higher baseline as well.

What do you think about including this in torchao + torchtitan? To be clear, this doesn't change the torchtitan API: the current filter_fqns config is still consumed and applied in the same way; there are just additional checks which automatically filter out layers that would hurt perf if converted.

Commits referencing this issue: b0902b2 (Jul 1, 2025), 7104125 (Jul 1, 2025), 84bd872 (Jul 8, 2025)