6
6
import copy
7
7
import logging
8
8
from enum import auto , Enum
9
- from typing import Callable , List , Optional , Type
9
+ from typing import Any , Callable , Dict , List , Optional , Type
10
10
11
11
import torch
12
12
import torch .distributed as dist
@@ -100,6 +100,7 @@ def swap_linear_with_float8_linear(
100
100
skip_fqn_list : Optional [List [str ]] = None ,
101
101
emulate : bool = False ,
102
102
linear_layer_filter : Optional [Callable [[nn .Linear ], bool ]] = None ,
103
+ from_float_kwargs : Dict [str , Any ] = None ,
103
104
) -> nn .Module :
104
105
"""
105
106
Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -114,6 +115,9 @@ def swap_linear_with_float8_linear(
114
115
linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
115
116
that pass the filter function will be swapped.
116
117
"""
118
+ if from_float_kwargs is None :
119
+ from_float_kwargs = {}
120
+
117
121
module_names_to_skip = set (skip_fqn_list or [])
118
122
if isinstance (module , nn .Linear ) and (
119
123
linear_layer_filter is None or linear_layer_filter (module )
@@ -122,7 +126,7 @@ def swap_linear_with_float8_linear(
122
126
raise AssertionError (
123
127
f"Does not support a root nn.Linear with children: { module } "
124
128
)
125
- return module_cls .from_float (module , emulate = emulate )
129
+ return module_cls .from_float (module , emulate = emulate , ** from_float_kwargs )
126
130
127
131
# Mark all modules to skip as visited
128
132
root_module = module
@@ -146,7 +150,9 @@ def post_order_traversal(
146
150
assert (
147
151
parent_module is not None
148
152
), f"Linear root module should return early: { module } "
149
- float8linear_module = module_cls .from_float (module , emulate = emulate )
153
+ float8linear_module = module_cls .from_float (
154
+ module , emulate = emulate , ** from_float_kwargs
155
+ )
150
156
setattr (parent_module , module_name , float8linear_module )
151
157
152
158
post_order_traversal (root_module , "" , None )
0 commit comments