Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .decompose_col_im import DecomposeColIm
from .decompose_einsum import DecomposeEinsum
from .decompose_expm1 import DecomposeExpM1
from .decompose_glu import DecomposeGlu
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
from .decompose_minmaxdim import DecomposeMinMaxDim
from .decompose_roll import DecomposeRoll
Expand Down Expand Up @@ -57,6 +58,7 @@
DecomposeColIm,
DecomposeEinsum,
DecomposeExpM1,
DecomposeGlu,
DecomposeLinalgVectorNorm,
DecomposeMinMaxDim,
DecomposeRoll,
Expand Down
8 changes: 8 additions & 0 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
QCOM_SCALE,
QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import get_quant_attrs
Expand All @@ -38,6 +39,9 @@ def __init__(
super(AnnotateQuantAttrs, self).__init__()
self.edge_program = edge_program
self.skip_advanced_requant = skip_advanced_requant
self.skip_requant_allowlist = {
exir_ops.edge.aten.sigmoid.default,
}

def _annotate_source_nodes(
self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
Expand Down Expand Up @@ -80,6 +84,10 @@ def _annotate_requant(self, n):
# node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
# We store {node2: quant_attr in dq_int32} in node1.meta
if n.target in q_ops and n.args[0].target not in dq_ops:
# for some fixed scale op, there is no need to requantize it
if n.args[0].target in self.skip_requant_allowlist:
return

dq_nodes = self._find_last_dq_nodes(n)
q_attrs = get_quant_attrs(self.edge_program, n)
for dq_node in dq_nodes:
Expand Down
28 changes: 8 additions & 20 deletions backends/qualcomm/_passes/decompose_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from executorch.exir import to_edge
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import merge_decomposed_graph


class Any(torch.nn.Module):
def __init__(self, dim, keepdim):
Expand Down Expand Up @@ -49,26 +51,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correctly updated in the new graph
remap = {"x": node.args[0]}

for decomposed_node in decomposed_module.graph.nodes:
# no need to copy existent 'output'
if decomposed_node.op == "output":
for user in node.users.copy():
# remap
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
# no need to copy existent placeholders
elif decomposed_node.op == "placeholder":
# replace node map from string to graph node
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
)
graph.erase_node(node)

graph.eliminate_dead_code()
Expand Down
28 changes: 8 additions & 20 deletions backends/qualcomm/_passes/decompose_cdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import merge_decomposed_graph


class CDist(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -54,26 +56,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correctly updated in the new graph
remap = {"x": node.args[0], "y": node.args[1]}

for decomposed_node in decomposed_module.graph.nodes:
# no need to copy existent 'output'
if decomposed_node.op == "output":
for user in node.users.copy():
# remap
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
# no need to copy existent placeholders
elif decomposed_node.op == "placeholder":
# replace node map from string to graph node
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
)
graph.erase_node(node)

graph.eliminate_dead_code()
Expand Down
33 changes: 8 additions & 25 deletions backends/qualcomm/_passes/decompose_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx

from .utils import copy_nn_module_stack
from .utils import merge_decomposed_graph


class DecomposeEinsum(ExportPass):
Expand Down Expand Up @@ -37,30 +37,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for i, arg in enumerate(node.args[1]):
remap[f"arg1_{i+1}"] = arg

for decomposed_node in decomposed_module.graph.nodes:
copy_nn_module_stack(node, decomposed_node)
# This is the arg[0] equation string, which is not required anymore after decomposition
if "arg0" in decomposed_node.name:
continue

# no need to copy existent 'output'
if decomposed_node.op == "output":
for user in node.users.copy():
# remap
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
# no need to copy existent placeholders
elif decomposed_node.op == "placeholder":
# replace node map from string to graph node
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
predicate=lambda decomp_node: "arg0" not in decomp_node.name,
)
graph.erase_node(node)

graph.eliminate_dead_code()
Expand Down
55 changes: 55 additions & 0 deletions backends/qualcomm/_passes/decompose_glu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import merge_decomposed_graph


# this wrapper is required for IO name mapping with decomposed graph
class Glu(torch.nn.Module):
def __init__(self, dim=-1):
super().__init__()
self.glu = torch.nn.GLU(dim=dim)

def forward(self, x):
return self.glu(x)


class DecomposeGlu(ExportPass):
"""
Decompose glu for quantization annotation to work properly.
"""

def __init__(self) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
for node in graph.nodes:
if node.target == torch.ops.aten.glu.default:
ep = torch.export.export(
Glu(dim=-1 if len(node.args) < 2 else node.args[1]),
(node.args[0].meta["val"],),
)
decomposed_module = ep.run_decompositions().graph_module

with graph.inserting_before(node):
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correctly updated in the new graph
remap = {"x": node.args[0]}
merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
)
graph.erase_node(node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
29 changes: 7 additions & 22 deletions backends/qualcomm/_passes/decompose_linalg_vector_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from executorch.exir import to_edge
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_nn_module_stack
from .utils import merge_decomposed_graph


class LinalgVectorNorm(torch.nn.Module):
Expand Down Expand Up @@ -62,27 +62,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correctly updated in the new graph
remap = {"x": node.args[0]}

for decomposed_node in decomposed_module.graph.nodes:
copy_nn_module_stack(node, decomposed_node)
# no need to copy existent 'output'
if decomposed_node.op == "output":
for user in node.users.copy():
# remap
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
# no need to copy existent placeholders
elif decomposed_node.op == "placeholder":
# replace node map from string to graph node
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
)
graph.erase_node(node)

graph.eliminate_dead_code()
Expand Down
29 changes: 7 additions & 22 deletions backends/qualcomm/_passes/decompose_roll.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_nn_module_stack
from .utils import merge_decomposed_graph


class SliceCopy(torch.nn.Module):
Expand Down Expand Up @@ -65,27 +65,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correctly updated in the new graph
remap = {"x": input_node}

for decomposed_node in decomposed_module.graph.nodes:
copy_nn_module_stack(node, decomposed_node)
# no need to copy existent 'output'
if decomposed_node.op == "output":
for user in node.users.copy():
# remap
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
# no need to copy existent placeholders
elif decomposed_node.op == "placeholder":
# replace node map from string to graph node
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
)
graph.erase_node(node)

graph.eliminate_dead_code()
Expand Down
27 changes: 9 additions & 18 deletions backends/qualcomm/_passes/decompose_wrap_with_autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_nn_module_stack
from .utils import merge_decomposed_graph


class DecomposeWrapWithAutocast(ExportPass):
Expand Down Expand Up @@ -52,7 +52,7 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
graph = gm.graph
for node in graph.nodes:
if isinstance(node.target, torch._higher_order_ops.wrap.WrapWithAutocast):
submod, submod_name = self._get_submod(gm, node)
submod, _ = self._get_submod(gm, node)
n_args = node.args
input_submod = n_args[4]
decomposed_module = submod
Expand All @@ -61,22 +61,13 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
# which ensures that reference to nodes are correctly updated in the new graph
# remap = {"expand_1": node.args[5], "to_4": node.args[6]}
remap = {n_args[i].name: n_args[i] for i in range(5, len(n_args))}

for decomposed_node in decomposed_module.graph.nodes:
copy_nn_module_stack(node, decomposed_node)
# no need to copy existent 'output'
if decomposed_node.op == "output":
self._replace_output(node, decomposed_node, remap)
# no need to copy existent placeholders
elif decomposed_node.op == "placeholder":
# replace node map from string to graph node
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

merge_decomposed_graph(
remap=remap,
target_node=node,
target_graph=graph,
decomposed_graph_module=decomposed_module,
output_processor=self._replace_output,
)
graph.erase_node(node)

graph.erase_node(input_submod)
Expand Down
Loading
Loading