This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 0c01953

bc breaking - unify filtering functions
Summary: BC-breaking, but we don't have BC yet, so just mentioning this upfront.

ghstack-source-id: a6664ad
Pull Request resolved: #322
1 parent c58fb5d
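
Since this change removes the skip_fqn_list and linear_layer_filter arguments in favor of a single layer_filter_fn, here is a rough migration sketch for existing call sites. The old and new keyword arguments are taken from the diffs below; the toy model and FQN values are illustrative only, and emulate=True is used so the sketch does not require float8-capable hardware.

import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

# before this commit (hypothetical call sites):
#   swap_linear_with_float8_linear(m, skip_fqn_list=["1"])
#   swap_linear_with_float8_linear(m, linear_layer_filter=lambda lin: lin.in_features % 16 == 0)

# after this commit, both behaviors go through one callable that receives
# the module FQN and the module instance:
def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
    if fqn == "1":
        # replaces skip_fqn_list=["1"]
        return False
    if isinstance(mod, nn.Linear) and mod.in_features % 16 != 0:
        # replaces the old linear_layer_filter
        return False
    return True

m = swap_linear_with_float8_linear(m, emulate=True, layer_filter_fn=layer_filter_fn)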

File tree: 4 files changed (+90 -83 lines)


README.md (12 additions, 1 deletion)

@@ -43,8 +43,19 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
 # create model
 m = Model(...)

+# optional: filter layers from being eligible for float8 conversion
+def layer_filter_fn(fqn: str, mod: torch.nn.Module):
+    # don't convert the output layer
+    if fqn == "output":
+        return False
+    # don't convert linear layers with weight dimensions not divisible by 16
+    if isinstance(mod, torch.nn.Linear):
+        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+            return False
+    return True
+
 # convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m)
+swap_linear_with_float8_linear(m, layer_filter_fn=layer_filter_fn)

 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
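
The fqn argument passed to the filter is a dot-separated module path built from named_children(), so it matches the names produced by model.named_modules(). A quick way to inspect the FQNs a filter will see (a minimal sketch, assuming m is the model above):

# print the candidate FQNs and shapes of the linear layers in m
for fqn, mod in m.named_modules():
    if isinstance(mod, torch.nn.Linear):
        print(fqn, mod.in_features, mod.out_features)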

float8_experimental/float8_linear_utils.py (30 additions, 46 deletions)

@@ -59,26 +59,11 @@ def _update_history_stack(
     amax_history_stack.copy_(new_amax_history_stack)


-def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], bool]:
-    """
-    Returns a callable that filters out small (dimensions less than the given `size_limit`)
-    and unaligned (dimenstions not divisible by 16) layers.
-    It can be passed as the `linear_layer_filter` argument to `swap_linear_with_float8_linear`.
-    """
-    return (
-        lambda linear_layer: linear_layer.in_features >= size_limit
-        and linear_layer.out_features >= size_limit
-        and linear_layer.in_features % 16 == 0
-        and linear_layer.out_features % 16 == 0
-    )
-
-
 def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    skip_fqn_list: Optional[List[str]] = None,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -90,18 +75,17 @@ def swap_linear_layers(
     Args:
         module: Module to modify.
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
-        from_float_kwargs: Additional keyword arguments for from_float_func.
+        layer_filter_fn: If specified, only the modules that
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.

     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    module_names_to_skip = set(skip_fqn_list or [])
-
     if isinstance(module, nn.Linear) and (
-        linear_layer_filter is None or linear_layer_filter(module)
+        # linear_layer_filter is None or linear_layer_filter(module)
+        layer_filter_fn is None
+        or layer_filter_fn("", module)
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
@@ -112,43 +96,44 @@ def swap_linear_layers(
         )

     root_module = module
-    visited_modules = {root_module}
-
-    for module_name, module in root_module.named_modules():
-        if module_name in module_names_to_skip:
-            visited_modules.add(module)

     def post_order_traversal(
-        module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
+        module: nn.Module,
+        cur_fqn: Optional[str] = None,
+        parent_module: Optional[nn.Module] = None,
     ):
-        nonlocal visited_modules
+        if cur_fqn is None:
+            cur_fqn = ""
+
         for child_module_name, child_module in module.named_children():
-            if child_module not in visited_modules:
-                visited_modules.add(child_module)
-                post_order_traversal(child_module, child_module_name, module)
+            if cur_fqn == "":
+                new_fqn = child_module_name
+            else:
+                new_fqn = f"{cur_fqn}.{child_module_name}"
+
+            post_order_traversal(child_module, new_fqn, module)

         if isinstance(module, nn.Linear) and (
-            linear_layer_filter is None or linear_layer_filter(module)
+            # linear_layer_filter is None or linear_layer_filter(module)
+            layer_filter_fn is None
+            or layer_filter_fn(cur_fqn, module)
         ):
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
             new_linear_module = from_float_func(module)
-            setattr(parent_module, module_name, new_linear_module)
+            cur_module_name = cur_fqn.split(".")[-1]
+            setattr(parent_module, cur_module_name, new_linear_module)

-    post_order_traversal(root_module, "", None)
-    # Without this explicit `del`, this set only gets deleted upon an explicit
-    # garbage collection (not from when its refcount hits zero)
-    del visited_modules
+    post_order_traversal(root_module)
     return root_module


 def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
-    skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
@@ -158,10 +143,10 @@ def swap_linear_with_float8_linear(

     Args:
         module: Module to modify.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
         emulate: If True, emulation is used instead of hardware accelerated gemm
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
+        layer_filter_fn: If specified, only the modules that
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.
         scaling_type_x (TensorScalingType): scaling type for `x`
         scaling_type_w (TensorScalingType): scaling type for `w`
         scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
@@ -179,8 +164,7 @@ def swap_linear_with_float8_linear(
     return swap_linear_layers(
         module,
         from_float,
-        skip_fqn_list=skip_fqn_list,
-        linear_layer_filter=linear_layer_filter,
+        layer_filter_fn=layer_filter_fn,
     )
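Note that the removed filter_out_small_unaligned_layers helper has no direct replacement; the equivalent check can be expressed as a layer_filter_fn. A minimal sketch mirroring the removed logic (the factory name small_unaligned_filter is illustrative, not part of the library; the updated test below spells out the same condition inline):

import torch.nn as nn

def small_unaligned_filter(size_limit: int):
    # Returns a layer_filter_fn that skips nn.Linear modules smaller than
    # size_limit or with dimensions not divisible by 16. Non-linear modules
    # return True, but they are never swapped anyway.
    def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
        if not isinstance(mod, nn.Linear):
            return True
        return (
            mod.in_features >= size_limit
            and mod.out_features >= size_limit
            and mod.in_features % 16 == 0
            and mod.out_features % 16 == 0
        )

    return layer_filter_fn

# usage: swap_linear_with_float8_linear(m, layer_filter_fn=small_unaligned_filter(32))
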
float8_experimental/inference.py (6 additions, 4 deletions)

@@ -10,7 +10,7 @@
 from dataclasses import dataclass

 from enum import auto, Enum
-from typing import List, Optional
+from typing import Callable, List, Optional

 import float8_experimental.config as config

@@ -209,7 +209,7 @@ def quantize_to_float8(
     module: nn.Module,
     quant_config: QuantConfig,
     *,
-    skip_fqn_list: Optional[List[str]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     use_fast_accum: bool = True,
 ) -> Optional[nn.Module]:
     """
@@ -222,7 +222,9 @@ def quantize_to_float8(
     Args:
         module (nn.Module): The module to modify.
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
-        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
+        layer_filter_fn: If specified, only the modules that
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.

     Returns:
@@ -234,5 +236,5 @@ def quantize_to_float8(
     return swap_linear_layers(
         module,
         lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
-        skip_fqn_list=skip_fqn_list,
+        layer_filter_fn=layer_filter_fn,
     )
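
The inference entry point keeps the same filter signature, so one callable can be shared between the training-time swap and quantize_to_float8. A minimal sketch, assuming a model and a QuantConfig instance built elsewhere (the fields of QuantConfig are not part of this diff, so it is left abstract here):

from float8_experimental.inference import quantize_to_float8

# model: an nn.Module; quant_config: a QuantConfig constructed elsewhere
model = quantize_to_float8(
    model,
    quant_config,
    # skip the output projection; same (fqn, module) signature as above
    layer_filter_fn=lambda fqn, mod: fqn != "output",
)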

test/test_base.py (42 additions, 32 deletions)

@@ -18,7 +18,6 @@

 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
-    filter_out_small_unaligned_layers,
     linear_requires_sync,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -631,24 +630,34 @@ def __init__(self, dim: int):
                 self.lin1 = nn.Linear(dim, 4 * dim)
                 self.lin2 = nn.Linear(4 * dim, 4 * dim)

-        for emulate in [True, False]:
-            model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
-            # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
-            model = swap_linear_with_float8_linear(
-                model,
-                emulate=emulate,
-                linear_layer_filter=filter_out_small_unaligned_layers(32),
+        model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
+        # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
+
+        size_limit = 32
+
+        def layer_filter_fn(fqn, mod):
+            return (
+                mod.in_features >= size_limit
+                and mod.out_features >= size_limit
+                and mod.in_features % 16 == 0
+                and mod.out_features % 16 == 0
             )
-            # in_features=8, out_features=32, 8 is less than 32.
-            self.assertNotIsInstance(model[0].lin1, Float8Linear)
-            # in_features=32, out_features=32,
-            self.assertIsInstance(model[0].lin2, Float8Linear)
-            # in_features=32, out_features=32,
-            self.assertIsInstance(model[1], Float8Linear)
-            # in_features=40, out_features=160, 40 is not divisible by 16.
-            self.assertNotIsInstance(model[2].lin1, Float8Linear)
-            # in_features=160, out_features=160,
-            self.assertIsInstance(model[2].lin2, Float8Linear)
+
+        model = swap_linear_with_float8_linear(
+            model,
+            emulate=True,
+            layer_filter_fn=layer_filter_fn,
+        )
+        # in_features=8, out_features=32, 8 is less than 32.
+        self.assertNotIsInstance(model[0].lin1, Float8Linear)
+        # in_features=32, out_features=32,
+        self.assertIsInstance(model[0].lin2, Float8Linear)
+        # in_features=32, out_features=32,
+        self.assertIsInstance(model[1], Float8Linear)
+        # in_features=40, out_features=160, 40 is not divisible by 16.
+        self.assertNotIsInstance(model[2].lin1, Float8Linear)
+        # in_features=160, out_features=160,
+        self.assertIsInstance(model[2].lin2, Float8Linear)

     def test_swap_submodule_linears_with_skip(self):
         class MLP(nn.Module):
@@ -657,20 +666,21 @@ def __init__(self, dim: int):
                 self.lin1 = nn.Linear(dim, 4 * dim)
                 self.lin2 = nn.Linear(4 * dim, dim)

-        for emulate in [True, False]:
-            model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-            skip_fqn_list = ["2", "0.lin2"]
-            model = swap_linear_with_float8_linear(
-                model, emulate=emulate, skip_fqn_list=skip_fqn_list
-            )
-            self.assertIsInstance(model[0].lin1, Float8Linear)
-            self.assertNotIsInstance(model[0].lin2, Float8Linear)
-            self.assertIsInstance(model[0].lin2, nn.Linear)
-            self.assertIsInstance(model[1], Float8Linear)
-            self.assertNotIsInstance(model[2].lin2, Float8Linear)
-            self.assertNotIsInstance(model[2].lin2, Float8Linear)
-            self.assertIsInstance(model[2].lin1, nn.Linear)
-            self.assertIsInstance(model[2].lin2, nn.Linear)
+        model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
+        layer_filter_fn = lambda fqn, mod: fqn not in [
+            "0.lin2",
+            "2.lin1",
+        ]
+        model = swap_linear_with_float8_linear(
+            model,
+            emulate=True,
+            layer_filter_fn=layer_filter_fn,
+        )
+        self.assertTrue(type(model[0].lin1) is Float8Linear)
+        self.assertTrue(type(model[0].lin2) is nn.Linear)
+        self.assertTrue(type(model[1]) is Float8Linear)
+        self.assertTrue(type(model[2].lin1) is nn.Linear)
+        self.assertTrue(type(model[2].lin2) is Float8Linear)

     def test_fp8_tensor_statistics(self):
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
0 commit comments
