Skip to content

Commit 67a1622

Browse files
committed
Qualcomm AI Engine Direct - XR model enablement pipe_clean
Summary: support linalg_vector_norm, instance_norm, any, and ne; expand coverage of the quantization annotator; add test cases; small refactor of the _passes imports.
1 parent 433e30b commit 67a1622

29 files changed

+865
-139
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,55 @@
22
from .annotate_decomposed import AnnotateDecomposed
33
from .annotate_quant_attrs import AnnotateQuantAttrs
44
from .constant_i64_to_i32 import ConstantI64toI32
5+
from .convert_binary_op_with_scalar import ConvertBinaryOpsWithScalar
56
from .convert_bmm_to_matmul import ConvertBmmToMatmul
67
from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
78
from .convert_prelu import ConvertPReLU
89
from .convert_to_linear import ConvertToLinear
10+
from .decompose_any import DecomposeAny
11+
from .decompose_einsum import DecomposeEinsum
12+
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
13+
from .decompose_silu import DecomposeSilu
914
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
1015
from .fold_qdq import FoldQDQ
16+
from .fuse_consecutive_transpose import FuseConsecutiveTranspose
17+
from .insert_io_qdq import InsertIOQDQ
18+
from .insert_requantize import InsertRequantize
1119
from .layout_transform import LayoutTransform
1220
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
1321
from .recompose_rms_norm import RecomposeRmsNorm
22+
from .reduce_dynamic_range import ReduceDynamicRange
1423
from .remove_redundancy import RemoveRedundancy
1524
from .replace_index_put_input import ReplaceIndexPutInput
25+
from .replace_inf_buffer import ReplaceInfBuffer
1626
from .tensor_i64_to_i32 import TensorI64toI32
1727

1828

1929
__all__ = [
2030
AnnotateAndQuantScalar,
2131
AnnotateDecomposed,
2232
AnnotateQuantAttrs,
33+
ConstantI64toI32,
2334
ConvertBmmToMatmul,
35+
ConvertBinaryOpsWithScalar,
2436
ConvertInterpolateWithUpsample2D,
2537
ConvertPReLU,
2638
ConvertToLinear,
39+
DecomposeAny,
40+
DecomposeEinsum,
41+
DecomposeLinalgVectorNorm,
42+
DecomposeSilu,
2743
ExpandBroadcastTensorShape,
2844
FoldQDQ,
29-
ConstantI64toI32,
30-
TensorI64toI32,
45+
FuseConsecutiveTranspose,
46+
InsertIOQDQ,
47+
InsertRequantize,
3148
LayoutTransform,
3249
RecomposePixelUnshuffle,
3350
RecomposeRmsNorm,
51+
ReduceDynamicRange,
3452
RemoveRedundancy,
3553
ReplaceIndexPutInput,
54+
ReplaceInfBuffer,
55+
TensorI64toI32,
3656
]

backends/qualcomm/_passes/convert_to_linear.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class ConvertToLinear(ExportPass):
3939
mm = exir_ops.edge.aten.mm.default
4040

4141
addmm_patterns = [
42+
{view_copy: 1, permute_copy: 1, addmm: 1},
4243
{view_copy: 2, permute_copy: 1, addmm: 1},
4344
{permute_copy: 1, addmm: 1},
4445
]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir import to_edge
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
11+
12+
class Any(torch.nn.Module):
    """
    Reference module that expresses an ``any`` reduction with simpler ops.

    The input is cast to bool and then to int32, summed over ``dim``, and the
    sum is compared against zero with ``not_equal``.

    Args:
        dim: axis (int), axes (list / tuple) or None. None reduces over every
            element by flattening the input first.
        keepdim: forwarded to ``torch.sum``.
    """

    def __init__(self, dim, keepdim):
        super().__init__()
        # a list of axes is stored as a tuple; any other value is kept as given
        self.dim = tuple(dim) if isinstance(dim, list) else dim
        self.keepdim = keepdim

    def forward(self, x):
        # Resolve the reduction axis in a local variable. Writing it back to
        # self.dim would make a second call skip the flatten below and reduce
        # only axis 0 of the un-flattened input.
        dim = self.dim
        if dim is None:
            x = torch.flatten(x)
            dim = 0

        # non-zero elements become 1, so the sum counts the truthy entries
        x = x.to(torch.bool).to(torch.int32)
        x = torch.sum(x, dim=dim, keepdim=self.keepdim, dtype=torch.int32)
        return torch.not_equal(x, torch.zeros(1, dtype=torch.int32))
26+
27+
28+
class DecomposeAny(ExportPass):
    """
    Replace every node whose target contains "any.dim" with the edge-dialect
    graph exported from the ``Any`` reference module.
    """

    def __init__(self, quantization_capture=False) -> None:
        # quantization_capture is accepted but not used by this pass
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        fx_graph = graph_module.graph
        for target_node in fx_graph.nodes:
            if "any.dim" not in str(target_node.target):
                continue

            # optional positional arguments fall back to None / False
            arg_count = len(target_node.args)
            reduce_dim = target_node.args[1] if arg_count > 1 else None
            keep = target_node.args[2] if arg_count > 2 else False

            reference = Any(reduce_dim, keep)
            exported = torch.export.export(
                reference, (target_node.args[0].meta["val"],)
            )
            reference_program = to_edge(exported).exported_program()

            with fx_graph.inserting_before(target_node):
                # node_map translates nodes of the reference graph into nodes
                # of the graph being rewritten; it is seeded with the name of
                # the reference module's single input
                node_map = {"x": target_node.args[0]}

                for ref_node in reference_program.graph.nodes:
                    if ref_node.op == "placeholder":
                        # re-key the seeded entry from input name to node
                        node_map[ref_node] = node_map.pop(ref_node.name)
                    elif ref_node.op == "output":
                        # the output is not copied; consumers of the original
                        # node are redirected to the copied result instead
                        for consumer in target_node.users.copy():
                            consumer.replace_input_with(
                                target_node,
                                node_map[ref_node.args[0][0]],
                            )
                    else:
                        node_map[ref_node] = fx_graph.node_copy(
                            ref_node,
                            arg_transform=lambda arg, node_map=node_map: node_map[
                                arg
                            ],
                        )

                fx_graph.erase_node(target_node)

        fx_graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2828

2929
with graph.inserting_before(node):
3030
# remap is used to map original node values to new node values,
31-
# which ensures that reference to nodes are correclty updated in the new graph
31+
# which ensures that reference to nodes are correctly updated in the new graph
3232
remap = {}
3333
# Different from other nodes, einsum args[0] is the einsum equation,
3434
# while input nodes are stored in args[1]
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir import to_edge
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
11+
12+
class LinalgVectorNorm(torch.nn.Module):
    """
    Reference module that expresses a vector norm with simpler ops.

    Args:
        exp: order of the norm. ``inf`` / ``-inf`` give the max / min of the
            absolute values, ``0`` counts the non-zero elements, and any other
            value ``p`` computes ``sum(abs(x) ** p) ** (1 / p)``.
        dim: axis (int), axes (list / tuple) or None. None reduces over every
            element by flattening the input first.
        keepdim: forwarded to the reduction op.
    """

    def __init__(self, exp, dim, keepdim):
        super().__init__()
        self.exp = exp
        # a bare int is a valid axis but is not iterable, so only sequences
        # are converted to a tuple
        if dim is None or isinstance(dim, int):
            self.dim = dim
        else:
            self.dim = tuple(dim)
        self.keepdim = keepdim

    def forward(self, x):
        # Resolve the reduction axis in a local variable. Writing it back to
        # self.dim would make a second call skip the flatten below and reduce
        # only axis 0 of the un-flattened input.
        dim = self.dim
        if dim is None:
            x = torch.flatten(x)
            dim = 0

        if self.exp == 0:
            # order 0 counts non-zero entries; 1 / exp would divide by zero
            return torch.sum((x != 0).to(x.dtype), dim=dim, keepdim=self.keepdim)

        x = torch.abs(x)
        # infinite orders reduce to max / min; the power formula below does
        # not produce these values
        if self.exp == float("inf"):
            return torch.amax(x, dim=dim, keepdim=self.keepdim)
        if self.exp == float("-inf"):
            return torch.amin(x, dim=dim, keepdim=self.keepdim)

        x = torch.pow(x, self.exp)
        x = torch.sum(x, dim=dim, keepdim=self.keepdim)
        return torch.pow(x, 1.0 / self.exp)
28+
29+
30+
class DecomposeLinalgVectorNorm(ExportPass):
    """
    Replace every node whose target contains "linalg_vector_norm" with the
    graph exported from the ``LinalgVectorNorm`` reference module.

    With ``quantization_capture`` set, the nodes are taken from the module
    returned by ``torch.export.export(...).module()``; otherwise they are
    taken from the edge-dialect exported program.
    """

    def __init__(self, quantization_capture=False) -> None:
        super().__init__()
        self.quantization_capture = quantization_capture

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        fx_graph = graph_module.graph
        for target_node in fx_graph.nodes:
            if "linalg_vector_norm" not in str(target_node.target):
                continue

            # optional positional arguments fall back to 2.0 / None / False
            arg_count = len(target_node.args)
            norm_order = target_node.args[1] if arg_count > 1 else 2.0
            reduce_dim = target_node.args[2] if arg_count > 2 else None
            keep = target_node.args[3] if arg_count > 3 else False

            reference = LinalgVectorNorm(norm_order, reduce_dim, keep)
            if self.quantization_capture:
                reference_module = torch.export.export(
                    reference, (target_node.args[0].meta["val"],)
                ).module()
            else:
                exported = torch.export.export(
                    reference, (target_node.args[0].meta["val"],)
                )
                reference_module = to_edge(exported).exported_program()

            with fx_graph.inserting_before(target_node):
                # node_map translates nodes of the reference graph into nodes
                # of the graph being rewritten; it is seeded with the name of
                # the reference module's single input
                node_map = {"x": target_node.args[0]}

                for ref_node in reference_module.graph.nodes:
                    if ref_node.op == "placeholder":
                        # re-key the seeded entry from input name to node
                        node_map[ref_node] = node_map.pop(ref_node.name)
                    elif ref_node.op == "output":
                        # the output is not copied; consumers of the original
                        # node are redirected to the copied result instead
                        for consumer in target_node.users.copy():
                            consumer.replace_input_with(
                                target_node,
                                node_map[ref_node.args[0][0]],
                            )
                    else:
                        node_map[ref_node] = fx_graph.node_copy(
                            ref_node,
                            arg_transform=lambda arg, node_map=node_map: node_map[
                                arg
                            ],
                        )

                fx_graph.erase_node(target_node)

        fx_graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/layout_transform.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ class LayoutTransform(ExportPass):
3333
exir_ops.edge.aten.adaptive_avg_pool2d.default,
3434
exir_ops.edge.aten.avg_pool2d.default,
3535
exir_ops.edge.aten.convolution.default,
36+
exir_ops.edge.aten.instance_norm.default,
3637
exir_ops.edge.aten.max_pool2d_with_indices.default,
3738
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
39+
exir_ops.edge.aten._native_batch_norm_legit.no_stats,
3840
exir_ops.edge.aten.native_group_norm.default,
3941
exir_ops.edge.aten.pixel_shuffle.default,
4042
exir_ops.edge.aten.pixel_unshuffle.default,
@@ -54,6 +56,7 @@ class LayoutTransform(ExportPass):
5456
exir_ops.edge.aten.eq.Scalar,
5557
exir_ops.edge.aten.eq.Tensor,
5658
exir_ops.edge.aten.full.default,
59+
exir_ops.edge.aten.full_like.default,
5760
exir_ops.edge.aten.ge.Scalar,
5861
exir_ops.edge.aten.ge.Tensor,
5962
exir_ops.edge.aten.gelu.default,
@@ -75,6 +78,8 @@ class LayoutTransform(ExportPass):
7578
exir_ops.edge.aten.mean.dim,
7679
exir_ops.edge.aten.minimum.default,
7780
exir_ops.edge.aten.mul.Tensor,
81+
exir_ops.edge.aten.ne.Scalar,
82+
exir_ops.edge.aten.ne.Tensor,
7883
exir_ops.edge.aten.neg.default,
7984
exir_ops.edge.aten.pow.Tensor_Scalar,
8085
exir_ops.edge.aten.prelu.default,

backends/qualcomm/_passes/tensor_i64_to_i32.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _cast_to_int32(self, core_ep: ExirExportedProgram):
3939
for n in core_ep.exported_program.graph.nodes:
4040
# Keep track of original output dtype so we ensure the dtype of the graph is consistent with nn.Module
4141
if is_graph_output(n):
42-
if isinstance(n.meta["val"], tuple):
42+
if isinstance(n.meta["val"], (tuple, list)):
4343
dtype_list = [tensor.dtype for tensor in n.meta["val"]]
4444
n.meta[QCOM_ORIG_DTYPE] = dtype_list
4545
else:
@@ -76,7 +76,7 @@ def _preserve_output_dtype(
7676
copy_op = exir_ops.edge.aten._to_copy.default
7777
for n in graph_module.graph.nodes:
7878
if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta:
79-
if isinstance(n.meta["val"], tuple):
79+
if isinstance(n.meta["val"], (tuple, list)):
8080
for i, dtype in enumerate(n.meta[QCOM_ORIG_DTYPE]):
8181
# TODO: Enable this in future to support OP such as topK
8282
if n.meta["val"][i].dtype != dtype:

backends/qualcomm/_passes/utils.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def get_passes_dependency_for_capture_program():
6565
ConvertInterpolateWithUpsample2D,
6666
ConvertPReLU,
6767
ConvertToLinear,
68+
DecomposeAny,
69+
DecomposeLinalgVectorNorm,
6870
ExpandBroadcastTensorShape,
6971
FoldQDQ,
7072
LayoutTransform,
@@ -76,14 +78,10 @@ def get_passes_dependency_for_capture_program():
7678
)
7779

7880
return {
79-
RecomposePixelUnshuffle: [RemoveRedundancy],
80-
RecomposeRmsNorm: [RemoveRedundancy],
81-
ConvertToLinear: [RecomposePixelUnshuffle],
82-
ConvertPReLU: [RemoveRedundancy],
83-
ConvertBmmToMatmul: [ConvertToLinear],
84-
ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
85-
ConstantI64toI32: [RemoveRedundancy],
86-
TensorI64toI32: [RemoveRedundancy],
81+
AnnotateAndQuantScalar: [
82+
AnnotateQuantAttrs,
83+
],
84+
AnnotateDecomposed: [RemoveRedundancy],
8785
AnnotateQuantAttrs: [
8886
RecomposePixelUnshuffle,
8987
RecomposeRmsNorm,
@@ -92,16 +90,22 @@ def get_passes_dependency_for_capture_program():
9290
ConvertBmmToMatmul,
9391
ConvertInterpolateWithUpsample2D,
9492
],
95-
AnnotateAndQuantScalar: [
96-
AnnotateQuantAttrs,
97-
],
98-
AnnotateDecomposed: [RemoveRedundancy],
99-
FoldQDQ: [AnnotateQuantAttrs, AnnotateAndQuantScalar, AnnotateDecomposed],
93+
ConstantI64toI32: [ConvertInterpolateWithUpsample2D],
94+
ConvertBmmToMatmul: [ConvertToLinear],
95+
ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
96+
ConvertPReLU: [RemoveRedundancy],
97+
ConvertToLinear: [RecomposePixelUnshuffle],
98+
DecomposeAny: [RemoveRedundancy],
99+
DecomposeLinalgVectorNorm: [RemoveRedundancy],
100100
ExpandBroadcastTensorShape: [RemoveRedundancy],
101+
FoldQDQ: [AnnotateQuantAttrs, AnnotateAndQuantScalar, AnnotateDecomposed],
101102
LayoutTransform: [
102103
AnnotateQuantAttrs,
103104
AnnotateAndQuantScalar,
104105
ExpandBroadcastTensorShape,
105106
],
107+
RecomposePixelUnshuffle: [RemoveRedundancy],
108+
RecomposeRmsNorm: [RemoveRedundancy],
106109
ReplaceIndexPutInput: [LayoutTransform],
110+
TensorI64toI32: [RemoveRedundancy],
107111
}

backends/qualcomm/builders/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
op_embedding,
2626
op_eq,
2727
op_expand,
28+
op_full,
2829
op_full_like,
2930
op_ge,
3031
op_gelu,
@@ -35,6 +36,7 @@
3536
op_hardtanh,
3637
op_index,
3738
op_index_put,
39+
op_instance_norm,
3840
op_layer_norm,
3941
op_le,
4042
op_linear,
@@ -48,6 +50,7 @@
4850
op_mean_dim,
4951
op_min,
5052
op_mul,
53+
op_ne,
5154
op_neg,
5255
op_pad,
5356
op_pow,
@@ -101,6 +104,7 @@
101104
op_embedding,
102105
op_eq,
103106
op_expand,
107+
op_full,
104108
op_full_like,
105109
op_ge,
106110
op_gelu,
@@ -111,6 +115,7 @@
111115
op_hardsigmoid,
112116
op_index,
113117
op_index_put,
118+
op_instance_norm,
114119
op_layer_norm,
115120
op_le,
116121
op_linear,
@@ -125,6 +130,7 @@
125130
op_min,
126131
op_mul,
127132
op_neg,
133+
op_ne,
128134
op_pad,
129135
op_pow,
130136
op_prelu,

backends/qualcomm/builders/op_arange.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ def define_node(
2828
step = node.args[2] if len(node.args) > 2 else 1
2929
out_tensor = torch.arange(start, end, step)
3030

31+
# since we can derive the constant value of current op in AoT stage
32+
# we only build static tensor here for consumers of current node
33+
# to correctly reference the data
3134
self.define_tensor(
3235
node,
3336
node,

0 commit comments

Comments
 (0)