Skip to content

Commit 889448b

Browse files
nathanaelsee, facebook-github-bot
authored and committed
update SqueezeInt4LinearInputs to process relu/gelu inputs too (#8601)
Summary: Update and rename the SqueezeInt4LinearInputs pass so that it also wraps gelu/relu with squeeze/unsqueeze view ops. Differential Revision: D69673068
1 parent 254eeca commit 889448b

File tree

5 files changed

+40
-11
lines changed

5 files changed

+40
-11
lines changed

backends/transforms/fuse_view_copy.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,24 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
4040
return graph
4141

4242

43+
def remove_noop_view_copy(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Strip view_copy nodes whose requested shape matches their input's shape.

    A ``view_copy`` call is treated as a no-op when the shape recorded in its
    input's ``meta["val"]`` (converted to a list) compares equal to the node's
    second positional argument. Every use of such a node is redirected to the
    node's input, after which dead-code elimination is run on the graph.

    Args:
        graph: The FX graph to clean up; it is mutated in place.

    Returns:
        The same ``graph`` object that was passed in.
    """
    view_copy_target = exir_ops.edge.aten.view_copy.default

    for candidate in graph.nodes:
        is_view_copy = (
            candidate.op == "call_function"
            and candidate.target == view_copy_target
        )
        if not is_view_copy:
            continue

        source = candidate.args[0]
        current_shape = list(source.meta["val"].shape)
        if current_shape == candidate.args[1]:
            # Rewire consumers straight to the input; the now-unused
            # view_copy node is removed by the dead-code sweep below.
            candidate.replace_all_uses_with(source)

    graph.eliminate_dead_code()
    return graph
57+
58+
4359
class FuseViewCopyTransform(ExportPass):
    """
    Export pass that tidies up view_copy nodes in a graph module.

    It applies ``merge_view_copy_chains`` and then ``remove_noop_view_copy``
    to the module's graph, reassigning ``graph_module.graph`` after each step.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run both view_copy clean-up steps on ``graph_module``.

        Args:
            graph_module: The module whose graph is rewritten.

        Returns:
            A ``PassResult`` wrapping ``graph_module``; its second argument is
            always ``True`` (presumably the "modified" flag -- confirm against
            ``PassResult``).
        """
        # Step 1: merge chained view_copy nodes.
        merged_graph = merge_view_copy_chains(graph_module.graph)
        graph_module.graph = merged_graph

        # Step 2: drop view_copy nodes that leave the shape unchanged.
        cleaned_graph = remove_noop_view_copy(graph_module.graph)
        graph_module.graph = cleaned_graph

        return PassResult(graph_module, True)

backends/vulkan/_passes/TARGETS

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,15 @@ runtime.python_library(
3131
)
3232

3333
runtime.python_library(
34-
name = "squeeze_int4_linear_inputs",
34+
name = "squeeze_unsqueeze_inputs",
3535
srcs = [
36-
"squeeze_int4_linear_inputs.py",
36+
"squeeze_unsqueeze_inputs.py",
3737
],
3838
visibility = [
3939
"//executorch/backends/...",
4040
],
4141
deps = [
42+
"//caffe2:torch",
4243
"//executorch/backends/vulkan:custom_ops_lib",
4344
"//executorch/exir:pass_base",
4445
"//executorch/exir/dialects:lib",
@@ -114,7 +115,7 @@ runtime.python_library(
114115
":remove_asserts",
115116
":remove_local_scalar_dense",
116117
":remove_redundant_ops",
117-
":squeeze_int4_linear_inputs",
118+
":squeeze_unsqueeze_inputs",
118119
":tag_memory_meta_pass",
119120
]
120121
)

backends/vulkan/_passes/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from executorch.backends.vulkan._passes.remove_redundant_ops import (
2121
RemoveRedundantOpsTransform,
2222
)
23-
from executorch.backends.vulkan._passes.squeeze_int4_linear_inputs import (
24-
SqueezeInt4LinearInputs,
23+
from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
24+
SqueezeUnsqueezeInputs,
2525
)
2626
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
2727

@@ -32,6 +32,6 @@
3232
"RemoveAssertsTransform",
3333
"RemoveLocalScalarDenseOpsTransform",
3434
"RemoveRedundantOpsTransform",
35-
"SqueezeInt4LinearInputs",
35+
"SqueezeUnsqueezeInputs",
3636
"TagMemoryMetaPass",
3737
]

backends/vulkan/_passes/squeeze_int4_linear_inputs.py renamed to backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,27 @@
66

77
# pyre-strict
88

9-
from typing import Dict, List, Tuple
9+
from typing import Dict, List, Set, Tuple, Union
1010

1111
import executorch.backends.vulkan.custom_ops_lib # noqa: needed to access vk op
1212
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1314
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
1415

16+
from torch._ops import OpOverload
17+
1518
from torch.fx.node import Argument
1619

20+
OpType = Union[str, OpOverload, EdgeOpOverload]
21+
22+
23+
class SqueezeUnsqueezeInputs(ExportPass):
24+
_squeezable_ops: Set[OpType] = {
25+
exir_ops.edge.et_vk.linear_weight_int4.default,
26+
exir_ops.edge.aten.relu.default,
27+
exir_ops.edge.aten.gelu.default,
28+
}
1729

18-
class SqueezeInt4LinearInputs(ExportPass):
1930
def call_operator(
2031
self,
2132
op, # pyre-ignore
@@ -26,7 +37,7 @@ def call_operator(
2637
def _squeezable(shape: List[int]) -> bool:
2738
return len(shape) > 2 and 1 in shape
2839

29-
if op != exir_ops.edge.et_vk.linear_weight_int4.default:
40+
if op not in self._squeezable_ops:
3041
return super().call_operator(op, args, kwargs, meta)
3142

3243
# pyre-ignore[16]: `None` has no attribute `node`

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
insert_prepack_nodes,
2727
RemoveLocalScalarDenseOpsTransform,
2828
RemoveRedundantOpsTransform,
29-
SqueezeInt4LinearInputs,
29+
SqueezeUnsqueezeInputs,
3030
TagMemoryMetaPass,
3131
)
3232

@@ -153,7 +153,7 @@ def preprocess( # noqa: C901
153153
RemoveRedundantOpsTransform(),
154154
AddmmToLinearTransform(),
155155
FuseDequantLinearPass(),
156-
SqueezeInt4LinearInputs(),
156+
SqueezeUnsqueezeInputs(),
157157
FuseViewCopyTransform(),
158158
ViewCopyToSqueezeUnsqueezePass(),
159159
FuseBatchNormWithConvPass(program),

0 commit comments

Comments
 (0)