I started analyzing traces, comparing bf16 vs fp8 tensorwise vs fp8 rowwise for vanilla TP.
Looking at the TP region specifically (FFN) for the forward pass, I see (1) the expected FSDP all-gather for the FFN weights occurring after the attention fwd, and (2) the expected exposed all-gather and reduce-scatter comms seen in TP+SP.
In the bf16 trace, these are the only comms in the FFN forward. In fp8 tensorwise and rowwise, there are additional all-reduces for syncing the scales (expected).
Forward pass
Durations from the start of the input all-gather to the end of the output reduce-scatter:
bf16: 1757us
fp8 tensorwise: 1543us
fp8 rowwise: 1777us
Looking at the 3 GEMMs in the FFN, the bf16 GEMMs are about 380us each vs about 180us each for fp8 rowwise, so we are seeing a solid speedup there and the GEMMs aren't the reason perf is flat.
The difference in the forward pass is pretty small: the fp8 tensorwise fwd pass is only ~1.15x faster than rowwise, but the e2e TPS is 1.5x+ higher. So this is not the whole story.
Backward pass
In the backward pass I see major differences.
The backward for attention -> FFN together takes 7320us for tensorwise, and 10128us for rowwise. So tensorwise bwd step is 1.38x faster than rowwise.
For rowwise there are 2 blocking all-gathers scheduled back to back, and subsequent computation seems to be blocked on / depend on both. For tensorwise, after the first all-gather runs, some compute runs, then the 2nd all-gather runs, and so on, so comms and compute are interleaved.
Looking at attention bwd traces more deeply, I see the following:
Tensorwise
Rowwise
The backward for attention (and Q/K/V projections) is 2917us in fp8 tensorwise and 4375us in fp8 rowwise. This 1.49x speedup of tensorwise over rowwise aligns with the roughly 1.5x difference we see in overall TPS throughput:
FP8 Tensorwise backward for attention:
FP8 Rowwise backward for attention:
Analyzing the components of the backward in more detail, to understand the source of the 4375us - 2917us = 1458us gap:
The flash attention bwd is roughly 1400us in both tensorwise and rowwise, so that's not contributing.
The triton kernels and 2 GEMMs for computing dWO and its input grad take 321us vs 789us, a difference of 468us, which is 468us/1458us = ~32% of our gap.
The triton kernels and 6 GEMMs after the flash attn bwd for computing dWQ, dWK, dWV (dW and input gradient for each of the 3 linears) take 855us vs 1539us, a difference of 684us, which is 684us/1458us = ~47% of our gap.
In total, 20% + 32% + 47% = 99% of our gap. So basically the entire perf difference for attention backward can be attributed to overhead in the Float8Linear backward passes (dynamic quant kernels, etc).
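As a quick check on those shares, here is a minimal sketch of the arithmetic for the two Float8Linear components called out above (timings taken directly from the trace numbers in this comment):

import torch  # not required; plain Python arithmetic

# Breakdown of the fp8 rowwise vs tensorwise gap in the attention backward.
tensorwise_us = 2917
rowwise_us = 4375
gap_us = rowwise_us - tensorwise_us          # 1458us

dwo_gap_us = 789 - 321                       # dWO + input grad: triton quant kernels + 2 GEMMs
dqkv_gap_us = 1539 - 855                     # dWQ/dWK/dWV + input grads: triton quant kernels + 6 GEMMs

print(f"dWO share of gap:    {dwo_gap_us / gap_us:.0%}")   # ~32%
print(f"dWQ/dWK/dWV share:   {dqkv_gap_us / gap_us:.0%}")  # ~47%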
I can see the attention bwd takes longer for rowwise, so what I'm checking is whether the linears are large enough after sharding the weights with TP=8 to benefit from fp8 rowwise.
Rerunning the Llama3 70b job on MAST and dumping the model (link):
Examining the float8 linear shapes and referencing the fp8 rowwise perf table, we can see that attention.wk and attention.wv have far too small an N dim to see any perf benefit from fp8 rowwise, and in fact will see a significant slowdown.
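To make the shape issue concrete, here is a small sketch of the per-rank GEMM dims for the K/V projections under TP=8. It assumes the standard Llama3 70b config (hidden dim 8192, 8 KV heads, head dim 128) and the usual column-parallel sharding of wk/wv, rather than reading these values from the dump itself:

# Per-rank output dim (N) of attention.wk / attention.wv under column-parallel TP.
hidden_dim = 8192        # Llama3 70b model dim (the K dim for these GEMMs)
n_kv_heads = 8
head_dim = 128
tp_degree = 8

global_n = n_kv_heads * head_dim     # 1024
local_n = global_n // tp_degree      # 128 per rank
print(f"wk/wv per-rank GEMM: K={hidden_dim}, N={local_n}")
# An N of 128 is far below the sizes where fp8 rowwise wins over bf16 in the perf table.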
Next I'll try rerunning the job without converting these to Float8Linears.
@tianyu-l @vkuzo I looked into this (see RCA above) and the TL;DR is that the default filter_fqns for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the fp8 rowwise recipe.
For Llama3 70b with TP=8, the attention.wk and attention.wv linear layers have a small enough N dimension that there is actually a significant slowdown (estimating ~40% slowdown based on this fp8 rowwise perf reference table I generated recently).
Rerunning fp8 rowwise (Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16) with --float8.filter_fqns="output,attention.wk,attention.wv" shows vanilla TP gives a ~10%+ TPS increase over the bf16 baseline and Async TP yields a ~15%+ increase over bf16, which is more aligned with my expectations.
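For reference, here is a minimal sketch of how an fqn-based filter like this gets applied at conversion time. It assumes torchao's convert_to_float8_training entry point with a module_filter_fn callback; the wiring below is illustrative, not the exact torchtitan code path:

import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Mirrors --float8.filter_fqns="output,attention.wk,attention.wv"
filter_fqns = ["output", "attention.wk", "attention.wv"]

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # Return True to convert this module to Float8Linear, False to leave it in bf16.
    if not isinstance(mod, nn.Linear):
        return False
    return not any(skip in fqn for skip in filter_fqns)

# model = ...  # Llama3 model, built and parallelized as usual
# convert_to_float8_training(model, module_filter_fn=module_filter_fn)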
Solution
This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe by checking for the following criteria (see the sketch after this list):
dims not divisible by 16 (hardware requirement for float8)
dim sizes below thresholds that will result in worse perf for that given recipe, using simple heuristics based on the linked recipe perf tables above.
fqn matches one of the user-defined filter_fqns
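A minimal sketch of what such an auto filter can look like is below. The actual implementation is _auto_filter_for_recipe in pytorch/ao#2410; the threshold values here are placeholders rather than the tuned numbers from the linked perf tables:

import torch.nn as nn

def make_auto_filter_for_recipe(recipe: str, filter_fqns: list[str]):
    """Build a module_filter_fn for convert_to_float8_training.

    Sketch only: the dim thresholds below are illustrative placeholders.
    """
    # Hypothetical per-recipe minimum (N, K) sizes below which float8 is a net loss.
    min_n, min_k = {"tensorwise": (1024, 1024), "rowwise": (2048, 4096)}[recipe]

    def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
        if not isinstance(mod, nn.Linear):
            return False
        # 1. Respect the user-defined fqn filters.
        if any(skip in fqn for skip in filter_fqns):
            return False
        n, k = mod.out_features, mod.in_features
        # 2. Hardware requirement: float8 GEMM dims must be divisible by 16.
        if n % 16 != 0 or k % 16 != 0:
            return False
        # 3. Heuristic: small N/K dims run slower in float8 than bf16 for this recipe.
        if n < min_n or k < min_k:
            return False
        return True

    return module_filter_fn

The key design point is that the user-facing filter_fqns behavior is unchanged; the size checks only skip conversions that the perf tables indicate would be a regression.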
I integrated a PoC into torchtitan (_auto_filter_for_recipe, #1319) and the auto filter improved fp8 rowwise perf in both a local Llama3 8b run and a Llama3 70b MAST run, compared to the default filter_fn we have now.
It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns.
Results
See #1207 (comment) showing Llama3 70b fp8 rowwise w/ TP=8 improves TPS ~10% over bf16 baseline (previously 0%). Confirmed async TP + fp8 rowwise is still getting an additional +5% TPS over the new higher baseline as well.
What do you think about including this in torchao + torchtitan? To be clear, this doesn't change the torchtitan API: the current filter_fqns config is still consumed and applied in the same way - there are just additional checks which automatically filter out layers that would hurt perf if converted.