Commit a9dd5a9

[wip] store inv_scale on Float8Tensor
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 6ab4322
ghstack-comment-id: 2273795212
Pull Request resolved: #628
1 parent d582f9a commit a9dd5a9

File tree

7 files changed: +62 -31 lines changed


benchmarks/float8/profile_linear_float8.py

Lines changed: 5 additions & 1 deletion
@@ -209,6 +209,7 @@ def main(
 scaling_type_grad_output: str = "dynamic",
 model_type: str = "linear",
 dtype_filter: str = "both",
+skip_amax_sync: bool = False,
 ):
 assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
 assert dtype_filter in ("both", "float8", "bfloat16")
@@ -220,6 +221,9 @@ def main(
 cast_config_input=CastConfig(scaling_type=scaling_type_input),
 cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
 cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+# for now we don't care about amax init for performance profiling
+enable_amax_init=False,
+enable_pre_and_post_forward = not skip_amax_sync,
 )
 scaling_repr = "_".join(
 [
@@ -290,7 +294,7 @@ def float8_forw_backward_wrapper(x):
 # inspection of the fw+bw torch.compile without the scale
 # syncing code
 # TODO(future): make this better
-if linear_requires_sync(config):
+if linear_requires_sync(config) and not skip_amax_sync:
 with record_function("scale_amax_and_scales"):
 sync_amax_history(m_float8)
 out = float8_forw(x)

test/float8/test_base.py

Lines changed: 4 additions & 3 deletions
@@ -128,6 +128,7 @@ def test_copy_(self):
 fp8_b = Float8Tensor(
 torch.empty(16, dtype=torch.float8_e4m3fn),
 scale_a,
+scale_a.reciprocal(),
 torch.bfloat16,
 fp8_a._linear_mm_config,
 )
@@ -417,14 +418,14 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):

 out_scaled_mm = addmm_float8_unwrapped(
 a_fp8._data,
-a_fp8._scale,
+a_fp8._inv_scale,
 b_fp8._data,
-b_fp8._scale,
+b_fp8._inv_scale,
 output_dtype=output_dtype,
 use_fast_accum=use_fast_accum,
 )
 out_emulated = torch.ops.aten.mm_float8_emulated(
-a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
+a_fp8._data, a_fp8._inv_scale, b_fp8._data, b_fp8._inv_scale, output_dtype
 )

 if output_dtype != base_dtype:
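
Not part of the commit, but a useful companion to the updated test_copy_ fixture: a minimal sketch of constructing a Float8Tensor by hand with the new argument order (data, scale, inv_scale, orig_dtype, linear_mm_config). It assumes this commit of torchao is importable; passing None for linear_mm_config falls back to the default LinearMMConfig(), as shown in the __new__ diff below.

    import torch
    from torchao.float8.float8_tensor import Float8Tensor

    scale = torch.tensor(4.0)
    fp8_b = Float8Tensor(
        torch.empty(16, dtype=torch.float8_e4m3fn),  # raw fp8 payload
        scale,                                       # scale used at cast time
        scale.reciprocal(),                          # inverse scale, now stored explicitly
        torch.bfloat16,                              # original high-precision dtype
        None,                                        # linear_mm_config -> LinearMMConfig()
    )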

torchao/float8/float8_aten_api.py

Lines changed: 6 additions & 6 deletions
@@ -15,14 +15,14 @@

 def mm_float8_emulated(
 m1, # input 1 data
-s1, # input 1 scale
+inv_s1, # input 1 inverse scale
 m2, # input 2 data
-s2, # input 2 scale
+inv_s2, # input 2 inverse scale
 dtype3, # output dtype
 ):
 # naive implementation: dq -> op -> q
-m1_fp32 = m1.float() / s1
-m2_fp32 = m2.float() / s2
+m1_fp32 = m1.float() * inv_s1
+m2_fp32 = m2.float() * inv_s2
 m3_fp32 = torch.mm(m1_fp32, m2_fp32)

 return m3_fp32.to(dtype3)
@@ -37,13 +37,13 @@ def mm_float8_emulated(
 lib = Library("aten", "FRAGMENT")

 lib.define(
-"mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor"
+"mm_float8_emulated(Tensor m1, Tensor inv_s1, Tensor m2, Tensor inv_s2, ScalarType dtype3) -> Tensor"
 )
 lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU")
 lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA")


 @torch.library.impl(lib, "mm_float8_emulated", "Meta")
-def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3):
+def _mm_float8_emulated_meta(m1, inv_s1, m2, inv_s2, dtype3):
 out = torch.mm(m1.float(), m2.float()).to(dtype3)
 return out
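
Since the emulated path now dequantizes by multiplying with the precomputed inverse scale instead of dividing by the scale, a quick standalone check (plain float32 tensors, not part of the commit) confirms the two formulations agree numerically:

    import torch

    # dq -> mm with division by the scale vs. multiplication by its reciprocal
    m1, m2 = torch.randn(4, 8), torch.randn(8, 4)
    s1, s2 = torch.tensor(448.0), torch.tensor(224.0)

    out_div = torch.mm(m1 / s1, m2 / s2)
    out_mul = torch.mm(m1 * s1.reciprocal(), m2 * s2.reciprocal())
    torch.testing.assert_close(out_div, out_mul)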

torchao/float8/float8_ops.py

Lines changed: 20 additions & 11 deletions
@@ -48,6 +48,7 @@ def float8_desugar_op(aten_op, args, kwargs=None):
 return Float8Tensor(
 new_data,
 args[0]._scale,
+args[0]._inv_scale,
 args[0]._orig_dtype,
 args[0]._linear_mm_config,
 args[0]._gemm_input_role,
@@ -62,6 +63,7 @@ def make_float8(data):
 return Float8Tensor(
 data,
 args[0]._scale,
+args[0]._inv_scale,
 args[0]._orig_dtype,
 args[0]._linear_mm_config,
 args[0]._gemm_input_role,
@@ -78,6 +80,7 @@ def float8_cat(aten_op, args, kwargs=None):

 orig_dtype = chunked_tensors[0]._orig_dtype
 scale = chunked_tensors[0]._scale
+inv_scale = chunked_tensors[0]._inv_scale
 mm_config = chunked_tensors[0]._linear_mm_config
 fp8_dtype = chunked_tensors[0]._data.dtype
 gemm_input_role = chunked_tensors[0]._gemm_input_role
@@ -105,7 +108,7 @@ def float8_cat(aten_op, args, kwargs=None):

 new_data = aten_op(chunk_data, *args[1:], **kwargs)
 new_data = new_data.view(fp8_dtype)
-return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role)
+return Float8Tensor(new_data, scale, inv_scale, orig_dtype, mm_config, gemm_input_role)


 @implements([aten.sum.dim_IntList])
@@ -130,7 +133,7 @@ def unwrap(x):

 def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
 a_data = a._data
-a_scale = a._scale
+a_inv_scale = a._inv_scale
 b_data = b._data

 scaled_mm_config = choose_scaled_mm_config(
@@ -151,8 +154,8 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
 a_data = a_data.contiguous()
 if is_row_major(b_data.stride()):
 b_data = b_data.t().contiguous().t()
-b_scale = b._scale
-return a_data, a_scale, b_data, b_scale
+b_inv_scale = b._inv_scale
+return a_data, a_inv_scale, b_data, b_inv_scale


 @implements([aten.mm.default, aten.matmul.default])
@@ -165,7 +168,7 @@ def float8_mm(aten_op, args, kwargs=None):
 ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
 type(a), type(b)
 )
-a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+a_data, a_inv_scale, b_data, b_inv_scale = preprocess_addmm(a, b)
 output_dtype = a._orig_dtype
 scaled_mm_config = choose_scaled_mm_config(
 a._gemm_input_role,
@@ -175,13 +178,13 @@ def float8_mm(aten_op, args, kwargs=None):
 )
 if scaled_mm_config.emulate:
 return torch.ops.aten.mm_float8_emulated(
-a._data, a._scale, b._data, b._scale, output_dtype
+a._data, a._inv_scale, b._data, b._inv_scale, output_dtype
 )
 tensor_out = addmm_float8_unwrapped(
 a_data,
-a_scale,
+a_inv_scale,
 b_data,
-b_scale,
+b_inv_scale,
 output_dtype,
 output_scale=None,
 bias=None,
@@ -200,7 +203,7 @@ def float8_addmm(aten_op, args, kwargs=None):
 bias = args[0]
 a = args[1]
 b = args[2]
-a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+a_data, a_inv_scale, b_data, b_inv_scale = preprocess_addmm(a, b)
 output_dtype = a._orig_dtype
 assert bias.dtype == output_dtype, "bias dtype must match output dtype"
 scaled_mm_config = choose_scaled_mm_config(
@@ -210,15 +213,16 @@ def float8_addmm(aten_op, args, kwargs=None):
 b._linear_mm_config,
 )
 if scaled_mm_config.emulate:
+# TODO inv scale here
 out = torch.ops.aten.mm_float8_emulated(
 a._data, a._scale, b._data, b._scale, output_dtype
 )
 return out + bias
 tensor_out = addmm_float8_unwrapped(
 a_data,
-a_scale,
+a_inv_scale,
 b_data,
-b_scale,
+b_inv_scale,
 output_dtype,
 output_scale=None,
 bias=bias,
@@ -249,6 +253,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
 return Float8Tensor(
 args[0]._data,
 args[0]._scale,
+args[0]._inv_scale,
 kwargs["dtype"],
 args[0]._linear_mm_config,
 args[0]._gemm_input_role,
@@ -276,6 +281,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
 return Float8Tensor(
 fp8_out,
 fp8_input._scale,
+fp8_input._inv_scale,
 fp8_input._orig_dtype,
 fp8_input._linear_mm_config,
 fp8_input._gemm_input_role,
@@ -292,6 +298,7 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
 return Float8Tensor(
 fp8_out,
 fp8_input._scale,
+fp8_input._inv_scale,
 fp8_input._orig_dtype,
 fp8_input._linear_mm_config,
 fp8_input._gemm_input_role,
@@ -314,6 +321,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
 return Float8Tensor(
 fp8_out,
 fp8_self._scale,
+fp8_self._inv_scale,
 fp8_self._orig_dtype,
 fp8_self._linear_mm_config,
 fp8_self._gemm_input_role,
@@ -355,6 +363,7 @@ def copy_fp8(aten_op, args, kwargs=None):
 return Float8Tensor(
 fp8_out,
 self._scale,
+self._inv_scale,
 self._orig_dtype,
 self._linear_mm_config,
 self._gemm_input_role,
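
All of these op handlers follow the same shape: run the aten op on the raw fp8 payload, then rebuild the wrapper with every piece of metadata, which after this commit includes the inverse scale. (Note that the emulated branch of float8_addmm still passes _scale and is marked with a TODO in this WIP commit.) A standalone sketch of that recurring pattern; rebuild_like is an illustrative helper, not torchao code:

    from torchao.float8.float8_tensor import Float8Tensor

    def rebuild_like(template: Float8Tensor, new_data) -> Float8Tensor:
        # Thread all metadata through, including the new _inv_scale field;
        # dropping it here would silently lose the precomputed reciprocal on
        # views, casts, cat, all-gather, etc.
        return Float8Tensor(
            new_data,
            template._scale,
            template._inv_scale,
            template._orig_dtype,
            template._linear_mm_config,
            template._gemm_input_role,
        )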

torchao/float8/float8_python_api.py

Lines changed: 6 additions & 4 deletions
@@ -23,9 +23,9 @@
 # For output going from fp32 -> fp8 we multiply by the scale
 def addmm_float8_unwrapped(
 a_data: torch.Tensor,
-a_scale: torch.Tensor,
+a_inv_scale: torch.Tensor,
 b_data: torch.Tensor,
-b_scale: torch.tensor,
+b_inv_scale: torch.tensor,
 output_dtype: torch.dtype,
 output_scale: Optional[torch.Tensor] = None,
 bias: Optional[torch.Tensor] = None,
@@ -36,8 +36,10 @@ def addmm_float8_unwrapped(
 as inputs. This is used to standardize the logic between subclassed and non subclassed
 versions of the linear module.
 """
-a_inverse_scale = a_scale.reciprocal()
-b_inverse_scale = b_scale.reciprocal()
+# a_inverse_scale = a_scale.reciprocal()
+# b_inverse_scale = b_scale.reciprocal()
+a_inverse_scale = a_inv_scale
+b_inverse_scale = b_inv_scale
 if output_dtype == torch.float32 and bias is not None:
 # Bias is not supported by _scaled_mm when output is fp32
 output = torch._scaled_mm(
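
With the reciprocal now precomputed on the tensor, addmm_float8_unwrapped simply forwards the inverse scales it receives instead of computing them at matmul time. A hedged call sketch mirroring the updated test; it assumes a_fp8 and b_fp8 are Float8Tensor operands as in test_scaled_mm_vs_emulated and is a fragment, not a runnable script:

    out = addmm_float8_unwrapped(
        a_fp8._data,
        a_fp8._inv_scale,   # precomputed at cast time
        b_fp8._data,
        b_fp8._inv_scale,
        output_dtype=torch.bfloat16,
        use_fast_accum=False,
    )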

torchao/float8/float8_tensor.py

Lines changed: 14 additions & 2 deletions
@@ -157,6 +157,7 @@ def forward(
 """
 tensor_scaled = tensor * scale
 bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+inv_scale = scale.reciprocal()

 if isinstance(bits_fp8, DTensor):
 assert isinstance(
@@ -166,9 +167,11 @@ def forward(
 bits_placements = bits_fp8.placements
 local_bits = bits_fp8.to_local()
 local_scale = scale.to_local()
+local_inv_scale = inv_scale.to_local()
 inner_float8_tensor = Float8Tensor(
 local_bits,
 local_scale,
+local_inv_scale,
 tensor.dtype,
 linear_mm_config=linear_mm_config,
 gemm_input_role=gemm_input_role,
@@ -185,6 +188,7 @@ def forward(
 return Float8Tensor(
 bits_fp8,
 scale,
+inv_scale,
 tensor.dtype,
 linear_mm_config=linear_mm_config,
 gemm_input_role=gemm_input_role,
@@ -251,6 +255,11 @@ class Float8Tensor(torch.Tensor):
 * `_scale`: the scale used to scale the original fp32 tensor. We multiply
 by scale to go from fp32 range to fp8 range, and divide by scale to go
 from fp8 range to fp32 range.
+* `_inv_scale`: the inverse of `_scale`. We need this because the
+`torch._scaled_mm` function requires inverse scales, and torch.compile
+does not reliably fuse this into preceding ops, which can lead to extra
+GPU kernel launches. If we calculate the inverse scale colocated with
+creating the `Float8Tensor` instance, we don't see the extra GPU kernels.
 * `_orig_dtype`: the original dtype of the tensor used to create this
 tensor.
 * `_emulate`: if true using fp32 emulation for the matmuls, helpful
@@ -275,6 +284,7 @@ def __new__(
 cls,
 data: torch.Tensor,
 scale: torch.Tensor,
+inv_scale: torch.Tensor,
 orig_dtype: torch.dtype,
 linear_mm_config: Optional[LinearMMConfig],
 gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
@@ -297,6 +307,7 @@ def __new__(
 )
 self._data = data
 self._scale = scale
+self._inv_scale = inv_scale
 self._orig_dtype = orig_dtype
 self._linear_mm_config = (
 linear_mm_config if linear_mm_config is not None else LinearMMConfig()
@@ -314,14 +325,15 @@ def __tensor_flatten__(self):
 "_linear_mm_config": self._linear_mm_config,
 "_gemm_input_role": self._gemm_input_role,
 }
-return ["_data", "_scale"], ctx
+return ["_data", "_scale", "_inv_scale"], ctx

 @staticmethod
 def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
-assert len(inner_tensors) == 2
+assert len(inner_tensors) == 3
 return Float8Tensor(
 inner_tensors["_data"],
 inner_tensors["_scale"],
+inner_tensors["_inv_scale"],
 metadata["_orig_dtype"],
 metadata["_linear_mm_config"],
 metadata["_gemm_input_role"],

torchao/float8/fsdp_utils.py

Lines changed: 7 additions & 4 deletions
@@ -217,7 +217,7 @@ def fsdp_pre_all_gather(self, mesh):
 reduce_amax=True,
 gemm_input_role=GemmInputRole.WEIGHT,
 )
-return (float8_tensor._data,), (float8_tensor._scale,)
+return (float8_tensor._data,), (float8_tensor._scale, float8_tensor._inv_scale)

 def fsdp_post_all_gather(
 self,
@@ -228,7 +228,7 @@ def fsdp_post_all_gather(
 out: Optional[torch.Tensor] = None,
 ):
 (data,) = all_gather_outputs
-(scale,) = metadata
+(scale, inv_scale) = metadata
 if out is not None:
 from torch.distributed._tensor import DTensor
 if isinstance(out, Float8Tensor):
@@ -245,6 +245,7 @@ def fsdp_post_all_gather(
 return Float8Tensor(
 data,
 scale,
+inv_scale,
 param_dtype,
 self._linear_mm_config,
 gemm_input_role=GemmInputRole.WEIGHT,
@@ -407,7 +408,7 @@ def fsdp_pre_all_gather(self, mesh):
 self._linear_mm_config,
 GemmInputRole.WEIGHT,
 )
-return (float8_tensor._data,), (float8_tensor._scale,)
+return (float8_tensor._data,), (float8_tensor._scale, float8_tensor._inv_scale)

 def fsdp_post_all_gather(
 self,
@@ -418,14 +419,16 @@ def fsdp_post_all_gather(
 out: Optional[torch.Tensor] = None,
 ):
 (data,) = all_gather_outputs
-(scale,) = metadata
+(scale, inv_scale) = metadata
 if out is not None:
 assert isinstance(out, Float8Tensor), f"{type(out)}"
 out._scale = scale
+out._inv_scale = inv_scale
 return
 return Float8Tensor(
 data,
 scale,
+inv_scale,
 param_dtype,
 self._linear_mm_config,
 gemm_input_role=GemmInputRole.WEIGHT,
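
The FSDP2 hooks now ship the reciprocal alongside the scale as all-gather metadata, so the post-all-gather side can rebuild (or patch) a Float8Tensor without recomputing it. A standalone schematic of that metadata round trip; plain tensors stand in for the real sharded fp8 weight, and this is illustrative rather than torchao code:

    import torch

    data = torch.zeros(8, dtype=torch.float8_e4m3fn)
    scale = torch.tensor(123.0)

    # fsdp_pre_all_gather: ship the payload plus (scale, inv_scale) as metadata
    all_gather_inputs, metadata = (data,), (scale, scale.reciprocal())

    # fsdp_post_all_gather: unpack both and hand them to the Float8Tensor
    # constructor (or write them onto a preallocated `out` tensor)
    (gathered_data,) = all_gather_inputs
    scale, inv_scale = metadata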
