Skip to content

Commit a528854

Browse files
committed
Arm backend: Support for identity for TOSA 1.0
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I3f133d55feac211be6196f4d6811feff97ade98c
1 parent 2318ca2 commit a528854

File tree

1 file changed

+38
-4
lines changed

1 file changed

+38
-4
lines changed

backends/arm/operators/ops_identity.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,50 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import torch
1111
import torch.fx
1212

13-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
14-
1513
from executorch.backends.arm.operators.node_visitor import (
1614
NodeVisitor,
1715
register_node_visitor,
1816
)
1917
from executorch.backends.arm.tosa_mapping import TosaArg
2018

2119

20+
def identity_operator_factory_v0_80(identity_target: str):
21+
"""
22+
Creates and registers NodeVisitors for operators that map directly
23+
to a TOSA IDENTITY op.
24+
"""
25+
26+
class IdentityOperatorVisitor(NodeVisitor):
27+
target = identity_target
28+
29+
tosa_specs = NodeVisitor.tosa_specs_0_80
30+
31+
def define_node(
32+
self,
33+
node: torch.fx.Node,
34+
tosa_graph: Any,
35+
inputs: List[TosaArg],
36+
output: TosaArg,
37+
) -> None:
38+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
39+
40+
# Simply add an identityOp
41+
tosa_graph.addOperator(
42+
ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
43+
)
44+
45+
register_node_visitor(IdentityOperatorVisitor)
46+
47+
48+
identity_operator_factory_v0_80("getitem")
49+
identity_operator_factory_v0_80("aten.alias_copy.default")
50+
51+
2252
def identity_operator_factory(identity_target: str):
2353
"""
2454
Creates and registers NodeVisitors for operators that map directly
@@ -28,13 +58,17 @@ def identity_operator_factory(identity_target: str):
2858
class IdentityOperatorVisitor(NodeVisitor):
2959
target = identity_target
3060

61+
tosa_specs = NodeVisitor.tosa_specs_1_00
62+
3163
def define_node(
3264
self,
3365
node: torch.fx.Node,
34-
tosa_graph: ts.TosaSerializer,
66+
tosa_graph: Any,
3567
inputs: List[TosaArg],
3668
output: TosaArg,
3769
) -> None:
70+
import serializer.tosa_serializer as ts
71+
3872
# Simply add an identityOp
3973
tosa_graph.addOperator(
4074
ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]

0 commit comments

Comments
 (0)