Skip to content

Commit 5cc98bc

Browse files
Arm backend: Use dbg_fail when node visitors raise exceptions (#9391)
Adds a try-expect around the node_visitor call to be able to call dbg_fail() when an error/exception is raised. Signed-off-by: Oscar Andersson <[email protected]> Co-authored-by: Digant Desai <[email protected]>
1 parent dc93fde commit 5cc98bc

File tree

4 files changed

+39
-37
lines changed

4 files changed

+39
-37
lines changed

backends/arm/ethosu_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
# debug functionality
2525
logger = logging.getLogger(__name__)
26-
logger.setLevel(logging.WARNING)
2726

2827

2928
@final

backends/arm/process_node.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
from executorch.backends.arm.operators.node_visitor import NodeVisitor
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_specification import TosaSpecification
17-
from executorch.backends.arm.tosa_utils import (
18-
get_node_debug_info,
19-
getNodeArgs,
20-
tosa_shape,
21-
)
17+
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
2218
from torch.export.exported_program import ExportedProgram
2319

2420

@@ -36,7 +32,7 @@ def process_call_function(
3632
output = TosaArg(node)
3733
except ValueError as e:
3834
raise ValueError(
39-
f"Failed processing call_function:\n{get_node_debug_info(node)}"
35+
f"Failed processing call_function: {node.name}. "
4036
"Is the original torch function supported?"
4137
) from e
4238
tosa_graph.currRegion.currBasicBlock.addTensor(
@@ -74,7 +70,7 @@ def process_inputs(
7470
tosa_arg = TosaArg(node)
7571
except ValueError as e:
7672
raise ValueError(
77-
f"Failed processing input placeholder:\n{get_node_debug_info(node)}"
73+
f"Failed processing input placeholder: {node.name}. "
7874
"Is the original torch function supported?"
7975
) from e
8076
input_shape = tosa_arg.shape
@@ -100,7 +96,7 @@ def process_inputs_to_parameters(
10096
tosa_arg = TosaArg(node)
10197
except ValueError as e:
10298
raise ValueError(
103-
f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}"
99+
f"Failed processing parameter placeholder: {node.name}. "
104100
"Is the original torch function supported?"
105101
) from e
106102
parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
@@ -129,7 +125,7 @@ def process_inputs_to_buffers(
129125
tosa_arg = TosaArg(node)
130126
except ValueError as e:
131127
raise ValueError(
132-
f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}"
128+
f"Failed processing buffer placeholder: {node.name}. "
133129
"Is the original torch function supported?"
134130
) from e
135131
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
@@ -157,7 +153,7 @@ def process_inputs_to_lifted_tensor_constants(
157153
tosa_arg = TosaArg(node)
158154
except ValueError as e:
159155
raise ValueError(
160-
f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}"
156+
f"Failed processing lifted tensor constant placeholder: {node.name}. "
161157
"Is the original torch function supported?"
162158
) from e
163159
tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[

backends/arm/tosa_backend.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535
# TOSA backend debug functionality
3636
logger = logging.getLogger(__name__)
37-
logger.setLevel(logging.WARNING)
3837
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
3938
if TOSA_DBG_VERBOSE:
4039
logging.basicConfig(level=logging.INFO)
@@ -101,18 +100,22 @@ def preprocess( # noqa: C901
101100
input_count = 0
102101
for node in graph_module.graph.nodes:
103102
node = cast(Node, node)
104-
if node.op == "call_function":
105-
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
106-
elif node.op == "placeholder":
107-
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
108-
if node.name in edge_program.graph_signature.user_inputs:
109-
input_count += 1
110-
elif node.op == "output":
111-
process_output(node, tosa_graph)
112-
else:
113-
# This will only happen if an unpartitioned graph is passed without
114-
# any checking of compatibility.
115-
dbg_fail(node, tosa_graph, artifact_path)
103+
try:
104+
if node.op == "call_function":
105+
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
106+
elif node.op == "placeholder":
107+
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
108+
if node.name in edge_program.graph_signature.user_inputs:
109+
input_count += 1
110+
elif node.op == "output":
111+
process_output(node, tosa_graph)
112+
else:
113+
# This will only happen if an unpartitioned graph is passed without
114+
# any checking of compatibility.
115+
raise RuntimeError(f"{node.name} is unsupported op {node.op}")
116+
except (AssertionError, RuntimeError, ValueError):
117+
dbg_fail(node, graph_module, tosa_graph, artifact_path)
118+
raise
116119

117120
if len(input_order) > 0:
118121
if input_count != len(input_order):

backends/arm/tosa_utils.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,32 @@
77

88
import logging
99
import os
10-
from typing import Any, Tuple
10+
from typing import Any, Optional, Tuple
1111

1212
import serializer.tosa_serializer as ts # type: ignore
1313
import torch
1414
from executorch.backends.arm.tosa_mapping import TosaArg
1515

1616
from executorch.exir.dialects._ops import ops as exir_ops
17+
from executorch.exir.print_program import inspect_node
1718
from serializer.tosa_serializer import TosaOp
1819
from torch.fx import Node
1920

2021
logger = logging.getLogger(__name__)
21-
logger.setLevel(logging.WARNING)
2222
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
2323
if TOSA_DBG_VERBOSE:
2424
logging.basicConfig(level=logging.INFO)
2525
logger.setLevel(logging.INFO)
2626

2727

28-
def dbg_node(node: torch.fx.Node):
28+
def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule):
2929
# Debug output of node information
30-
logger.info(get_node_debug_info(node))
30+
logger.info(get_node_debug_info(node, graph_module))
3131

3232

33-
def get_node_debug_info(node: torch.fx.Node) -> str:
33+
def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str:
3434
output = (
35+
f" {inspect_node(graph=graph_module.graph, node=node)}\n"
3536
"-- NODE DEBUG INFO --\n"
3637
f" Op is {node.op}\n"
3738
f" Name is {node.name}\n"
@@ -71,21 +72,24 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
7172
assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"
7273

7374

74-
def dbg_fail(node, tosa_graph, path):
75-
dbg_tosa_dump(tosa_graph, path)
75+
def dbg_fail(
76+
node,
77+
graph_module,
78+
tosa_graph: Optional[ts.TosaSerializer] = None,
79+
path: Optional[str] = None,
80+
):
7681
logger.warning("Internal error due to poorly handled node:")
77-
dbg_node(node)
78-
logger.warning(f"Debug output captured in '{path}'.")
79-
raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
82+
if tosa_graph is not None and path is not None:
83+
dbg_tosa_dump(tosa_graph, path)
84+
logger.warning(f"Debug output captured in '{path}'.")
85+
dbg_node(node, graph_module)
8086

8187

8288
def getNodeArgs(node: Node) -> list[TosaArg]:
8389
try:
8490
return [TosaArg(arg) for arg in node.args]
8591
except ValueError as e:
86-
raise ValueError(
87-
f"Failed processing args to op:\n{get_node_debug_info(node)}"
88-
) from e
92+
raise ValueError(f"Failed processing args to op:\n{node}") from e
8993

9094

9195
def get_output_node(node: Node) -> Node:

0 commit comments

Comments
 (0)