@@ -48,6 +48,7 @@ def float8_desugar_op(aten_op, args, kwargs=None):
48
48
return Float8Tensor (
49
49
new_data ,
50
50
args [0 ]._scale ,
51
+ args [0 ]._inv_scale ,
51
52
args [0 ]._orig_dtype ,
52
53
args [0 ]._linear_mm_config ,
53
54
args [0 ]._gemm_input_role ,
@@ -62,6 +63,7 @@ def make_float8(data):
62
63
return Float8Tensor (
63
64
data ,
64
65
args [0 ]._scale ,
66
+ args [0 ]._inv_scale ,
65
67
args [0 ]._orig_dtype ,
66
68
args [0 ]._linear_mm_config ,
67
69
args [0 ]._gemm_input_role ,
@@ -78,6 +80,7 @@ def float8_cat(aten_op, args, kwargs=None):
78
80
79
81
orig_dtype = chunked_tensors [0 ]._orig_dtype
80
82
scale = chunked_tensors [0 ]._scale
83
+ inv_scale = chunked_tensors [0 ]._inv_scale
81
84
mm_config = chunked_tensors [0 ]._linear_mm_config
82
85
fp8_dtype = chunked_tensors [0 ]._data .dtype
83
86
gemm_input_role = chunked_tensors [0 ]._gemm_input_role
@@ -105,7 +108,7 @@ def float8_cat(aten_op, args, kwargs=None):
105
108
106
109
new_data = aten_op (chunk_data , * args [1 :], ** kwargs )
107
110
new_data = new_data .view (fp8_dtype )
108
- return Float8Tensor (new_data , scale , orig_dtype , mm_config , gemm_input_role )
111
+ return Float8Tensor (new_data , scale , inv_scale , orig_dtype , mm_config , gemm_input_role )
109
112
110
113
111
114
@implements ([aten .sum .dim_IntList ])
@@ -130,7 +133,7 @@ def unwrap(x):
130
133
131
134
def preprocess_addmm (a : Float8Tensor , b : Float8Tensor ):
132
135
a_data = a ._data
133
- a_scale = a ._scale
136
+ a_inv_scale = a ._inv_scale
134
137
b_data = b ._data
135
138
136
139
scaled_mm_config = choose_scaled_mm_config (
@@ -151,8 +154,8 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
151
154
a_data = a_data .contiguous ()
152
155
if is_row_major (b_data .stride ()):
153
156
b_data = b_data .t ().contiguous ().t ()
154
- b_scale = b ._scale
155
- return a_data , a_scale , b_data , b_scale
157
+ b_inv_scale = b ._inv_scale
158
+ return a_data , a_inv_scale , b_data , b_inv_scale
156
159
157
160
158
161
@implements ([aten .mm .default , aten .matmul .default ])
@@ -165,7 +168,7 @@ def float8_mm(aten_op, args, kwargs=None):
165
168
), "Expecting both Float8Tensor for mm inputs but found {} and {}" .format (
166
169
type (a ), type (b )
167
170
)
168
- a_data , a_scale , b_data , b_scale = preprocess_addmm (a , b )
171
+ a_data , a_inv_scale , b_data , b_inv_scale = preprocess_addmm (a , b )
169
172
output_dtype = a ._orig_dtype
170
173
scaled_mm_config = choose_scaled_mm_config (
171
174
a ._gemm_input_role ,
@@ -175,13 +178,13 @@ def float8_mm(aten_op, args, kwargs=None):
175
178
)
176
179
if scaled_mm_config .emulate :
177
180
return torch .ops .aten .mm_float8_emulated (
178
- a ._data , a ._scale , b ._data , b ._scale , output_dtype
181
+ a ._data , a ._inv_scale , b ._data , b ._inv_scale , output_dtype
179
182
)
180
183
tensor_out = addmm_float8_unwrapped (
181
184
a_data ,
182
- a_scale ,
185
+ a_inv_scale ,
183
186
b_data ,
184
- b_scale ,
187
+ b_inv_scale ,
185
188
output_dtype ,
186
189
output_scale = None ,
187
190
bias = None ,
@@ -200,7 +203,7 @@ def float8_addmm(aten_op, args, kwargs=None):
200
203
bias = args [0 ]
201
204
a = args [1 ]
202
205
b = args [2 ]
203
- a_data , a_scale , b_data , b_scale = preprocess_addmm (a , b )
206
+ a_data , a_inv_scale , b_data , b_inv_scale = preprocess_addmm (a , b )
204
207
output_dtype = a ._orig_dtype
205
208
assert bias .dtype == output_dtype , "bias dtype must match output dtype"
206
209
scaled_mm_config = choose_scaled_mm_config (
@@ -210,15 +213,16 @@ def float8_addmm(aten_op, args, kwargs=None):
210
213
b ._linear_mm_config ,
211
214
)
212
215
if scaled_mm_config .emulate :
216
+ # TODO inv scale here
213
217
out = torch .ops .aten .mm_float8_emulated (
214
218
a ._data , a ._scale , b ._data , b ._scale , output_dtype
215
219
)
216
220
return out + bias
217
221
tensor_out = addmm_float8_unwrapped (
218
222
a_data ,
219
- a_scale ,
223
+ a_inv_scale ,
220
224
b_data ,
221
- b_scale ,
225
+ b_inv_scale ,
222
226
output_dtype ,
223
227
output_scale = None ,
224
228
bias = bias ,
@@ -249,6 +253,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
249
253
return Float8Tensor (
250
254
args [0 ]._data ,
251
255
args [0 ]._scale ,
256
+ args [0 ]._inv_scale ,
252
257
kwargs ["dtype" ],
253
258
args [0 ]._linear_mm_config ,
254
259
args [0 ]._gemm_input_role ,
@@ -276,6 +281,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
276
281
return Float8Tensor (
277
282
fp8_out ,
278
283
fp8_input ._scale ,
284
+ fp8_input ._inv_scale ,
279
285
fp8_input ._orig_dtype ,
280
286
fp8_input ._linear_mm_config ,
281
287
fp8_input ._gemm_input_role ,
@@ -292,6 +298,7 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
292
298
return Float8Tensor (
293
299
fp8_out ,
294
300
fp8_input ._scale ,
301
+ fp8_input ._inv_scale ,
295
302
fp8_input ._orig_dtype ,
296
303
fp8_input ._linear_mm_config ,
297
304
fp8_input ._gemm_input_role ,
@@ -314,6 +321,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
314
321
return Float8Tensor (
315
322
fp8_out ,
316
323
fp8_self ._scale ,
324
+ fp8_self ._inv_scale ,
317
325
fp8_self ._orig_dtype ,
318
326
fp8_self ._linear_mm_config ,
319
327
fp8_self ._gemm_input_role ,
@@ -355,6 +363,7 @@ def copy_fp8(aten_op, args, kwargs=None):
355
363
return Float8Tensor (
356
364
fp8_out ,
357
365
self ._scale ,
366
+ self ._inv_scale ,
358
367
self ._orig_dtype ,
359
368
self ._linear_mm_config ,
360
369
self ._gemm_input_role ,
0 commit comments