     _SplitterSettingBase,
 )
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
-from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MIN_BLOCK_SIZE,
+    REQUIRE_FULL_COMPILATION,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -92,6 +96,7 @@ class TRTPartitioner(_SplitterBase):  # type: ignore
         allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
             Generally useful for module-level exclusion ops which are intensive despite being single functions
         min_block_size: Minimum number of computational operators per block
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -104,6 +109,7 @@ def __init__(
             Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
+        require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ):
         """
         Preprocesses graph before splitting:
@@ -142,6 +148,7 @@ def __init__(
 
         self.num_trt_accelerated_subgraphs: Optional[int] = None
         self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+        self.require_full_compilation = require_full_compilation
 
     def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         """
@@ -151,12 +158,16 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph
         result: List[Subgraph] = []
         for subgraph in subgraphs:
             if subgraph.is_acc:
-                if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
-                    self.allowed_single_node_partition_ops is not None
-                    and any(
-                        ConverterRegistry.qualified_name_or_str(node.target)
-                        in self.allowed_single_node_partition_ops
-                        for node in subgraph.nodes
+                if (
+                    len(subgraph.nodes) >= self.settings.min_acc_module_size
+                    or self.require_full_compilation
+                    or (
+                        self.allowed_single_node_partition_ops is not None
+                        and any(
+                            ConverterRegistry.qualified_name_or_str(node.target)
+                            in self.allowed_single_node_partition_ops
+                            for node in subgraph.nodes
+                        )
                     )
                 ):
                     result.append(subgraph)
@@ -185,6 +196,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Delegate nodes based on operator coverage
         subgraphs = self.put_nodes_into_subgraphs()
 
+        # A graph is fully supported if there is a single partition and all operators are supported/convertible
+        full_support = len([s for s in subgraphs if s.is_acc]) == 1 and not getattr(
+            self.operator_support, "unsupported_operators", True
+        )
+
+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
+        if (
+            full_support
+            and self.require_full_compilation
+            and self.settings.min_acc_module_size != MIN_BLOCK_SIZE
+        ):
+            logger.warning(
+                "Detected both require_full_compilation and min_block_size compilation "
+                "arguments were specified. Disregarding min_block_size argument for "
+                "fully supported model."
+            )
+
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
 
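Note on the full-support check above: getattr(self.operator_support, "unsupported_operators", True) falls back to True when the support tester does not expose that attribute, so the graph is only treated as fully supported when the tester explicitly reports no unsupported operators. A minimal, illustrative sketch of the two cases the check distinguishes (these class names are hypothetical and not part of the change):

import operator

# Illustrative only: the truthiness that `not getattr(..., True)` relies on.
class ReportingSupport:
    unsupported_operators: dict = {}  # empty report -> falsy -> fully supported

class SilentSupport:
    pass  # attribute absent -> getattr default True -> treated as not fully supported

assert not getattr(ReportingSupport(), "unsupported_operators", True)
assert getattr(SilentSupport(), "unsupported_operators", True)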
@@ -217,6 +249,7 @@ def partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -226,6 +259,7 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -236,7 +270,12 @@ def partition(
 
     # Construct
     supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
-    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+    partitioner = TRTPartitioner(
+        gm,
+        supported_ops,
+        min_block_size=min_block_size,
+        require_full_compilation=require_full_compilation,
+    )
 
     partitioned_graph = partitioner.partition_graph()
 
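For reference, a minimal usage sketch of the new flag at the partition() level, assuming partition is re-exported from torch_tensorrt.dynamo.partitioning; the toy aten-level graph below is illustrative and not part of this change:

import torch
from torch import fx

from torch_tensorrt.dynamo import partitioning  # assumed re-export of partition()

# Build a tiny aten-level GraphModule with a single supported computational op.
graph = fx.Graph()
x = graph.placeholder("x")
y = graph.call_function(torch.ops.aten.relu.default, (x,))
graph.output(y)
gm = fx.GraphModule(torch.nn.Module(), graph)

# With require_full_compilation=True the partitioner keeps every accelerated
# subgraph regardless of min_block_size, and partition_graph() raises an
# AssertionError instead of silently falling back to Torch when any operator
# lacks a converter.
partitioned_gm = partitioning.partition(
    gm,
    min_block_size=5,
    require_full_compilation=True,
)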