From 9661a3bf4720ace4820f16d0448fe387c1e48201 Mon Sep 17 00:00:00 2001
From: Saoirse Stewart
Date: Fri, 28 Feb 2025 14:14:09 +0000
Subject: [PATCH] Arm backend: Add test pipeline for
 run_transform_for_annotation_pipeline

- Added eager mode testing for the pipeline
---
 backends/arm/test/ops/test_scalars.py     | 274 +++++++++++++---------
 backends/arm/test/tester/arm_tester.py    |  74 +++++-
 backends/arm/test/tester/test_pipeline.py |  48 ++++
 3 files changed, 276 insertions(+), 120 deletions(-)

diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py
index 2ab420bd59e..17dcd6f1d27 100644
--- a/backends/arm/test/ops/test_scalars.py
+++ b/backends/arm/test/ops/test_scalars.py
@@ -5,11 +5,16 @@
 
 import unittest
 
+from typing import Tuple
+
 import torch
 
 from executorch.backends.arm.test import common
-from executorch.backends.arm.test.tester.arm_tester import ArmTester
-from parameterized import parameterized
+from executorch.backends.arm.test.tester.test_pipeline import (
+    TosaPipelineBI,
+    TosaPipelineMI,
+    TransformAnnotationPassPipeline,
+)
 
 """
 Summary of non-working cases.
@@ -24,6 +29,7 @@
 # MLETORCH-408 Sub or inplace-sub with an integer input.
 """
 
+input_t1 = Tuple[torch.Tensor, torch.scalar_tensor]  # Input x, Input y
 
 class TestScalars(unittest.TestCase):
@@ -92,112 +98,160 @@ def forward(self, x):
             x -= 10
             return x
 
-    # Inplace ops end with '_' (from aten naming)
-    ops = [
-        ("Add", Add()),
-        ("Sub", Sub()),
-        ("Mul", Mul()),
-        ("Div", Div()),
-        ("Add_", AddInplace()),
-        ("Sub_", SubInplace()),
-        ("Mul_", MulInplace()),
-        ("Div_", DivInplace()),
-        ("MulScalar", MulScalar()),
-        ("DivScalar", DivScalar()),
-        ("AddScalar", AddScalar()),
-        ("SubScalar", SubScalar()),
-    ]
-
-    const_ops = [("Add", AddConst())]
-
-    dtypes = [("int", 3), ("float", 3.0)]
-    sizes = [("r1", (1)), ("r4", (2, 4, 5, 3))]
-
-    # Create combinations of tests
-    tensor_scalar_tests = []
-    for op in ops:
-        for dtype in dtypes:
-            for size in sizes:
-                test_name = f"{op[0]}_{dtype[0]}_{size[0]}"
-                tensor = torch.rand(size[1])
-                scalar = dtype[1]
-                tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))
-
-                # Don't add (scalar, tensor) test case for .Scalar ops.
-                if op[0][-6:] == "Scalar":
-                    continue
-
-                tensor_scalar_tests.append((test_name + "_st", op[1], scalar, tensor))
-
-    tensor_const_tests = []
-    for op in const_ops:
+
+# Inplace ops end with '_' (from aten naming)
+ops = [
+    ("Add", TestScalars.Add()),
+    ("Sub", TestScalars.Sub()),
+    ("Mul", TestScalars.Mul()),
+    ("Div", TestScalars.Div()),
+    ("Add_", TestScalars.AddInplace()),
+    ("Sub_", TestScalars.SubInplace()),
+    ("Mul_", TestScalars.MulInplace()),
+    ("Div_", TestScalars.DivInplace()),
+    ("MulScalar", TestScalars.MulScalar()),
+    ("DivScalar", TestScalars.DivScalar()),
+    ("AddScalar", TestScalars.AddScalar()),
+    ("SubScalar", TestScalars.SubScalar()),
+]
+
+const_ops = [("Add", TestScalars.AddConst())]
+
+dtypes = [("int", 3), ("float", 3.0)]
+sizes = [("r1", (1)), ("r4", (2, 4, 5, 3))]
+
+# Create combinations of tests
+tensor_scalar_tests = {}
+for op in ops:
+    for dtype in dtypes:
         for size in sizes:
-            test_name = f"{op[0]}_{size[0]}"
+            test_name = f"{op[0]}_{dtype[0]}_{size[0]}"
             tensor = torch.rand(size[1])
-            tensor_const_tests.append((test_name, op[1], tensor))
-
-    def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple):
-        (
-            ArmTester(
-                module,
-                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
-            )
-            .export()
-            .to_edge()
-            .partition()
-            .to_executorch()
-            .run_method_and_compare_outputs(inputs=test_data)
-        )
-
-    def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
-        (
-            ArmTester(
-                module,
-                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
-            )
-            .quantize()
-            .export()
-            .to_edge()
-            .partition()
-            .to_executorch()
-            .run_method_and_compare_outputs(inputs=test_data)
-        )
-
-    @parameterized.expand(tensor_scalar_tests)
-    def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
-        expected_exception = None
-        if any(token in test_name for token in ("Sub_int", "Sub__int")):
-            expected_exception = AssertionError
-        if test_name.endswith("_st"):
-            expected_exception = AttributeError
-
-        if expected_exception:
-            with self.assertRaises(
-                expected_exception, msg=f"Test {test_name} is expected to fail."
-            ):
-                self._test_add_tosa_MI_pipeline(op, (x, y))
-            return
-
-        self._test_add_tosa_MI_pipeline(op, (x, y))
-
-    # op(Scalar float, tensor) works if the scalar is constant.
-    @parameterized.expand(tensor_const_tests)
-    def test_MI_const(self, test_name: str, op: torch.nn.Module, x):
-        self._test_add_tosa_MI_pipeline(op, (x,))
-
-    @parameterized.expand(tensor_scalar_tests)
-    def test_BI(self, test_name: str, op: torch.nn.Module, x, y):
-        self._test_add_tosa_BI_pipeline(op, (x, y))
-
-    # op(Scalar float, tensor) works if the scalar is constant.
-    @parameterized.expand(tensor_const_tests)
-    def test_BI_const(self, test_name: str, op: torch.nn.Module, x):
-        self._test_add_tosa_BI_pipeline(op, (x,))
-
-    def test_shift_sub_inplace_tosa_MI(self):
-        self._test_add_tosa_MI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))
-
-    def test_shift_sub_inplace_tosa_BI(self):
-        self._test_add_tosa_BI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))
+            scalar = dtype[1]
+            tensor_scalar_tests[test_name + "_ts"] = (op[1], tensor, scalar)
+            # Don't add (scalar, tensor) test case for .Scalar ops.
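+            # (The aten *.Scalar overloads take the tensor as their first
+            # argument and the scalar as their second, so a (scalar, tensor)
+            # ordering cannot be expressed for them.)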
+            if op[0][-6:] == "Scalar":
+                continue
+
+            tensor_scalar_tests[test_name + "_st"] = (op[1], scalar, tensor)
+
+tensor_const_tests = {}
+for op in const_ops:
+    for size in sizes:
+        test_name = f"{op[0]}_{size[0]}"
+        tensor = torch.rand(size[1])
+        tensor_const_tests[test_name] = (op[1], tensor)
+
+
+def _test_add_tosa_MI_pipeline(module: torch.nn.Module, test_data: tuple):
+    pipeline = TosaPipelineMI[input_t1](module, test_data, aten_op=[], exir_op=[])
+    pipeline.run()
+
+
+def _test_add_tosa_BI_pipeline(
+    module: torch.nn.Module, test_data: tuple, check_quant_nodes=True
+):
+    pipeline = TosaPipelineBI[input_t1](module, test_data, aten_op=[], exir_op=[])
+    if not check_quant_nodes:
+        pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+fail_str = "MLETORCH-408: Arithmetic ops can't handle scalars first for MI"
+MI_xfails = {
+    "Add_int_r1_st": fail_str,
+    "Add_int_r4_st": fail_str,
+    "Add_float_r1_st": fail_str,
+    "Add_float_r4_st": fail_str,
+    "Sub_int_r1_ts": fail_str,
+    "Sub_int_r1_st": fail_str,
+    "Sub_int_r4_ts": fail_str,
+    "Sub_int_r4_st": fail_str,
+    "Sub_float_r1_st": fail_str,
+    "Sub_float_r4_st": fail_str,
+    "Mul_int_r1_st": fail_str,
+    "Mul_int_r4_st": fail_str,
+    "Mul_float_r1_st": fail_str,
+    "Mul_float_r4_st": fail_str,
+    "Div_int_r1_st": fail_str,
+    "Div_int_r4_st": fail_str,
+    "Div_float_r1_st": fail_str,
+    "Div_float_r4_st": fail_str,
+    "Add__int_r1_st": fail_str,
+    "Add__float_r1_st": fail_str,
+    "Add__float_r4_st": fail_str,
+    "Add__int_r4_st": fail_str,
+    "Sub__int_r1_ts": fail_str,
+    "Sub__int_r1_st": fail_str,
+    "Sub__int_r4_ts": fail_str,
+    "Sub__int_r4_st": fail_str,
+    "Sub__float_r1_st": fail_str,
+    "Sub__float_r4_st": fail_str,
+    "Mul__int_r1_st": fail_str,
+    "Mul__int_r4_st": fail_str,
+    "Mul__float_r1_st": fail_str,
+    "Mul__float_r4_st": fail_str,
+    "Div__int_r1_st": fail_str,
+    "Div__int_r4_st": fail_str,
+    "Div__float_r1_st": fail_str,
+    "Div__float_r4_st": fail_str,
+}
+
+
+@common.parametrize("tensor_scalar_tests", tensor_scalar_tests, MI_xfails)
+def test_MI(tensor_scalar_tests: tuple):
+    op, x, y = tensor_scalar_tests
+    _test_add_tosa_MI_pipeline(op, (x, y))
+
+
+def _test_passes_tosa_BI_pipeline(module: torch.nn.Module, test_data: tuple):
+    pipeline = TransformAnnotationPassPipeline[input_t1](
+        module, test_data, tosa_version="TOSA-0.80+BI"
+    )
+    pipeline.run()
+
+
+fail_str = "MLETORCH-770: Numerical issues on Div Scalar."
+passes_xfails = {
+    "Div__int_r1_ts": fail_str,
+    "Div__int_r4_ts": fail_str,
+    "Div__float_r1_ts": fail_str,
+    "Div__float_r4_ts": fail_str,
+}
+
+
+@common.parametrize("tensor_scalar_tests", tensor_scalar_tests, passes_xfails)
+def test_passes_BI(tensor_scalar_tests: tuple):
+    op, x, y = tensor_scalar_tests
+    _test_passes_tosa_BI_pipeline(op, (x, y))
+
+
+# op(Scalar float, tensor) works if the scalar is constant.
+@common.parametrize("tensor_const_tests", tensor_const_tests)
+def test_MI_const(tensor_const_tests: tuple):
+    op, x = tensor_const_tests
+    _test_add_tosa_MI_pipeline(op, (x,))
+
+
+@common.parametrize("tensor_scalar_tests", tensor_scalar_tests)
+def test_BI(tensor_scalar_tests: tuple):
+    op, x, y = tensor_scalar_tests
+    _test_add_tosa_BI_pipeline(op, (x, y))
+
+
+# op(Scalar float, tensor) works if the scalar is constant.
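+# ("Constant" here means the scalar lives inside AddConst.forward() rather
+# than being passed as a call input -- note that the const tests call the
+# module with the tensor only -- so tracing folds it into the graph as a
+# constant.)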
+@common.parametrize("tensor_const_tests", tensor_const_tests)
+def test_BI_const(tensor_const_tests: tuple):
+    op, x = tensor_const_tests
+    _test_add_tosa_BI_pipeline(op, (x,))
+
+
+def test_shift_sub_inplace_tosa_MI():
+    _test_add_tosa_MI_pipeline(TestScalars.ShiftInplaceSub(), (torch.IntTensor(5),))
+
+
+# Do not check for quant nodes in the graph for rshift.
+def test_shift_sub_inplace_tosa_BI():
+    _test_add_tosa_BI_pipeline(
+        TestScalars.ShiftInplaceSub(), (torch.IntTensor(5),), check_quant_nodes=False
+    )
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index a6da2accd1d..0c538e0d599 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -16,6 +16,7 @@
 import torch.fx
 import torch.utils._pytree as pytree
 
+from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
 from executorch.backends.arm.arm_backend import (
     get_intermediate_path,
@@ -59,7 +60,10 @@
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.lowered_backend_module import LoweredBackendModule
 from executorch.exir.pass_base import ExportPass
-from executorch.exir.program._program import _update_exported_program_graph_module
+from executorch.exir.program._program import (
+    _copy_module,
+    _update_exported_program_graph_module,
+)
 from tabulate import tabulate
 
 from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
@@ -181,6 +185,7 @@ def __init__(
         """Passes are run in the order they are passed: first pass_list, second pass_functions, and lastly passes_with_exported_program."""
         self.pass_with_exported_program = passes_with_exported_program
+
         super().__init__(pass_list, pass_functions)
 
     def run(
@@ -347,6 +352,7 @@
         rtol=1e-03,
         qtol=0,
         error_callbacks=None,
+        run_eager_mode=False,
     ):
         """
         Compares the run_artifact output of 'stage' with the output of a reference stage.
@@ -362,12 +368,23 @@ def run_method_and_compare_outputs(
             inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data. The default is random data.
+            run_eager_mode (bool): If True, run the Export-stage module in eager mode and compare it against the reference, instead of running a lowered artifact.
         """
-        edge_stage = self.stages[self.stage_name(tester.ToEdge)]
-        if edge_stage is None:
-            edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
-        assert (
-            edge_stage is not None
-        ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
+
+        if not run_eager_mode:
+            edge_stage = self.stages[self.stage_name(tester.ToEdge)]
+            if edge_stage is None:
+                edge_stage = self.stages[
+                    self.stage_name(tester.ToEdgeTransformAndLower)
+                ]
+            assert (
+                edge_stage is not None
+            ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
+        else:
+            # Run the model in eager mode. We do this to check that the passes
+            # are numerically accurate and that the exported graph is correct.
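+            # Only the Export stage is needed for this: the exported program's
+            # module() is run directly further down, instead of a lowered
+            # artifact.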
+            export_stage = self.stages[self.stage_name(tester.Export)]
+            assert (
+                export_stage is not None
+            ), "To compare outputs in eager mode, the model must be at the Export stage."
 
         stage = stage or self.cur
         test_stage = self.stages[stage]
@@ -380,6 +397,7 @@ def run_method_and_compare_outputs(
 
             exported_program = self.stages[self.stage_name(tester.Export)].artifact
             output_nodes = get_output_nodes(exported_program)
+
             output_qparams = get_output_quantization_params(output_nodes)
 
             quantization_scales = []
@@ -404,9 +422,19 @@ def run_method_and_compare_outputs(
                 reference_outputs, _ = pytree.tree_flatten(
                     reference_stage.run_artifact(reference_input)
                 )
-                test_outputs, _ = pytree.tree_flatten(
-                    test_stage.run_artifact(reference_input)
-                )
+
+                if run_eager_mode:
+                    # Run the exported module directly
+                    test_outputs, _ = pytree.tree_flatten(
+                        self._calculate_reference_output(
+                            exported_program.module(), reference_input
+                        )
+                    )
+                else:
+                    # Run the lowered model on the target
+                    test_outputs, _ = pytree.tree_flatten(
+                        test_stage.run_artifact(reference_input)
+                    )
 
             for reference_output, test_output, quantization_scale in zip(
                 reference_outputs, test_outputs, quantization_scales
@@ -533,6 +561,32 @@ def dump_dtype_distribution(
             _dump_str(to_print, path_to_dump)
         return self
 
+    def run_transform_for_annotation_pipeline(
+        self, stage: str | None = None
+    ) -> torch.fx.GraphModule:
+        """Run transform_for_annotation_pipeline on the exported program to ensure
+        that the passes do not break the initial model before quantization.
+
+        One caveat: because the passes register buffers on the graph modules,
+        re-exporting the resulting graph can fail. Use this only to compare
+        numerical correctness in eager mode.
+
+        Returns the exported program with the passes applied.
+        """
+
+        if stage is None:
+            stage = self.cur
+        artifact = self.get_artifact(stage)
+        if stage == self.stage_name(tester.Export):
+            new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline(  # type: ignore[arg-type]
+                graph_module=artifact.graph_module
+            )
+        else:
+            raise RuntimeError("Can only run passes on the Export stage.")
+        # Copy the transformed graph module back into the artifact in place so
+        # that the exported program's state_dict is preserved after the passes run.
+        _copy_module(artifact.graph_module, new_gm)
+        return artifact
+
     @staticmethod
     def _calculate_reference_output(
         module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py
index 62d0b633224..d50b4a9eb56 100644
--- a/backends/arm/test/tester/test_pipeline.py
+++ b/backends/arm/test/tester/test_pipeline.py
@@ -580,6 +580,54 @@ def __init__(
         self.add_stage(self.tester.run_method_and_compare_outputs)
 
 
+class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]):
+    """
+    Runs the transform_for_annotation_pipeline passes directly on an exported
+    program and checks the output.
+
+    Attributes:
+        module: The module which the pipeline is applied to.
+        test_data: Data used for testing the module.
+        tosa_version: The TOSA version to test against.
+
+        custom_path: Path to dump intermediate artifacts such as TOSA and PTE files to.
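+
+    A minimal usage sketch (the module, input, and type alias are illustrative):
+
+        input_t = Tuple[torch.Tensor]
+        pipeline = TransformAnnotationPassPipeline[input_t](
+            MyModule(), (torch.rand(2, 4),), tosa_version="TOSA-0.80+BI"
+        )
+        pipeline.run()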
+
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        tosa_version: str,
+        custom_path: str | None = None,
+    ):
+        compile_spec = common.get_tosa_compile_spec(
+            tosa_version, custom_path=custom_path
+        )
+        super().__init__(
+            module,
+            test_data,
+            None,
+            compile_spec,
+            None,
+            use_to_edge_transform_and_lower=True,
+        )
+        self.add_stage_after(
+            "export", self.tester.run_transform_for_annotation_pipeline
+        )
+
+        # Delete most of the default pipeline stages
+        self.pop_stage("check_not.exir")
+        self.pop_stage("check_count.exir")
+        self.pop_stage("to_executorch")
+        self.pop_stage("to_edge_transform_and_lower")
+        self.pop_stage("check.aten")
+        self.add_stage(
+            self.tester.run_method_and_compare_outputs,
+            inputs=test_data,
+            run_eager_mode=True,
+        )
+
+
 class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]):
     """
     Runs the partitioner on a module and checks that ops are not delegated to test