@@ -59,26 +59,11 @@ def _update_history_stack(
     amax_history_stack.copy_(new_amax_history_stack)


-def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], bool]:
-    """
-    Returns a callable that filters out small (dimensions less than the given `size_limit`)
-    and unaligned (dimenstions not divisible by 16) layers.
-    It can be passed as the `linear_layer_filter` argument to `swap_linear_with_float8_linear`.
-    """
-    return (
-        lambda linear_layer: linear_layer.in_features >= size_limit
-        and linear_layer.out_features >= size_limit
-        and linear_layer.in_features % 16 == 0
-        and linear_layer.out_features % 16 == 0
-    )
-
-
 def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    skip_fqn_list: Optional[List[str]] = None,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
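With `filter_out_small_unaligned_layers` deleted and both `skip_fqn_list` and `linear_layer_filter` collapsed into a single `layer_filter_fn`, the old size/alignment predicate can still be expressed by the caller. A minimal sketch against the new `(fqn, module)` signature; the factory name `small_unaligned_filter` is illustrative and not part of this diff:

```python
import torch.nn as nn


def small_unaligned_filter(size_limit: int):
    """Reproduces the removed helper against the new (fqn, module) signature."""

    def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
        # Keep only Linear layers whose dimensions are at least `size_limit`
        # and divisible by 16; the FQN argument is ignored here.
        return (
            isinstance(mod, nn.Linear)
            and mod.in_features >= size_limit
            and mod.out_features >= size_limit
            and mod.in_features % 16 == 0
            and mod.out_features % 16 == 0
        )

    return layer_filter_fn
```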
@@ -90,18 +75,17 @@ def swap_linear_layers(
     Args:
         module: Module to modify.
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
-        from_float_kwargs: Additional keyword arguments for from_float_func.
+        layer_filter_fn: If specified, only the modules that
+            pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.

     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    module_names_to_skip = set(skip_fqn_list or [])
-
     if isinstance(module, nn.Linear) and (
-        linear_layer_filter is None or linear_layer_filter(module)
+        # linear_layer_filter is None or linear_layer_filter(module)
+        layer_filter_fn is None
+        or layer_filter_fn("", module)
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
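Note that a bare `nn.Linear` passed in as the root is filtered with an empty FQN (the `layer_filter_fn("", module)` call above). With `skip_fqn_list` removed, skipping by name also moves into the same filter; a hypothetical equivalent that, like the old visited-set logic, skips everything nested under a skipped FQN as well:

```python
from typing import List

import torch.nn as nn


def skip_fqns(fqns_to_skip: List[str]):
    # Swap a module only if its FQN is not one of the skipped names or nested
    # underneath one of them; the root module arrives with fqn == "".
    def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
        return not any(
            fqn == skip or fqn.startswith(skip + ".") for skip in fqns_to_skip
        )

    return layer_filter_fn
```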
@@ -112,43 +96,44 @@ def swap_linear_layers(
             )

     root_module = module
-    visited_modules = {root_module}
-
-    for module_name, module in root_module.named_modules():
-        if module_name in module_names_to_skip:
-            visited_modules.add(module)

     def post_order_traversal(
-        module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
+        module: nn.Module,
+        cur_fqn: Optional[str] = None,
+        parent_module: Optional[nn.Module] = None,
     ):
-        nonlocal visited_modules
+        if cur_fqn is None:
+            cur_fqn = ""
+
         for child_module_name, child_module in module.named_children():
-            if child_module not in visited_modules:
-                visited_modules.add(child_module)
-                post_order_traversal(child_module, child_module_name, module)
+            if cur_fqn == "":
+                new_fqn = child_module_name
+            else:
+                new_fqn = f"{cur_fqn}.{child_module_name}"
+
+            post_order_traversal(child_module, new_fqn, module)

         if isinstance(module, nn.Linear) and (
-            linear_layer_filter is None or linear_layer_filter(module)
+            # linear_layer_filter is None or linear_layer_filter(module)
+            layer_filter_fn is None
+            or layer_filter_fn(cur_fqn, module)
         ):
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
             new_linear_module = from_float_func(module)
-            setattr(parent_module, module_name, new_linear_module)
+            cur_module_name = cur_fqn.split(".")[-1]
+            setattr(parent_module, cur_module_name, new_linear_module)

-    post_order_traversal(root_module, "", None)
-    # Without this explicit `del`, this set only gets deleted upon an explicit
-    # garbage collection (not from when its refcount hits zero)
-    del visited_modules
+    post_order_traversal(root_module)
     return root_module


 def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
-    skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
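The FQNs built by the traversal above match the dotted names that `nn.Module.named_modules()` reports, so a filter can be prototyped against that API before wiring it in. A tiny illustrative model (not part of this diff):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(16, 32),
    nn.Sequential(nn.ReLU(), nn.Linear(32, 16)),
)

# The outer Linear is seen by layer_filter_fn as "0", the nested one as "1.1".
for fqn, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        print(fqn)  # prints "0" then "1.1"
```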
@@ -158,10 +143,10 @@ def swap_linear_with_float8_linear(

     Args:
         module: Module to modify.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
         emulate: If True, emulation is used instead of hardware accelerated gemm
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
+        layer_filter_fn: If specified, only the modules that
+            pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.
         scaling_type_x (TensorScalingType): scaling type for `x`
         scaling_type_w (TensorScalingType): scaling type for `w`
         scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
@@ -179,8 +164,7 @@ def swap_linear_with_float8_linear(
     return swap_linear_layers(
         module,
         from_float,
-        skip_fqn_list=skip_fqn_list,
-        linear_layer_filter=linear_layer_filter,
+        layer_filter_fn=layer_filter_fn,
     )

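A call site for the updated API might then look like the sketch below; the import path and the choice of which FQN to skip are assumptions for illustration, not something this diff pins down:

```python
import torch.nn as nn

# Illustrative import path; adjust to wherever swap_linear_with_float8_linear
# lives in your checkout.
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

# Convert every Linear except the first one ("0"); emulate and the scaling
# types keep the defaults from the signature above.
model = swap_linear_with_float8_linear(
    model,
    layer_filter_fn=lambda fqn, mod: fqn != "0",
)
```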