@@ -16,7 +16,10 @@
 
 
 if TORCH_VERSION_AFTER_2_3:
-    from torchao.quantization.GPTQ import _replace_linear_8da4w
+    from torchao.quantization.GPTQ import (
+        _replace_linear_8da4w,
+        Int8DynActInt4WeightLinear,
+    )
 
     class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
         """
@@ -60,10 +63,38 @@ def convert(
             *args: Any,
             **kwargs: Any
         ) -> torch.nn.Module:
-            # TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
-            pass
-
+            _convert_qat_linear_8da4w(model)
+            return model
 
+    def _convert_qat_linear_8da4w(module: torch.nn.Module):
+        """
+        Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightQATLinear):
+                quantized_linear = Int8DynActInt4WeightLinear(
+                    child.in_features,
+                    child.out_features,
+                    bias=False,
+                    groupsize=child.groupsize,
+                    precision=child.precision,
+                    scales_precision=child.scales_precision,
+                )
+                setattr(module, name, quantized_linear)
+
+                # Load weights and qparams into quantized linear
+                n_bit = 4
+                (qmin, qmax) = child._get_qmin_qmax(n_bit)
+                (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
+                q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                    child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize,
+                )
+                quantized_linear.weight = q_weight
+                quantized_linear.scales = s
+                quantized_linear.zeros = zp
+            else:
+                _convert_qat_linear_8da4w(child)
+
     class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
         """
         This module implements a linear layer with int8 dynamic per token fake
@@ -96,6 +127,7 @@ def __init__(
             ), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
             assert not bias, "require bias=False"
             self.groupsize = groupsize
+            self.precision = precision
             self.scales_precision = scales_precision
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -123,6 +155,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             return torch.nn.functional.linear(x_fq, w_fq)
 
+        # TODO: move this to common util
         def _get_qmin_qmax(self, n_bit: int):
             qmin = -(2 ** (n_bit - 1))
             qmax = 2 ** (n_bit - 1) - 1
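
For readers less familiar with the 8da4w numerics: the convert path above computes per-group symmetric scales and zero points with `get_group_qparams_symmetric`, then packs the trained weight into the int4 range via `torch.ops.quantized_decomposed.quantize_per_channel_group`. Below is a minimal plain-PyTorch sketch of that per-group symmetric quantize/dequantize round trip; the toy shapes, `group_size`, and the exact scale formula are illustrative assumptions, not code from this PR.

```python
import torch

# Toy weight: 8 output channels x 32 input features, groups of 16 along the input dim.
weight = torch.randn(8, 32)
group_size = 16
n_bit = 4
qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1  # -8, 7, as in _get_qmin_qmax

# One symmetric scale per (row, group); the zero point is fixed at 0.
# (torchao's get_group_qparams_symmetric may use a slightly different scale formula.)
grouped = weight.reshape(-1, group_size)                 # (num_groups, group_size)
scales = grouped.abs().amax(dim=1, keepdim=True) / qmax
scales = scales.clamp(min=1e-6)

# Quantize to the int4 range, then dequantize: this is the "fake quantized"
# weight the QAT linear trains against and that convert() later bakes in for real.
q = torch.clamp(torch.round(grouped / scales), qmin, qmax)
w_fq = (q * scales).reshape(weight.shape)

print((weight - w_fq).abs().max())  # worst-case per-group rounding error
```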
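For context on how the two steps fit together, here is a hedged usage sketch of the quantizer this PR completes. The import path, the `prepare()` behavior, and the toy model are assumptions based on the `TwoStepQuantizer` interface shown above, not part of this diff.

```python
import torch
# Assumed import path for the prototype QAT module at the time of this PR.
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# The QAT linear asserts bias=False, so the toy model avoids biases.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))

# Step 1 (prepare): swap nn.Linear -> Int8DynActInt4WeightQATLinear so that
# fine-tuning runs against fake-quantized activations and weights.
quantizer = Int8DynActInt4WeightQATQuantizer()
model = quantizer.prepare(model)
# ... fine-tune `model` here ...

# Step 2 (convert, added in this PR): swap in Int8DynActInt4WeightLinear with
# the trained weights quantized to grouped int4 plus per-group scales/zeros.
model = quantizer.convert(model)
out = model(torch.randn(2, 256))
```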