@@ -116,7 +116,7 @@ def convert_to_float8_training(
116
116
)
117
117
118
118
119
- def auto_filter_for_recipe (
119
+ def _auto_filter_for_recipe (
120
120
recipe : Float8LinearRecipeName , filter_fqns : List [str ]
121
121
) -> Callable [[nn .Module , str ], bool ]:
122
122
"""Automatically filters nn.Linear modules that meet at least one of the following criteria:
@@ -127,7 +127,9 @@ def auto_filter_for_recipe(
127
127
NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
128
128
for your model. For the best performance, we recommend defining your own module_filter_fn customized for
129
129
your module, using the performance tables for the given float8 recipe here:
130
- https://github.com/pytorch/ao/tree/main/torchao/float8#performance).
130
+ https://github.com/pytorch/ao/tree/main/torchao/float8#performance. Note that the benchmarks referenced
131
+ for auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
132
+
131
133
132
134
The design of this function may change in the future.
133
135
"""
@@ -156,8 +158,10 @@ def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -
156
158
if not dims_multiples_of_16 :
157
159
return False
158
160
159
- # Dims below these thresholds will result in worse performance
161
+ # Dims below these thresholds may result in worse performance
160
162
# (see https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling)
163
# Note that the benchmarks referenced for auto filtering layers were run on
164
+ # H100 GPUs, and may not be representative of other hardware.
161
165
if N <= 2048 :
162
166
return False
163
167
elif K <= 1024 :
@@ -184,8 +188,10 @@ def _auto_filter_for_tensorwise(
184
188
if not dims_multiples_of_16 :
185
189
return False
186
190
187
- # Dims below these thresholds will result in worse performance
191
+ # Dims below these thresholds may result in worse performance
188
192
# (see https://github.com/pytorch/ao/tree/main/torchao/float8#tensorwise-scaling)
193
+ # Note that the benchmarks referenced for auto filtering layers were run on
194
+ # H100 GPUs, and may not be representative of other hardware.
189
195
if K <= 4096 and N <= 1024 :
190
196
return False
191
197
return True
0 commit comments