[float8] add _auto_filter_for_recipe for float8 training #1319
Fixes #1207
Problem
The default filter_fqns for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.
Solution
This has been a footgun for various users (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, checking each Linear against both recipe-specific criteria and any user-defined filter_fqns. It prevents users from hitting this common footgun, while preserving the flexibility to define model-specific filter_fqns.
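To illustrate the idea, here is a minimal sketch of what such an auto-filter predicate could look like. This is not the torchao implementation (see pytorch/ao#2410 for the real logic); the factory name, the user_filter_fqns parameter, and the specific dim-divisibility check are illustrative assumptions, with the divisibility-by-16 check standing in for the recipe-specific criteria.

```python
import torch.nn as nn

def make_auto_filter(user_filter_fqns):
    """Hypothetical sketch: build a predicate deciding whether a module
    should be converted to float8, combining recipe-agnostic checks with
    user-defined fqn exclusions. Not the actual torchao API."""

    def should_convert(module: nn.Module, fqn: str) -> bool:
        # Only Linear layers are candidates for float8 conversion.
        if not isinstance(module, nn.Linear):
            return False
        # Respect user-defined, model-specific exclusions.
        if any(skip in fqn for skip in user_filter_fqns):
            return False
        # Example recipe criterion (assumed here): skip Linears whose
        # dims are not divisible by 16, as required by float8 kernels.
        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
            return False
        return True

    return should_convert
```

Usage: pass the returned predicate to the model-conversion step, so each (module, fqn) pair is filtered automatically instead of relying on hand-maintained fqn lists per recipe.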
Results
Benchmarks show a ~10% TPS improvement for TP and a ~15% TPS improvement for async TP (both over the bf16 TP baseline).
Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16: