diff --git a/backends/arm/ethosu_backend.py b/backends/arm/ethosu_backend.py index 9b14a7a72b8..402f3fed42b 100644 --- a/backends/arm/ethosu_backend.py +++ b/backends/arm/ethosu_backend.py @@ -23,7 +23,6 @@ # debug functionality logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) @final diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 377f8c17c4c..f9b77e28493 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -14,11 +14,7 @@ from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_utils import ( - get_node_debug_info, - getNodeArgs, - tosa_shape, -) +from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape from torch.export.exported_program import ExportedProgram @@ -36,7 +32,7 @@ def process_call_function( output = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing call_function:\n{get_node_debug_info(node)}" + f"Failed processing call_function: {node.name}. " "Is the original torch function supported?" ) from e tosa_graph.currRegion.currBasicBlock.addTensor( @@ -74,7 +70,7 @@ def process_inputs( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing input placeholder:\n{get_node_debug_info(node)}" + f"Failed processing input placeholder: {node.name}. " "Is the original torch function supported?" ) from e input_shape = tosa_arg.shape @@ -100,7 +96,7 @@ def process_inputs_to_parameters( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}" + f"Failed processing parameter placeholder: {node.name}. " "Is the original torch function supported?" ) from e parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name] @@ -129,7 +125,7 @@ def process_inputs_to_buffers( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}" + f"Failed processing buffer placeholder: {node.name}. " "Is the original torch function supported?" ) from e buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name] @@ -157,7 +153,7 @@ def process_inputs_to_lifted_tensor_constants( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}" + f"Failed processing lifted tensor constant placeholder: {node.name}. " "Is the original torch function supported?" ) from e tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[ diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py index 0eb0757e262..f64b3681098 100644 --- a/backends/arm/tosa_backend.py +++ b/backends/arm/tosa_backend.py @@ -34,7 +34,6 @@ # TOSA backend debug functionality logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" if TOSA_DBG_VERBOSE: logging.basicConfig(level=logging.INFO) @@ -101,18 +100,22 @@ def preprocess( # noqa: C901 input_count = 0 for node in graph_module.graph.nodes: node = cast(Node, node) - if node.op == "call_function": - process_call_function(node, tosa_graph, node_visitors, tosa_spec) - elif node.op == "placeholder": - process_placeholder(node, tosa_graph, edge_program, tosa_spec) - if node.name in edge_program.graph_signature.user_inputs: - input_count += 1 - elif node.op == "output": - process_output(node, tosa_graph) - else: - # This will only happen if an unpartitioned graph is passed without - # any checking of compatibility. - dbg_fail(node, tosa_graph, artifact_path) + try: + if node.op == "call_function": + process_call_function(node, tosa_graph, node_visitors, tosa_spec) + elif node.op == "placeholder": + process_placeholder(node, tosa_graph, edge_program, tosa_spec) + if node.name in edge_program.graph_signature.user_inputs: + input_count += 1 + elif node.op == "output": + process_output(node, tosa_graph) + else: + # This will only happen if an unpartitioned graph is passed without + # any checking of compatibility. + raise RuntimeError(f"{node.name} is unsupported op {node.op}") + except (AssertionError, RuntimeError, ValueError): + dbg_fail(node, graph_module, tosa_graph, artifact_path) + raise if len(input_order) > 0: if input_count != len(input_order): diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 45473a496e1..788ebf39696 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -7,31 +7,32 @@ import logging import os -from typing import Any, Tuple +from typing import Any, Optional, Tuple import serializer.tosa_serializer as ts # type: ignore import torch from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.print_program import inspect_node from serializer.tosa_serializer import TosaOp from torch.fx import Node logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" if TOSA_DBG_VERBOSE: logging.basicConfig(level=logging.INFO) logger.setLevel(logging.INFO) -def dbg_node(node: torch.fx.Node): +def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): # Debug output of node information - logger.info(get_node_debug_info(node)) + logger.info(get_node_debug_info(node, graph_module)) -def get_node_debug_info(node: torch.fx.Node) -> str: +def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: output = ( + f" {inspect_node(graph=graph_module.graph, node=node)}\n" "-- NODE DEBUG INFO --\n" f" Op is {node.op}\n" f" Name is {node.name}\n" @@ -71,21 +72,24 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON" -def dbg_fail(node, tosa_graph, path): - dbg_tosa_dump(tosa_graph, path) +def dbg_fail( + node, + graph_module, + tosa_graph: Optional[ts.TosaSerializer] = None, + path: Optional[str] = None, +): logger.warning("Internal error due to poorly handled node:") - dbg_node(node) - logger.warning(f"Debug output captured in '{path}'.") - raise RuntimeError("TOSA Internal Error on node, enable logging for further info.") + if tosa_graph is not None and path is not None: + dbg_tosa_dump(tosa_graph, path) + logger.warning(f"Debug output captured in '{path}'.") + dbg_node(node, graph_module) def getNodeArgs(node: Node) -> list[TosaArg]: try: return [TosaArg(arg) for arg in node.args] except ValueError as e: - raise ValueError( - f"Failed processing args to op:\n{get_node_debug_info(node)}" - ) from e + raise ValueError(f"Failed processing args to op:\n{node}") from e def get_output_node(node: Node) -> Node: