
Commit eb1511e

Refactor custom FPx cast (#363)
* refactor custom fp cast
* add dequant
* small formatting
* compile with fullgraph=True
* add fullgraph=true
* undo
* add another version
* fast path for mbits=1
* add back docstring
1 parent 664f073 commit eb1511e

4 files changed: +239 −348 lines changed


test/prototype/test_fp6_llm.py

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ def test_to_tc_float6_e3m2_compile(self, device):
         x = torch.randn(256, 64, device=device)
 
         expected = to_tc_float6_e3m2(x)
-        actual = torch.compile(to_tc_float6_e3m2)(x)
+        actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
     @parametrize("device", _DEVICES)
@@ -53,7 +53,7 @@ def test_from_tc_float6_e3m2_compile(self, device):
         x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)
 
         expected = from_tc_float6_e3m2(x)
-        actual = torch.compile(from_tc_float6_e3m2)(x)
+        actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -81,7 +81,7 @@ def test_fp6_llm_linear_compile(self, bias):
 
         x = torch.randn(N, IC, device=device, dtype=torch.half)
         expected = fp6_linear(x)
-        actual = torch.compile(fp6_linear)(x)
+        actual = torch.compile(fp6_linear, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
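Note on the hunks above: passing fullgraph=True makes torch.compile raise an error on any graph break instead of silently splitting the function into multiple graphs, so these tests now also check that the refactored cast compiles as a single graph. A minimal sketch of the difference (the toy function below is illustrative, not from this repo):

import torch

def f(x):
    y = x * 2
    print("sum:", y.sum())  # print() typically triggers a graph break under torch.compile
    return y + 1

x = torch.randn(4)

torch.compile(f)(x)  # default: compiles around the graph break and still runs

try:
    torch.compile(f, fullgraph=True)(x)  # raises because the function cannot be captured as one graph
except Exception as err:
    print("fullgraph=True surfaced the graph break:", type(err).__name__)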
Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
#   1. No encodings are reserved for special values (+/-inf, NaN).
#   2. When downcasting from FP32 to FPx,
#      - Rounding mode is round to nearest, ties to even.
#      - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
#        magnitude (sign is preserved).

import torch
from torch import Tensor


def _n_ones(n: int) -> int:
    return (1 << n) - 1


EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert FP32 numbers to sub-byte floating point numbers with the given
    number of exponent and mantissa bits.

    Input: torch.Tensor of dtype torch.float
    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding

    Note: there are no special values (NaN, inf) support in this code. Values
    outside the representable range of FPx after rounding are clamped to the
    maximum FPx magnitude (sign is preserved).

    Code below is an adaptation of https://fburl.com/code/ciwofcg4

    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
    """
    assert x.dtype == torch.float
    assert 1 + ebits + mbits <= 8

    # calculate constants
    exp_bias = _n_ones(ebits - 1)
    max_int = _n_ones(ebits + mbits)
    sign_mask = 1 << (ebits + mbits)

    # TODO document this better
    magic_adder = _n_ones(MBITS_F32 - mbits - 1)

    # all E bits and M bits are 1s
    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))

    # E bits = 1, M bits = 0
    min_normal = 2 ** (1 - exp_bias)

    denorm_exp = (
        # exp bias conversion between formats
        (F32_EXP_BIAS - exp_bias)
        # mantissa length difference between formats
        + (MBITS_F32 - mbits)
        # add one to encoded exponent for denormalized numbers
        + 1
    )
    denorm_mask_int = denorm_exp << MBITS_F32

    # reinterpret int32 as float32
    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)

    # save the sign
    # Note that we have torch.uint32, but some ops like cpu bit shifts
    # do not work on it. So, we stay in int32.
    x = x.view(torch.int32)
    sign = x & 0x80000000

    # set everything to positive, will add sign back at the end
    x = x ^ sign

    # TODO: can the branch floating point comparisons below be done without
    # converting to float? probably but need to verify
    x = x.view(torch.float)

    # rewrite saturate/denorm/norm branches without explicit data dependent
    # control flow, to be more compiler friendly
    saturate_mask = x >= max_normal
    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

    #
    # branch 1: saturate to max val - handled later in the code which combines
    #   the branches
    #

    #
    # branch 2: to conversion to denormal as well as rounding up to normal
    #
    denormal_x = x + denorm_mask_float
    denormal_x = denormal_x.view(torch.int32)
    denormal_x -= denorm_mask_int
    denormal_x = denormal_x.to(torch.uint8)

    #
    # branch 3: stay in normal range, adjust the exponent and round
    #
    normal_x = x.view(torch.int32)
    # resulting mantissa is odd
    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
    # update exponent, rounding bias part 1
    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
    normal_x += val_to_add
    # rounding bias part 2
    normal_x += mant_odd
    # take the bits!
    normal_x = normal_x >> (MBITS_F32 - mbits)
    normal_x = normal_x.to(torch.uint8)

    #
    # combine the branches
    #
    x = torch.full_like(x, max_int, dtype=torch.uint8)
    x = torch.where(denormal_mask, denormal_x, x)
    x = torch.where(normal_mask, normal_x, x)

    # add sign back
    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
    sign_lp = sign_lp.to(torch.uint8)
    # Right shift of a negative signed integer can fill the least significant
    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
    # doesn't have an uint32 dtype, we mask out these bits to get just the
    # f4 sign bit
    sign_lp = sign_lp & sign_mask
    x = x | sign_lp

    return x.to(torch.uint8)


# TODO(future): check if LUT for everything is faster than bit shifting,
# especially for fp4 (only 2^4=16 unique values).
def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert sub-byte floating point numbers with the given number of exponent
    and mantissa bits to FP32.

    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
    Output: torch.Tensor of dtype fp32 with the dequantized value
    """
    assert x.dtype == torch.uint8
    assert 1 + ebits + mbits <= 8

    sign_mask = 1 << (ebits + mbits)
    exp_bias = _n_ones(ebits - 1)
    mantissa_mask = _n_ones(mbits)

    # save the sign
    sign_lp = x & sign_mask

    # set everything to positive, will add sign back at the end
    x_pos = x ^ sign_lp

    #
    # 1. Calculate zero mask
    #
    zero_mask = x_pos == 0

    #
    # 2. Calculate the denormal path mask
    #
    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))

    #
    # 3. Calculate the normal path
    #

    # calculate the new exponent and shift it to bits 2:9 of the result
    exp_biased_lp = x_pos >> mbits
    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32

    # shift the mantissa to bits 10:32 of the result
    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
    result = exp_biased_f32 | mantissa_f32

    #
    # 4. Add the zero and denormal casts to the already casted normal path
    #
    result[zero_mask] = 0

    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

    # fast path.
    # without this, performance for FP4_E2M1 is slower by 2x
    if mbits == 1:
        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

    else:
        # iterate over all possible values of mantissa
        # i=0, j=1
        # i=1, j=10,11
        # i=2, j=100,101,110,111
        # and so on
        for i in range(mbits):
            for mantissa_cmp in range(1 << i, 1 << (i+1)):
                # left shift mantissa until it overflows (create an implicit 1)
                # subtract exponent by the same amount
                left_shift = mbits - i
                mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

                # we can update this in-place since the values won't overlap
                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32

        result = torch.where(denormal_mask, mantissa_lp_int32, result)

    # add sign back
    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
    result = result | sign_f32

    return result.view(torch.float)
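To make the intent of the two new helpers concrete, here is a small round-trip sketch for FP6 E3M2 (ebits=3, mbits=2). The import path below is a placeholder, since this view does not show the new file's name; the expected values follow from the code above (the E3M2 maximum magnitude works out to 28.0 from the max_normal formula, and out-of-range inputs saturate to it because no inf/NaN encodings are reserved).

import torch

# Placeholder import path: point this at wherever the new module lives in the repo.
from custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

x = torch.tensor([0.0, 1.5, -3.0, 28.0, 30.0], dtype=torch.float)

# Quantize: one uint8 per element, FP6 E3M2 encoding stored in the least significant bits.
fp6 = _f32_to_fpx_unpacked(x, ebits=3, mbits=2)

# Dequantize back to FP32.
y = _fpx_unpacked_to_f32(fp6, ebits=3, mbits=2)

# 0.0, 1.5, -3.0 and 28.0 are exactly representable in E3M2; 30.0 is clamped to the
# maximum magnitude 28.0 (sign preserved, no special values).
print(y)  # expected: tensor([ 0.0000,  1.5000, -3.0000, 28.0000, 28.0000])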

torchao/prototype/mx_formats/benchmarks/bench_qdq.py

Lines changed: 2 additions & 2 deletions
@@ -61,8 +61,8 @@ def run(profile_folder: Optional[str] = None):
     data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32)
 
     if not use_fp4_custom_triton_dequant_kernel:
-        quant = torch.compile(MXTensor.to_mx)
-        dequant = torch.compile(data_lp.to_dtype)
+        quant = torch.compile(MXTensor.to_mx, fullgraph=True)
+        dequant = torch.compile(data_lp.to_dtype, fullgraph=True)
     else:
         # As of 2024-04, torch.compile didn't work with the
         # handwritten triton kernel,
