diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index a6c5bf7e0a..48a171a75f 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -69,8 +69,6 @@ def new_test(self, value=value):


 class TorchAOBasicTestCase(common_utils.TestCase):
-    """Basic test case for tensor subclasses
-    """
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

@@ -142,6 +140,66 @@ def test_linear(self, device, dtype):
         lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

+
+class TorchAOCompileTestCase(common_utils.TestCase):
+    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
+    TENSOR_SUBCLASS = AffineQuantizedTensor
+    FACTORY_FN = to_affine_quantized_intx
+    kwargs = {
+        "mapping_type": MappingType.ASYMMETRIC,
+        "block_size": (1, 32),
+        "target_dtype": torch.uint8,
+    }
+    # minimum sqnr for linear operation when the weight is quantized to low precision
+    # with the above setting
+    LINEAR_MIN_SQNR = 40
+    COMPILE_MIN_SQNR = 50
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_input_output_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        def f(tensor):
+            return tensor
+
+        ref = f(lp_tensor)
+        f = torch.compile(f)
+        compiled = f(lp_tensor)
+        self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
+        self.assertEqual(ref.dequantize(), compiled.dequantize())
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_input_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        def f(tensor):
+            return tensor.dequantize()
+
+        ref = f(lp_tensor)
+        f = torch.compile(f)
+        compiled = f(lp_tensor)
+        self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
+        self.assertEqual(ref, compiled)
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_output_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        def f(hp_tensor):
+            return self.FACTORY_FN(hp_tensor, **self.kwargs)
+
+        ref = f(hp_tensor)
+        f = torch.compile(f)
+        compiled = f(hp_tensor)
+        self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
+        # bfloat16 seems to result in much larger numerical differences
+        if dtype != torch.bfloat16:
+            self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)
+
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_linear_compile(self, device, dtype):
@@ -155,7 +213,10 @@ def test_linear_compile(self, device, dtype):
         lp_res = torch.compile(l)(hp_act_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

+
+
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
+common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)

 if __name__ == "__main__":
     unittest.main()
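
Usage note (not part of the patch): below is a minimal sketch of how a downstream test file could reuse the new TorchAOCompileTestCase by overriding its class attributes, following the pattern visible in the diff above. The subclass name, block size, and threshold values are illustrative assumptions; the import paths for AffineQuantizedTensor, to_affine_quantized_intx, and MappingType assume torchao's usual public modules.

# Sketch only: a downstream compile test for one quantization configuration.
# The class name and overridden values are illustrative, not part of the patch.
import unittest

import torch
from torch.testing._internal import common_utils

from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType
from torchao.testing.utils import TorchAOCompileTestCase


class MyUint8CompileTestCase(TorchAOCompileTestCase):
    # Reuse the affine-quantized tensor subclass, but exercise a different block size.
    TENSOR_SUBCLASS = AffineQuantizedTensor
    FACTORY_FN = to_affine_quantized_intx
    kwargs = {
        "mapping_type": MappingType.ASYMMETRIC,
        "block_size": (1, 64),
        "target_dtype": torch.uint8,
    }
    # Thresholds can be tuned per configuration; these mirror the defaults above.
    LINEAR_MIN_SQNR = 40
    COMPILE_MIN_SQNR = 50


# Generates the device/dtype-parametrized variants of the inherited tests.
common_utils.instantiate_parametrized_tests(MyUint8CompileTestCase)

if __name__ == "__main__":
    unittest.main()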