Commit 6919f72

refactor: Add require_full_compilation in Dynamo
- Add support for the full compilation argument in Dynamo paths
1 parent 21d6f98 commit 6919f72

File tree

6 files changed: +164 -13 lines changed

py/torch_tensorrt/dynamo/_defaults.py
py/torch_tensorrt/dynamo/_settings.py
py/torch_tensorrt/dynamo/compile.py
py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
tests/py/dynamo/backend/test_partitioning.py

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions

@@ -12,3 +12,4 @@
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
+REQUIRE_FULL_COMPILATION = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REQUIRE_FULL_COMPILATION,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -54,3 +55,4 @@ class CompilationSettings:
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION
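
For context, the new field can be set directly on the settings dataclass shown above. A minimal sketch (not part of this commit), assuming CompilationSettings is imported from torch_tensorrt.dynamo._settings as in the diff; every other field keeps its default from _defaults.py:

from torch_tensorrt.dynamo._settings import CompilationSettings

# Request that the entire graph be compiled to TRT; all other options
# fall back to their defaults (REQUIRE_FULL_COMPILATION is False by default).
settings = CompilationSettings(require_full_compilation=True)
assert settings.require_full_compilation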

py/torch_tensorrt/dynamo/compile.py

Lines changed: 7 additions & 3 deletions

@@ -19,6 +19,7 @@
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REQUIRE_FULL_COMPILATION,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -52,7 +53,7 @@ def compile(
     dla_global_dram_size: int = 536870912,
     calibrator: object = None,
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
-    require_full_compilation: bool = False,
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[List[str]] = None,
     torch_executed_modules: Optional[List[str]] = None,
@@ -75,8 +76,10 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         "following arguments are supported: "
         "{enabled_precisions, debug, workspace_size, min_block_size, "
-        "torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
-        "enable_experimental_decompositions}"
+        "max_aux_streams, version_compatible, optimization_level, "
+        "torch_executed_ops, pass_through_build_failures, "
+        "use_fast_partitioner, enable_experimental_decompositions, "
+        "require_full_compilation}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -118,6 +121,7 @@ def compile(
         "truncate_long_and_double": truncate_long_and_double,
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
+        "require_full_compilation": require_full_compilation,
     }

     settings = CompilationSettings(**compilation_options)
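
For illustration, a minimal usage sketch of the argument at the Dynamo compile entry point (not part of this commit). MyModel and the input shape are hypothetical; the other keyword arguments follow the existing compile signature shown above and are unchanged by this commit:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical user model
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# With require_full_compilation=True, partitioning raises an AssertionError
# if any operator lacks a TRT converter; otherwise the whole graph is
# compiled into a single TRT engine.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=inputs,
    require_full_compilation=True,
)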

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 41 additions & 7 deletions

@@ -12,7 +12,11 @@
     _SplitterSettingBase,
 )
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
-from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MIN_BLOCK_SIZE,
+    REQUIRE_FULL_COMPILATION,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -92,6 +96,7 @@ class TRTPartitioner(_SplitterBase):  # type: ignore
         allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
             Generally useful for module-level exclusion ops which are intensive despite being single functions
         min_block_size: Minimum number of computational operators per block
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -104,6 +109,7 @@ def __init__(
             Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
+        require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ):
         """
         Preprocesses graph before splitting:
@@ -142,6 +148,7 @@ def __init__(

         self.num_trt_accelerated_subgraphs: Optional[int] = None
         self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+        self.require_full_compilation = require_full_compilation

     def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         """
@@ -151,12 +158,16 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         result: List[Subgraph] = []
         for subgraph in subgraphs:
             if subgraph.is_acc:
-                if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
-                    self.allowed_single_node_partition_ops is not None
-                    and any(
-                        ConverterRegistry.qualified_name_or_str(node.target)
-                        in self.allowed_single_node_partition_ops
-                        for node in subgraph.nodes
+                if (
+                    len(subgraph.nodes) >= self.settings.min_acc_module_size
+                    or self.require_full_compilation
+                    or (
+                        self.allowed_single_node_partition_ops is not None
+                        and any(
+                            ConverterRegistry.qualified_name_or_str(node.target)
+                            in self.allowed_single_node_partition_ops
+                            for node in subgraph.nodes
+                        )
                     )
                 ):
                     result.append(subgraph)
@@ -185,6 +196,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Delegate nodes based on operator coverage
         subgraphs = self.put_nodes_into_subgraphs()

+        # A graph is fully supported if there is a single partition and all operators are supported/convertible
+        full_support = len([s for s in subgraphs if s.is_acc]) == 1 and not getattr(
+            self.operator_support, "unsupported_operators", True
+        )
+
+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
+        if (
+            full_support
+            and self.require_full_compilation
+            and self.settings.min_acc_module_size != MIN_BLOCK_SIZE
+        ):
+            logger.warning(
+                "Detected both require_full_compilation and min_block_size compilation "
+                "arguments were specified. Disregarding min_block_size argument for "
+                "fully supported model."
+            )
+
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)

@@ -217,6 +249,7 @@ def partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -226,6 +259,7 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 33 additions & 3 deletions

@@ -5,7 +5,11 @@
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
-from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MIN_BLOCK_SIZE,
+    REQUIRE_FULL_COMPILATION,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -26,6 +30,7 @@ class TRTPartitioner(CapabilityBasedPartitioner):  # type: ignore[misc]
         allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
             Generally useful for module-level exclusion ops which are intensive despite being single functions
         min_block_size: Minimum number of computational operators per block
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -40,6 +45,7 @@ def __init__(
             Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
+        require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ) -> None:
         super().__init__(
             graph_module,
@@ -50,6 +56,7 @@ def __init__(
         )

         self.min_block_size = min_block_size
+        self.require_full_compilation = require_full_compilation

     def propose_partitions(self) -> List[Partition]:
         # Propose partitions using the default, then refine the results
@@ -61,6 +68,22 @@ def propose_partitions(self) -> List[Partition]:
             self.operator_support, "unsupported_operators", True
         )

+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
+        if (
+            full_support
+            and self.require_full_compilation
+            and self.min_block_size != MIN_BLOCK_SIZE
+        ):
+            logger.warning(
+                "Detected both require_full_compilation and min_block_size compilation "
+                "arguments were specified. Disregarding min_block_size argument for "
+                "fully supported model."
+            )
+
         # For each partition, determine whether or not the number of computational operators
         # exceeds the threshold, and if not, remove that partition
         partitions_to_remove = {}
@@ -89,7 +112,7 @@ def propose_partitions(self) -> List[Partition]:
             if (
                 compute_node_count < self.min_block_size
                 and not exempted_partition
-                and not full_support
+                and not (full_support and self.require_full_compilation)
             ):
                 partitions_to_remove[id] = compute_node_count

@@ -181,6 +204,7 @@ def partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Set[str]] = None,
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -190,6 +214,7 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Whether to require that all operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -198,7 +223,12 @@ def partition(
         if torch_executed_ops is not None
         else set()
     )
-    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+    partitioner = TRTPartitioner(
+        gm,
+        supported_ops,
+        min_block_size=min_block_size,
+        require_full_compilation=require_full_compilation,
+    )

     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
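
In the global partitioner, a fully supported graph compiled with require_full_compilation=True is now kept as a single accelerated partition even below the min_block_size threshold, and a warning is logged if a non-default min_block_size was also supplied. A short sketch (not part of this commit), mirroring the test added below; fx_graph is a hypothetical single-op torch.fx.GraphModule whose only operator has a TRT converter:

from torch_tensorrt.dynamo import partitioning

# One supported op: normally pruned for falling below min_block_size, but
# retained as a TRT-accelerated submodule when full compilation is required.
partitioned = partitioning.global_partition(fx_graph, require_full_compilation=True)
print(len(list(partitioned.named_children())))  # 1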

tests/py/dynamo/backend/test_partitioning.py

Lines changed: 80 additions & 0 deletions

@@ -30,6 +30,52 @@ def forward(self, x, y):
             "Single operators should not be segmented",
         )

+    def test_partition_fully_supported_one_op_require_full_compilation(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.fast_partition(
+            deepcopy(fx_graph), require_full_compilation=True
+        )
+        self.assertEquals(
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graph.named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
+            1,
+            "Single operators can be segmented if full compilation is required",
+        )
+
+    def test_partition_fully_supported_one_op(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.fast_partition(deepcopy(fx_graph))
+        self.assertEquals(
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graph.named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
+            0,
+            "Single operators should not be segmented",
+        )
+
     def test_partition_fully_supported_multi_op(self):
         class FullySupportedMultiOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
@@ -155,6 +201,40 @@ def forward(self, x, y):


 class TestGlobalPartitioning(TestCase):
+    def test_partition_fully_supported_one_op(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.global_partition(deepcopy(fx_graph))
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            0,
+            "Single operators should not be segmented",
+        )
+
+    def test_partition_fully_supported_one_op_require_full_compilation(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.global_partition(
+            deepcopy(fx_graph), require_full_compilation=True
+        )
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "Single operators can be segmented if full compilation is required",
+        )
+
     def test_partition_fully_supported_multi_op(self):
         class FullySupportedMultiOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
