|
27 | 27 | get_groupwise_affine_qparams,
|
28 | 28 | groupwise_affine_quantize_tensor,
|
29 | 29 | )
|
30 |
| -from torchao.utils import TORCH_VERSION_AFTER_2_4 |
| 30 | +from torchao.utils import ( |
| 31 | + TORCH_VERSION_AFTER_2_4, |
| 32 | + TORCH_VERSION_AFTER_2_5, |
| 33 | +) |
31 | 34 |
|
32 | 35 |
|
33 | 36 | # TODO: put this in a common test utils file
|
@@ -366,6 +369,8 @@ def _assert_close_4w(self, val, ref):
|
366 | 369 |
|
367 | 370 | @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
|
368 | 371 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
|
| 372 | + # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
| 373 | + @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") |
369 | 374 | def test_qat_4w_primitives(self):
|
370 | 375 | n_bit = 4
|
371 | 376 | group_size = 32
|
@@ -411,6 +416,8 @@ def test_qat_4w_primitives(self):
|
411 | 416 |
|
412 | 417 | @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
|
413 | 418 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
|
| 419 | + # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
| 420 | + @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") |
414 | 421 | def test_qat_4w_linear(self):
|
415 | 422 | from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
|
416 | 423 | from torchao.quantization.GPTQ import WeightOnlyInt4Linear
|
@@ -439,6 +446,8 @@ def test_qat_4w_linear(self):
|
439 | 446 |
|
440 | 447 | @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
|
441 | 448 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
|
| 449 | + # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
| 450 | + @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") |
442 | 451 | def test_qat_4w_quantizer(self):
|
443 | 452 | from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
|
444 | 453 | from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
|
|
0 commit comments