Commit d3ecc01

Apply automatic Ruff fixes
1 parent a870ed0 commit d3ecc01

6 files changed: 33 additions, 18 deletions


torchao/dtypes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,13 +2,13 @@
     AffineQuantizedTensor,
     Float8AQTTensorImpl,
     Float8Layout,
+    Int4CPULayout,
     Layout,
     MarlinQQQLayout,
     MarlinSparseLayout,
     PlainLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
-    Int4CPULayout,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 6 additions & 2 deletions
@@ -688,9 +688,11 @@ def extra_repr(self):

 @dataclass(frozen=True)
 class Int4CPULayout(Layout):
-    """ Only for PyTorch version at least 2.6 """
+    """Only for PyTorch version at least 2.6"""
+
     pass

+
 @dataclass(frozen=True)
 class Float8Layout(Layout):
     mm_config: Optional[Float8MMConfig] = None

@@ -1965,7 +1967,8 @@ def from_plain(
                 int_data.dtype == torch.int32
             ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
             packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
-                int_data, 1 # TODO:remove
+                int_data,
+                1,  # TODO:remove
             )
         elif TORCH_VERSION_AT_LEAST_2_5:
             int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)

@@ -2124,6 +2127,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     def get_layout(self) -> Layout:
         return self._layout

+
 #####################################################
 # torch functional and aten operator implementation #
 #####################################################

torchao/quantization/GPTQ.py

Lines changed: 9 additions & 4 deletions
@@ -17,6 +17,7 @@
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten

+from torchao.dtypes.utils import is_device
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_6,

@@ -37,7 +38,6 @@
     pack_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
-from torchao.dtypes.utils import is_device

 aten = torch.ops.aten

@@ -788,9 +788,14 @@ def _create_quantized_state_dict(
                     self.precision,  # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
-                if is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
-                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
-                        w_int4x8.to(self.device), self.inner_k_tiles
+                if (
+                    is_device(w_int4x8.device.type, "cpu")
+                    and TORCH_VERSION_AT_LEAST_2_6
+                ):
+                    weight_int4pack = (
+                        torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                            w_int4x8.to(self.device), self.inner_k_tiles
+                        )
                     )
                 else:
                     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
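
Note: the CPU branch reformatted above (and the matching branches in qat/linear.py, subclass.py, and utils.py below) all share one dispatch shape. A minimal sketch of that pattern as a hypothetical pack_int4_weight helper (not part of torchao; is_device and TORCH_VERSION_AT_LEAST_2_6 are the same names imported in the diffs, and the _for_cpu op exists only on PyTorch 2.6+):

import torch

from torchao.dtypes.utils import is_device
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6


def pack_int4_weight(w_int4x8: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # Hypothetical helper: route to the CPU-specific packing op on
    # PyTorch >= 2.6, otherwise to the generic int4 packing op.
    if is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
        return torch.ops.aten._convert_weight_to_int4pack_for_cpu(
            w_int4x8, inner_k_tiles
        )
    return torch.ops.aten._convert_weight_to_int4pack(w_int4x8, inner_k_tiles)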

torchao/quantization/qat/linear.py

Lines changed: 6 additions & 3 deletions
@@ -9,6 +9,7 @@
 import torch
 import torch.nn.functional as F

+from torchao.dtypes.utils import is_device
 from torchao.quantization.GPTQ import (
     Int8DynActInt4WeightLinear,
     WeightOnlyInt4Linear,

@@ -23,14 +24,13 @@
 )
 from torchao.quantization.unified import TwoStepQuantizer
 from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 from .api import FakeQuantizeConfig
 from .fake_quantizer import FakeQuantizer
 from .utils import (
     _get_qmin_qmax,
 )
-from torchao.dtypes.utils import is_device
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6


 class FakeQuantizedLinear(torch.nn.Linear):

@@ -375,7 +375,10 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                 n_bit,
                 config.group_size,
             )
-            if is_device(q_weight.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+            if (
+                is_device(q_weight.device.type, "cpu")
+                and TORCH_VERSION_AT_LEAST_2_6
+            ):
                 q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
                     q_weight.to(child.weight.device),
                     child.inner_k_tiles,

torchao/quantization/subclass.py

Lines changed: 5 additions & 4 deletions
@@ -8,16 +8,15 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing

+from torchao.dtypes.utils import is_device
 from torchao.quantization.utils import (
     dequantize_per_channel,
     dynamically_quantize_per_channel,
     groupwise_affine_quantize_tensor,
     quant_int8_dynamic_per_token_linear,
     unpack_tinygemm_scales_and_zeros,
 )
-from torchao.utils import find_multiple
-from torchao.dtypes.utils import is_device
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple

 __all__ = [
     "Int8DynamicallyQuantizedLinearWeight",

@@ -620,7 +619,9 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
             input_float, 4, groupsize, dtype=input_float.dtype
         )
         if is_device(input_float.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
-            int_data = aten._convert_weight_to_int4pack_for_cpu(input_int4x8, inner_k_tiles)
+            int_data = aten._convert_weight_to_int4pack_for_cpu(
+                input_int4x8, inner_k_tiles
+            )
         else:
             int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles

torchao/quantization/utils.py

Lines changed: 6 additions & 4 deletions
@@ -9,6 +9,7 @@
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode

+from torchao.dtypes.utils import is_device
 from torchao.kernel import (
     int_scaled_matmul,
 )

@@ -20,7 +21,6 @@
     quantize_affine,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
-from torchao.dtypes.utils import is_device

 __all__ = [
     "compute_error",

@@ -418,9 +418,11 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     assert groupsize > 1
     assert w_int4x8.dim() == 2
     # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
-    if TORCH_VERSION_AT_LEAST_2_5 and (
-        w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
-    ) and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+    if (
+        TORCH_VERSION_AT_LEAST_2_5
+        and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
+        and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)
+    ):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
         low_bits = data & 0x0F
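
Note: the last hunk is the unpacking side of the two-nibbles-per-byte scheme visible elsewhere in this commit: from_plain packs pairs of int4 values into uint8 via `(int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)`, and the dequant path splits the nibbles back out with `>> 4` and `& 0x0F`. A minimal round-trip sketch in plain PyTorch (illustration only, not torchao's API):

import torch

# Values in [0, 16) fit in a nibble; use an even number of columns.
int4_vals = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Pack: even columns become the high nibble, odd columns the low nibble.
packed = (int4_vals[:, ::2] << 4 | int4_vals[:, 1::2]).to(torch.uint8)

# Unpack: recover both nibbles and re-interleave them, mirroring the
# high_bits / low_bits split in groupwise_affine_dequantize_tensor_from_qparams.
data = packed.to(torch.int32)
high_bits = data >> 4
low_bits = data & 0x0F
unpacked = torch.stack([high_bits, low_bits], dim=-1).reshape(int4_vals.shape)

assert torch.equal(unpacked, int4_vals)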

0 commit comments
