Skip to content

Commit 9c531e9

Browse files
committed
Rebase and lint fixes
1 parent 82477b9 commit 9c531e9

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

test/hqq/test_hqq_affine.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
import unittest
22
import torch
33
from torchao.dtypes.affine_quantized_tensor import (
4-
to_affine_quantized_intx,
54
ZeroPointDomain,
6-
PlainAQTTensorImpl,
7-
PlainLayout,
8-
TensorCoreTiledAQTTensorImpl,
9-
TensorCoreTiledLayout,
105
MappingType,
116
)
127

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
140140
self.zero_point_domain,
141141
output_dtype=output_dtype,
142142
)
143+
from torchao.dtypes.uintx import TensorCoreTiledLayout
144+
143145
if isinstance(self._layout, TensorCoreTiledLayout):
144146
# need to return to original shape if tensor was padded
145147
# in preprocessing

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
101101
AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op
102102

103103

104-
# # following are a list of (dispatch_condition, implementation) functions that takes the following args:
105-
# # input_tensor: dimension is (M1, M2, ..., in_features)
106-
# # weight_tensor: dimension is (out_features, in_features)
107-
# # bias: dimension is (out_features,)
108-
# # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
104+
# _register_aqt_quantized_linear_dispatches function has a list of (dispatch_condition, implementation) functions, defined in their dtype layout classes, that takes the following args:
105+
# input_tensor: dimension is (M1, M2, ..., in_features)
106+
# weight_tensor: dimension is (out_features, in_features)
107+
# bias: dimension is (out_features,)
108+
# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
109109
def _register_aqt_quantized_linear_dispatches():
110110
for dispatch_condition, impl in [
111111
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),

torchao/dtypes/uintx/plain_layout.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
register_layout,
1212
)
1313
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
14+
from torchao.kernel import (
15+
int_scaled_matmul,
16+
)
1417
from torchao.quantization.quant_primitives import (
1518
ZeroPointDomain,
16-
int_scaled_matmul,
1719
)
1820
from torchao.utils import fill_defaults
1921

0 commit comments

Comments
 (0)