Skip to content

Commit ab7f401

Browse files
committed
Add cachemask variant for fake_quantize_affine
Summary: In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example:

```
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
grad = grad * mask
```

The existing `fake_quantize_affine` returns only the dequantized values, so callers do not have access to this mask. This commit adds a variant of this op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask
1 parent 1029df3 commit ab7f401

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from torchao.quantization.quant_primitives import (
1212
fake_quantize_affine,
13+
fake_quantize_affine_cachemask,
1314
quantize_affine,
1415
dequantize_affine,
1516
choose_qparams_affine,
@@ -523,5 +524,28 @@ def test_fake_quantize_affine(self):
523524
fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
524525
torch.testing.assert_close(dequantized, fake_quantized)
525526

527+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_affine_cachemask(self):
    """Check `fake_quantize_affine_cachemask` against quantize + dequantize.

    The fake-quantized output must match `dequantize_affine(quantize_affine(x))`
    and, for this input and these qparams, the returned mask is expected to
    be True everywhere.
    """
    x = torch.randn(10, 10)

    # Block spans the whole last dimension and a single element of every
    # other dimension: [1, ..., 1, x.shape[-1]].
    block_size = [1] * (x.dim() - 1) + [x.shape[-1]]
    target_dtype = torch.int8
    qmin, qmax = -127, 127
    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.SYMMETRIC,
        block_size,
        target_dtype,
        qmin,
        qmax,
        eps=1e-5,
        scale_dtype=torch.float,
    )

    # Reference path: explicit quantize followed by dequantize.
    reference = dequantize_affine(
        quantize_affine(x, block_size, scale, zero_point, target_dtype, qmin, qmax),
        block_size,
        scale,
        zero_point,
        target_dtype,
        qmin,
        qmax,
    )

    # Path under test: fused fake quantize that also returns the mask.
    fake_quantized, mask = fake_quantize_affine_cachemask(
        x, block_size, scale, zero_point, target_dtype, qmin, qmax,
    )

    torch.testing.assert_close(reference, fake_quantized)
    torch.testing.assert_close(torch.ones_like(x, dtype=torch.bool), mask)
549+
526550
if __name__ == "__main__":
527551
unittest.main()

torchao/quantization/quant_primitives.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"quantize_affine",
2525
"dequantize_affine",
2626
"fake_quantize_affine",
27+
"fake_quantize_affine_cachemask",
2728
]
2829

2930
class MappingType(Enum):
@@ -411,6 +412,74 @@ def fake_quantize_affine(
411412
value during quantization
412413
default is ZeroPointDomain.INT
413414
"""
415+
(_, fq) = _do_fake_quantize_affine(
416+
input,
417+
block_size,
418+
scale,
419+
zero_point,
420+
quant_dtype,
421+
quant_min,
422+
quant_max,
423+
zero_point_domain,
424+
)
425+
return fq
426+
427+
428+
def fake_quantize_affine_cachemask(
    input: torch.Tensor,
    block_size: Tuple[int, ...],
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
    quant_min: Optional[int] = None,
    quant_max: Optional[int] = None,
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    General fake quantize op for quantization-aware training (QAT).
    This is equivalent to calling `quantize_affine` + `dequantize_affine`
    but without the dtype casts.

    Note: Compared to `fake_quantize_affine`, this consumes more memory and
    returns an additional outlier mask for intermediate quantized values.

    Returns:
      A 2-tuple of (
          final fake quantized values,
          boolean mask that is True where the intermediate quantized value
          lies within [quant_min, quant_max] and False for outliers
      )

    Please refer to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`
    for details about the arguments.
    """
    (q, dq) = _do_fake_quantize_affine(
        input,
        block_size,
        scale,
        zero_point,
        quant_dtype,
        quant_min,
        quant_max,
        zero_point_domain,
    )
    # `quant_min` / `quant_max` default to None, and a tensor cannot be
    # compared against None. Resolve the effective bounds the same way the
    # helper does before building the mask.
    # NOTE(review): presumably this falls back to the `quant_dtype` range
    # when a bound is None -- confirm against `_get_and_check_qmin_qmax`.
    mask_min, mask_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
    mask = torch.logical_and((q >= mask_min), (q <= mask_max))
    return (dq, mask)
467+
468+
469+
def _do_fake_quantize_affine(
470+
input: torch.Tensor,
471+
block_size: Tuple[int, ...],
472+
scale: torch.Tensor,
473+
zero_point: Optional[torch.Tensor],
474+
quant_dtype: torch.dtype,
475+
quant_min: Optional[int] = None,
476+
quant_max: Optional[int] = None,
477+
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
478+
) -> Tuple[torch.Tensor, torch.Tensor]:
479+
"""
480+
Helper function for `fake_quantize_affine` that returns both the
481+
intermediate quantized values and the final dequantized values.
482+
"""
414483
input_dtype = input.dtype
415484
quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
416485
q = _quantize_affine_no_dtype_cast(
@@ -432,7 +501,7 @@ def fake_quantize_affine(
432501
zero_point_domain.name,
433502
output_dtype=input_dtype,
434503
)
435-
return dq
504+
return (q, dq)
436505

437506

438507
def choose_qparams_affine(

0 commit comments

Comments
 (0)