@@ -95,20 +95,24 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
     assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0
 
     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
     assert torch.equal(t, unpacked)
 
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
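Note on the new `TORCH_VERSION_AFTER_2_5` branches above: starting with PyTorch 2.5, `torch.ops.aten._convert_weight_to_int4pack` expects its input already packed as `uint8`, two int4 values per byte, while `unpack_tensor_core_tiled_layout` still returns one int4 value per element. The test therefore packs `t` before the conversion and applies the same packing to `unpacked` before comparing. A minimal standalone sketch of that nibble round trip (CPU, for illustration only, assuming values in `[0, 16)` and an even column count):

```python
import torch

# Random int4 values stored one per element in an int32 tensor.
t = torch.randint(0, 16, (8, 16), dtype=torch.int)

# Pack: even columns become the high nibble, odd columns the low nibble.
packed = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)  # shape (8, 8)

# Unpack: split each byte back into its two nibbles and re-interleave.
high = (packed >> 4).to(torch.int)
low = (packed & 0xF).to(torch.int)
roundtrip = torch.stack([high, low], dim=-1).reshape(t.shape)

assert torch.equal(t, roundtrip)
```

The same repacking pattern recurs in every hunk below, which is why each test grows the identical two-line guard.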
@@ -122,6 +126,8 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
         test_utils.append("test_aot_dispatch_dynamic")
 
     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
 
     opcheck(
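For context on the `opcheck(` call this hunk leads into: `opcheck` is PyTorch's custom-op test harness, which re-runs the op under each utility listed in `test_utils` (schema check, fake-tensor propagation, autograd registration, and, on 2.5+, AOT dispatch). A hedged sketch of the full call, with the op handle and argument layout assumed from the surrounding tests rather than shown in this diff:

```python
# Sketch only: the real args continue past the truncated opcheck( line above.
opcheck(
    torch.ops.torchao.unpack_tensor_core_tiled_layout,  # op under test (handle assumed)
    (packed_w, inner_k_tiles),                          # positional args
    test_utils=test_utils,                              # e.g. ["test_schema", ...]
)
```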
@@ -151,7 +157,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
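The `dequant_ref` helper named in this hunk's context line is the hand-written reference the dequant tests compare against; its body is outside this diff. A hedged reconstruction, assuming the usual tinygemm group-wise affine convention (formula assumed, not taken from the source):

```python
def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
    # Assumed tinygemm convention: dq = (q - 2**(nbits - 1)) * scale + zero,
    # with one (scale, zero) pair per group of `group_size` columns.
    n, k = q.shape
    n_groups = k // group_size
    q = q.reshape(n, n_groups, group_size).to(dtype)
    scales = scales.reshape(n, n_groups, 1).to(dtype)
    zeros = zeros.reshape(n, n_groups, 1).to(dtype)
    return ((q - 2 ** (nbits - 1)) * scales + zeros).reshape(n, k)
```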
@@ -210,7 +216,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
 
 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -229,6 +235,9 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
 
     # Unpack and dequantize
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
+
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         unpacked, scales, zeros, n_bit=4, groupsize=group_size
     )
@@ -264,13 +273,15 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     assert diff_op_ao < 1e-1
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     n, k = shape
     device = "cuda"
 
     q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
+    if TORCH_VERSION_AFTER_2_5:
+        q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
     packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
     q_groups = k // group_size
     scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
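This last hunk only pre-packs `q`; the rest of `test_dequantize_tensor_core_tiled_layout_op` (outside the diff) builds zeros, combines the qparams, and runs the op under test. A hedged sketch of that continuation, with the qparams-packing helper and the op's argument order assumed rather than shown here:

```python
# Continuation sketch (names and argument order assumed, not in this diff).
zeros = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)

# Assumed helper: interleaves scales and zeros into the combined
# (k // group_size, n, 2) tensor the tinygemm kernels expect.
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

dequantized = torchao.ops.dequantize_tensor_core_tiled_layout(
    packed_w, scales_and_zeros, group_size, inner_k_tiles
)
assert dequantized.shape == (n, k)
```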