diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py
index 1aec147cd67..aaf7f051b09 100644
--- a/backends/cadence/aot/fuse_ops.py
+++ b/backends/cadence/aot/fuse_ops.py
@@ -863,6 +863,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return result
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class FuseMulTensorIntoQuantPass(ExportPass):
+    """
+    Looks for the pattern where aten.mul.Tensor is followed by a quant node.
+    If found, updates the quant scale to reflect the multiplication and
+    removes the mul node.
+    """
+
+    def attempt_fusion(
+        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
+    ) -> None:
+        full_nodes = [
+            arg
+            for arg in mul_node.args
+            if isinstance(arg, torch.fx.Node)
+            and arg.target == exir_ops.edge.aten.full.default
+        ]
+
+        if len(full_nodes) != 1 or len(mul_node.users) != 1:
+            return
+
+        full_node = full_nodes[0]
+        mul_user = list(mul_node.users.keys())[0]
+
+        if mul_user.target not in {
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            exir_ops.edge.cadence.quantize_per_tensor.default,
+        }:
+            return
+
+        quant_node = mul_user
+
+        # Calculate the new scale: a multiply-by-k on the input is absorbed by scale / k.
+        prev_scale = quant_node.args[1]
+        assert isinstance(prev_scale, (int, float))
+        mul_scalar = full_node.args[1]
+        assert isinstance(mul_scalar, (int, float))
+        new_scale = float(prev_scale) / float(mul_scalar)
+
+        logging.debug(
+            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
+        )
+
+        # Replace the input first.
+        quant_node.replace_input_with(
+            cast(torch.fx.Node, quant_node.args[0]),
+            cast(torch.fx.Node, mul_node.args[0]),
+        )
+
+        # Now update the scale in the args.
+        new_quant_args = list(quant_node.args)
+        new_quant_args[1] = new_scale
+        quant_node.args = tuple(new_quant_args)
+
+        # Detach the mul node so it can be erased.
+        mul_node.args = ()
+        mul_node.users = {}
+
+        graph_module.graph.erase_node(mul_node)
+        graph_module.graph.erase_node(full_node)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.mul.Tensor
+        ):
+            self.attempt_fusion(graph_module, node)
+        graph_module.graph.eliminate_dead_code()
+        return super().call(graph_module)
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class FuseMulTensorIntoDequantPass(ExportPass):
     """
diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py
index d01e2e57859..30ea91bafb5 100644
--- a/backends/cadence/aot/tests/test_fusion_ops_passes.py
+++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -20,6 +20,7 @@
     FuseMMWithAdd,
     FuseMulScalarIntoDequantPass,
     FuseMulTensorIntoDequantPass,
+    FuseMulTensorIntoQuantPass,
     FuseQuantDequantToRequantizePass,
     FuseTransposeOrPermuteOpPairsPass,
 )
@@ -587,6 +588,48 @@ def test_fuse_mul_scalar_into_dequant(self):
             deq_scale = node.args[1]
         self.assertEqual(deq_scale, dequant_scale * mul_value)
 
+    def test_fuse_mul_into_quant(self):
+        quant_scale = 1.5
+        mul_value = 10
+
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], mul_value),
+        )
+        mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(x, full),
+        )
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(mul, quant_scale, 0, 0, 255, torch.uint8),
+        )
+        builder.output(quant)
+        graph_module = FuseMulTensorIntoQuantPass()(
+            builder.get_graph_module()
+        ).graph_module
+
+        # Verify that the mul and full ops were removed.
+        self.check_op_counts(
+            graph_module,
+            expected_op_counts={
+                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
+                exir_ops.edge.aten.full.default: 0,
+                exir_ops.edge.aten.mul.Tensor: 0,
+            },
+        )
+
+        # Verify that the quant scale value was updated correctly.
+        for node in graph_module.graph.nodes:
+            if (
+                node.target
+                == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+            ):
+                new_scale = node.args[1]
+                self.assertEqual(new_scale, quant_scale / mul_value)
+
     def test_fuse_then_transpose_pass(self):
         # Create a graph with full -> transpose.
         builder = GraphBuilder()
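
Why the scale is divided rather than multiplied: quantize_per_tensor maps x to clamp(round(x / scale) + zero_point, qmin, qmax), so quantizing x * k with scale s produces the same integers as quantizing x with scale s / k. This is the mirror image of the dequant fusions above, where a multiply after dequant scales the dequant scale up. A minimal standalone sketch of that identity in plain PyTorch; the ref_quantize_per_tensor helper below is a hand-rolled stand-in for the reference op, written here for illustration only:

import torch

def ref_quantize_per_tensor(x, scale, zero_point, qmin, qmax, dtype):
    # Reference behavior of quantize_per_tensor: divide by scale, round,
    # shift by zero_point, clamp to the quantized range.
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, qmin, qmax).to(dtype)

x = torch.randn(4, 32)
k, s = 10.0, 1.5

before = ref_quantize_per_tensor(x * k, s, 0, 0, 255, torch.uint8)  # mul then quant
after = ref_quantize_per_tensor(x, s / k, 0, 0, 255, torch.uint8)   # fused quant
# Identical up to floating-point rounding at exact .5 ties, i.e. the two
# results are at most one quantization step apart.
assert (before.to(torch.int32) - after.to(torch.int32)).abs().max() <= 1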