Skip to content

Commit 51d3ca0

Browse files
committed
Fix wrong scale eps applied
1 parent 79e3366 commit 51d3ca0

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,66 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
957957
torch.testing.assert_close(expected_quantized, quantized)
958958
torch.testing.assert_close(expected_dequantized, dequantized)
959959

960+
@parameterized.expand(
    [
        torch.float64,
        torch.float32,
        torch.bfloat16,
        torch.float16,
    ]
)
def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
    # Fixed by #1770, the test will fail for all the variants
    # before that fix, and will pass afterwards.
    #
    # The scale value must be forcefully clamped, within
    # _choose_qparams_affine() function, (that
    # choose_qparams_affine() and others call into) to a large
    # enough number so that its reciprocal does not become Inf.
    # Otherwise during the quantization, by multiplying with scale
    # reciprocal, all the values will be quantized to Inf value,
    # except from zero value that would produce NaN (0*Inf) as
    # quantized value.
    #
    # The minimal normalized value for given floating point data
    # type is given by torch.finfo(hp_dtype).tiny - let's call
    # this value "tiny". It could be seen by checking, that for
    # all of torch.float64, torch.float32, torch.float16 and
    # torch.bfloat16, a denormalized number that is equal to tiny/4
    # will produce Inf as its reciprocal.
    #
    # Thus, to reproduce the problem, one would create a tensor
    # with such values that their absolute maximum, after being
    # divided with the range of quantized data (that is 57344 for
    # torch.float8_e5m2), would produce scale smaller than tiny/4.
    # Also, eps parameter should be set to value no greater than
    # tiny/4, as scale is clamped from below to that value. With
    # such inputs, choose_qparams_affine() will produce Inf as
    # scale value.
    #
    # Note that this may seem like a contrived reproducer. However,
    # there are cases with existing code that would pass
    # torch.finfo(torch.float32).eps as eps value, regardless of
    # scale_dtype. The float16 has rather small range, so this
    # value is well below torch.finfo(torch.float32).eps, and for
    # such eps value, the code below would produce Inf scale even
    # for float16 tensor that has 0.5 as its maximum value.
    float8_dtype = torch.float8_e5m2
    tiny = torch.finfo(hp_dtype).tiny
    x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
    scale, _ = choose_qparams_affine(
        input=x,
        mapping_type=MappingType.SYMMETRIC,
        block_size=[1, 2],
        target_dtype=float8_dtype,
        eps=tiny / 4,
        scale_dtype=hp_dtype,
        preserve_zero=True,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    scale_reciprocal = scale.reciprocal()
    # The clamp added in _choose_qparams_affine() (to torch.finfo(scale.dtype).tiny)
    # must keep 1/scale finite; before the fix this reciprocal was Inf.
    assert not torch.any(torch.isinf(scale_reciprocal)).item()
1019+
9601020

9611021
if __name__ == "__main__":
9621022
unittest.main()

torchao/quantization/quant_primitives.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,7 @@ def _choose_qparams_affine(
856856
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
857857
and `zero_point_domain`
858858
"""
859+
859860
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
860861
assert mapping_type in [
861862
MappingType.SYMMETRIC.name,
@@ -944,10 +945,16 @@ def _choose_qparams_affine(
944945
else:
945946
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
946947
scale = torch.clamp(scale, min=eps)
948+
if torch.is_floating_point(scale):
949+
# Prevent 1.0 / scale to become Inf.
950+
scale = torch.clamp(scale, min=torch.finfo(scale.dtype).tiny)
947951
else:
948952
assert mapping_type == MappingType.ASYMMETRIC.name
949953
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
950954
scale = torch.clamp(scale, min=eps)
955+
if torch.is_floating_point(scale):
956+
# Prevent 1.0 / scale to become Inf.
957+
scale = torch.clamp(scale, min=torch.finfo(scale.dtype).tiny)
951958
if zero_point_domain == ZeroPointDomain.NONE.name:
952959
zero_point = None
953960
else:

0 commit comments

Comments
 (0)