|
36 | 36 | pack_tinygemm_scales_and_zeros,
|
37 | 37 | per_token_dynamic_quant,
|
38 | 38 | )
|
| 39 | +from torchao.dtypes.utils import is_device |
39 | 40 |
|
40 | 41 | aten = torch.ops.aten
|
41 | 42 |
|
@@ -765,9 +766,14 @@ def _create_quantized_state_dict(
|
765 | 766 | self.precision, # dtype for scales_and_zeros
|
766 | 767 | )
|
767 | 768 | # TODO: just get the device from mod.weight.device?
|
768 |
| - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( |
769 |
| - w_int4x8.to(self.device), self.inner_k_tiles |
770 |
| - ) |
| 769 | + if is_device(w_int4x8.device.type, "cpu"): |
| 770 | + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( |
| 771 | + w_int4x8.to(self.device), self.inner_k_tiles |
| 772 | + ) |
| 773 | + else: |
| 774 | + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( |
| 775 | + w_int4x8.to(self.device), self.inner_k_tiles |
| 776 | + ) |
771 | 777 | cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
|
772 | 778 | cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
|
773 | 779 | self.device
|
@@ -851,9 +857,14 @@ def make_names_and_values_dict_func(q, qparams):
|
851 | 857 | # how much we need to pad the weight
|
852 | 858 | delta_k = int((new_k - k) / 2)
|
853 | 859 | q = q.to(self.device)
|
854 |
| - final_q = torch.ops.aten._convert_weight_to_int4pack( |
855 |
| - F.pad(q, pad=(0, delta_k)), inner_k_tiles |
856 |
| - ) |
| 860 | + if is_device(self.device.type, "cpu"): |
| 861 | + final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu( |
| 862 | + F.pad(q, pad=(0, delta_k)), inner_k_tiles |
| 863 | + ) |
| 864 | + else: |
| 865 | + final_q = torch.ops.aten._convert_weight_to_int4pack( |
| 866 | + F.pad(q, pad=(0, delta_k)), inner_k_tiles |
| 867 | + ) |
857 | 868 | scales = qparams[0].to(torch.bfloat16).to(self.device)
|
858 | 869 | zeros = qparams[1].to(torch.bfloat16).to(self.device)
|
859 | 870 | scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
|
|
0 commit comments