# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
+from typing import Any, Tuple

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl

-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.unified import TwoStepQuantizer


-if TORCH_VERSION_AFTER_2_3:
+if TORCH_VERSION_AFTER_2_4:
    from torchao.quantization.GPTQ import (
        _replace_linear_8da4w,
        Int8DynActInt4WeightLinear,
@@ -54,7 +54,7 @@ def prepare(
                self.precision,
                self.scales_precision,
                Int8DynActInt4WeightQATLinear,
-                copy_weights=True,
+                copy_weights=True,
            )
            return model

@@ -95,7 +95,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
                quantized_linear.zeros = zp
            else:
                _convert_qat_linear_8da4w(child)
-
+
    class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
        """
        This module implements a linear layer with int8 dynamic per token fake
@@ -131,6 +131,8 @@ def __init__(
            self.groupsize = groupsize
            self.precision = precision
            self.scales_precision = scales_precision
+            # TODO: make this configurable?
+            self.zero_points_precision = torch.int32
            self._fake_quant_enabled = True

        def enable_fake_quant(self, enabled: bool = True):
@@ -142,8 +144,8 @@ def disable_fake_quant(self):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # activations: int8 dynamic asymmetric quant
            if self._fake_quant_enabled:
-                (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
-                    x, torch.int8,  # dtype not used
+                (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+                    x, self.scales_precision, self.zero_points_precision,
                )
                (act_qmin, act_qmax) = self._get_qmin_qmax(8)
                x_fq = fake_quantize_per_token(
@@ -157,6 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                (weight_scales, weight_zp) = get_group_qparams_symmetric(
                    self.weight, 4, self.groupsize, self.scales_precision,
                )
+                # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+                weight_zp = weight_zp.to(self.zero_points_precision)
                (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
                w_fq = fake_quantize_per_channel_group(
                    self.weight,
@@ -190,6 +194,20 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
        if isinstance(mod, Int8DynActInt4WeightQATLinear):
            mod.disable_fake_quant()

+else:  # not TORCH_VERSION_AFTER_2_4
+
+    class Int8DynActInt4WeightQATQuantizer:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATQuantizer is only supported after PyTorch 2.4+"
+            )
+
+    class Int8DynActInt4WeightQATLinear:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATLinear is only supported after PyTorch 2.4+"
+            )
+

# ========================
# | QUANT PRIMITIVES |
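
For orientation, here is a minimal sketch of the two-step flow the quantizer above implements: prepare() swaps nn.Linear modules for Int8DynActInt4WeightQATLinear with fake quantization enabled, and convert() (via _convert_qat_linear_8da4w) swaps them for real Int8DynActInt4WeightLinear modules. The import path, the argless constructor, and the convert() name (implied by the TwoStepQuantizer interface) are assumptions, not something this diff shows directly.

import torch

# assumed import path for the module this diff touches
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))

# step 1: insert QAT linears with fake quantization enabled
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)

# ... fine-tune; each forward fake quantizes activations (int8) and weights (int4) ...

# fake quantization can be toggled per module, e.g. for a float eval baseline
model.apply(disable_8da4w_fake_quant)

# step 2: swap QAT linears for real quantized Int8DynActInt4WeightLinear modules
model = qat_quantizer.convert(model)
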
@@ -205,13 +223,15 @@ class _GenericFakeQuantize(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, scales, zero_points, quant_min, quant_max):
+        assert input.dtype == torch.float32
+        assert scales.dtype == torch.float32
+        assert zero_points.dtype == torch.int32
        # Note: this diverges from `torch.fake_quantize_per_channel_affine`,
-        # which rounds first before adding the zero points. However, this
-        # is what `quantize_per_channel_group` and `quantize_per_token`
-        # do and here we try to match that behavior as closely as possible.
+        # which rounds first before adding the zero points. However, since
+        # zero points are integers here, the ordering of these two ops
+        # shouldn't matter in practice.
        q = input.mul(1.0 / scales).add(zero_points).round()
        dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
-        # TODO: do we need this mask?
        mask = torch.logical_and((q >= quant_min), (q <= quant_max))
        ctx.save_for_backward(mask)
        return dq
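
The backward of this autograd.Function is untouched by the diff, so it is not shown here, but the saved mask implies the usual straight-through estimator: gradients flow through elements that stayed inside [quant_min, quant_max] and are zeroed where the value was clipped. A sketch of that pattern (not necessarily the exact code in the file):

    @staticmethod
    def backward(ctx, gy):
        # straight-through estimator: block gradients only for clipped elements
        (mask,) = ctx.saved_tensors
        # only `input` gets a gradient; scales, zero_points, quant_min, quant_max get None
        return gy * mask, None, None, None, None
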
@@ -239,14 +259,13 @@ def fake_quantize_per_channel_group(
    assert group_size > 1
    assert input.shape[-1] % group_size == 0
    assert input.dim() == 2
-    assert torch.isnan(input).sum() == 0
-    grouped_input = input.reshape(-1, group_size)
+    grouped_input = input.reshape(-1, group_size).to(torch.float32)
    scales = scales.reshape(-1, 1)
    zero_points = zero_points.reshape(-1, 1)
    fq = _GenericFakeQuantize.apply(
        grouped_input, scales, zero_points, quant_min, quant_max,
    )
-    return fq.reshape_as(input)
+    return fq.reshape_as(input).to(input.dtype)

# TODO: move this to core
quantized_decomposed_lib.define(
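
As a concrete example of the weight path these primitives serve: compute per-group symmetric qparams, cast the zero points to int32 as the QAT linear now does, then fake quantize. The import path for the helper above and the full fake_quantize_per_channel_group argument order are assumptions consistent with the body shown, and the 4-bit range mirrors _get_qmin_qmax(4).

import torch

from torchao.quantization.quant_primitives import get_group_qparams_symmetric
# assumed import path for the helper defined above
from torchao.quantization.prototype.qat import fake_quantize_per_channel_group

weight = torch.randn(128, 256)  # [out_features, in_features]
group_size = 32

# per-group symmetric qparams for 4-bit weights, fp32 scales
scales, zero_points = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
zero_points = zero_points.to(torch.int32)  # match the new zero_points_precision

qmin, qmax = -(2 ** 3), 2 ** 3 - 1  # 4-bit signed range
w_fq = fake_quantize_per_channel_group(
    weight, scales, zero_points, qmin, qmax, group_size,
)
assert w_fq.shape == weight.shape and w_fq.dtype == weight.dtype
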
@@ -266,17 +285,20 @@ def fake_quantize_per_token(
    from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check

    _per_token_quant_qparam_dim_check(input, scales, zero_points)
-    return _GenericFakeQuantize.apply(
-        input, scales, zero_points, quant_min, quant_max,
+    fq_input = input.to(torch.float32)
+    fq = _GenericFakeQuantize.apply(
+        fq_input, scales, zero_points, quant_min, quant_max,
    )
+    return fq.reshape_as(input).to(input.dtype)

# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
# The version in pytorch does not have backward support yet so we add
# it here for now until https://github.com/pytorch/pytorch/pull/123452
# is landed.
def _choose_qparams_per_token_asymmetric(
    input: torch.Tensor,
-    dtype: torch.dtype,
+    scales_precision: torch.dtype = torch.float32,
+    zero_points_precision: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
@@ -285,7 +307,8 @@ def _choose_qparams_per_token_asymmetric(

    Args:
        input (torch.Tensor): original float32/float16 Tensor
-        dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+        scales_precision (torch.dtype): precision of returned scales
+        zero_points_precision (torch.dtype): precision of returned zero points

    Returns:
        scales and zero_points, both float32 Tensors
@@ -314,4 +337,4 @@ def _choose_qparams_per_token_asymmetric(
    )
    zero_point = torch.clamp(zero_point, qmin, qmax).round()

-    return scale.to(torch.float32), zero_point.to(torch.float32)
+    return scale.to(scales_precision), zero_point.to(zero_points_precision)
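
And the activation path after this change, with fp32 scales and int32 zero points requested explicitly. The import path is an assumption, and the 8-bit range mirrors _get_qmin_qmax(8).

import torch

# assumed import path for the helpers defined above
from torchao.quantization.prototype.qat import (
    _choose_qparams_per_token_asymmetric,
    fake_quantize_per_token,
)

x = torch.randn(8, 1024)  # per token: one (scale, zero_point) pair per row

scales, zero_points = _choose_qparams_per_token_asymmetric(
    x, torch.float32, torch.int32,
)

qmin, qmax = -(2 ** 7), 2 ** 7 - 1  # 8-bit signed range
x_fq = fake_quantize_per_token(x, scales, zero_points, qmin, qmax)

# fake quant computes in fp32 internally but returns the input's dtype
assert x_fq.dtype == x.dtype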