Skip to content

Commit 21d6f98

Browse files
committed
feat: Add support for exempting full-support blocks
- When a graph is fully supported, we can ignore the minimum block size argument, which is primarily helpful in reducing segmentation. If the minimum block size exceeds the total number of operators in the graph and all of those operators are supported, the whole graph would otherwise run in Torch regardless. As a result, we can exempt fully supported graphs from the min block size requirement.
- Alternatively, if preferable, we could display a warning in such a case but still respect the minimum block size argument.
1 parent 148b3ba commit 21d6f98

File tree

2 files changed

+10
-17
lines changed

2 files changed

+10
-17
lines changed

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def propose_partitions(self) -> List[Partition]:
5656
initial_proposed_partitions = super().propose_partitions()
5757
partitions = dict(enumerate(initial_proposed_partitions))
5858

59+
# A graph is fully supported if there is a single partition and all operators are supported/convertible
60+
full_support = len(partitions) == 1 and not getattr(
61+
self.operator_support, "unsupported_operators", True
62+
)
63+
5964
# For each partition, determine whether or not the number of computational operators
6065
# exceeds the threshold, and if not, remove that partition
6166
partitions_to_remove = {}
@@ -81,7 +86,11 @@ def propose_partitions(self) -> List[Partition]:
8186
):
8287
compute_node_count += 1
8388

84-
if compute_node_count < self.min_block_size and not exempted_partition:
89+
if (
90+
compute_node_count < self.min_block_size
91+
and not exempted_partition
92+
and not full_support
93+
):
8594
partitions_to_remove[id] = compute_node_count
8695

8796
# Remove any nodes violating the criteria specified by the user

tests/py/dynamo/backend/test_partitioning.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -155,22 +155,6 @@ def forward(self, x, y):
155155

156156

157157
class TestGlobalPartitioning(TestCase):
158-
def test_partition_fully_supported_one_op(self):
159-
class FullySupportedOneOp(torch.nn.Module):
160-
def __init__(self, *args, **kwargs) -> None:
161-
super().__init__(*args, **kwargs)
162-
163-
def forward(self, x, y):
164-
return torch.ops.aten.add.Tensor(x, y)
165-
166-
fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
167-
partitioned_graph = partitioning.global_partition(deepcopy(fx_graph))
168-
self.assertEquals(
169-
len(list(partitioned_graph.named_children())),
170-
0,
171-
"Single operators should not be segmented",
172-
)
173-
174158
def test_partition_fully_supported_multi_op(self):
175159
class FullySupportedMultiOp(torch.nn.Module):
176160
def __init__(self, *args, **kwargs) -> None:

0 commit comments

Comments
 (0)