|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | + |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3). |
| 8 | +# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain: |
| 9 | +# 1. No encodings are reserved for special values (+/-inf, NaN). |
| 10 | +# 2. When downcasting from FP32 to FPx, |
| 11 | +# - Rounding mode is round to nearest, ties to even. |
| 12 | +# - Values outside the representable range of FPx after rounding are clamped to the maximum FPx |
| 13 | +# magnitude (sign is preserved). |
| 14 | + |
| 15 | +import torch |
| 16 | +from torch import Tensor |
| 17 | + |
| 18 | + |
| 19 | +def _n_ones(n: int) -> int: |
| 20 | + return (1 << n) - 1 |
| 21 | + |
| 22 | + |
| 23 | +EBITS_F32, MBITS_F32 = 8, 23 |
| 24 | +F32_EXP_BIAS = _n_ones(EBITS_F32 - 1) |
| 25 | + |
| 26 | + |
| 27 | +def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: |
| 28 | + """Convert FP32 numbers to sub-byte floating point numbers with the given |
| 29 | + number of exponent and mantissa bits. |
| 30 | +
|
| 31 | + Input: torch.Tensor of dtype torch.float |
| 32 | + Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored |
| 33 | + in the least significant bits. e.g. |
| 34 | + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding |
| 35 | + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding |
| 36 | +
|
| 37 | + Note: there are no special values (NaN, inf) support in this code. Values |
| 38 | + outside the representable range of FPx after rounding are clamped to the |
| 39 | + maximum FPx magnitude (sign is preserved). |
| 40 | +
|
| 41 | + Code below is an adaptation of https://fburl.com/code/ciwofcg4 |
| 42 | +
|
| 43 | + Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501 |
| 44 | + Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5 |
| 45 | + """ |
| 46 | + assert x.dtype == torch.float |
| 47 | + assert 1 + ebits + mbits <= 8 |
| 48 | + |
| 49 | + # calculate constants |
| 50 | + exp_bias = _n_ones(ebits - 1) |
| 51 | + max_int = _n_ones(ebits + mbits) |
| 52 | + sign_mask = 1 << (ebits + mbits) |
| 53 | + |
| 54 | + # TODO document this better |
| 55 | + magic_adder = _n_ones(MBITS_F32 - mbits - 1) |
| 56 | + |
| 57 | + # all E bits and M bits are 1s |
| 58 | + max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) |
| 59 | + |
| 60 | + # E bits = 1, M bits = 0 |
| 61 | + min_normal = 2 ** (1 - exp_bias) |
| 62 | + |
| 63 | + denorm_exp = ( |
| 64 | + # exp bias conversion between formats |
| 65 | + (F32_EXP_BIAS - exp_bias) |
| 66 | + # mantissa length difference between formats |
| 67 | + + (MBITS_F32 - mbits) |
| 68 | + # add one to encoded exponent for denormalized numbers |
| 69 | + + 1 |
| 70 | + ) |
| 71 | + denorm_mask_int = denorm_exp << MBITS_F32 |
| 72 | + |
| 73 | + # reinterpret int32 as float32 |
| 74 | + denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32) |
| 75 | + |
| 76 | + # save the sign |
| 77 | + # Note that we have torch.uint32, but some ops like cpu bit shifts |
| 78 | + # do not work on it. So, we stay in int32. |
| 79 | + x = x.view(torch.int32) |
| 80 | + sign = x & 0x80000000 |
| 81 | + |
| 82 | + # set everything to positive, will add sign back at the end |
| 83 | + x = x ^ sign |
| 84 | + |
| 85 | + # TODO: can the branch floating point comparisons below be done without |
| 86 | + # converting to float? probably but need to verify |
| 87 | + x = x.view(torch.float) |
| 88 | + |
| 89 | + # rewrite saturate/denorm/norm branches without explicit data dependent |
| 90 | + # control flow, to be more compiler friendly |
| 91 | + saturate_mask = x >= max_normal |
| 92 | + denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal) |
| 93 | + normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask)) |
| 94 | + |
| 95 | + # |
| 96 | + # branch 1: saturate to max val - handled later in the code which combines |
| 97 | + # the branches |
| 98 | + # |
| 99 | + |
| 100 | + # |
| 101 | + # branch 2: to conversion to denormal as well as rounding up to normal |
| 102 | + # |
| 103 | + denormal_x = x + denorm_mask_float |
| 104 | + denormal_x = denormal_x.view(torch.int32) |
| 105 | + denormal_x -= denorm_mask_int |
| 106 | + denormal_x = denormal_x.to(torch.uint8) |
| 107 | + |
| 108 | + # |
| 109 | + # branch 3: stay in normal range, adjust the exponent and round |
| 110 | + # |
| 111 | + normal_x = x.view(torch.int32) |
| 112 | + # resulting mantissa is odd |
| 113 | + mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1 |
| 114 | + # update exponent, rounding bias part 1 |
| 115 | + val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder |
| 116 | + normal_x += val_to_add |
| 117 | + # rounding bias part 2 |
| 118 | + normal_x += mant_odd |
| 119 | + # take the bits! |
| 120 | + normal_x = normal_x >> (MBITS_F32 - mbits) |
| 121 | + normal_x = normal_x.to(torch.uint8) |
| 122 | + |
| 123 | + # |
| 124 | + # combine the branches |
| 125 | + # |
| 126 | + x = torch.full_like(x, max_int, dtype=torch.uint8) |
| 127 | + x = torch.where(denormal_mask, denormal_x, x) |
| 128 | + x = torch.where(normal_mask, normal_x, x) |
| 129 | + |
| 130 | + # add sign back |
| 131 | + sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits) |
| 132 | + sign_lp = sign_lp.to(torch.uint8) |
| 133 | + # Right shift of a negative signed integer can fill the least significant |
| 134 | + # bits with either 1s or 0s, depending on the implementation. Since PyTorch |
| 135 | + # doesn't have an uint32 dtype, we mask out these bits to get just the |
| 136 | + # f4 sign bit |
| 137 | + sign_lp = sign_lp & sign_mask |
| 138 | + x = x | sign_lp |
| 139 | + |
| 140 | + return x.to(torch.uint8) |
| 141 | + |
| 142 | + |
| 143 | +# TODO(future): check if LUT for everything is faster than bit shifting, |
| 144 | +# especially for fp4 (only 2^4=16 unique values). |
| 145 | +def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: |
| 146 | + """Convert sub-byte floating point numbers with the given number of exponent |
| 147 | + and mantissa bits to FP32. |
| 148 | +
|
| 149 | + Input: torch.Tensor of dtype uint8, where the bit encoding is stored |
| 150 | + in the least significant bits. e.g. |
| 151 | + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding |
| 152 | + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding |
| 153 | + Output: torch.Tensor of dtype fp32 with the dequantized value |
| 154 | + """ |
| 155 | + assert x.dtype == torch.uint8 |
| 156 | + assert 1 + ebits + mbits <= 8 |
| 157 | + |
| 158 | + sign_mask = 1 << (ebits + mbits) |
| 159 | + exp_bias = _n_ones(ebits - 1) |
| 160 | + mantissa_mask = _n_ones(mbits) |
| 161 | + |
| 162 | + # save the sign |
| 163 | + sign_lp = x & sign_mask |
| 164 | + |
| 165 | + # set everything to positive, will add sign back at the end |
| 166 | + x_pos = x ^ sign_lp |
| 167 | + |
| 168 | + # |
| 169 | + # 1. Calculate zero mask |
| 170 | + # |
| 171 | + zero_mask = x_pos == 0 |
| 172 | + |
| 173 | + # |
| 174 | + # 2. Calculate the denormal path mask |
| 175 | + # |
| 176 | + denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0)) |
| 177 | + |
| 178 | + # |
| 179 | + # 3. Calculate the normal path |
| 180 | + # |
| 181 | + |
| 182 | + # calculate the new exponent and shift it to bits 2:9 of the result |
| 183 | + exp_biased_lp = x_pos >> mbits |
| 184 | + exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS |
| 185 | + exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32 |
| 186 | + |
| 187 | + # shift the mantissa to bits 10:32 of the result |
| 188 | + mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32) |
| 189 | + mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits) |
| 190 | + result = exp_biased_f32 | mantissa_f32 |
| 191 | + |
| 192 | + # |
| 193 | + # 4. Add the zero and denormal casts to the already casted normal path |
| 194 | + # |
| 195 | + result[zero_mask] = 0 |
| 196 | + |
| 197 | + denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS |
| 198 | + |
| 199 | + # fast path. |
| 200 | + # without this, performance for FP4_E2M1 is slower by 2x |
| 201 | + if mbits == 1: |
| 202 | + result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32 |
| 203 | + |
| 204 | + else: |
| 205 | + # iterate over all possible values of mantissa |
| 206 | + # i=0, j=1 |
| 207 | + # i=1, j=10,11 |
| 208 | + # i=2, j=100,101,110,111 |
| 209 | + # and so on |
| 210 | + for i in range(mbits): |
| 211 | + for mantissa_cmp in range(1 << i, 1 << (i+1)): |
| 212 | + # left shift mantissa until it overflows (create an implicit 1) |
| 213 | + # subtract exponent by the same amount |
| 214 | + left_shift = mbits - i |
| 215 | + mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits) |
| 216 | + exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 |
| 217 | + |
| 218 | + # we can update this in-place since the values won't overlap |
| 219 | + mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32 |
| 220 | + |
| 221 | + result = torch.where(denormal_mask, mantissa_lp_int32, result) |
| 222 | + |
| 223 | + # add sign back |
| 224 | + sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits) |
| 225 | + result = result | sign_f32 |
| 226 | + |
| 227 | + return result.view(torch.float) |
0 commit comments