
Commit b787940

mengniwang95, pre-commit-ci[bot], and xin3he authored
add docstring for mx quant (#1932)
Signed-off-by: Mengni Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: xinhe <[email protected]>
1 parent 0c52e12 · commit b787940

File tree

4 files changed: +93 −53 lines changed

.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 /neural-compressor/neural_compressor/strategy
 /neural-compressor/neural_compressor/training.py
 /neural-compressor/neural_compressor/utils
+/neural_compressor/torch/algorithms/mx_quant
 /neural-compressor/neural_compressor/torch/algorithms/static_quant
 /neural-compressor/neural_compressor/torch/algorithms/smooth_quant
 /neural_compressor/torch/algorithms/pt2e_quant

neural_compressor/torch/algorithms/mx_quant/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 # pylint:disable=import-error
+"""MX quantization."""

neural_compressor/torch/algorithms/mx_quant/mx.py

Lines changed: 8 additions & 1 deletion

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""MX quantization."""
 
 from collections import OrderedDict

@@ -31,6 +31,8 @@
 
 
 class MXLinear(torch.nn.Linear):
+    """Linear for MX data type."""
+
     def __init__(
         self,
         in_features,

@@ -39,13 +41,15 @@ def __init__(
         mx_specs=None,
         name=None,
     ):
+        """Initialization function."""
         self.mx_none = mx_specs is None
 
         self.name = name
         self.mx_specs = mx_specs
         super().__init__(in_features, out_features, bias)
 
     def apply_mx_specs(self):
+        """Apply MX data type to weight."""
         if self.mx_specs is not None:
             if self.mx_specs.out_dtype != "float32":
                 self.weight.data = quantize_elemwise_op(self.weight.data, mx_specs=self.mx_specs)

@@ -63,6 +67,7 @@ def apply_mx_specs(self):
             )
 
     def forward(self, input):
+        """Forward function."""
         if self.mx_none:
             return super().forward(input)
 

@@ -93,6 +98,8 @@ def forward(self, input):
 
 
 class MXQuantizer(Quantizer):
+    """Quantizer of MX data type."""
+
     def __init__(self, quant_config: OrderedDict = {}):
         """Init a MXQuantizer object.
 
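As context beyond the commit itself, a minimal usage sketch of the fallback path documented above; the positional (in_features, out_features, bias) order is inferred from the super().__init__ call in the hunk, and the shapes are illustrative:

    import torch

    # With mx_specs=None the constructor sets self.mx_none, and forward()
    # falls back to the stock torch.nn.Linear path shown in the diff.
    layer = MXLinear(64, 32, True, mx_specs=None)
    y = layer(torch.randn(4, 64))  # behaves like a plain nn.Linear here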
neural_compressor/torch/algorithms/mx_quant/utils.py

Lines changed: 83 additions & 52 deletions
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""MX quantization utils."""
 
 from enum import Enum, IntEnum
@@ -28,6 +28,8 @@
 
 
 class ElemFormat(Enum):
+    """Element format."""
+
     int8 = 1
     int4 = 2
     int2 = 3
@@ -44,6 +46,7 @@ class ElemFormat(Enum):
 
     @staticmethod
     def from_str(s):
+        """Get element format with str."""
         assert s is not None, "String elem_format == None"
         s = s.lower()
         if hasattr(ElemFormat, s):
@@ -53,6 +56,7 @@ def from_str(s):
 
     @staticmethod
     def is_bf(s):
+        """Whether the format is brain floating-point format."""
         if isinstance(s, str):
             assert s is not None, "String elem_format == None"
             s = s.lower()
@@ -65,6 +69,7 @@ def is_bf(s):
 
     @staticmethod
     def is_fp(s):
+        """Whether the format is floating-point format."""
         if isinstance(s, str):
             assert s is not None, "String elem_format == None"
             s = s.lower()
@@ -77,6 +82,7 @@ def is_fp(s):
 
     @staticmethod
     def is_int(s):
+        """Whether the format is integer format."""
         if isinstance(s, str):
             assert s is not None, "String elem_format == None"
             s = s.lower()
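A hedged sketch of the helpers above: from_str lowercases its argument and resolves it on the enum (the return statement after the hasattr check is truncated in this hunk, so the exact lookup is assumed), and the is_* classifiers accept either strings or ElemFormat members per their isinstance checks:

    fmt = ElemFormat.from_str("INT8")  # case-insensitive; assumed to yield ElemFormat.int8
    ElemFormat.is_int(fmt)             # expected True for an integer format
    ElemFormat.is_fp("fp8_e4m3")       # string input is lowercased before classification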
@@ -89,12 +95,15 @@ def is_int(s):
 
 
 class RoundingMode(IntEnum):
+    """Rounding mode."""
+
     nearest = 0
     floor = 1
     even = 2
 
     @staticmethod
     def string_enums():
+        """Rounding mode names."""
         return [s.name for s in list(RoundingMode)]
 
 
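The members above fully determine the helper's output, so this one is not a guess:

    RoundingMode.string_enums()  # -> ["nearest", "floor", "even"]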
@@ -115,14 +124,19 @@ def _get_max_norm(ebits, mbits):
 
 
 def _get_format_params(fmt):
-    """Allowed formats:
+    """Get parameters of the format.
+
+    Allowed formats:
     - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation
     - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf
     - bfloatX/bfX: 9 <= X <= 32
     - fp4, no NaN/Inf
     - fp6_e3m2/e2m3, no NaN/Inf
     - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior
 
+    Args:
+        fmt (str or ElemFormat): format
+
     Returns:
         ebits: exponent bits
         mbits: mantissa bits: includes sign and implicit bits
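To make the mbits convention concrete (the Returns section counts sign and implicit bits as part of the mantissa), here is how fp8_e4m3 decomposes; the arithmetic follows the stated convention rather than a verified call to _get_format_params:

    # fp8_e4m3: 1 sign bit + 4 exponent bits + 3 explicit mantissa bits
    ebits = 4
    mbits = 3 + 2  # explicit mantissa bits, plus the sign and implicit leading one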
@@ -198,17 +212,19 @@ def _safe_rshift(x, bits, exp):
 
 
 def _round_mantissa(A, bits, round, clamp=False):
-    """
-    Rounds mantissa to nearest bits depending on the rounding method 'round'
+    """Rounds mantissa to nearest bits depending on the rounding method 'round'.
+
     Args:
-        A {PyTorch tensor} -- Input tensor
-        round {str} -- Rounding method
-            "floor" rounds to the floor
-            "nearest" rounds to ceil or floor, whichever is nearest
+        A (torch.Tensor): input tensor
+        bits (int): bit number of mantissa
+        round (str): rounding method
+            "floor" rounds to the floor
+            "nearest" rounds to ceil or floor, whichever is nearest
+        clamp (bool, optional): whether to clamp. Defaults to False.
+
     Returns:
-        A {PyTorch tensor} -- Tensor with mantissas rounded
+        torch.Tensor: tensor with mantissas rounded
     """
-
     if round == "dither":
         rand_A = torch.rand_like(A, requires_grad=False)
         A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
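By analogy with the dither branch visible above (sign times floor of magnitude plus uniform noise), a hedged guess at what the "nearest" branch computes once the mantissa has been shifted into the integer domain:

    import torch

    A = torch.tensor([1.2, -2.7, 3.5])
    nearest = torch.sign(A) * torch.floor(torch.abs(A) + 0.5)  # assumed rounding rule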
@@ -235,16 +251,18 @@ def _shared_exponents(A, method="max", axes=None, ebits=0):
     """Get shared exponents for the passed matrix A.
 
     Args:
-        A {PyTorch tensor} -- Input tensor
-        method {str} -- Exponent selection method.
-            "max" uses the max absolute value
-            "none" uses an exponent for each value (i.e., no sharing)
-        axes {list(int)} -- List of integers which specifies the axes across which
-            shared exponents are calculated.
+        A (torch.Tensor): Input tensor
+        method (str, optional): Exponent selection method.
+            "max" uses the max absolute value.
+            "none" uses an exponent for each value (i.e., no sharing)
+            Defaults to "max".
+        axes (list(int), optional): list of integers which specifies the axes across which
+            shared exponents are calculated. Defaults to None.
+        ebits (int, optional): bit number of the shared exponents. Defaults to 0.
+
     Returns:
-        shared_exp {PyTorch tensor} -- Tensor of shared exponents
+        shared_exp (torch.Tensor): Tensor of shared exponents
     """
-
     if method == "max":
         if axes is None:
             shared_exp = torch.max(torch.abs(A))
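This shared exponent is the heart of MX block scaling: every element in a block reuses the exponent of the block's largest magnitude. A hedged sketch with one exponent per row (floor(log2(·)) is the conventional extraction; the ebits clamping handled by utils.py is not reproduced):

    import torch

    A = torch.randn(4, 8)
    amax = torch.max(torch.abs(A), dim=1, keepdim=True).values
    amax = torch.where(amax == 0, torch.ones_like(amax), amax)  # guard log2(0)
    shared_exp = torch.floor(torch.log2(amax))                  # one exponent per row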
@@ -346,21 +364,20 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):
 
 
 def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True):
-    """Core function used for element-wise quantization
-    Arguments:
-        A {PyTorch tensor} -- A tensor to be quantized
-        bits {int} -- Number of mantissa bits. Includes
-            sign bit and implicit one for floats
-        exp_bits {int} -- Number of exponent bits, 0 for ints
-        max_norm {float} -- Largest representable normal number
-        round {str} -- Rounding mode: (floor, nearest, even)
-        saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
-            that exceed max norm are clamped.
-            Must be True for correct MX conversion.
-        allow_denorm {bool} -- If False, flush denorm numbers in the
-            elem_format to zero.
+    """Core function used for element-wise quantization.
+
+    Args:
+        A (torch.Tensor): tensor to be quantized
+        bits (int): number of mantissa bits. Includes sign bit and implicit one for floats
+        exp_bits (int): number of exponent bits, 0 for ints
+        max_norm (float): largest representable normal number
+        round (str, optional): rounding mode: (floor, nearest, even). Defaults to "nearest".
+        saturate_normals (bool, optional): whether to clip normal numbers that exceed max norm.
+            Must be True for correct MX conversion. Defaults to False.
+        allow_denorm (bool, optional): if False, flush denorm numbers in the elem_format to zero. Defaults to True.
+
     Returns:
-        quantized tensor {PyTorch tensor} -- A tensor that has been quantized
+        torch.Tensor: tensor that has been quantized
     """
     # Flush values < min_norm to zero if denorms are not allowed
     if not allow_denorm and exp_bits > 0:
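A hedged call sketch using the signature as documented; the constants describe fp6_e3m2, whose largest normal value is 28, and saturate_normals=True follows the docstring's requirement for correct MX conversion:

    import torch

    q = _quantize_elemwise_core(
        torch.randn(16),
        bits=4,          # 2 explicit mantissa bits + sign + implicit one (e3m2)
        exp_bits=3,
        max_norm=28.0,   # largest normal of fp6_e3m2
        round="nearest",
        saturate_normals=True,
        allow_denorm=True,
    )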
@@ -401,15 +418,20 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True):
 
 
 def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True):
-    """Quantize values to IEEE fpX format.
-
-    The format defines NaN/Inf
-    and subnorm numbers in the same way as FP32 and FP16.
-    Arguments:
-        exp_bits {int} -- number of bits used to store exponent
-        mantissa_bits {int} -- number of bits used to store mantissa, not
-            including sign or implicit 1
-        round {str} -- Rounding mode, (floor, nearest, even)
+    """Quantize values to IEEE fpX format.
+
+    The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        exp_bits (int, optional): number of bits used to store exponent. Defaults to None.
+        mantissa_bits (int, optional): number of bits used to store mantissa.
+            Not including sign or implicit 1. Defaults to None.
+        round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest".
+        allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.
+
+    Returns:
+        torch.Tensor: tensor that has been quantized
     """
     # Shortcut for no quantization
     if exp_bits is None or mantissa_bits is None:
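As a hedged orientation point: IEEE fp16 uses 5 exponent and 10 explicit mantissa bits, so the first call below emulates fp16 rounding while keeping fp32 storage, and the second exercises the documented no-quantization shortcut:

    import torch

    x = torch.randn(8)
    x16 = _quantize_fp(x, exp_bits=5, mantissa_bits=10, round="nearest")
    x_id = _quantize_fp(x)  # exp_bits/mantissa_bits left at None: returned as-is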
@@ -425,11 +447,17 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True):
 
 
 def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):
-    """Quantize values to bfloatX format
-    Arguments:
-        bfloat {int} -- Total number of bits for bfloatX format,
-            Includes 1 sign, 8 exp bits, and variable
-            mantissa bits. Must be >= 9.
+    """Quantize values to bfloatX format.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        bfloat (int): total number of bits for bfloatX format.
+            Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9.
+        round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest".
+        allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.
+
+    Returns:
+        torch.Tensor: tensor that has been quantized
     """
     # Shortcut for no quantization
     if bfloat == 0 or bfloat == 32:
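Per the docstring, bfloat counts the full width, so bfloat=16 means 1 sign + 8 exponent + 7 mantissa bits (standard bfloat16), while 0 and 32 hit the shortcut above; a hedged sketch:

    import torch

    xb = _quantize_bfloat(torch.randn(8), bfloat=16, round="nearest")  # bfloat16 emulation
    x_id = _quantize_bfloat(torch.randn(8), bfloat=32)                 # no-quantization path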
@@ -443,12 +471,14 @@ def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):
 
 
 def quantize_elemwise_op(A, mx_specs):
-    """A function used for element-wise quantization with mx_specs
-    Arguments:
-        A {PyTorch tensor} -- a tensor that needs to be quantized
-        mx_specs {dictionary} -- dictionary to specify mx_specs
+    """A function used for element-wise quantization with mx_specs.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        mx_specs (dict): dictionary to specify mx_specs
+
     Returns:
-        quantized value {PyTorch tensor} -- a tensor that has been quantized
+        torch.Tensor: tensor that has been quantized
     """
     if mx_specs is None:
         return A
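The None branch above is explicit in the diff, so this minimal sketch is safe; the populated-dict path depends on mx_specs keys not shown here:

    import torch

    x = torch.randn(4)
    y = quantize_elemwise_op(x, mx_specs=None)  # documented passthrough: returns x unchanged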
@@ -530,14 +560,15 @@ def _quantize_mx(
 
 
 def quantize_mx_op(
-    A,
+    A: torch.Tensor,
     elem_format: str,
     round: str,
     block_size: int,
     scale_bits=8,
     axes=None,
     expand_and_reshape=False,
 ):
+    """Quantize tensor to MX data type."""
     if elem_format is None:
         return A
     elif type(elem_format) is str:
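Finally, a hedged call sketch of the public entry point using the annotated signature; "int8" is an ElemFormat member shown earlier, block_size=32 matches the common MX block convention, and axes=[-1] (one shared scale per block along the last dimension) is an illustrative assumption:

    import torch

    W = torch.randn(128, 256)
    Wq = quantize_mx_op(
        W,
        elem_format="int8",  # the str branch above resolves this via ElemFormat
        round="nearest",
        block_size=32,
        scale_bits=8,
        axes=[-1],           # assumption, not taken from the diff
    )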
