Skip to content

Commit 3a75936

Browse files
committed
Skip int4 QAT tests for nightly for now
int4 tinygemm quantization is currently broken in master and being fixed in #517. Let's skip these tests for now until that is fixed.
1 parent f8789f7 commit 3a75936

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

test/quantization/test_qat.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,10 @@
2727
get_groupwise_affine_qparams,
2828
groupwise_affine_quantize_tensor,
2929
)
30-
from torchao.utils import TORCH_VERSION_AFTER_2_4
30+
from torchao.utils import (
31+
TORCH_VERSION_AFTER_2_4,
32+
TORCH_VERSION_AFTER_2_5,
33+
)
3134

3235

3336
# TODO: put this in a common test utils file
@@ -366,6 +369,8 @@ def _assert_close_4w(self, val, ref):
366369

367370
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
368371
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
372+
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
373+
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
369374
def test_qat_4w_primitives(self):
370375
n_bit = 4
371376
group_size = 32
@@ -411,6 +416,8 @@ def test_qat_4w_primitives(self):
411416

412417
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
413418
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
419+
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
420+
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
414421
def test_qat_4w_linear(self):
415422
from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
416423
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -439,6 +446,8 @@ def test_qat_4w_linear(self):
439446

440447
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
441448
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
449+
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
450+
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
442451
def test_qat_4w_quantizer(self):
443452
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
444453
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

0 commit comments

Comments (0)