From 72c4a43dea4ffa132e5221ebe46d6935fa1ae46e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 19 Sep 2024 12:15:21 -0700 Subject: [PATCH 1/3] Add compile tests to test suite Summary: This is a follow up PR addressing https://github.com/pytorch/ao/pull/839#discussion_r1750720771 We can add more compiler related tests in the future. Next * refactor a bit to use quantize_ API directly * use the test suite in existing API tests Test Plan: python torchao/testing/utils.py Reviewers: Subscribers: Tasks: Tags: --- torchao/testing/utils.py | 42 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index a6c5bf7e0a..19e99c30f9 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -69,8 +69,6 @@ def new_test(self, value=value): class TorchAOBasicTestCase(common_utils.TestCase): - """Basic test case for tensor subclasses - """ COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -142,6 +140,43 @@ def test_linear(self, device, dtype): lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor) self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + +class TorchAOCompileTestCase(common_utils.TestCase): + COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + TENSOR_SUBCLASS = AffineQuantizedTensor + FACTORY_FN = to_affine_quantized_intx + kwargs = { + "mapping_type": MappingType.ASYMMETRIC, + "block_size": (1, 32), + "target_dtype": torch.uint8, + } + # minimum sqnr for linear operation when the weight is quantized to low precision + # with the above setting + LINEAR_MIN_SQNR = 40 + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_input_output(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): + return tensor.t() + + f = torch.compile(f) + self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_input_output(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + def f(hp_tensor): + return self.FACTORY_FN(hp_tensor, **self.kwargs) + + f = torch.compile(f) + self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS)) + @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_linear_compile(self, device, dtype): @@ -155,7 +190,10 @@ def test_linear_compile(self, device, dtype): lp_res = torch.compile(l)(hp_act_tensor) self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) +common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) if __name__ == "__main__": unittest.main() From a180182756a1298519e272dcea3ad555395eb7bd Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 19 Sep 2024 12:25:36 -0700 Subject: [PATCH 2/3] rename --- torchao/testing/utils.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 19e99c30f9..6165b1214c 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -158,18 +158,29 @@ class TorchAOCompileTestCase(common_utils.TestCase): @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) - def test_input_output(self, device, dtype): + def test_input_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) def f(tensor): - return tensor.t() + return tensor f = torch.compile(f) self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) - def test_input_output(self, device, dtype): + def test_input_tensor_subclass(self, device, dtype): + hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): + return tensor.dequantize() + + f = torch.compile(f) + self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) + + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) def f(hp_tensor): return self.FACTORY_FN(hp_tensor, **self.kwargs) From 7d3ceb78f1d932666417bb337eced1dad214ff6c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 17:58:41 -0700 Subject: [PATCH 3/3] add result check --- torchao/testing/utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 6165b1214c..48a171a75f 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -155,6 +155,7 @@ class TorchAOCompileTestCase(common_utils.TestCase): # minimum sqnr for linear operation when the weight is quantized to low precision # with the above setting LINEAR_MIN_SQNR = 40 + COMPILE_MIN_SQNR = 50 @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -164,8 +165,11 @@ def test_input_output_tensor_subclass(self, device, dtype): def f(tensor): return tensor + ref = f(lp_tensor) f = torch.compile(f) + compiled = f(lp_tensor) self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) + self.assertEqual(ref.dequantize(), compiled.dequantize()) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -175,8 +179,11 @@ def test_input_tensor_subclass(self, device, dtype): def f(tensor): return tensor.dequantize() + ref = f(lp_tensor) f = torch.compile(f) + compiled = f(lp_tensor) self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS)) + self.assertEqual(ref, compiled) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -185,8 +192,13 @@ def test_output_tensor_subclass(self, device, dtype): def f(hp_tensor): return self.FACTORY_FN(hp_tensor, **self.kwargs) + ref = f(hp_tensor) f = torch.compile(f) + compiled = f(hp_tensor) self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS)) + # bfloat16 seems to result in much larger numerical differences + if dtype != torch.bfloat16: + self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES)