
Commit fbb2cae

Update based on comments
1 parent 104d1f3 commit fbb2cae

File tree

5 files changed: +14 -17 lines changed
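
Every file in this commit switches an ad-hoc device comparison over to the is_device helper from torchao.dtypes.utils. A minimal sketch of what that helper presumably looks like (the real implementation in torchao/dtypes/utils.py may differ in detail):

from typing import Union

import torch

def is_device(target_device_str: str, device: Union[str, torch.device]) -> bool:
    # Compare by device *type*, so "cuda", "cuda:0", and torch.device("cuda")
    # all match the target regardless of device index.
    return torch.device(device).type == torch.device(target_device_str).type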

test/integration/test_integration.py

Lines changed: 2 additions & 2 deletions
@@ -91,6 +91,7 @@
     is_fbcode,
     benchmark_model
 )
+from torchao.dtypes.utils import is_device
 
 logger = logging.getLogger("INFO")
 
@@ -132,8 +133,7 @@ def _int8da_int8w_api(mod):
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
-        device_type = next(mod.parameters()).device
-        if device_type == torch.device("cpu"):
+        if is_device(next(mod.parameters()).device.type, "cpu"):
             quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
         else:
             quantize_(mod, int4_weight_only(), set_inductor_config=False)
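
The replaced check compared a full torch.device object against torch.device("cpu"). That style of comparison is index-sensitive, which is the pitfall the helper avoids; a quick illustration (not part of the diff):

import torch

print(torch.device("cuda:0") == torch.device("cuda"))  # False: the index makes them unequal
print(torch.device("cuda:0").type == "cuda")           # True: comparing by type ignores the index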

test/quantization/test_quant_primitives.py

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@
     TORCH_VERSION_AT_LEAST_2_6,
     is_fbcode,
 )
+from torchao.dtypes.utils import is_device
 
 _SEED = 1234
 torch.manual_seed(_SEED)
@@ -102,7 +103,7 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .reshape_as(w)
     )
     if TORCH_VERSION_AT_LEAST_2_5:
-        if w.device.type != "cpu":
+        if not is_device(w.device.type, "cpu"):
            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
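
The guarded line packs pairs of int4 values into single uint8 bytes, with even columns in the high nibble. A small worked example of that expression (the tensor here is illustrative only):

import torch

w = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)  # four 4-bit values in one row
packed = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)
print(packed)        # tensor([[18, 52]], dtype=torch.uint8), i.e. 0x12 and 0x34
print(packed.shape)  # torch.Size([1, 2]): the last dimension halves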

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 5 additions & 5 deletions
@@ -632,8 +632,7 @@ def extra_repr(self):
 
 @dataclass(frozen=True)
 class Int4CPULayout(Layout):
-    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        return input
+    pass
 
 @dataclass(frozen=True)
 class Float8Layout(Layout):
@@ -1714,6 +1713,10 @@ def from_plain(
         return cls(packed_weight, scale_and_zero, False, _layout)
 
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
+        if not is_device(torch.device(self.device).type, device):
+            raise ValueError(
+                f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
+            )
         return self.__class__(
@@ -1724,9 +1727,6 @@ def to(self, *args, **kwargs):
         )
 
     def _apply_fn_to_data(self, fn):
-        # self.packed_weight = fn(self.packed_weight)
-        # self.scale_and_zero = fn(self.scale_and_zero)
-        # return self
         return self.__class__(
             fn(self.packed_weight),
             fn(self.scale_and_zero),
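
Note the guard must run after device = kwargs["device"] is assigned, otherwise it would reference the name before it exists; the hunk above reflects that ordering. The effect is that to() rejects any attempt to move an Int4CPUAQTTensorImpl off CPU. A standalone sketch of that behavior, reusing the is_device sketch from the top of this page (Int4CPUAQTTensorImpl itself is not constructed here):

import torch

def _check_device(current: torch.device, target) -> None:
    # Mirrors the guard added in to(): the CPU-only impl refuses cross-device moves.
    if not is_device(torch.device(current).type, target):
        raise ValueError(
            f"Int4CPUAQTTensorImpl does not support conversion from {current} to {target}"
        )

_check_device(torch.device("cpu"), "cpu")    # passes silently
# _check_device(torch.device("cpu"), "cuda") # raises ValueError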

torchao/quantization/subclass.py

Lines changed: 3 additions & 2 deletions
@@ -16,6 +16,7 @@
     unpack_tinygemm_scales_and_zeros,
 )
 from torchao.utils import find_multiple
+from torchao.dtypes.utils import is_device
 
 __all__ = [
     "Int8DynamicallyQuantizedLinearWeight",
@@ -458,7 +459,7 @@ def _quantized_op(act_mat, w_qtensor, bias):
         act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))
 
         # matmul
-        if act_mat.device == torch.device("cpu"):
+        if is_device(act_mat.device.type, "cpu"):
             y = aten._weight_int4pack_mm_for_cpu(
                 act_mat.contiguous(),
                 w_qtensor.int_data,
@@ -617,7 +618,7 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
             input_float, 4, groupsize, dtype=input_float.dtype
         )
-        if input_float.device == torch.device("cpu"):
+        if is_device(input_float.device.type, "cpu"):
            int_data = aten._convert_weight_to_int4pack_for_cpu(input_int4x8, inner_k_tiles)
         else:
            int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
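
Both hunks in this file follow the same dispatch pattern: when the tensor lives on CPU, use the *_for_cpu variant of the int4 op. A sketch of that pattern in isolation (op names are taken from the diff; the shape and packing requirements of the real ops are not modeled here):

import torch

aten = torch.ops.aten

def pack_int4_weight(w_int4x8: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # CPU gets its own packing kernel; everything else uses the generic one.
    if is_device(w_int4x8.device.type, "cpu"):
        return aten._convert_weight_to_int4pack_for_cpu(w_int4x8, inner_k_tiles)
    return aten._convert_weight_to_int4pack(w_int4x8, inner_k_tiles)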

torchao/quantization/utils.py

Lines changed: 2 additions & 7 deletions
@@ -18,6 +18,7 @@
     quantize_affine,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.dtypes.utils import is_device
 
 __all__ = [
     "compute_error",
@@ -400,14 +401,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
         zero_point_domain=ZeroPointDomain.FLOAT,
     )
     if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1:
-        int_data_device_type = int_data.device.type
-        # Move to cpu, until issue with MPS memory management of temporary tensors is resolved
-        # if int_data_device_type == "mps":
-        #     int_data = int_data.cpu()
-        if int_data_device_type != "cpu":
+        if not is_device(int_data.device.type, "cpu"):
             int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
-        # if int_data_device_type == "mps":
-        #     int_data = int_data.to(device="mps")
     return int_data
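
Beyond the is_device swap, this hunk drops the commented-out MPS workaround. Note the w.shape[-1] > 1 condition on the packing branch: with a single column there is no odd-column partner, so the bitwise-or would broadcast against an empty slice and silently produce an empty tensor, which is presumably why the guard exists (illustration, not from the diff):

import torch

w = torch.tensor([[7]], dtype=torch.uint8)
print(w[::, ::2].shape, w[::, 1::2].shape)    # torch.Size([1, 1]) torch.Size([1, 0])
print((w[::, ::2] << 4 | w[::, 1::2]).shape)  # torch.Size([1, 0]): the value would be lost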