5
5
6
6
# pyre-unsafe
7
7
8
- from typing import List
8
+ from typing import Any , List
9
9
10
10
import torch
11
11
import torch .fx
12
12
13
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts
14
-
15
13
from executorch .backends .arm .operators .node_visitor import (
16
14
NodeVisitor ,
17
15
register_node_visitor ,
18
16
)
19
17
from executorch .backends .arm .tosa_mapping import TosaArg
20
18
21
19
20
+ def identity_operator_factory_v0_80 (identity_target : str ):
21
+ """
22
+ Creates and registers NodeVisitors for operators that map directly
23
+ to a TOSA IDENTITY op.
24
+ """
25
+
26
+ class IdentityOperatorVisitor (NodeVisitor ):
27
+ target = identity_target
28
+
29
+ tosa_specs = NodeVisitor .tosa_specs_0_80
30
+
31
+ def define_node (
32
+ self ,
33
+ node : torch .fx .Node ,
34
+ tosa_graph : Any ,
35
+ inputs : List [TosaArg ],
36
+ output : TosaArg ,
37
+ ) -> None :
38
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts
39
+
40
+ # Simply add an identityOp
41
+ tosa_graph .addOperator (
42
+ ts .TosaOp .Op ().IDENTITY , [inputs [0 ].name ], [output .name ]
43
+ )
44
+
45
+ register_node_visitor (IdentityOperatorVisitor )
46
+
47
+
48
+ identity_operator_factory_v0_80 ("getitem" )
49
+ identity_operator_factory_v0_80 ("aten.alias_copy.default" )
50
+
51
+
22
52
def identity_operator_factory (identity_target : str ):
23
53
"""
24
54
Creates and registers NodeVisitors for operators that map directly
@@ -28,13 +58,17 @@ def identity_operator_factory(identity_target: str):
28
58
class IdentityOperatorVisitor (NodeVisitor ):
29
59
target = identity_target
30
60
61
+ tosa_specs = NodeVisitor .tosa_specs_1_00
62
+
31
63
def define_node (
32
64
self ,
33
65
node : torch .fx .Node ,
34
- tosa_graph : ts . TosaSerializer ,
66
+ tosa_graph : Any ,
35
67
inputs : List [TosaArg ],
36
68
output : TosaArg ,
37
69
) -> None :
70
+ import serializer .tosa_serializer as ts
71
+
38
72
# Simply add an identityOp
39
73
tosa_graph .addOperator (
40
74
ts .TosaOp .Op ().IDENTITY , [inputs [0 ].name ], [output .name ]
0 commit comments