from typing import Callable, Optional

+from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.quantization.prototype.qat import (
+    disable_8da4w_fake_quant,
+    enable_8da4w_fake_quant,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_8da4w_fake_quant_module_swap,
+    enable_8da4w_fake_quant_module_swap,
+    Int8DynActInt4WeightQATQuantizerModuleSwap,
+)
+
+
__all__ = [
    "get_quantizer_mode",
+    "Int8DynActInt4WeightQuantizer",
+    "Int8DynActInt4WeightQATQuantizer",
]


_quantizer_to_mode = {}
_quantizer_mode_to_disable_fake_quant = {}
_quantizer_mode_to_enable_fake_quant = {}


-from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+# ========================================================
+# int8 dynamic activations + int4 weight tensor subclass |
+# ========================================================

-__all__.append("Int8DynActInt4WeightQuantizer")
-_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"

+class Int8DynActInt4WeightQuantizer:
+    """
+    Quantizer for applying int8 per token dynamic activation + int4
+    per group weight quantization to linear layers in the model.
+    """
+
+    def __init__(self, groupsize: int = 256):
+        self.groupsize = groupsize
+
+    def quantize(self, model):
+        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_(model, quantize_fn)
+        return model
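
A minimal usage sketch of the quantizer defined above. The toy model and group size are illustrative only; each linear layer's in_features must be divisible by the group size:

import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 256))
quantizer = Int8DynActInt4WeightQuantizer(groupsize=256)
model = quantizer.quantize(model)  # weights become int4 tensor subclasses in place
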

-from torchao.quantization.prototype.qat import (
-    disable_8da4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int8DynActInt4WeightQATQuantizer,
-)

-__all__.append("Int8DynActInt4WeightQATQuantizer")
+_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
_quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant

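For context, a sketch of how the QAT hooks registered above are typically driven from a recipe; the training loop is elided and `model` is assumed to be an already-instantiated nn.Module:

qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = qat_quantizer.prepare(model)   # swap in fake-quantized linear layers
model.apply(disable_8da4w_fake_quant)  # optionally delay fake quant for early steps
# ... train for some steps ...
model.apply(enable_8da4w_fake_quant)   # continue training with fake quant enabled
model = qat_quantizer.convert(model)   # lower to real quantized ops for inference
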
-try:
-    # Note: QAT tensor subclass implementation in torchao only works
-    # with FSDP2 today. For other distribution strategies like DDP and
-    # FSDP1, users will need to fall back to the old module swap flow.
-    # TODO: remove this try catch once we upgrade to torchao 0.5.0
-
-    from torchao.quantization.prototype.qat._module_swap_api import (
-        disable_8da4w_fake_quant_module_swap,
-        enable_8da4w_fake_quant_module_swap,
-        Int8DynActInt4WeightQATQuantizerModuleSwap,
-    )
-
-    __all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
-    _quantizer_to_mode[
-        Int8DynActInt4WeightQATQuantizerModuleSwap
-    ] = "8da4w-qat-module-swap"
-    _quantizer_mode_to_disable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = disable_8da4w_fake_quant_module_swap
-    _quantizer_mode_to_enable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = enable_8da4w_fake_quant_module_swap
-except ImportError:
-    pass
+
+# ====================================================
+# int8 dynamic activations + int4 weight module swap |
+# ====================================================
+
+# Note: QAT tensor subclass implementation in torchao only works
+# with FSDP2 today. For other distribution strategies like DDP and
+# FSDP1, users will need to fall back to the old module swap flow.
+__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "8da4w-qat-module-swap"
+] = disable_8da4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "8da4w-qat-module-swap"
+] = enable_8da4w_fake_quant_module_swap
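
As the note above says, DDP and FSDP1 users fall back to the module swap flow. Under the same assumptions as the earlier sketch, that flow is simply the ModuleSwap variants of the same API:

qat_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=256)
model = qat_quantizer.prepare(model)
model.apply(enable_8da4w_fake_quant_module_swap)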


def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
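
Assuming the lookup resolves via type(quantizer) against _quantizer_to_mode, which is what the registrations above imply, usage looks like:

get_quantizer_mode(Int8DynActInt4WeightQuantizer(groupsize=256))     # "8da4w"
get_quantizer_mode(Int8DynActInt4WeightQATQuantizer(groupsize=256))  # "8da4w-qat"
get_quantizer_mode(None)                                             # None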