diff --git a/.mypy.ini b/.mypy.ini index 8c1c9dbcadc..5ee07ddb2bf 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -80,6 +80,9 @@ ignore_missing_imports = True [mypy-serializer.*] ignore_missing_imports = True +[mypy-tosa_tools.*] +ignore_missing_imports = True + [mypy-setuptools.*] ignore_missing_imports = True diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index f2c7ce9f9ce..72fb58f582c 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -7,8 +7,9 @@ from typing import Dict, List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.export import ExportedProgram diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index 886a96fd520..648edde04f4 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -9,15 +9,13 @@ import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification - -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -70,7 +68,7 @@ def define_node( # Do the INT32 Abs tosa_graph.addOperator( - TosaOp.Op().ABS, + ts.TosaOp.Op().ABS, [ rescaled_inputs[0].name, ], @@ -126,7 +124,7 @@ def define_node( # MI lowering tosa_graph.addOperator( - TosaOp.Op().ABS, + ts.TosaOp.Op().ABS, [inputs[0].name], [output.name], None, diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 1be4a218232..904a2405047 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -10,14 +10,13 @@ import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -82,7 +81,7 @@ def define_node( # Do the INT32 Add tosa_graph.addOperator( - TosaOp.Op().ADD, + ts.TosaOp.Op().ADD, [input1.name, input2.name], [add_output.name], None, @@ -135,7 +134,7 @@ def define_node( # MI lowering tosa_graph.addOperator( - TosaOp.Op().ADD, + ts.TosaOp.Op().ADD, [input1.name, input2.name], [output.name], None, diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index 7347648c454..059f6c1e553 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -4,14 +4,13 @@ # LICENSE file in the root directory of this source tree. 
from typing import List -import serializer.tosa_serializer as ts +import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -48,5 +47,5 @@ def define_node( attr.AxisAttribute(input.dim_order.index(dim)) tosa_graph.addOperator( - TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr + ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr ) diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 37625cfcc52..85e43b76c4c 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -4,14 +4,13 @@ # LICENSE file in the root directory of this source tree. from typing import List -import serializer.tosa_serializer as ts +import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -48,5 +47,5 @@ def define_node( attr.AxisAttribute(input.dim_order.index(dim)) tosa_graph.addOperator( - TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr + ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr ) diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index ffb2e8a3c5d..b65ebb2ac5d 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -6,14 +6,13 @@ # pyre-unsafe from typing import cast, List -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( # type: ignore NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -49,5 +48,5 @@ def define_node( attr.AxisAttribute(inputs[0].dim_order.index(dim)) tosa_graph.addOperator( - TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr + ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 772f8353565..bdd3425fda5 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -6,9 +6,10 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index af02fc30dd8..6dc0ec8002d 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -7,9 +7,10 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, 
get_output_qparams, @@ -20,7 +21,6 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import build_rescale -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -64,7 +64,7 @@ def define_node( attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp) tosa_graph.addOperator( - TosaOp.Op().MATMUL, + ts.TosaOp.Op().MATMUL, [inputs[0].name, inputs[1].name], [bmm_output_name], attr, diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index f786395cc39..6b1710301b1 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -7,13 +7,12 @@ from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -42,5 +41,8 @@ def define_node( attr.AxisAttribute(dim) tosa_graph.addOperator( - TosaOp.Op().CONCAT, [tensor.name for tensor in tensors], [output.name], attr + ts.TosaOp.Op().CONCAT, + [tensor.name for tensor in tensors], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 7c4ad8682fa..67fff8b8a60 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -8,9 +8,9 @@ from typing import Any, List, Tuple -import serializer.tosa_serializer as ts # type: ignore - import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -18,7 +18,6 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -51,7 +50,7 @@ def _create_clamp_node( min_fp32, max_fp32, ) - tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr) + tosa_graph.addOperator(ts.TosaOp.Op().CLAMP, [input_name], [output_name], attr) def _get_min_max_arguments( self, node: Node, dtype_min: int | float, dtype_max: int | float diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 73f6d2751c5..b2c31df96ab 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -7,9 +7,10 @@ from typing import List -import serializer.tosa_serializer as ts import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -18,7 +19,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -71,4 +71,6 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.PadAttribute(tosa_graph.builder, output_pad, pad_const_qs, pad_const_fp) - tosa_graph.addOperator(TosaOp.Op().PAD, [inputs[0].name], [output.name], attr) + tosa_graph.addOperator( + ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr + ) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 2fe00b6758f..90475af1476 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -6,9 
+6,10 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index de49b267641..7f87fb5a81d 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -9,13 +9,12 @@ import executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -53,7 +52,7 @@ def define_node( # Do the equal comparison tosa_graph.addOperator( - TosaOp.Op().EQUAL, + ts.TosaOp.Op().EQUAL, [input_nodes[0].name, input_nodes[1].name], output.name, None, diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index d0dc2af572f..23850c0241d 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -5,15 +5,15 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch.fx + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -41,4 +41,4 @@ def define_node( if not (inputs[0].dtype == ts.DType.FP32): raise ValueError("All inputs need to be FP32." 
f"Got {inputs[0].dtype=}") # MI lowering - tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index d9f0e4f5197..ca067b3b8be 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -6,15 +6,13 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification - -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -46,4 +44,4 @@ def define_node( f"{inputs[0].dtype} and output dtype: {output.dtype}" ) - tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index f06b9873e63..e68610a37a0 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -8,7 +8,7 @@ import numpy as np -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index d18156e1a50..b2193a2e7ed 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -9,13 +9,12 @@ import executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -52,7 +51,7 @@ def define_node( input_nodes = rescaled_inputs tosa_graph.addOperator( - TosaOp.Op().GREATER_EQUAL, + ts.TosaOp.Op().GREATER_EQUAL, [input_nodes[0].name, input_nodes[1].name], [output.name], None, diff --git a/backends/arm/operators/op_get_item.py b/backends/arm/operators/op_get_item.py index 577a8c8d2ea..0e1192b3bef 100644 --- a/backends/arm/operators/op_get_item.py +++ b/backends/arm/operators/op_get_item.py @@ -6,14 +6,14 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -32,4 +32,4 @@ def define_node( ) -> None: item_name = inputs[0].name ## Simply add an identityOp - tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 25ff2463c5c..06f29e4505c 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -9,13 +9,12 @@ import 
executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -52,7 +51,7 @@ def define_node( input_nodes = rescaled_inputs tosa_graph.addOperator( - TosaOp.Op().GREATER, + ts.TosaOp.Op().GREATER, [input_nodes[0].name, input_nodes[1].name], [output.name], None, diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index 8ed5539034b..fadf4848359 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -9,13 +9,12 @@ import executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -52,7 +51,7 @@ def define_node( input_nodes = rescaled_inputs tosa_graph.addOperator( - TosaOp.Op().GREATER_EQUAL, + ts.TosaOp.Op().GREATER_EQUAL, [input_nodes[1].name, input_nodes[0].name], [output.name], None, diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index d8a136e37f8..34911075065 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -6,15 +6,13 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification - -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -38,4 +36,4 @@ def define_node( assert len(node.all_input_nodes) == 1 assert inputs[0].dtype == output.dtype == ts.DType.FP32 - tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().LOG, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 3bb71f611a5..a261cd2db9f 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -9,13 +9,12 @@ import executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -52,7 +51,7 @@ def define_node( input_nodes = rescaled_inputs tosa_graph.addOperator( - TosaOp.Op().GREATER, + ts.TosaOp.Op().GREATER, [input_nodes[1].name, input_nodes[0].name], [output.name], None, diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 9dd627a3e4f..fcf2636977d 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -6,9 +6,10 @@ # pyre-unsafe from typing import List -import 
serializer.tosa_serializer as ts # type: ignore import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -18,7 +19,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -75,7 +75,7 @@ def define_node( ) tosa_graph.addOperator( - TosaOp.Op().MAX_POOL2D, + ts.TosaOp.Op().MAX_POOL2D, [input_tensor.name], [output.name], attr, diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 4eb7e47fac8..fdee4e61855 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -8,7 +8,7 @@ from typing import List import executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ -19,8 +19,6 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import tosa_shape - -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -62,7 +60,7 @@ def define_node( operand_inputs = inputs tosa_graph.addOperator( - TosaOp.Op().MAXIMUM, + ts.TosaOp.Op().MAXIMUM, [ operand_inputs[0].name, operand_inputs[1].name, diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 1b8c1960411..dd3afe90cb5 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -9,7 +9,7 @@ import executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ -20,8 +20,6 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import tosa_shape - -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -63,7 +61,7 @@ def define_node( operand_inputs = inputs tosa_graph.addOperator( - TosaOp.Op().MINIMUM, + ts.TosaOp.Op().MINIMUM, [ operand_inputs[0].name, operand_inputs[1].name, diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 2f6c7e7130c..dcceb36b0ab 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -9,10 +9,10 @@ import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils - -import serializer.tosa_serializer as ts # type: ignore import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -24,7 +24,6 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import reshape_for_broadcast -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -87,7 +86,7 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.MulAttribute(shift=0) tosa_graph.addOperator( - TosaOp.Op().MUL, + ts.TosaOp.Op().MUL, [input1.name, input2.name], [mul_output.name], attr, @@ -119,5 +118,5 @@ def define_node( attr = 
ts.TosaSerializerAttribute() attr.MulAttribute(shift=0) tosa_graph.addOperator( - TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr + ts.TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr ) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index e659918baf2..c92a008a281 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -7,14 +7,14 @@ from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor: @@ -117,5 +117,5 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.TransposeAttribute(permutation_vector) tosa_graph.addOperator( - TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr + ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 0f251a8aa6d..d3b92feff12 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -7,14 +7,13 @@ from typing import List -import serializer.tosa_serializer as ts +import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -47,7 +46,7 @@ def define_node( ) tosa_graph.addOperator( - TosaOp.Op().POW, + ts.TosaOp.Op().POW, [ inputs[0].name, inputs[1].name, diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 5410e1dd99a..11d2cbc2cc1 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -6,15 +6,15 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -35,4 +35,6 @@ def define_node( output: TosaArg, ) -> None: assert inputs[0].dtype == output.dtype == ts.DType.FP32 - tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]) + tosa_graph.addOperator( + ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index b97d7023ef0..142ccb1d25a 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -5,15 +5,14 @@ # pyre-unsafe -import serializer.tosa_serializer as ts # type: ignore import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import 
tosa_shape -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -35,4 +34,6 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.TileAttribute(tosa_shape(multiples, output.dim_order)) - tosa_graph.addOperator(TosaOp.Op().TILE, [inputs[0].name], [output.name], attr) + tosa_graph.addOperator( + ts.TosaOp.Op().TILE, [inputs[0].name], [output.name], attr + ) diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index 098fbeccce1..c59015dcc14 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -8,10 +8,10 @@ from typing import cast, List import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils -import serializer.tosa_serializer as ts # type: ignore import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -import tosa.Op as TosaOp # type: ignore +import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 8ea0343faaa..125f5493a29 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -7,15 +7,15 @@ from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import Tosa_0_80 -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -39,7 +39,7 @@ def define_node( attr.ArithmeticRightShiftAttribute(round=round) tosa_graph.addOperator( - TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, + ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, [inputs[0].name, inputs[1].name], [output.name], attr, diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 0fbb203b081..52bcc937c96 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -6,15 +6,15 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -35,4 +35,4 @@ def define_node( output: TosaArg, ) -> None: assert inputs[0].dtype == output.dtype == ts.DType.FP32 - tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index abf60bf747f..9a002036fee 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -6,15 +6,13 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from 
executorch.backends.arm.tosa_specification import TosaSpecification - -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -46,4 +44,4 @@ def define_node( f"{inputs[0].dtype} and output_dtype: {output.dtype}" ) - tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index a3ce80c5b24..a9e1a6cdbe2 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -7,13 +7,12 @@ from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -58,5 +57,5 @@ def define_node( attr.SliceAttribute(start_attr, size_attr) tosa_graph.addOperator( - TosaOp.Op().SLICE, [input_node.name], [output.name], attr + ts.TosaOp.Op().SLICE, [input_node.name], [output.name], attr ) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 6cd422095ab..1f701f29b1e 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -10,14 +10,13 @@ import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -63,7 +62,7 @@ def define_node( # Do the INT32 Sub tosa_graph.addOperator( - TosaOp.Op().SUB, + ts.TosaOp.Op().SUB, [ rescaled_inputs[0].name, rescaled_inputs[1].name, @@ -110,7 +109,7 @@ def define_node( # MI lowering tosa_graph.addOperator( - TosaOp.Op().SUB, + ts.TosaOp.Op().SUB, [inputs[0].name, inputs[1].name], [output.name], None, diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index b5b388b3352..135566e48ac 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -10,14 +10,13 @@ import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -69,7 +68,7 @@ def define_node( ) tosa_graph.addOperator( - TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr + ts.TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr ) prev_node = next_node @@ -120,7 +119,7 @@ def define_node( ).name tosa_graph.addOperator( - TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr + ts.TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr ) input_name = output_name diff --git 
a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py index 40214b265f0..6a2053bea0d 100644 --- a/backends/arm/operators/op_table.py +++ b/backends/arm/operators/op_table.py @@ -8,15 +8,14 @@ from typing import List import numpy as np - -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -49,5 +48,5 @@ def define_node( table_attr.TableAttribute(np.array(table)) tosa_graph.addOperator( - TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr + ts.TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr ) diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 89dd15c97d6..51cf1ee786b 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -6,14 +6,13 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -44,4 +43,4 @@ def define_node( f"{inputs[0].dtype} and output_dtype: {output.dtype}" ) - tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().TANH, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py index feaec3a41e9..90485b71d50 100644 --- a/backends/arm/operators/op_to_copy.py +++ b/backends/arm/operators/op_to_copy.py @@ -6,9 +6,10 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch -import tosa.Op as TosaOp # type: ignore + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index 397979a439d..f144beba29f 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -6,9 +6,10 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch -import tosa.Op as TosaOp # type: ignore + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py index 54a79297dd6..b909aef2ac9 100644 --- a/backends/arm/operators/op_transpose.py +++ b/backends/arm/operators/op_transpose.py @@ -7,14 +7,14 @@ from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from 
executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp @register_node_visitor @@ -39,5 +39,5 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.TransposeAttribute(perms) tosa_graph.addOperator( - TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr + ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py index 38e4087d38d..23d24b78339 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -6,17 +6,17 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape -from serializer.tosa_serializer import TosaOp -from tosa.ResizeMode import ResizeMode # type: ignore +from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore @register_node_visitor @@ -65,5 +65,5 @@ def in_int16_range(x): ) tosa_graph.addOperator( - TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr + ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index 119e32fa58f..e063b8e39ec 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -6,9 +6,10 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch -import tosa.Op as TosaOp # type: ignore + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index c8b35e831d4..c45d787ef38 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -5,7 +5,7 @@ from typing import List, Sequence -import serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -13,7 +13,6 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -40,7 +39,7 @@ def _add_node_to_tosa_graph( ) tosa_graph.addOperator( - TosaOp.Op().SELECT, + ts.TosaOp.Op().SELECT, [inputs[0].name, inputs[1].name, inputs[2].name], [output.name], None, diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 307710e38e9..a17da41f767 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -7,16 +7,16 @@ from typing import List -import serializer.tosa_serializer as ts import torch import torch.fx +import tosa_tools.v0_80.serializer.tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp def binary_operator_factory(bw_target: str, tosa_op): @@ 
-46,12 +46,12 @@ def define_node( register_node_visitor(BinaryOperator) -binary_operator_factory("aten.bitwise_and.Tensor", TosaOp.Op().BITWISE_AND) -binary_operator_factory("aten.bitwise_xor.Tensor", TosaOp.Op().BITWISE_XOR) -binary_operator_factory("aten.bitwise_or.Tensor", TosaOp.Op().BITWISE_OR) -binary_operator_factory("aten.logical_and.default", TosaOp.Op().LOGICAL_AND) -binary_operator_factory("aten.logical_xor.default", TosaOp.Op().LOGICAL_XOR) -binary_operator_factory("aten.logical_or.default", TosaOp.Op().LOGICAL_OR) +binary_operator_factory("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND) +binary_operator_factory("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR) +binary_operator_factory("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR) +binary_operator_factory("aten.logical_and.default", ts.TosaOp.Op().LOGICAL_AND) +binary_operator_factory("aten.logical_xor.default", ts.TosaOp.Op().LOGICAL_XOR) +binary_operator_factory("aten.logical_or.default", ts.TosaOp.Op().LOGICAL_OR) binary_operator_factory( - "aten.bitwise_left_shift.Tensor", TosaOp.Op().LOGICAL_LEFT_SHIFT + "aten.bitwise_left_shift.Tensor", ts.TosaOp.Op().LOGICAL_LEFT_SHIFT ) diff --git a/backends/arm/operators/ops_unary.py b/backends/arm/operators/ops_unary.py index 0a7d45ffe98..3f713e086e6 100644 --- a/backends/arm/operators/ops_unary.py +++ b/backends/arm/operators/ops_unary.py @@ -6,15 +6,15 @@ # pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore import torch.fx + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp def unary_operator_factory(unary_target: str, tosa_op): @@ -53,6 +53,6 @@ def define_node( register_node_visitor(UnaryOperator) -unary_operator_factory("aten.ceil.default", TosaOp.Op().CEIL) -unary_operator_factory("aten.floor.default", TosaOp.Op().FLOOR) -unary_operator_factory("aten.logical_not.default", TosaOp.Op().LOGICAL_NOT) +unary_operator_factory("aten.ceil.default", ts.TosaOp.Op().CEIL) +unary_operator_factory("aten.floor.default", ts.TosaOp.Op().FLOOR) +unary_operator_factory("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index f9b77e28493..07bbfaa37b9 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -8,9 +8,9 @@ from typing import cast, Dict import numpy as np -import serializer.tosa_serializer as ts # type: ignore import torch import torch.fx +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification diff --git a/backends/arm/scripts/install_reference_model.sh b/backends/arm/scripts/install_reference_model.sh new file mode 100755 index 00000000000..796a1ed418e --- /dev/null +++ b/backends/arm/scripts/install_reference_model.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +set -euo pipefail + +# Installation script to manage the transition to TOSA 1.0 + +# TOSA reference model +tosa_reference_model_url="https://git.gitlab.arm.com/tosa/tosa-reference-model.git" +tosa_reference_model_0_80_branch="v0.80" +tosa_reference_model_0_80_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a" +tosa_serialization_lib_0_80_rev="v0.80.1" +tosa_reference_model_1_0_rev="v1.0" + +script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) + +source ${script_dir}/utils.sh + + +function setup_tosa_reference_model() { + local work_dir="$1" + + if [[ -z "$work_dir" ]]; then + echo "Error: work_dir parameter is required." + return 1 + fi + + mkdir -p "$work_dir" + pushd "$work_dir" || exit 1 + + # Install a patched version of TOSA reference model v0.80.1 so that it can co-exist with 1.0 during the transition period + if [[ ! -d "reference_model" ]]; then + git clone --recurse-submodules --branch ${tosa_reference_model_0_80_branch} "$tosa_reference_model_url" reference_model + fi + + patches_dir=${script_dir}/../third-party/reference_model/patches/v0.80 + patch_repo reference_model ${tosa_reference_model_0_80_rev} ${patches_dir} + patch_repo reference_model/thirdparty/serialization_lib ${tosa_serialization_lib_0_80_rev} ${patches_dir} + + pushd reference_model + rm -rf build + # reference_model's flatbuffers version clashes with Vela's. + # go with Vela's since it is newer. + # Vela's flatbuffers requirement is expected to loosen; remove this workaround then. MLETORCH-565 + CMAKE_POLICY_VERSION_MINIMUM=3.5 pip install . --no-dependencies flatbuffers + popd +} + +setup_tosa_reference_model $1 diff --git a/backends/arm/scripts/utils.sh b/backends/arm/scripts/utils.sh index e3ed04ffa22..8b4c8d4f96f 100644 --- a/backends/arm/scripts/utils.sh +++ b/backends/arm/scripts/utils.sh @@ -46,7 +46,7 @@ function patch_repo() { local patch_dir="${3}/$name" echo -e "[${FUNCNAME[0]}] Patching ${name} repo_dir:${repo_dir} base_rev:${base_rev} patch_dir:${patch_dir}" - cd $repo_dir + pushd $repo_dir git fetch git reset --hard ${base_rev} @@ -54,4 +54,5 @@ function patch_repo() { git am -3 ${patch_dir}/*.patch echo -e "[${FUNCNAME[0]}] Patched ${name} @ $(git describe --all --long 2> /dev/null) in ${repo_dir} dir.\n" + popd } diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py index e5d7783fea3..12220acbae9 100644 --- a/backends/arm/test/conftest.py +++ b/backends/arm/test/conftest.py @@ -15,7 +15,7 @@ import pytest try: - import tosa_reference_model + import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model except ImportError: logging.warning("tosa_reference_model not found, can't run reference model tests") tosa_reference_model = None diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 28bbee052f9..7fc4db2273a 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) try: - import tosa_reference_model + import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model except ImportError: tosa_reference_model = None from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa @@ -34,7 +34,7 @@ from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from tosa import TosaGraph +from tosa_tools.v0_80.tosa import TosaGraph logger = logging.getLogger(__name__) logger.setLevel(logging.CRITICAL) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 7b74603cfb2..6346a53edef 100644 --- 
a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -14,10 +14,10 @@ import executorch.backends.xnnpack.test.tester.tester as tester -import serializer.tosa_serializer as ts # type: ignore[import-untyped] - import torch.fx import torch.utils._pytree as pytree + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore[import-untyped] from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager from executorch.backends.arm.arm_backend import ( diff --git a/backends/arm/third-party/reference_model/patches/v0.80/reference_model/0001-Move-tosa-tools-to-be-namespaced-into-tosa-tools.v0_.patch b/backends/arm/third-party/reference_model/patches/v0.80/reference_model/0001-Move-tosa-tools-to-be-namespaced-into-tosa-tools.v0_.patch new file mode 100644 index 00000000000..512c105bda2 --- /dev/null +++ b/backends/arm/third-party/reference_model/patches/v0.80/reference_model/0001-Move-tosa-tools-to-be-namespaced-into-tosa-tools.v0_.patch @@ -0,0 +1,154 @@ +From 20c2059723d5c6952cecfb7fcde92601639ef825 Mon Sep 17 00:00:00 2001 +From: =?UTF-8?q?Per=20=C3=85strand?= <per.astrand@arm.com> +Date: Wed, 5 Feb 2025 12:31:47 +0100 +Subject: [PATCH 1/2] Move tosa-tools to be namespaced into tosa-tools.v0_80 + +--- + CMakeLists.txt | 4 ++- + pyproject.toml | 3 ++- + setup.cfg | 70 +++++++++++++++++++++++++------------------------- + setup.py | 3 ++- + 4 files changed, 42 insertions(+), 38 deletions(-) + +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 68e8d8a..34becd0 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -1,4 +1,6 @@ +-cmake_minimum_required (VERSION 3.4) ++cmake_minimum_required (VERSION 3.19) ++ ++cmake_policy(SET CMP0077 NEW) + + set(CMAKE_INSTALL_PREFIX ".") + project(tosa_tools LANGUAGES CXX) +diff --git a/pyproject.toml b/pyproject.toml +index 7565f93..60448e7 100644 +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -6,7 +6,8 @@ requires = [ + "setuptools>=42", + "wheel", + "setuptools_scm[toml]>=6.0", +- "cmake" ++ "cmake", ++ "ninja", + ] + build-backend = "setuptools.build_meta" + +diff --git a/setup.cfg b/setup.cfg +index 82ec9b8..c1bd1a8 100644 +--- a/setup.cfg ++++ b/setup.cfg +@@ -2,7 +2,7 @@ + # SPDX-License-Identifier: Apache-2.0 + + [metadata] +-name = tosa-tools ++name = tosa-tools-v0.80 + # version = done by setuptools_scm in pyproject.toml + author = Arm Limited + #author_email = +@@ -25,44 +25,44 @@ install_requires = + python_requires = >=3.6 + include_package_data = True + packages = +- runner +- generator +- checker +- frameworks +- tests +- conformance +- xunit +- json2fbbin +- json2numpy +- schemavalidation +- convert2conformance +- tosa +- serializer +- tosa_reference_model ++ tosa_tools.v0_80.verif.runner ++ tosa_tools.v0_80.verif.generator ++ tosa_tools.v0_80.verif.checker ++ tosa_tools.v0_80.verif.frameworks ++ tosa_tools.v0_80.verif.tests ++ tosa_tools.v0_80.verif.conformance ++ tosa_tools.v0_80.xunit ++ tosa_tools.v0_80.json2fbbin ++ tosa_tools.v0_80.json2numpy ++ tosa_tools.v0_80.schemavalidation ++ tosa_tools.v0_80.convert2conformance ++ tosa_tools.v0_80.tosa ++ tosa_tools.v0_80.serializer ++ tosa_tools.v0_80.tosa_reference_model + package_dir = +- = verif +- xunit = scripts/xunit +- json2fbbin = scripts/json2fbbin +- json2numpy = scripts/json2numpy +- convert2conformance = scripts/convert2conformance +- tosa = thirdparty/serialization_lib/python/tosa +- serializer = thirdparty/serialization_lib/python/serializer +- tosa_reference_model = py_package +- schemavalidation = scripts/schemavalidation 
++ tosa_tools.v0_80.verif = verif ++ tosa_tools.v0_80.xunit = scripts/xunit ++ tosa_tools.v0_80.json2fbbin = scripts/json2fbbin ++ tosa_tools.v0_80.json2numpy = scripts/json2numpy ++ tosa_tools.v0_80.convert2conformance = scripts/convert2conformance ++ tosa_tools.v0_80.tosa = thirdparty/serialization_lib/python/tosa ++ tosa_tools.v0_80.serializer = thirdparty/serialization_lib/python/serializer ++ tosa_tools.v0_80.tosa_reference_model = py_package ++ tosa_tools.v0_80.schemavalidation = scripts/schemavalidation + + [options.entry_points] + console_scripts = +- tosa_verif_run_ref = runner.tosa_verif_run_tests:main +- tosa_verif_run_tests = runner.tosa_verif_run_tests:main +- tosa_verif_build_tests = generator.tosa_verif_build_tests:main +- tosa_json2numpy = json2numpy.json2numpy:main +- tosa_json2fbbin = json2fbbin.json2fbbin:main +- tosa_verif_result_check = checker.tosa_result_checker:main +- tosa_convert2conformance = convert2conformance.convert2conformance:main +- tosa_verif_framework_generator = frameworks.tosa_verif_framework_generator:main +- tosa_verif_framework_compiler_runner = frameworks.tosa_verif_framework_compiler_runner:main +- tosa_verif_conformance_generator = conformance.tosa_verif_conformance_generator:main +- tosa_schemavalidation = schemavalidation.schemavalidation:main ++ tosa_verif_run_ref = tosa_tools.v0_80.verif.runner.tosa_verif_run_tests:main ++ tosa_verif_run_tests = tosa_tools.v0_80.verif.runner.tosa_verif_run_tests:main ++ tosa_verif_build_tests = tosa_tools.v0_80.verif.generator.tosa_verif_build_tests:main ++ tosa_json2numpy = tosa_tools.v0_80.verif.json2numpy.json2numpy:main ++ tosa_json2fbbin = tosa_tools.v0_80.verif.json2fbbin.json2fbbin:main ++ tosa_verif_result_check = tosa_tools.v0_80.verif.checker.tosa_result_checker:main ++ tosa_convert2conformance = tosa_tools.v0_80.verif.convert2conformance.convert2conformance:main ++ tosa_verif_framework_generator = tosa_tools.v0_80.verif.frameworks.tosa_verif_framework_generator:main ++ tosa_verif_framework_compiler_runner = tosa_tools.v0_80.verif.frameworks.tosa_verif_framework_compiler_runner:main ++ tosa_verif_conformance_generator = tosa_tools.v0_80.verif.conformance.tosa_verif_conformance_generator:main ++ tosa_schemavalidation = tosa_tools.v0_80.verif.schemavalidation.schemavalidation:main + + [options.package_data] + schemavalidation= +diff --git a/setup.py b/setup.py +index 8c6b4cd..95896ad 100644 +--- a/setup.py ++++ b/setup.py +@@ -20,7 +20,7 @@ class CMakeBuild(build_py): + root_dir = Path(__file__).parent + build_dir = root_dir / "build" + build_dir.mkdir(exist_ok=True) +- package_dir = root_dir / "py_package" ++ package_dir = root_dir / "build/lib/tosa_tools/v0_80/tosa_reference_model/" + + cmake_cmd = [ + "cmake", +@@ -90,6 +90,7 @@ class CMakeBuild(build_py): + # Python will know which one to import + copied_so = False + so_dir = build_dir / "reference_model" ++ package_dir.mkdir(parents=True, exist_ok=True) + print(f"copying .so files from '{so_dir}' to '{package_dir}'") + for so_file in so_dir.glob("tosa_reference_model.*.so"): + shutil.copy(so_file, package_dir) +-- +2.39.5 (Apple Git-154) + diff --git a/backends/arm/third-party/reference_model/patches/v0.80/serialization_lib/0001-Make-TOSA-serializer-lib-to-be-self-contained.patch b/backends/arm/third-party/reference_model/patches/v0.80/serialization_lib/0001-Make-TOSA-serializer-lib-to-be-self-contained.patch new file mode 100644 index 00000000000..cc9cbc4edad --- /dev/null +++ 
b/backends/arm/third-party/reference_model/patches/v0.80/serialization_lib/0001-Make-TOSA-serializer-lib-to-be-self-contained.patch @@ -0,0 +1,283 @@ +From b3c8c3f779a7e051826f317598fb831fa9cfe923 Mon Sep 17 00:00:00 2001 +From: =?UTF-8?q?Per=20=C3=85strand?= <per.astrand@arm.com> +Date: Wed, 5 Feb 2025 12:30:09 +0100 +Subject: [PATCH] Make TOSA serializer lib to be self contained + +--- + CMakeLists.txt | 4 ++ + python/serializer/tosa_serializer.py | 57 ++++++++++++++-------------- + 2 files changed, 32 insertions(+), 29 deletions(-) + +diff --git a/CMakeLists.txt b/CMakeLists.txt +index ac34b75..5e191aa 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -19,6 +19,8 @@ + cmake_minimum_required(VERSION 3.13.4) + project(TosaSerialization) + ++cmake_policy(SET CMP0077 NEW) ++ + set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to") + set(CMAKE_CXX_STANDARD_REQUIRED YES) + +@@ -27,6 +29,8 @@ set(CMAKE_VERBOSE_MAKEFILE ON) + option(BUILD_TESTS "Build test applications" ON) + option(FLATBUFFERS_ROOT "Location where the flatbuffers 'include' and 'lib' folders to be found" Off) + ++message(STATUS "FLATBUFFERS_ROOT set to: ${FLATBUFFERS_ROOT}") ++ + include_directories(${PROJECT_SOURCE_DIR}/third_party/half/include) + + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py +index 7bc75f0..d191997 100644 +--- a/python/serializer/tosa_serializer.py ++++ b/python/serializer/tosa_serializer.py +@@ -14,12 +14,11 @@ + + import os + import struct +-import serializer.tosa_serializer as ts + import json + import flatbuffers + import numpy as np + from enum import IntEnum, unique +-from tosa import ( ++from ..tosa import ( + TosaGraph, + TosaRegion, + TosaBasicBlock, +@@ -27,8 +26,8 @@ from tosa import ( + TosaOperator, + Version, + ) +-import tosa.DType as TosaDType +-import tosa.Op as TosaOp ++from ..tosa import DType as TosaDType ++from ..tosa import Op as TosaOp + + # Keep version number in sync with the version default value with schema/tosa.fbs + TOSA_VERSION_MAJOR = 0 +@@ -159,7 +158,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): + output_zp, + accum_dtype, + ): +- from tosa import PoolAttribute as a, Attribute ++ from ..tosa import PoolAttribute as a, Attribute + + self.utype = Attribute.Attribute().PoolAttribute + +@@ -172,7 +171,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): + self.ints.append((a.AddAccumDtype, accum_dtype)) + + def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound): +- from tosa import ConvAttribute as a, Attribute ++ from ..tosa import ConvAttribute as a, Attribute + + self.utype = Attribute.Attribute().ConvAttribute + self.optFcns = (a.Start, a.End) +@@ -187,7 +186,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): + def TransposeConvAttribute( + self, outpad, stride, output_shape, input_zp, weight_zp, local_bound + ): +- from tosa import TransposeConvAttribute as a, Attribute ++ from ..tosa import TransposeConvAttribute as a, Attribute + + self.utype = Attribute.Attribute().TransposeConvAttribute + self.optFcns = (a.Start, a.End) +@@ -200,7 +199,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): + self.bools.append((a.AddLocalBound, local_bound)) + + def PadAttribute(self, serializer_builder, padding, pad_const_int, pad_const_fp): +- from tosa import PadAttribute as a, Attribute ++ from ..tosa import PadAttribute as a, Attribute + + self.utype = Attribute.Attribute().PadAttribute + self.optFcns = (a.Start, 
+@@ -210,14 +209,14 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ # pad_const_fp attribute serialized as uint8 vector
+ pad_const_float_as_bytes = struct.pack("<f", pad_const_fp)
+- serialized_pad_const_fp = ts.TosaSerializer.serializeUint8Vec(
++ serialized_pad_const_fp = TosaSerializer.serializeUint8Vec(
+ serializer_builder, pad_const_float_as_bytes
+ )
+
+ self.floats.append((a.AddPadConstFp, serialized_pad_const_fp))
+
+ def AxisAttribute(self, axis):
+- from tosa import AxisAttribute as a, Attribute
++ from ..tosa import AxisAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().AxisAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -225,7 +224,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.ints.append((a.AddAxis, axis))
+
+ def ReshapeAttribute(self, new_shape):
+- from tosa import ReshapeAttribute as a, Attribute
++ from ..tosa import ReshapeAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ReshapeAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -233,7 +232,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.intvecs.append((a.AddNewShape, new_shape))
+
+ def SliceAttribute(self, start, size):
+- from tosa import SliceAttribute as a, Attribute
++ from ..tosa import SliceAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().SliceAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -242,7 +241,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.intvecs.append((a.AddStart, start))
+ self.intvecs.append((a.AddSize, size))
+
+ def TileAttribute(self, multiples):
+- from tosa import TileAttribute as a, Attribute
++ from ..tosa import TileAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TileAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -250,7 +249,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.intvecs.append((a.AddMultiples, multiples))
+
+ def ResizeAttribute(self, scale, offset, border, mode):
+- from tosa import ResizeAttribute as a, Attribute
++ from ..tosa import ResizeAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ResizeAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -261,7 +260,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.ints.append((a.AddMode, mode))
+
+ def ClampAttribute(self, serializer_builder, minint, maxint, minfp, maxfp):
+- from tosa import ClampAttribute as a, Attribute
++ from ..tosa import ClampAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ClampAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -272,10 +271,10 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ # min/max float attributes serialized as uint8 vectors
+ minfp_bytes = struct.pack("<f", minfp)
+ maxfp_bytes = struct.pack("<f", maxfp)
+- serialized_minfp_bytes = ts.TosaSerializer.serializeUint8Vec(
++ serialized_minfp_bytes = TosaSerializer.serializeUint8Vec(
+ serializer_builder, minfp_bytes
+ )
+- serialized_maxfp_bytes = ts.TosaSerializer.serializeUint8Vec(
++ serialized_maxfp_bytes = TosaSerializer.serializeUint8Vec(
+ serializer_builder, maxfp_bytes
+ )
+
+@@ -294,7 +293,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+ input_unsigned,
+ output_unsigned,
+ ):
+- from tosa import RescaleAttribute as a, Attribute
++ from ..tosa import RescaleAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().RescaleAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -310,7 +309,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+ self.bools.append((a.AddInputUnsigned, input_unsigned))
+ self.bools.append((a.AddOutputUnsigned, output_unsigned))
+
+ def MulAttribute(self, shift):
+- from tosa import MulAttribute as a, Attribute
++ from ..tosa import MulAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().MulAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -318,7 +317,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.ints.append((a.AddShift, shift))
+
+ def ArithmeticRightShiftAttribute(self, round):
+- from tosa import ArithmeticRightShiftAttribute as a, Attribute
++ from ..tosa import ArithmeticRightShiftAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
+ self.optFcns = (
+@@ -329,7 +328,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.bools.append((a.AddRound, round))
+
+ def CondIfAttribute(self, then_branch, else_branch):
+- from tosa import CondIfAttribute as a, Attribute
++ from ..tosa import CondIfAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().CondIfAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -338,7 +337,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.strings.append((a.AddThenBranch, then_branch))
+ self.strings.append((a.AddElseBranch, else_branch))
+
+ def WhileLoopAttribute(self, cond_branch, body_branch):
+- from tosa import WhileLoopAttribute as a, Attribute
++ from ..tosa import WhileLoopAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().WhileLoopAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -347,7 +346,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.strings.append((a.AddCondBranch, cond_branch))
+ self.strings.append((a.AddBodyBranch, body_branch))
+
+ def TransposeAttribute(self, perms):
+- from tosa import TransposeAttribute as a, Attribute
++ from ..tosa import TransposeAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TransposeAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -355,7 +354,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.intvecs.append((a.AddPerms, perms))
+
+ def TableAttribute(self, table):
+- from tosa import TableAttribute as a, Attribute
++ from ..tosa import TableAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TableAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -363,7 +362,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.int16vecs.append((a.AddTable, table))
+
+ def MatMulAttribute(self, A_zp, B_zp):
+- from tosa import MatMulAttribute as a, Attribute
++ from ..tosa import MatMulAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().MatMulAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -372,7 +371,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.ints.append((a.AddAZp, A_zp))
+ self.ints.append((a.AddBZp, B_zp))
+
+ def FullyConnectedAttribute(self, input_zp, weight_zp):
+- from tosa import FullyConnectedAttribute as a, Attribute
++ from ..tosa import FullyConnectedAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().FullyConnectedAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -381,7 +380,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.ints.append((a.AddInputZp, input_zp))
+ self.ints.append((a.AddWeightZp, weight_zp))
+
+ def NegateAttribute(self, input1_zp, output_zp):
+- from tosa import NegateAttribute as a, Attribute
++ from ..tosa import NegateAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().NegateAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -390,7 +389,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
+
+ self.ints.append((a.AddInput1Zp, input1_zp))
+ self.ints.append((a.AddOutputZp, output_zp))
+
+ def FFTAttribute(self, inverse, local_bound):
+- from tosa import FFTAttribute as a, Attribute
++ from ..tosa import FFTAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().FFTAttribute
+ self.optFcns = (a.Start, a.End)
+@@ -399,7 +398,7 @@
+ self.bools.append((a.AddLocalBound, local_bound))
+
+ def RFFTAttribute(self, local_bound):
+- from tosa import RFFTAttribute as a, Attribute
++ from ..tosa import RFFTAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().RFFTAttribute
+ self.optFcns = (a.Start, a.End)
+--
+2.39.5 (Apple Git-154)
+
diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py
index 314f4c7d291..adb4fba1fc8 100644
--- a/backends/arm/tosa_backend.py
+++ b/backends/arm/tosa_backend.py
@@ -13,7 +13,7 @@ import logging
 from typing import cast, final, List
-import serializer.tosa_serializer as ts # type: ignore
+import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
 from executorch.backends.arm.arm_backend import get_tosa_spec
 from executorch.backends.arm.operators.node_visitor import get_node_visitors
diff --git a/backends/arm/tosa_mapping.py b/backends/arm/tosa_mapping.py
index b75f0e88fde..26441cbfb02 100644
--- a/backends/arm/tosa_mapping.py
+++ b/backends/arm/tosa_mapping.py
@@ -13,9 +13,10 @@ from typing import Any, Sequence
-import serializer.tosa_serializer as ts # type: ignore
 import torch
+
+import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
+
 UNSUPPORTED_DTYPES = (
 torch.float64,
diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index 70a2dd4281b..0cfa19eb453 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -12,13 +12,13 @@ import executorch.backends.arm.tosa_mapping
-import serializer.tosa_serializer as ts # type: ignore
 import torch.fx
 import torch.fx.node
-import tosa.Op as TosaOp # type: ignore
+
+import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
+import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
-from serializer.tosa_serializer import TosaSerializer, TosaSerializerTensor
 from torch import Tensor
 from torch.fx import Node
@@ -30,7 +30,7 @@ def insert_rescale_ops_to_int32(
 tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], node: Node
-) -> tuple[list[TosaSerializerTensor], float]:
+) -> tuple[list[ts.TosaSerializerTensor], float]:
 """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
 The scales are adjusted using the smallest scale of all 'nodes'.
@@ -61,7 +61,7 @@ def insert_rescale_ops_to_int32(
 min_scale = min([qarg.scale for qarg in qargs])
 scales = [qarg.scale / min_scale for qarg in qargs]
- rescaled_nodes: list[TosaSerializerTensor] = []
+ rescaled_nodes: list[ts.TosaSerializerTensor] = []
 for tensor, qarg, scale in zip(tensors, qargs, scales):
 rescaled_nodes.append(
 build_rescale_to_int32(
@@ -198,9 +198,9 @@ def compute_multiplier_and_shift(
 def build_rescale(
- tosa_fb: TosaSerializer,
+ tosa_fb: ts.TosaSerializer,
 scale: list[float],
- input_node: TosaSerializerTensor,
+ input_node: ts.TosaSerializerTensor,
 output_name: str,
 output_type: ts.DType,
 output_shape: List[int],
@@ -233,14 +233,14 @@ def build_rescale(
 def build_rescale_to_int32(
- tosa_fb: TosaSerializer,
+ tosa_fb: ts.TosaSerializer,
 input_arg: executorch.backends.arm.tosa_mapping.TosaArg,
 input_zp: int,
 rescale_scale: list[float],
 is_scale32: bool = True,
 is_double_round: bool = False,
 per_channel: bool = False,
-) -> TosaSerializerTensor:
+) -> ts.TosaSerializerTensor:
 multipliers, shifts = compute_multiplier_and_shift(rescale_scale)
 attr_rescale = ts.TosaSerializerAttribute()
 attr_rescale.RescaleAttribute(
@@ -266,7 +266,7 @@ def build_rescale_to_int32(
 def build_rescale_from_int32(
- tosa_fb: TosaSerializer,
+ tosa_fb: ts.TosaSerializer,
 input_name: str,
 output_name: str,
 output_zp: int,
@@ -300,8 +300,8 @@ def build_rescale_from_int32(
 def build_rescale_conv_output(
- tosa_fb: TosaSerializer,
- op: TosaSerializerTensor,
+ tosa_fb: ts.TosaSerializer,
+ op: ts.TosaSerializerTensor,
 output_name: str,
 output_type: ts.DType,
 input_scale: list[float],
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index 5fa603ea683..4d0f33003bc 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -9,14 +9,15 @@ import os
 from typing import Any, Optional, Tuple
-import serializer.tosa_serializer as ts # type: ignore
 import torch
+
+import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.print_program import inspect_node
-from serializer.tosa_serializer import TosaOp
 from torch.fx import Node
+from tosa_tools.v0_80.serializer.tosa_serializer import TosaOp
 logger = logging.getLogger(__name__)
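In the backend changes above, the serializer is consistently reached through the tosa_tools.v0_80 namespace, and helper signatures now spell types such as TosaSerializer and TosaSerializerTensor via the ts module alias rather than separate imports. A minimal sketch of that pattern follows; the helper name and operator choice are illustrative only, not code from this change:

    # Illustrative only: shows the namespaced alias pattern used above.
    import tosa_tools.v0_80.serializer.tosa_serializer as ts

    def add_identity(tosa_graph: ts.TosaSerializer, input_name: str, output_name: str) -> None:
        # Operator enums and serializer types hang off the ts alias,
        # so no separate TosaOp/TosaSerializer imports are needed.
        tosa_graph.addOperator(ts.TosaOp.Op().IDENTITY, [input_name], [output_name], None)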
diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh
index 016e03f04a2..8d77eabce0f 100755
--- a/examples/arm/setup.sh
+++ b/examples/arm/setup.sh
@@ -55,10 +55,6 @@ else
 echo "[main] Error: only x86-64 & aarch64/arm64 architecture is supported for now!"; exit 1;
 fi
-# tosa reference model
-tosa_reference_model_url="https://git.gitlab.arm.com/tosa/tosa-reference-model.git"
-tosa_reference_model_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"
-
 # vela
 vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
 vela_rev="425541302c7e4b6fbeca7c0061286b131ee507c3"
@@ -156,14 +152,6 @@ function setup_toolchain() {
 tar xf "${toolchain_dir}.tar.xz"
 }
-function setup_tosa_reference_model() {
- # reference_model flatbuffers version clashes with Vela.
- # go with Vela's since it newer.
- # Vela's flatbuffer requirement is expected to loosen, then remove this. MLETORCH-565
- CMAKE_POLICY_VERSION_MINIMUM=3.5 pip install tosa-tools@git+${tosa_reference_model_url}@${tosa_reference_model_rev} --no-dependencies flatbuffers
-
-}
-
 function setup_vela() {
 pip install ethos-u-vela@git+${vela_repo_url}@${vela_rev}
 }
@@ -233,7 +221,7 @@ if [[ $is_script_sourced -eq 0 ]]
 create_setup_path
 # Setup the tosa_reference_model
- setup_tosa_reference_model
+ $et_dir/backends/arm/scripts/install_reference_model.sh ${root_dir}
 # Setup vela and patch in codegen fixes
 setup_vela
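With setup.sh delegating to backends/arm/scripts/install_reference_model.sh, the reference model is no longer pip-installed straight from the upstream URL; it is built and copied into the tosa_tools.v0_80 namespace by the patched setup.py. A hypothetical post-install check along these lines could confirm the compiled extension landed where the setup.py hunk above copies it; only the package name and the .so glob pattern are taken from this change, everything else is assumed:

    # Hypothetical check, mirroring the glob used in the patched setup.py above.
    from importlib.util import find_spec
    from pathlib import Path

    spec = find_spec("tosa_tools.v0_80.tosa_reference_model")
    assert spec is not None and spec.submodule_search_locations, "namespaced package not found"
    pkg_dir = Path(list(spec.submodule_search_locations)[0])
    print(sorted(p.name for p in pkg_dir.glob("tosa_reference_model.*.so")))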