# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""MX quantization utils."""

from enum import Enum, IntEnum


class ElemFormat(Enum):
+    """Element format."""
+
    int8 = 1
    int4 = 2
    int2 = 3
@@ -44,6 +46,7 @@ class ElemFormat(Enum):

    @staticmethod
    def from_str(s):
+        """Get the element format from a string."""
        assert s is not None, "String elem_format == None"
        s = s.lower()
        if hasattr(ElemFormat, s):
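The lookup in `from_str` is just lowercase-then-`getattr` on the enum. A minimal, self-contained sketch of that pattern (using a stand-in enum so it runs on its own; not the library's exact code, which may handle extra aliases and error cases):

```python
from enum import Enum


class _Fmt(Enum):
    # Stand-in for ElemFormat with the members shown in the diff.
    int8 = 1
    int4 = 2
    int2 = 3


def _fmt_from_str(s: str) -> _Fmt:
    # Same pattern as ElemFormat.from_str: lowercase, then look the name up on the enum.
    s = s.lower()
    if hasattr(_Fmt, s):
        return getattr(_Fmt, s)
    raise ValueError(f"Undefined elem format: {s}")


print(_fmt_from_str("INT8"))  # _Fmt.int8
```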
@@ -53,6 +56,7 @@ def from_str(s):

    @staticmethod
    def is_bf(s):
+        """Whether the format is a brain floating-point format."""
        if isinstance(s, str):
            assert s is not None, "String elem_format == None"
            s = s.lower()
@@ -65,6 +69,7 @@ def is_bf(s):

    @staticmethod
    def is_fp(s):
+        """Whether the format is a floating-point format."""
        if isinstance(s, str):
            assert s is not None, "String elem_format == None"
            s = s.lower()
@@ -77,6 +82,7 @@ def is_fp(s):

    @staticmethod
    def is_int(s):
+        """Whether the format is an integer format."""
        if isinstance(s, str):
            assert s is not None, "String elem_format == None"
            s = s.lower()
@@ -89,12 +95,15 @@ def is_int(s):


class RoundingMode(IntEnum):
+    """Rounding mode."""
+
    nearest = 0
    floor = 1
    even = 2

    @staticmethod
    def string_enums():
+        """Return the rounding mode names."""
        return [s.name for s in list(RoundingMode)]


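Since `string_enums` is fully visible above, here is what it evaluates to; the snippet copies the enum verbatim so it runs stand-alone:

```python
from enum import IntEnum


class RoundingMode(IntEnum):
    # Copied from the diff above so the example is self-contained.
    nearest = 0
    floor = 1
    even = 2

    @staticmethod
    def string_enums():
        """Return the rounding mode names."""
        return [s.name for s in list(RoundingMode)]


print(RoundingMode.string_enums())  # ['nearest', 'floor', 'even']
```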
@@ -115,14 +124,19 @@ def _get_max_norm(ebits, mbits):


def _get_format_params(fmt):
-    """Allowed formats:
+    """Get the parameters of the format.
+
+    Allowed formats:
    - intX:          2 <= X <= 32, assume sign-magnitude, 1.xxx representation
    - floatX/fpX:    16 <= X <= 28, assume top exp is used for NaN/Inf
    - bfloatX/bfX:   9 <= X <= 32
    - fp4,           no NaN/Inf
    - fp6_e3m2/e2m3, no NaN/Inf
    - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior

+    Args:
+        fmt (str or ElemFormat): format
+
    Returns:
        ebits: exponent bits
        mbits: mantissa bits: includes sign and implicit bits
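For readers unfamiliar with the ebits/mbits convention, the largest normal value of a float format follows the usual IEEE-style arithmetic. The sketch below is not the module's `_get_max_norm`; it is only the textbook formula, stated in terms of exponent bits and explicit fraction bits (recall that `mbits` above also counts the sign and implicit bits), and it deliberately ignores the fp8_e4m3 special case called out in the docstring:

```python
def max_normal(ebits: int, frac_bits: int) -> float:
    # Textbook IEEE-style maximum normal value: 2**emax * (2 - 2**-frac_bits).
    emax = 2 ** (ebits - 1) - 1
    return 2.0 ** emax * (2.0 - 2.0 ** -frac_bits)


print(max_normal(5, 2))  # fp8_e5m2 -> 57344.0
print(max_normal(8, 7))  # bfloat16 -> ~3.39e38
```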
@@ -198,17 +212,19 @@ def _safe_rshift(x, bits, exp):


def _round_mantissa(A, bits, round, clamp=False):
-    """
-    Rounds mantissa to nearest bits depending on the rounding method 'round'
+    """Round the mantissa to the given number of bits using the rounding method 'round'.
+
    Args:
-        A {PyTorch tensor} -- Input tensor
-        round {str} -- Rounding method
-            "floor" rounds to the floor
-            "nearest" rounds to ceil or floor, whichever is nearest
+        A (torch.Tensor): input tensor
+        bits (int): number of mantissa bits
+        round (str): rounding method
+            "floor" rounds to the floor
+            "nearest" rounds to ceil or floor, whichever is nearest
+        clamp (bool, optional): whether to clamp the rounded values. Defaults to False.
+
    Returns:
-        A {PyTorch tensor} -- Tensor with mantissas rounded
+        torch.Tensor: tensor with mantissas rounded
    """
-
    if round == "dither":
        rand_A = torch.rand_like(A, requires_grad=False)
        A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
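The `dither` branch shown above adds uniform noise before flooring; the `floor` and `nearest` modes listed in the new docstring are conventionally the same sign-magnitude pattern without noise. A small illustrative sketch (assuming `A` already holds pre-scaled mantissa values; not claimed to be the exact library code):

```python
import torch


def round_scaled_mantissa(A: torch.Tensor, mode: str) -> torch.Tensor:
    # Sign-magnitude rounding of an already-scaled mantissa.
    if mode == "floor":
        return torch.sign(A) * torch.floor(torch.abs(A))
    if mode == "nearest":
        return torch.sign(A) * torch.floor(torch.abs(A) + 0.5)
    raise ValueError(f"Unsupported rounding mode: {mode}")


x = torch.tensor([1.25, -1.75, 2.5])
print(round_scaled_mantissa(x, "floor"))    # tensor([ 1., -1.,  2.])
print(round_scaled_mantissa(x, "nearest"))  # tensor([ 1., -2.,  3.])
```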
@@ -235,16 +251,18 @@ def _shared_exponents(A, method="max", axes=None, ebits=0):
    """Get shared exponents for the passed matrix A.

    Args:
-        A {PyTorch tensor} -- Input tensor
-        method {str} -- Exponent selection method.
-            "max" uses the max absolute value
-            "none" uses an exponent for each value (i.e., no sharing)
-        axes {list(int)} -- List of integers which specifies the axes across which
-            shared exponents are calculated.
+        A (torch.Tensor): input tensor
+        method (str, optional): exponent selection method.
+            "max" uses the max absolute value.
+            "none" uses an exponent for each value (i.e., no sharing).
+            Defaults to "max".
+        axes (list(int), optional): list of integers that specifies the axes across which
+            shared exponents are calculated. Defaults to None.
+        ebits (int, optional): number of bits for the shared exponents. Defaults to 0.
+
    Returns:
-        shared_exp {PyTorch tensor} -- Tensor of shared exponents
+        shared_exp (torch.Tensor): tensor of shared exponents
    """
-
    if method == "max":
        if axes is None:
            shared_exp = torch.max(torch.abs(A))
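With `method="max"` and `axes=None`, the shared exponent ultimately reduces to floor(log2(max|A|)) over the whole tensor. A minimal sketch of that idea (the real function additionally handles per-axis reduction, `ebits` clamping, and edge cases such as all-zero inputs):

```python
import torch


def shared_exponent_max(A: torch.Tensor) -> torch.Tensor:
    # One shared exponent for the whole tensor: floor(log2(max|A|)).
    max_abs = torch.max(torch.abs(A))
    return torch.floor(torch.log2(max_abs))


A = torch.tensor([0.3, -1.9, 6.0])
print(shared_exponent_max(A))  # tensor(2.) because 2**2 <= 6.0 < 2**3
```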
@@ -346,21 +364,20 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):


def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True):
-    """Core function used for element-wise quantization
-    Arguments:
-        A {PyTorch tensor} -- A tensor to be quantized
-        bits {int} -- Number of mantissa bits. Includes
-            sign bit and implicit one for floats
-        exp_bits {int} -- Number of exponent bits, 0 for ints
-        max_norm {float} -- Largest representable normal number
-        round {str} -- Rounding mode: (floor, nearest, even)
-        saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
-            that exceed max norm are clamped.
-            Must be True for correct MX conversion.
-        allow_denorm {bool} -- If False, flush denorm numbers in the
-            elem_format to zero.
+    """Core function used for element-wise quantization.
+
+    Args:
+        A (torch.Tensor): tensor to be quantized
+        bits (int): number of mantissa bits. Includes sign bit and implicit one for floats
+        exp_bits (int): number of exponent bits, 0 for ints
+        max_norm (float): largest representable normal number
+        round (str, optional): rounding mode (floor, nearest, even). Defaults to "nearest".
+        saturate_normals (bool, optional): whether to clip normal numbers that exceed max norm.
+            Must be True for correct MX conversion. Defaults to False.
+        allow_denorm (bool, optional): if False, flush denorm numbers in the elem_format to zero. Defaults to True.
+
    Returns:
-        quantized tensor {PyTorch tensor} -- A tensor that has been quantized
+        torch.Tensor: tensor that has been quantized
    """
    # Flush values < min_norm to zero if denorms are not allowed
    if not allow_denorm and exp_bits > 0:
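The comment above refers to flushing sub-normal values to zero. As a rough illustration of what that means (illustrative only; `emin` is derived from `exp_bits` in the standard IEEE way, which may differ from the module's exact bookkeeping):

```python
import torch


def flush_denorms(A: torch.Tensor, exp_bits: int) -> torch.Tensor:
    # Zero out magnitudes below the smallest normal number 2**emin.
    emin = 2 - 2 ** (exp_bits - 1)  # e.g. -14 for 5 exponent bits
    min_norm = 2.0 ** emin
    return torch.where(torch.abs(A) < min_norm, torch.zeros_like(A), A)


x = torch.tensor([3e-5, 1e-4, -2e-6])
print(flush_denorms(x, exp_bits=5))  # entries below 2**-14 (~6.1e-5) become 0
```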
@@ -401,15 +418,20 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", satura


def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True):
-    """Quantize values to IEEE fpX format.
-
-    The format defines NaN/Inf
-    and subnorm numbers in the same way as FP32 and FP16.
-    Arguments:
-        exp_bits {int} -- number of bits used to store exponent
-        mantissa_bits {int} -- number of bits used to store mantissa, not
-            including sign or implicit 1
-        round {str} -- Rounding mode, (floor, nearest, even)
+    """Quantize values to IEEE fpX format.
+
+    The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        exp_bits (int, optional): number of bits used to store exponent. Defaults to None.
+        mantissa_bits (int, optional): number of bits used to store mantissa.
+            Not including sign or implicit 1. Defaults to None.
+        round (str, optional): rounding mode (floor, nearest, even). Defaults to "nearest".
+        allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.
+
+    Returns:
+        torch.Tensor: tensor that has been quantized
    """
    # Shortcut for no quantization
    if exp_bits is None or mantissa_bits is None:
@@ -425,11 +447,17 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_de


def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):
-    """Quantize values to bfloatX format
-    Arguments:
-        bfloat {int} -- Total number of bits for bfloatX format,
-            Includes 1 sign, 8 exp bits, and variable
-            mantissa bits. Must be >= 9.
+    """Quantize values to bfloatX format.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        bfloat (int): total number of bits for bfloatX format.
+            Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9.
+        round (str, optional): rounding mode (floor, nearest, even). Defaults to "nearest".
+        allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.
+
+    Returns:
+        torch.Tensor: tensor that has been quantized
    """
    # Shortcut for no quantization
    if bfloat == 0 or bfloat == 32:
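As a side note on what bfloatX quantization means numerically: bfloat16 keeps float32's sign and 8 exponent bits and truncates the mantissa to 7 bits. The bit-twiddling sketch below shows floor-style truncation only; it is a conceptual illustration, not the path this module takes (which goes through the generic element-wise quantizer and also supports nearest/even rounding):

```python
import torch


def bfloat16_truncate(A: torch.Tensor) -> torch.Tensor:
    # Keep the top 16 bits of the float32 encoding (1 sign + 8 exponent + 7 mantissa bits).
    bits = A.to(torch.float32).view(torch.int32)
    return (bits & ~0xFFFF).view(torch.float32)


x = torch.tensor([3.1415926, -0.1], dtype=torch.float32)
print(bfloat16_truncate(x))  # approximately tensor([ 3.1406, -0.0996])
```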
@@ -443,12 +471,14 @@ def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):


def quantize_elemwise_op(A, mx_specs):
-    """A function used for element-wise quantization with mx_specs
-    Arguments:
-        A {PyTorch tensor} -- a tensor that needs to be quantized
-        mx_specs {dictionary} -- dictionary to specify mx_specs
+    """A function used for element-wise quantization with mx_specs.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        mx_specs (dict): dictionary to specify mx_specs
+
    Returns:
-        quantized value {PyTorch tensor} -- a tensor that has been quantized
+        torch.Tensor: tensor that has been quantized
    """
    if mx_specs is None:
        return A
@@ -530,14 +560,15 @@ def _quantize_mx(


def quantize_mx_op(
-    A,
+    A: torch.Tensor,
    elem_format: str,
    round: str,
    block_size: int,
    scale_bits=8,
    axes=None,
    expand_and_reshape=False,
):
+    """Quantize tensor to MX data type."""
    if elem_format is None:
        return A
    elif type(elem_format) is str:
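Finally, a hypothetical call of the public entry point, based only on the signature shown in this hunk (the import path is omitted and the argument values are illustrative, not recommendations; a string `elem_format` is presumably resolved through `ElemFormat.from_str` in the branch cut off above):

```python
import torch

weight = torch.randn(128, 128)

# Assumes quantize_mx_op is in scope (imported from the module under review).
q_weight = quantize_mx_op(
    weight,
    elem_format="int8",
    round="nearest",
    block_size=32,
    axes=[-1],  # assumption: blocks are formed along the last dimension
)

# Grounded in the lines above: elem_format=None is a no-op that returns the input as-is.
assert quantize_mx_op(weight, None, "nearest", 32) is weight
```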