From 1e98d08c380f18cbb5d9f9a7ea5fc3efe0a6cf3b Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Tue, 18 Mar 2025 14:32:57 -0700 Subject: [PATCH] Revert "Arm backend: Use dbg_fail when node visitors raise exceptions (#9268)" This reverts commit 622b79e6138ba95906f3d54b66d134ba9e0d22ed. --- backends/arm/ethosu_backend.py | 1 + backends/arm/process_node.py | 16 ++++++++++------ backends/arm/tosa_backend.py | 29 +++++++++++++---------------- backends/arm/tosa_utils.py | 30 +++++++++++++----------------- 4 files changed, 37 insertions(+), 39 deletions(-) diff --git a/backends/arm/ethosu_backend.py b/backends/arm/ethosu_backend.py index 402f3fed42b..9b14a7a72b8 100644 --- a/backends/arm/ethosu_backend.py +++ b/backends/arm/ethosu_backend.py @@ -23,6 +23,7 @@ # debug functionality logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) @final diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index f9b77e28493..377f8c17c4c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -14,7 +14,11 @@ from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape +from executorch.backends.arm.tosa_utils import ( + get_node_debug_info, + getNodeArgs, + tosa_shape, +) from torch.export.exported_program import ExportedProgram @@ -32,7 +36,7 @@ def process_call_function( output = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing call_function: {node.name}. " + f"Failed processing call_function:\n{get_node_debug_info(node)}" "Is the original torch function supported?" ) from e tosa_graph.currRegion.currBasicBlock.addTensor( @@ -70,7 +74,7 @@ def process_inputs( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing input placeholder: {node.name}. " + f"Failed processing input placeholder:\n{get_node_debug_info(node)}" "Is the original torch function supported?" ) from e input_shape = tosa_arg.shape @@ -96,7 +100,7 @@ def process_inputs_to_parameters( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing parameter placeholder: {node.name}. " + f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}" "Is the original torch function supported?" ) from e parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name] @@ -125,7 +129,7 @@ def process_inputs_to_buffers( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing buffer placeholder: {node.name}. " + f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}" "Is the original torch function supported?" ) from e buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name] @@ -153,7 +157,7 @@ def process_inputs_to_lifted_tensor_constants( tosa_arg = TosaArg(node) except ValueError as e: raise ValueError( - f"Failed processing lifted tensor constant placeholder: {node.name}. " + f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}" "Is the original torch function supported?" ) from e tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[ diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py index 18c39d133b7..6030b4e8bef 100644 --- a/backends/arm/tosa_backend.py +++ b/backends/arm/tosa_backend.py @@ -35,6 +35,7 @@ # TOSA backend debug functionality logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" if TOSA_DBG_VERBOSE: logging.basicConfig(level=logging.INFO) @@ -101,22 +102,18 @@ def preprocess( # noqa: C901 input_count = 0 for node in graph_module.graph.nodes: node = cast(Node, node) - try: - if node.op == "call_function": - process_call_function(node, tosa_graph, node_visitors, tosa_spec) - elif node.op == "placeholder": - process_placeholder(node, tosa_graph, edge_program, tosa_spec) - if node.name in edge_program.graph_signature.user_inputs: - input_count += 1 - elif node.op == "output": - process_output(node, tosa_graph) - else: - # This will only happen if an unpartitioned graph is passed without - # any checking of compatibility. - raise RuntimeError(f"{node.name} is unsupported op {node.op}") - except (AssertionError, RuntimeError, ValueError): - dbg_fail(node, graph_module, tosa_graph, artifact_path) - raise + if node.op == "call_function": + process_call_function(node, tosa_graph, node_visitors, tosa_spec) + elif node.op == "placeholder": + process_placeholder(node, tosa_graph, edge_program, tosa_spec) + if node.name in edge_program.graph_signature.user_inputs: + input_count += 1 + elif node.op == "output": + process_output(node, tosa_graph) + else: + # This will only happen if an unpartitioned graph is passed without + # any checking of compatibility. + dbg_fail(node, tosa_graph, artifact_path) if len(input_order) > 0: if input_count != len(input_order): diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 788ebf39696..45473a496e1 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -7,32 +7,31 @@ import logging import os -from typing import Any, Optional, Tuple +from typing import Any, Tuple import serializer.tosa_serializer as ts # type: ignore import torch from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.print_program import inspect_node from serializer.tosa_serializer import TosaOp from torch.fx import Node logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" if TOSA_DBG_VERBOSE: logging.basicConfig(level=logging.INFO) logger.setLevel(logging.INFO) -def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): +def dbg_node(node: torch.fx.Node): # Debug output of node information - logger.info(get_node_debug_info(node, graph_module)) + logger.info(get_node_debug_info(node)) -def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: +def get_node_debug_info(node: torch.fx.Node) -> str: output = ( - f" {inspect_node(graph=graph_module.graph, node=node)}\n" "-- NODE DEBUG INFO --\n" f" Op is {node.op}\n" f" Name is {node.name}\n" @@ -72,24 +71,21 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON" -def dbg_fail( - node, - graph_module, - tosa_graph: Optional[ts.TosaSerializer] = None, - path: Optional[str] = None, -): +def dbg_fail(node, tosa_graph, path): + dbg_tosa_dump(tosa_graph, path) logger.warning("Internal error due to poorly handled node:") - if tosa_graph is not None and path is not None: - dbg_tosa_dump(tosa_graph, path) - logger.warning(f"Debug output captured in '{path}'.") - dbg_node(node, graph_module) + dbg_node(node) + logger.warning(f"Debug output captured in '{path}'.") + raise RuntimeError("TOSA Internal Error on node, enable logging for further info.") def getNodeArgs(node: Node) -> list[TosaArg]: try: return [TosaArg(arg) for arg in node.args] except ValueError as e: - raise ValueError(f"Failed processing args to op:\n{node}") from e + raise ValueError( + f"Failed processing args to op:\n{get_node_debug_info(node)}" + ) from e def get_output_node(node: Node) -> Node: