Skip to content

Commit b133535

Browse files
address comments
1 parent d448443 commit b133535

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

torchao/float8/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
ScalingGranularity,
77
ScalingType,
88
)
9-
from torchao.float8.float8_linear_utils import convert_to_float8_training
9+
from torchao.float8.float8_linear_utils import (
10+
_auto_filter_for_recipe,
11+
convert_to_float8_training,
12+
)
1013
from torchao.float8.float8_tensor import (
1114
Float8Tensor,
1215
GemmInputRole,
@@ -44,6 +47,7 @@
4447
# top level UX
4548
"convert_to_float8_training",
4649
"precompute_float8_dynamic_scale_for_fsdp",
50+
"_auto_filter_for_recipe",
4751
# types
4852
"FP8Granularity",
4953
# note: Float8Tensor and Float8Linear are not public APIs

torchao/float8/float8_linear_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def convert_to_float8_training(
116116
)
117117

118118

119-
def auto_filter_for_recipe(
119+
def _auto_filter_for_recipe(
120120
recipe: Float8LinearRecipeName, filter_fqns: List[str]
121121
) -> Callable[[nn.Module, str], bool]:
122122
"""Automatically filters nn.Linear modules that meet at least one of the following criteria:
@@ -127,7 +127,9 @@ def auto_filter_for_recipe(
127127
NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
128128
for your model. For the best performance, we recommend defining your own module_filter_fn customized for
129129
your module, using the performance tables for the given float8 recipe here:
130-
https://github.com/pytorch/ao/tree/main/torchao/float8#performance).
130+
https://github.com/pytorch/ao/tree/main/torchao/float8#performance. Note that the benchmarks referenced
131+
for auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
132+
131133
132134
The design of this function may change in the future.
133135
"""
@@ -156,8 +158,10 @@ def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -
156158
if not dims_multiples_of_16:
157159
return False
158160

159-
# Dims below these thresholds will result in worse performance
161+
# Dims below these thresholds may result in worse performance
160162
# (see https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling)
163+
# Note that the benchmarks referenced for auto filtering layers were run on
164+
# H100 GPUs, and may not be representative of other hardware.
161165
if N <= 2048:
162166
return False
163167
elif K <= 1024:
@@ -184,8 +188,10 @@ def _auto_filter_for_tensorwise(
184188
if not dims_multiples_of_16:
185189
return False
186190

187-
# Dims below these thresholds will result in worse performance
191+
# Dims below these thresholds may result in worse performance
188192
# (see https://github.com/pytorch/ao/tree/main/torchao/float8#tensorwise-scaling)
193+
# Note that the benchmarks referenced for auto filtering layers were run on
194+
# H100 GPUs, and may not be representative of other hardware.
189195
if K <= 4096 and N <= 1024:
190196
return False
191197
return True

0 commit comments

Comments
 (0)