Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2318ca2

Browse files
committedApr 23, 2025
Arm backend: Support for to/to_dim_order for TOSA 1.0
Signed-off-by: Per Åstrand <[email protected]> Change-Id: Ia67431491d5c287e8ee1fe6674ebf5fe3903b612
1 parent c95bde4 commit 2318ca2

File tree

3 files changed

+68
-13
lines changed

3 files changed

+68
-13
lines changed
 

‎backends/arm/operator_support/to_copy_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def is_node_tosa_supported(
7777
) -> bool:
7878
assert node.target in self.targets
7979

80-
assert tosa_spec.support_integer()
8180
supported_dtypes = (
8281
self.ALL_SUPPORTED_TYPES
8382
if tosa_spec.support_float()

‎backends/arm/operators/op_to_copy.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,44 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12-
import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
13-
1411
from executorch.backends.arm.operators.node_visitor import (
1512
NodeVisitor,
1613
register_node_visitor,
1714
)
1815
from executorch.backends.arm.tosa_mapping import TosaArg
1916

2017

18+
@register_node_visitor
19+
class ToCopyVisitor_0_80(NodeVisitor):
20+
"""
21+
Implement the type cast functionality of _to_copy.
22+
23+
Other features like setting of the memory_format or moving a tensor to a
24+
different device are not supported.
25+
26+
Also note that the node should not be quantized.
27+
"""
28+
29+
target = "aten._to_copy.default"
30+
31+
tosa_specs = NodeVisitor.tosa_specs_0_80
32+
33+
def define_node(
34+
self,
35+
node: torch.fx.Node,
36+
tosa_graph: Any,
37+
inputs: List[TosaArg],
38+
output: TosaArg,
39+
) -> None:
40+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
41+
42+
tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name])
43+
44+
2145
@register_node_visitor
2246
class ToCopyVisitor(NodeVisitor):
2347
"""
@@ -31,11 +55,15 @@ class ToCopyVisitor(NodeVisitor):
3155

3256
target = "aten._to_copy.default"
3357

58+
tosa_specs = NodeVisitor.tosa_specs_1_00
59+
3460
def define_node(
3561
self,
3662
node: torch.fx.Node,
37-
tosa_graph: ts.TosaSerializer,
63+
tosa_graph: Any,
3864
inputs: List[TosaArg],
3965
output: TosaArg,
4066
) -> None:
41-
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])
67+
import serializer.tosa_serializer as ts # type: ignore
68+
69+
tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name])

‎backends/arm/operators/op_to_dim_order_copy.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,44 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12-
import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
13-
1411
from executorch.backends.arm.operators.node_visitor import (
1512
NodeVisitor,
1613
register_node_visitor,
1714
)
1815
from executorch.backends.arm.tosa_mapping import TosaArg
1916

2017

18+
@register_node_visitor
19+
class ToDimOrderCopyVisitor_0_80(NodeVisitor):
20+
"""
21+
Implement the type cast functionality of _to_dim_order_copy.
22+
23+
Other features like setting of the dim_order or moving a tensor to a
24+
different device are not supported.
25+
26+
Also note that the node should not be quantized.
27+
"""
28+
29+
target = "dim_order_ops._to_dim_order_copy.default"
30+
31+
tosa_specs = NodeVisitor.tosa_specs_0_80
32+
33+
def define_node(
34+
self,
35+
node: torch.fx.Node,
36+
tosa_graph: Any,
37+
inputs: List[TosaArg],
38+
output: TosaArg,
39+
) -> None:
40+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
41+
42+
tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name])
43+
44+
2145
@register_node_visitor
2246
class ToDimOrderCopyVisitor(NodeVisitor):
2347
"""
@@ -31,11 +55,15 @@ class ToDimOrderCopyVisitor(NodeVisitor):
3155

3256
target = "dim_order_ops._to_dim_order_copy.default"
3357

58+
tosa_specs = NodeVisitor.tosa_specs_1_00
59+
3460
def define_node(
3561
self,
3662
node: torch.fx.Node,
37-
tosa_graph: ts.TosaSerializer,
63+
tosa_graph: Any,
3864
inputs: List[TosaArg],
3965
output: TosaArg,
4066
) -> None:
41-
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])
67+
import serializer.tosa_serializer as ts # type: ignore
68+
69+
tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name])

0 commit comments

Comments
 (0)
Please sign in to comment.