19
19
Float8Tensor ,
20
20
merge_mm_configs ,
21
21
ScaledMMConfig ,
22
+ ScalingStrategy ,
22
23
tensor_already_casted_to_fp8 ,
23
24
to_fp8_no_autograd ,
24
25
)
@@ -36,21 +37,27 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
36
37
@staticmethod
37
38
def forward (
38
39
ctx ,
39
- tensor ,
40
+ tensor : torch . Tensor ,
40
41
mm_config : ScaledMMConfig ,
42
+ scaling_strategy : ScalingStrategy ,
41
43
):
42
44
ctx .mm_config = mm_config
45
+ ctx .scaling_strategy = scaling_strategy
43
46
return tensor
44
47
45
48
@staticmethod
46
- def backward (ctx , gradY ):
49
+ def backward (ctx , gradY : torch . Tensor ):
47
50
if tensor_already_casted_to_fp8 (gradY ):
48
- return gradY , None
51
+ return gradY , None , None
49
52
gradY_scale = tensor_to_scale (gradY , e5m2_dtype )
50
53
fp8_tensor = to_fp8_no_autograd (
51
- gradY , gradY_scale , e5m2_dtype , mm_config = ctx .mm_config
54
+ gradY ,
55
+ gradY_scale ,
56
+ e5m2_dtype ,
57
+ mm_config = ctx .mm_config ,
58
+ scaling_strategy = ctx .scaling_strategy ,
52
59
)
53
- return fp8_tensor , None
60
+ return fp8_tensor , None , None
54
61
55
62
56
63
class Float8DynamicLinear (torch .nn .Linear ):
@@ -63,13 +70,15 @@ def __init__(self, **super_kwargs):
63
70
super ().__init__ (** super_kwargs )
64
71
65
72
def forward (self , input : torch .Tensor ) -> torch .Tensor :
66
- x_fp8 = cast_to_float8_e4m3fn (input , self .forward_config )
73
+ x_fp8 = cast_to_float8_e4m3fn (input , self .forward_config , self . scaling_strategy )
67
74
if isinstance (self .weight , Float8Tensor ): # cast by FSDP
68
75
w_fp8 = self .weight
69
76
else :
70
- w_fp8 = cast_to_float8_e4m3fn (self .weight , self .forward_config )
77
+ w_fp8 = cast_to_float8_e4m3fn (
78
+ self .weight , self .forward_config , self .scaling_strategy
79
+ )
71
80
y = torch .nn .functional .linear (x_fp8 , w_fp8 , self .bias )
72
- y = cast_to_float8_e5m2_bw (y , self .backward_config )
81
+ y = cast_to_float8_e5m2_bw (y , self .backward_config , self . scaling_strategy )
73
82
return y
74
83
75
84
@classmethod
@@ -101,9 +110,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
101
110
fp8_output = False ,
102
111
pad_inner_dim = config .pad_inner_dim ,
103
112
)
113
+ # TODO: For now hardcode TensorWise scaling
114
+ new_mod .scaling_strategy = ScalingStrategy .TensorWise
115
+
104
116
if config .enable_fsdp_fp8_all_gather :
105
117
new_mod .weight = nn .Parameter (
106
- WeightWithDynamicFloat8CastTensor (mod .weight , new_mod .forward_config )
118
+ WeightWithDynamicFloat8CastTensor (
119
+ mod .weight , new_mod .forward_config , new_mod .scaling_strategy
120
+ )
107
121
)
108
122
else :
109
123
new_mod .weight = mod .weight
@@ -112,18 +126,27 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
112
126
113
127
114
128
def cast_to_float8_e4m3fn(
    inpt_tensor: torch.Tensor,
    mm_config: ScaledMMConfig,
    scaling_strategy: ScalingStrategy,
    reduce_amax: bool = False,
) -> Float8Tensor:
    """Dynamically cast ``inpt_tensor`` to an e4m3fn ``Float8Tensor``.

    Args:
        inpt_tensor: tensor to cast.
        mm_config: scaled-matmul configuration carried on the result.
        scaling_strategy: scaling strategy carried on the result.
        reduce_amax: when True, request an amax reduction while computing
            the scale (used under FSDP — see ``fsdp_pre_all_gather``).
    """
    # Tensors that are already fp8 pass through unchanged.
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
    return Float8Tensor.to_float8(
        inpt_tensor,
        scale,
        e4m3_dtype,
        mm_config=mm_config,
        scaling_strategy=scaling_strategy,
    )
121
144
122
145
123
146
def cast_to_float8_e5m2_bw(
    gradY: torch.Tensor, mm_config: ScaledMMConfig, scaling_strategy: ScalingStrategy
) -> torch.Tensor:
    """Return ``gradY`` unchanged in forward; cast its gradient to e5m2 fp8.

    Thin wrapper over ``NoopFwToFloat8E5M2Bw.apply``.
    """
    result = NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_strategy)
    return result
127
150
128
151
129
152
# FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +166,12 @@ def cast_to_float8_e5m2_bw(
143
166
144
167
class WeightWithDynamicFloat8CastTensor (torch .Tensor ):
145
168
@staticmethod
146
- def __new__ (cls , tensor : torch .Tensor , mm_config : ScaledMMConfig ):
169
+ def __new__ (
170
+ cls ,
171
+ tensor : torch .Tensor ,
172
+ mm_config : ScaledMMConfig ,
173
+ scaling_strategy : ScalingStrategy ,
174
+ ):
147
175
return torch .Tensor ._make_wrapper_subclass (
148
176
cls ,
149
177
tensor .size (),
@@ -157,24 +185,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
157
185
requires_grad = tensor .requires_grad ,
158
186
)
159
187
160
- def __init__ (self , tensor : torch .Tensor , mm_config : ScaledMMConfig ):
188
+ def __init__ (
189
+ self ,
190
+ tensor : torch .Tensor ,
191
+ mm_config : ScaledMMConfig ,
192
+ scaling_strategy : ScalingStrategy ,
193
+ ):
161
194
self ._tensor = tensor
162
195
self ._mm_config = mm_config
196
+ self ._scaling_strategy = scaling_strategy
163
197
164
198
@classmethod
165
199
def __torch_dispatch__ (cls , func , types , args , kwargs = None ):
166
200
if func == torch .ops .aten .detach .default :
167
201
return WeightWithDynamicFloat8CastTensor (
168
- args [0 ]._tensor , args [0 ]._mm_config
202
+ args [0 ]._tensor , args [0 ]._mm_config , args [ 0 ]. _scaling_strategy
169
203
)
170
204
mm_config : Optional [ScaledMMConfig ] = None
205
+ scaling_strategy : Optional [ScalingStrategy ] = None
171
206
172
207
def unwrap (t ):
173
208
nonlocal mm_config
209
+ nonlocal scaling_strategy
174
210
if mm_config is None :
175
211
mm_config = t ._mm_config
176
212
else :
177
213
mm_config = merge_mm_configs (mm_config , t ._mm_config )
214
+
215
+ if scaling_strategy is None :
216
+ scaling_strategy = t ._scaling_strategy
217
+ else :
218
+ # TODO For now we assume that the scaling strategy is same across all tensors
219
+ assert scaling_strategy == t ._scaling_strategy
178
220
return t ._tensor
179
221
180
222
args , kwargs = pytree .tree_map_only (
@@ -184,23 +226,31 @@ def unwrap(t):
184
226
if func not in _ops_to_preserve_subclass :
185
227
return out
186
228
return pytree .tree_map_only (
187
- torch .Tensor , lambda x : WeightWithDynamicFloat8CastTensor (x , mm_config ), out
229
+ torch .Tensor ,
230
+ lambda x : WeightWithDynamicFloat8CastTensor (x , mm_config , scaling_strategy ),
231
+ out ,
188
232
)
189
233
190
234
def __tensor_flatten__ (self ):
191
- return ["_tensor" ], self ._mm_config
235
+ return ["_tensor" ], {
236
+ "_mm_config" : self ._mm_config ,
237
+ "_scaling_strategy" : self ._scaling_strategy ,
238
+ }
192
239
193
240
@staticmethod
194
241
def __tensor_unflatten__ (inner_tensors , flatten_spec , outer_size , outer_stride ):
195
- mm_config = flatten_spec
196
- return WeightWithDynamicFloat8CastTensor (inner_tensors ["_tensor" ], mm_config )
242
+ mm_config = flatten_spec ["_mm_config" ]
243
+ scaling_strategy = flatten_spec ["_scaling_strategy" ]
244
+ return WeightWithDynamicFloat8CastTensor (
245
+ inner_tensors ["_tensor" ], mm_config , scaling_strategy
246
+ )
197
247
198
248
def __repr__ (self ):
199
- return f"WeightWithDynamicFloat8CastTensor(tensor={ self ._tensor } , mm_config={ self ._mm_config } )"
249
+ return f"WeightWithDynamicFloat8CastTensor(tensor={ self ._tensor } , mm_config={ self ._mm_config } , scaling_strategy= { self . _scaling_strategy } )"
200
250
201
251
def fsdp_pre_all_gather (self , mesh ):
202
252
float8_tensor = cast_to_float8_e4m3fn (
203
- self ._tensor , self ._mm_config , reduce_amax = True
253
+ self ._tensor , self ._mm_config , self . _scaling_strategy , reduce_amax = True
204
254
)
205
255
return (float8_tensor ._data ,), (float8_tensor ._scale ,)
206
256
@@ -218,4 +268,6 @@ def fsdp_post_all_gather(
218
268
assert isinstance (out , Float8Tensor ), f"{ type (out )} "
219
269
out ._scale = scale
220
270
return
221
- return Float8Tensor (data , scale , param_dtype , self ._mm_config ), (data ,)
271
+ return Float8Tensor (
272
+ data , scale , param_dtype , self ._mm_config , self ._scaling_strategy
273
+ ), (data ,)
0 commit comments