
Commit ee1f65d

feat: Add support for exempting full-support blocks
- When a graph is fully supported, we can ignore the minimum block size argument, which is primarily helpful in reducing segmentation. If the minimum block size is above the total number of operators in the graph, and we support all of those, the whole graph will run in Torch regardless. As a result, we can exempt fully supported graphs from the min block size requirement.
- Alternatively, if preferable, we could display a warning in such a case but still respect the minimum block size argument.
1 parent a765b72 commit ee1f65d
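
To illustrate the intended behavior, here is a minimal sketch (not part of this commit; it assumes the partition helper exercised by the test suite accepts a min_block_size keyword and is importable from the module modified below): a fully supported single-op graph should now be kept as one accelerated subgraph even when min_block_size exceeds the operator count.

from copy import deepcopy

import torch
from torch_tensorrt.dynamo.lowering._partition import partition


class FullySupportedOneOp(torch.nn.Module):
    def forward(self, x, y):
        # A single ATen op; every operator in the graph is convertible
        return torch.ops.aten.add.Tensor(x, y)


fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())

# min_block_size (3) exceeds the operator count (1), but because the graph is
# fully supported, the partition is kept rather than falling back to Torch
partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
print(len(list(partitioned_graph.named_children())))  # expected: 1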

File tree

2 files changed: +10 -17 lines changed


py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 10 additions & 1 deletion
@@ -61,6 +61,11 @@ def propose_partitions(self) -> List[Partition]:
         initial_proposed_partitions = super().propose_partitions()
         partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
 
+        # A graph is fully supported if there is a single partition and all operators are supported/convertible
+        full_support = len(partitions) == 1 and not getattr(
+            self.operator_support, "unsupported_operators", True
+        )
+
         # For each partition, determine whether or not the number of computational operators
         # exceeds the threshold, and if not, remove that partition
         partitions_to_remove = {}
@@ -85,7 +90,11 @@ def propose_partitions(self) -> List[Partition]:
                 ):
                     compute_node_count += 1
 
-            if compute_node_count < self.min_block_size and not exempted_partition:
+            if (
+                compute_node_count < self.min_block_size
+                and not exempted_partition
+                and not full_support
+            ):
                 partitions_to_remove[id] = compute_node_count
 
         # Remove any nodes violating the criteria specified by the user
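
For context on the new check, a small hedged sketch of the assumption it encodes: the operator-support object is expected to expose an unsupported_operators collection that is empty when every node is convertible, so `not getattr(..., True)` only evaluates to True when nothing was flagged as unsupported. The class below is a hypothetical stand-in, not the real torch_tensorrt support checker.

# Hypothetical stand-in for the partitioner's operator-support object;
# the real implementation records operators it could not convert.
class FakeOperatorSupport:
    def __init__(self, unsupported_operators):
        self.unsupported_operators = unsupported_operators


# Empty collection -> fully convertible -> the min_block_size exemption can apply
print(not getattr(FakeOperatorSupport(set()), "unsupported_operators", True))  # True

# Non-empty collection -> some ops unsupported -> no exemption
print(not getattr(FakeOperatorSupport({"aten.nonzero"}), "unsupported_operators", True))  # False

# Missing attribute -> default of True -> conservatively treated as not fully supported
print(not getattr(object(), "unsupported_operators", True))  # False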

tests/py/dynamo/backend/test_partitioning.py

Lines changed: 0 additions & 16 deletions
@@ -7,22 +7,6 @@
77

88

99
class TestPartitioning(TestCase):
10-
def test_partition_fully_supported_one_op(self):
11-
class FullySupportedOneOp(torch.nn.Module):
12-
def __init__(self, *args, **kwargs) -> None:
13-
super().__init__(*args, **kwargs)
14-
15-
def forward(self, x, y):
16-
return torch.ops.aten.add.Tensor(x, y)
17-
18-
fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
19-
partitioned_graph = partition(deepcopy(fx_graph))
20-
self.assertEquals(
21-
len(list(partitioned_graph.named_children())),
22-
0,
23-
"Single operators should not be segmented",
24-
)
25-
2610
def test_partition_fully_supported_multi_op(self):
2711
class FullySupportedMultiOp(torch.nn.Module):
2812
def __init__(self, *args, **kwargs) -> None:
