
Commit 512eb75
Commit message: Update
1 parent fbb2cae

File tree (5 files changed: +43 -17 lines changed)

test/quantization/test_quant_primitives.py
torchao/prototype/hqq/hqq_tinygemm_linear.py
torchao/quantization/GPTQ.py
torchao/quantization/qat/linear.py
torchao/quantization/utils.py


test/quantization/test_quant_primitives.py

Lines changed: 3 additions & 2 deletions

@@ -526,8 +526,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            if not is_device(input.device.type, "cpu"):
+                input = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
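
The test above applies the nibble packing only when the tensor is not on CPU: on other devices, two 4-bit values are packed into each uint8 byte before dequantization, mirroring the TORCH_VERSION_AT_LEAST_2_5 branch in the diff. A minimal sketch of that packing step, assuming a 2-D tensor of values in [0, 15] (the helper name is illustrative, not part of the commit):

import torch

def pack_int4_rows_to_uint8(w_int4: torch.Tensor) -> torch.Tensor:
    # Even columns become the high nibble, odd columns the low nibble,
    # halving the last dimension: (n, k) values in [0, 15] -> (n, k // 2) uint8.
    return (w_int4[::, ::2] << 4 | w_int4[::, 1::2]).to(torch.uint8)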

torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 11 additions & 4 deletions

@@ -13,6 +13,7 @@
 
 import torch.nn.functional as F
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.dtypes.utils import is_device
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):

@@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta):
         W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
             W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
         )
-        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            W_q_torch, self.inner_k_tiles
-        )
+        if is_device(W_q.device.type, "cpu"):
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                W_q_torch, self.inner_k_tiles
+            )
+        else:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                W_q_torch, self.inner_k_tiles
+            )
         self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)
 
         del W_q_torch, scales_torch, zeros_torch

@@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants(
             .contiguous()
         )
         if TORCH_VERSION_AT_LEAST_2_5:
-            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+            if not is_device(W_q.device.type, "cpu"):
+                W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
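
The GPTQ and QAT changes below repeat the same device dispatch introduced here: CPU tensors are packed with aten._convert_weight_to_int4pack_for_cpu, while all other devices keep the original aten._convert_weight_to_int4pack call. A hedged sketch of that pattern as a standalone helper (the helper itself is not part of the commit; the two aten ops and is_device are used exactly as in the diffs):

import torch
from torchao.dtypes.utils import is_device

def convert_weight_to_int4pack(w: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # Route CPU tensors to the dedicated CPU packing op; every other device
    # keeps the op that was used unconditionally before this commit.
    if is_device(w.device.type, "cpu"):
        return torch.ops.aten._convert_weight_to_int4pack_for_cpu(w, inner_k_tiles)
    return torch.ops.aten._convert_weight_to_int4pack(w, inner_k_tiles)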

torchao/quantization/GPTQ.py

Lines changed: 17 additions & 6 deletions

@@ -36,6 +36,7 @@
     pack_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
+from torchao.dtypes.utils import is_device
 
 aten = torch.ops.aten
 

@@ -765,9 +766,14 @@ def _create_quantized_state_dict(
                     self.precision,  # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    w_int4x8.to(self.device), self.inner_k_tiles
-                )
+                if is_device(w_int4x8.device.type, "cpu"):
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
+                else:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                     self.device

@@ -851,9 +857,14 @@ def make_names_and_values_dict_func(q, qparams):
            # how much we need to pad the weight
            delta_k = int((new_k - k) / 2)
            q = q.to(self.device)
-           final_q = torch.ops.aten._convert_weight_to_int4pack(
-               F.pad(q, pad=(0, delta_k)), inner_k_tiles
-           )
+           if is_device(self.device.type, "cpu"):
+               final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                   F.pad(q, pad=(0, delta_k)), inner_k_tiles
+               )
+           else:
+               final_q = torch.ops.aten._convert_weight_to_int4pack(
+                   F.pad(q, pad=(0, delta_k)), inner_k_tiles
+               )
            scales = qparams[0].to(torch.bfloat16).to(self.device)
            zeros = qparams[1].to(torch.bfloat16).to(self.device)
            scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

torchao/quantization/qat/linear.py

Lines changed: 11 additions & 4 deletions

@@ -29,6 +29,7 @@
 from .utils import (
     _get_qmin_qmax,
 )
+from torchao.dtypes.utils import is_device
 
 
 class FakeQuantizedLinear(torch.nn.Linear):

@@ -373,10 +374,16 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                     n_bit,
                     config.group_size,
                 )
-                q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                    q_weight.to(child.weight.device),
-                    child.inner_k_tiles,
-                )
+                if is_device(q_weight.device.type, "cpu"):
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
+                else:
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
                 quantized_linear.weight = q_weight
                 quantized_linear.scales_and_zeros = scales_and_zeros
             else:

torchao/quantization/utils.py

Lines changed: 1 addition & 1 deletion

@@ -418,7 +418,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
     if TORCH_VERSION_AT_LEAST_2_5 and (
         w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
-    ):
+    ) and not is_device(w_int4x8.device.type, "cpu"):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
         low_bits = data & 0x0F
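
With this guard, the shift-and-mask unpacking of packed uint8 weights runs only off-CPU; CPU callers pass the unpacked tensor straight through, matching the CPU branches added in the files above. A small sketch of that unpacking step, the inverse of the packing shown after the test diff (helper name illustrative, 2-D input assumed):

import torch

def unpack_uint8_to_int4_rows(packed: torch.Tensor) -> torch.Tensor:
    # Split each byte into its high and low nibbles and re-interleave them,
    # restoring an (n, k) tensor from an (n, k // 2) packed uint8 tensor.
    data = packed.to(torch.int32)
    high_bits, low_bits = data >> 4, data & 0x0F
    return torch.stack([high_bits, low_bits], dim=-1).reshape(packed.shape[0], -1)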
