Commit ffed9d6

feat: Inform user if no valid partitions
- Fall back to global partitioner if fast partitioner fails
1 parent: f3b34ac

5 files changed: +114 lines, -39 lines
py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 43 additions & 7 deletions

@@ -108,15 +108,47 @@ def _compile_module(
     Returns:
         Compiled FX GraphModule
     """
-    # Partition module into components that can be TRT-accelerated
-    if settings.use_fast_partitioner:
-        partitioned_module = partitioning.fast_partition(
-            gm,
-            verbose=settings.debug,
-            min_block_size=settings.min_block_size,
-            torch_executed_ops=settings.torch_executed_ops,
+    # Check the number of supported operations in the graph
+    num_supported_ops, total_ops = partitioning.get_graph_converter_support(
+        gm, settings.debug, settings.torch_executed_ops
+    )
+
+    # If the number of supported operations is 0 or less than the block size, skip the subgraph
+    # TODO: Add condition to second expression below when require_full_compilation is added
+    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
+        logger.warning(
+            f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
+            f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
         )
+        return gm
     else:
+        logger.debug(
+            f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
+        )
+
+    # Partition module into components that can be TRT-accelerated
+    fast_partitioner_failed = False
+
+    # If specified, try using the fast partitioner and fall back to the global one on failure
+    if settings.use_fast_partitioner:
+        try:
+            partitioned_module = partitioning.fast_partition(
+                gm,
+                verbose=settings.debug,
+                min_block_size=settings.min_block_size,
+                torch_executed_ops=settings.torch_executed_ops,
+            )
+        except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
+            logger.error(
+                "Partitioning failed on the subgraph with fast partition. See trace above. "
+                + "Retrying with global partition.",
+                exc_info=True,
+            )
+
+            fast_partitioner_failed = True
+            settings.use_fast_partitioner = False
+
+    if not settings.use_fast_partitioner:
         partitioned_module = partitioning.global_partition(
             gm,
             verbose=settings.debug,

@@ -162,4 +194,8 @@ def _compile_module(
     for name, trt_mod in trt_modules.items():
         setattr(partitioned_module, name, trt_mod)
 
+    # Reset settings object to user specification after fallback to global partitioning mode
+    if fast_partitioner_failed:
+        settings.use_fast_partitioner = True
+
     return partitioned_module
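
The two hunks above (a) gate compilation on converter support and (b) wrap the fast partitioner in a try/except with a global-partitioner fallback. A minimal sketch of that control flow, consolidated into one hypothetical helper (partition_with_fallback is not part of the commit; gm and settings are assumed to match what _compile_module receives):

# Hypothetical helper (not in the commit) consolidating the fallback logic above.
# Assumes `settings` exposes the same fields used in _compile_module.
import torch
from torch.fx.passes.splitter_base import FxNetSplitterInternalError
from torch_tensorrt.dynamo import partitioning


def partition_with_fallback(gm: torch.fx.GraphModule, settings) -> torch.fx.GraphModule:
    fast_partitioner_failed = False

    if settings.use_fast_partitioner:
        try:
            # Adjacency-based partitioner: faster, but can raise on some graphs
            return partitioning.fast_partition(
                gm,
                verbose=settings.debug,
                min_block_size=settings.min_block_size,
                torch_executed_ops=settings.torch_executed_ops,
            )
        except FxNetSplitterInternalError:
            # Fall back to the global partitioner and remember to restore the flag
            fast_partitioner_failed = True
            settings.use_fast_partitioner = False

    try:
        # Capability-based global partitioner as the fallback path
        return partitioning.global_partition(
            gm,
            verbose=settings.debug,
            min_block_size=settings.min_block_size,
            torch_executed_ops=settings.torch_executed_ops,
        )
    finally:
        # Reset the settings object to the user's original specification
        if fast_partitioner_failed:
            settings.use_fast_partitioner = True

With this pattern, an FxNetSplitterInternalError no longer aborts compilation: the graph is re-partitioned with the capability-based global partitioner, and the user-facing use_fast_partitioner setting is restored afterwards so later subgraphs retry the fast path.
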
py/torch_tensorrt/dynamo/partitioning/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_submod_inputs
+from .common import get_graph_converter_support, get_submod_inputs

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 5 additions & 3 deletions

@@ -12,8 +12,10 @@
     _SplitterSettingBase,
 )
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
-from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
-from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.conversion.converter_registry import (
+    DYNAMO_CONVERTERS as CONVERTERS,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 
 from .common import DEFAULT_SINGLE_NODE_PARTITIONS

@@ -212,7 +214,7 @@ def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
 
 def partition(
     gm: torch.fx.GraphModule,
-    verbose: bool = True,
+    verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
 ) -> torch.fx.GraphModule:
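
Beyond the import reshuffle, the behavioral change here is the verbose default: it now tracks the shared DEBUG constant rather than being hard-coded to True. A hedged check of the new defaults (assumes a torch_tensorrt build containing this commit):

# Hedged check: both partition entry points now default `verbose` to the DEBUG constant
import inspect

from torch_tensorrt.dynamo import partitioning
from torch_tensorrt.dynamo._defaults import DEBUG

for fn in (partitioning.fast_partition, partitioning.global_partition):
    assert inspect.signature(fn).parameters["verbose"].default == DEBUG
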

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 25 additions & 23 deletions

@@ -3,13 +3,13 @@
 
 import torch
 from torch.fx.graph_module import GraphModule
-from torch.fx.node import _get_qualified_name
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
-from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 
 from .common import DEFAULT_SINGLE_NODE_PARTITIONS
 

@@ -69,14 +69,15 @@ def propose_partitions(self) -> List[Partition]:
                 # Partitions are exempted from min_block_size if they contain an allowed single-node op
                 if (
                     node.op == "call_function"
-                    and _get_qualified_name(node.target)
+                    and ConverterRegistry.qualified_name_or_str(node.target)
                     in self.allowed_single_node_partition_ops
                 ):
                     exempted_partition = True
                     break
                 elif (
                     node.op == "call_function"
-                    and _get_qualified_name(node.target) not in non_compute_ops
+                    and ConverterRegistry.qualified_name_or_str(node.target)
+                    not in non_compute_ops
                 ):
                     compute_node_count += 1
 

@@ -118,11 +119,7 @@
     def is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
-        node_name = (
-            _get_qualified_name(node.target)
-            if not isinstance(node.target, str)
-            else node.target
-        )
+        node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if node in CONVERTERS and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator

@@ -142,32 +139,37 @@ def is_node_supported(
 
         return False
 
-    def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None:
+    def print_support_overview(
+        self, num_trt_blocks: Optional[int] = None, print_node_support: bool = False
+    ) -> None:
         if num_trt_blocks is not None:
             logger.debug(
                 f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
             )
 
-        # Reformat support messages for debugger to print node overview as a single string
-        supported_nodes_str = "\nSupported Nodes:\n"
-        for node_name, count in self.supported_operators.items():
-            supported_nodes_str += f"- {node_name} + Operator Count: {count}\n"
+        if print_node_support:
+            # Reformat support messages for debugger to print node overview as a single string
+            supported_nodes_str = "\nSupported Nodes:\n"
+            for node_name, count in self.supported_operators.items():
+                supported_nodes_str += f"- {node_name} + Operator Count: {count}\n"
 
-        logger.debug(supported_nodes_str)
+            logger.debug(supported_nodes_str)
 
-        if self.unsupported_operators:
-            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
-            for node_name, count in self.unsupported_operators.items():
-                unsupported_nodes_str += f"- {node_name} + Operator Count: {count}\n"
+            if self.unsupported_operators:
+                unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
+                for node_name, count in self.unsupported_operators.items():
+                    unsupported_nodes_str += (
+                        f"- {node_name} + Operator Count: {count}\n"
+                    )
 
-            logger.debug(unsupported_nodes_str)
-        else:
-            logger.debug("\nAll Nodes Supported\n")
+                logger.debug(unsupported_nodes_str)
+            else:
+                logger.debug("\nAll Nodes Supported\n")
 
 
 def partition(
     gm: torch.fx.GraphModule,
-    verbose: bool = True,
+    verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Set[str]] = None,
 ) -> torch.fx.GraphModule:
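
ConverterRegistry.qualified_name_or_str replaces the removed _get_qualified_name/isinstance branch in is_node_supported: it resolves callable targets to a qualified name and passes string targets through unchanged. A hedged illustration (outputs are indicative, not verified):

# Hedged illustration of the helper swapped in above; output strings are indicative only.
import torch
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

# Callable targets resolve to a qualified name, e.g. an aten overload
print(ConverterRegistry.qualified_name_or_str(torch.ops.aten.relu.default))

# String targets (the case the removed isinstance branch handled) pass through unchanged
print(ConverterRegistry.qualified_name_or_str("getitem"))
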

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 40 additions & 5 deletions

@@ -1,8 +1,8 @@
-import logging
-from typing import Any, Optional, Sequence, Set
+from typing import Any, Optional, Sequence, Set, Tuple
 
 import torch
 from torch.fx.node import _get_qualified_name
+from torch_tensorrt.dynamo._defaults import DEBUG
 from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
 
 DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = {

@@ -11,9 +11,6 @@
 }
 
 
-logger = logging.getLogger(__name__)
-
-
 def get_submod_inputs(
     mod: torch.fx.GraphModule,
     submod: torch.fx.GraphModule,

@@ -39,3 +36,41 @@ def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None:
     mod(*inputs)
     handle.remove()
     return acc_inputs
+
+
+def get_graph_converter_support(
+    graph_module: torch.fx.GraphModule,
+    verbose: bool = DEBUG,
+    torch_executed_ops: Optional[Set[str]] = None,
+) -> Tuple[int, int]:
+    """Helper function to get converter support overview pre-partitioning
+
+    Args:
+        graph_module: FX GraphModule to determine support for
+        verbose: Bool representing whether to print operator support
+        torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
+    Returns:
+        The number of supported call_function nodes in the graph
+    """
+    from ._global_partitioner import TorchTensorRTOperatorSupport
+
+    # Instantiate operator support object and module dictionary
+    op_support = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
+    module_dict = dict(graph_module.named_modules())
+
+    number_of_supported_nodes = 0
+    total_functional_nodes = 0
+
+    # Iterate over all nodes in the graph, enumerating call_function nodes
+    for node in graph_module.graph.nodes:
+        if node.op == "call_function":
+            total_functional_nodes += 1
+
+            if op_support.is_node_supported(module_dict, node):
+                number_of_supported_nodes += 1
+
+    # Print node support overview prior to partitioning
+    if verbose:
+        op_support.print_support_overview(print_node_support=True)
+
+    return number_of_supported_nodes, total_functional_nodes
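
A hedged usage sketch of the new helper; ToyModel and the min_block_size value are made up, and the counts depend on the installed converter registry. It mirrors the check _compile_module now performs before partitioning:

# Hedged usage sketch; the module and threshold below are illustrative only.
import torch
from torch_tensorrt.dynamo.partitioning import get_graph_converter_support


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2.0


gm = torch.fx.symbolic_trace(ToyModel())

# verbose=True prints the per-operator support overview introduced above
num_supported_ops, total_ops = get_graph_converter_support(gm, verbose=True)

min_block_size = 5  # hypothetical threshold, standing in for settings.min_block_size
if num_supported_ops == 0 or num_supported_ops < min_block_size:
    print(
        f"{num_supported_ops} of {total_ops} computational ops supported; "
        f"subgraph would be skipped (min_block_size={min_block_size})"
    )

Since the helper returns both the supported count and the total number of call_function nodes, callers can apply the same num_supported_ops == 0 or num_supported_ops < min_block_size test that backends.py uses to decide whether to skip TRT conversion for a subgraph.
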
