
Commit 92533a6

Fix in-place processing error in quant_weight function (#1703)
Signed-off-by: xin3he <[email protected]>
Signed-off-by: Cheng, Penghui <[email protected]>

Parent: 3b150d6 · Commit: 92533a6

File tree

3 files changed: +40 -16 lines

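The "in-place processing error" in the title: quant_weight (and its 3.x counterpart quant_tensor) is expected to fake-quantize the tensor it receives in place and return that same object, but rebinding the local name with weight = weight.reshape(...) and returning the rebound value hands the caller a different Python object. The fix pattern used throughout this commit is to stash orig_weight = weight up front, copy_ every result back into it, and return it. A minimal sketch of the bug and of the fix pattern, using hypothetical fake_quant_* helpers rather than the library's code:

import torch

def fake_quant_broken(weight, group_size=32):
    orig_shape = weight.shape
    weight = weight.reshape(-1, group_size)  # rebinds the name; usually a view
    weight.round_()                          # mutates the shared storage...
    return weight.reshape(orig_shape)        # ...but returns a different object

def fake_quant_fixed(weight, group_size=32):
    orig_weight = weight                     # keep a handle on the caller's tensor
    weight = weight.reshape(-1, group_size)
    weight.round_()
    orig_weight.copy_(weight.reshape(orig_weight.shape))  # write results back
    return orig_weight                       # the very object that was passed in

x = torch.randn(4, 64)
assert fake_quant_fixed(x) is x
assert fake_quant_broken(x) is not x         # the contract violation being fixed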

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 13 additions & 8 deletions
@@ -176,7 +176,7 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_rang
     weight.round_()
     weight.clamp_(minq, maxq)
     if return_int:
-        return weight, scale.type(torch.float), None
+        return weight, scale, None
     return weight.mul_(scale)

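A side note on this hunk: Tensor.type(torch.float) allocates a new tensor whenever the dtype differs, so returning scale directly avoids a copy and keeps the scale in whatever dtype the quantization math produced (an fp16 weight keeps an fp16 scale). This rationale is inferred from the diff; the commit message does not spell it out.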
@@ -238,6 +238,7 @@ def quant_weight(
 
     orig_shape = weight.shape
     if weight.shape[1] % group_size == 0:
+        orig_weight = weight
         weight = weight.reshape(-1, group_size)
         if return_int:
             weight, scale, zp = qdq_weight_actor(

@@ -250,17 +251,21 @@ def quant_weight(
                 data_type=data_type,
             )
             weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
             scale = scale.reshape(orig_shape[0], -1)
             if zp is not None:
                 zp = zp.reshape(orig_shape[0], -1)
-            return weight, scale, zp
+            return orig_weight, scale, zp
         else:
             qdq_weight_actor(
                 weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
-            return weight.reshape(orig_shape)
+            weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
+            return orig_weight
     else:
         split_index = weight.shape[1] // group_size * group_size
+        orig_weight = weight
         weight1 = weight[:, :split_index]
         weight1 = weight1.reshape(-1, group_size)
         if return_int:

@@ -277,7 +282,7 @@ def quant_weight(
             if zp1 is not None:
                 zp1 = zp1.reshape(orig_shape[0], -1)
         else:
-            weight1 = qdq_weight_actor(
+            qdq_weight_actor(
                 weight1, num_bits, scheme=scheme, quantile=quantile, data_type=data_type, full_range=full_range
             )
         weight1 = weight1.reshape(orig_shape[0], split_index)

@@ -292,19 +297,19 @@ def quant_weight(
                 return_int=True,
                 full_range=full_range,
             )
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
             scale = torch.cat([scale1, scale2], dim=1)
             if zp2 is not None:
                 zp = torch.cat([zp1, zp2], dim=1)
             else:
                 zp = None
-            return weight, scale, zp
+            return orig_weight, scale, zp
         else:
             weight2 = qdq_weight_actor(
                 weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
-            return weight
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
+            return orig_weight
 
 
 def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False):
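In the evenly divisible branch, weight.reshape(-1, group_size) is normally a view, so the in-place round_/clamp_ inside qdq_weight_actor already lands in the caller's storage and only the returned object was wrong. The uneven branch is stricter: torch.cat always allocates fresh memory, so without the copy_ back into orig_weight the quantized halves would never reach the caller's tensor. A small demonstration of that allocation behavior (illustrative tensors, not library code):

import torch

w = torch.randn(4, 70)                  # 70 % 32 != 0: the uneven-group case
w1, w2 = w[:, :64], w[:, 64:]           # both are views into w
merged = torch.cat([w1.round(), w2.round()], dim=1)  # cat allocates a new tensor
assert merged is not w                  # mutating `merged` would never reach w
w.copy_(merged)                         # copy_ is what writes the result back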

neural_compressor/torch/algorithms/weight_only/utility.py

Lines changed: 14 additions & 8 deletions
@@ -266,8 +266,10 @@ def quant_tensor(
         group_size = weight.shape[1]
     # case 2, reshape based on group size
     orig_shape = weight.shape
+    orig_weight = weight
     if weight.shape[1] % group_size == 0:
         weight = weight.reshape(-1, group_size)
+        # return weight for unpacking scale and zp
         weight = qdq_weight_actor(
             weight,
             bits,

@@ -281,12 +283,15 @@ def quant_tensor(
         if return_int or quant_scale:
             weight, scale, zp = weight
             weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
             scale = scale.reshape(orig_shape[0], -1)
             if zp is not None:
                 zp = zp.reshape(orig_shape[0], -1)
-            q_state = weight, scale, zp
+            q_state = orig_weight, scale, zp
         else:
-            return weight.reshape(orig_shape)
+            weight = weight.reshape(orig_shape)
+            orig_weight.copy_(weight)
+            return orig_weight
     else:
         # case 3, process left part split by group size
         split_index = weight.shape[1] // group_size * group_size

@@ -321,13 +326,13 @@ def quant_tensor(
             )
         if return_int or quant_scale:
             weight2, scale2, zp2 = weight2
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
             scale = torch.cat([scale1, scale2], dim=1)
             zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1)
             q_state = (weight, scale, zp)
         else:
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
-            return weight
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
+            return orig_weight
     if quant_scale:
         weight, scale, zp = q_state
         scale_dtype = kwargs.get("double_quant_dtype", "int")
@@ -343,7 +348,7 @@ def quant_tensor(
             scale.sub_(scale_mean)
             scale_scheme = "sym"
         # process: scale
-        scale = quant_tensor(
+        quant_tensor(
            scale,
            dtype=scale_dtype,
            bits=scale_bits,
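This hunk is where the strengthened contract pays off: the recursive double-quant call that quantizes the scale no longer rebinds the name, because quant_tensor now mutates its argument and returns the same object, so aliases of the original scale tensor stay valid. A sketch of the contract it relies on, using a single-argument call as in the new test (all other settings assumed to be defaults):

import torch
from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor

t = torch.randn(64, 64)     # stand-in tensor; width divisible by a 32 group
out = quant_tensor(t)       # fake-quantized in place
assert out is t             # and the same object comes back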
@@ -375,15 +380,16 @@ def quant_tensor(
             weight1 = weight1.mul_(scale[:, :-1].reshape(-1, 1))
             weight1 = weight1.reshape(orig_shape[0], -1)
             weight2 = weight2.mul_(scale[:, -1].reshape(-1, 1))
-            weight.copy_(torch.cat([weight1, weight2], dim=1))
+            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
         else:
             if zp is not None:
                 weight = weight.reshape(-1, group_size) - zp.reshape(-1, 1)
             else:
                 weight = weight.reshape(-1, group_size)
             weight = weight.mul_(scale.reshape(-1, 1))
             weight = weight.reshape(orig_shape[0], -1)
-            return weight
+            orig_weight.copy_(weight)
+            return orig_weight
     else:
         return q_state

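With quant_tensor strictly in place, a caller can fake-quantize a module's weight without any reassignment. A minimal usage sketch under assumed default settings (not taken from the repo's documentation):

import torch
from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor

linear = torch.nn.Linear(64, 16)        # weight shape (16, 64)
with torch.no_grad():
    quant_tensor(linear.weight.data)    # weight updated in place; no rebinding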
New test file

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import pytest
+import torch
+
+
+@pytest.mark.parametrize("shape", [1024, 512, 300])
+def test_quant_tensor_id(shape):
+    from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor
+
+    input = torch.randn(shape, shape)
+    id1 = id(input)
+    output = quant_tensor(input)
+    id2 = id(output)
+    assert id1 == id2, "quant_tensor function is an in-place operator"
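Comparing id(input) == id(output) is equivalent to output is input, which pins the in-place contract: the function must hand back the very tensor it was given. The three shapes presumably cover both code paths: 1024 and 512 divide evenly into 32-wide groups, while 300 leaves a remainder and exercises the split-and-concatenate branch (this divisibility reading assumes a default group size of 32, which the diff itself does not state).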
