
feat: Hierarchical Partitioner to support multi-backends #3539


Merged (7 commits, Jun 18, 2025)
13 changes: 13 additions & 0 deletions docsrc/contributors/partitioning.rst
@@ -239,3 +239,16 @@ In this example we will collect the arithmetic ops in a TensorRT segment and the
In some cases this approach may create adjacent segments in the partition which have the same target. As a clean-up step we can consolidate these adjacent segments to further reduce the number of segments in the final partition.
The merge segments step identifies a list of segments that are adjacent in the graph, have the same target, and are not marked as `do_not_merge`. The nodes from these segments will be combined into a single new segment that will replace the merged segments in the partition.
The `do_not_merge` marking is used to prevent merging of segments created for conditional nodes and loops that are handled as special cases in graph stitching and should not be merged with adjacent segments of the same type.
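
The sketch below is only an illustration of this clean-up pass; it assumes a hypothetical `Segment` object with `target`, `nodes`, and `do_not_merge` attributes rather than the partitioner's real data structures:

.. code-block:: python

    def merge_adjacent_segments(segments):
        # Illustrative only: fold a segment into its predecessor when both
        # share a target and neither is marked do_not_merge.
        merged = []
        for seg in segments:
            prev = merged[-1] if merged else None
            if (
                prev is not None
                and prev.target == seg.target
                and not prev.do_not_merge
                and not seg.do_not_merge
            ):
                prev.nodes.extend(seg.nodes)
            else:
                merged.append(seg)
        return merged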


Hierarchical Partitioner for Dynamo
===================================

The Hierarchical Partitioner is an extension to the standard TensorRT partitioner that allows for more sophisticated partitioning strategies by considering backend priority and operator support. This is particularly useful when you want to distribute different parts of your model across multiple backends based on their capabilities and priorities.

We currently support the hierarchical adjacency partitioner, which extends the standard adjacency partitioner with the following capabilities:

1. **Backend priority ordering**: Assign operators to backends based on a priority order, ensuring that operators are assigned to the highest-priority backend that supports them.
2. **Multi-backend support**: Distribute model execution across multiple backends based on operator support.

Please refer to `hierarchical_partitioner_example <https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/hierarchical_partitioner_example.py>`_ for more details.
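
A minimal sketch of what the call looks like, using the same arguments as the linked example (the specific ops and priority order are illustrative):

.. code-block:: python

    from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
        DYNAMO_CONVERTERS as CONVERTERS,
    )
    from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
        hierarchical_adjacency_partition,
    )

    # Backends are tried in priority order; an operator is assigned to the
    # first backend in backend_priority whose backend_support_map entry contains it.
    partitioned_gm, op_support = hierarchical_adjacency_partition(
        gm,  # an exported and lowered FX GraphModule
        min_block_size=1,
        backend_priority=["inductor", "tensorrt"],
        backend_support_map={
            "inductor": {"torch.ops.aten.convolution.default"},
            "tensorrt": CONVERTERS.keys(),
        },
        require_full_compilation=False,
        skip_fusion=True,
    )
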
176 changes: 176 additions & 0 deletions examples/dynamo/hierarchical_partitioner_example.py
@@ -0,0 +1,176 @@
from typing import Any, Callable

import torch
import torch.nn as nn
import torch_tensorrt
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo import partitioning
from torch_tensorrt.dynamo._compiler import convert_module
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
pre_export_lowering,
)
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
hierarchical_adjacency_partition,
)
from torch_tensorrt.dynamo.utils import (
get_output_metadata,
)
from torchvision import models


class InductorModule(torch.nn.Module): # type: ignore[misc]
"""Wrapper module for inductor compiled function."""

def __init__(self, func: Callable[..., Any]) -> None:
super().__init__()
self.func = func

def forward(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)


class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(128)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = torch.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = torch.relu(x)
return x


def main():
# Create model
model = SimpleModel().cuda()
# model = models.efficientnet_b0(pretrained=True).cuda()
model = model.eval()

# Create example input
example_input = torch.randn(1, 3, 224, 224).cuda()

exported_program = torch.export.export(model, (example_input,))
exported_program = pre_export_lowering(exported_program)
exported_program = exported_program.run_decompositions(get_decompositions())

gm = exported_program.module()

print("Original Model Structure:\n", gm)

original_output = model(example_input)

# 1. Partition the model into blocks that can be executed by different backends
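    # backend_priority orders backends from highest to lowest priority; each op is
    # assigned to the highest-priority backend whose backend_support_map entry
    # contains it, so aten.convolution is routed to inductor and the remaining
    # supported ops fall through to tensorrt.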
partitioned_model, op_support = hierarchical_adjacency_partition(
gm,
min_block_size=1,
backend_priority=["inductor", "tensorrt"],
backend_support_map={
"inductor": {
"torch.ops.aten.convolution.default",
},
"tensorrt": CONVERTERS.keys(),
},
torch_executed_ops={
"torch.ops.aten._native_batch_norm_legit_no_training.default"
},
require_full_compilation=False,
skip_fusion=True,
)

print("1. Partitioned Model Structure:\n", partitioned_model)

# 2. Compile each submodule with the corresponding backend
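    # Submodules handed to a backend carry "_run_on_acc_<backend>" in their name;
    # everything else stays in eager PyTorch and is simply moved to the GPU below.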
submodule_node_dict = {}
for node in partitioned_model.graph.nodes:
if "_run_on_acc" not in node.name:
continue
submodule_node_dict[node.name] = node

# Store compiled replicas of Torch subgraphs
compiled_modules = {}

for name, _ in partitioned_model.named_children():
submodule = getattr(partitioned_model, name)
if not isinstance(submodule, torch.fx.graph_module.GraphModule):
continue

if "_run_on_acc" not in name:
submodule.to("cuda")
continue

if name not in submodule_node_dict:
raise ValueError(
f"node_name: {name} does not exist in the submodule node dictionary"
)

# set the submodule metadata back to the parent module_node
metadata_list = get_output_metadata(submodule)
assert len(metadata_list) > 0
metadata_keys = ["val", "tensor_meta"]
for key in metadata_keys:
if key not in submodule_node_dict[name].meta:
meta_val_list = [
metadata[key] for metadata in metadata_list if key in metadata
]
submodule_node_dict[name].meta[key] = meta_val_list
break

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.construct_submodule_inputs(submodule)
assert submodule_inputs is not None

# compile submodule with pytorch inductor backend
if "_run_on_acc_inductor" in name:
sub_inputs = []
for input in submodule_inputs:
sub_input = input.torch_tensor.to(
dtype.to(input.dtype, t=torch.dtype)
).cuda()
sub_inputs.append(sub_input)

compiled_func = torch._inductor.compile(
submodule,
sub_inputs,
)
# Wrap the compiled function to be a torch.nn.Module
compiled_submodule = InductorModule(compiled_func)

# compile submodule with tensorrt backend
elif "_run_on_acc_tensorrt" in name:
compiled_submodule = convert_module(
submodule,
submodule_inputs,
name=name,
)
else:
raise ValueError(f"Unknown backend for submodule: {name}")

compiled_modules[name] = compiled_submodule

# Replace all FX Modules with compiled Modules
for name, compiled_module in compiled_modules.items():
setattr(partitioned_model, name, compiled_module)

print("2. Compiled Model Structure:\n", partitioned_model)

with torch.no_grad():
partitioned_output = partitioned_model(example_input)
print(
"3. Verify that Partitioned output == Original output:",
torch.allclose(partitioned_output, original_output, 1e-2, 1e-2),
)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,5 +1,6 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from ._hierarchical_partitioner import hierarchical_adjacency_partition
from .common import (
construct_submodule_inputs,
get_graph_converter_support,
@@ -261,7 +261,6 @@ def partition(

Args:
gm: FX GraphModule to partition
verbose: Bool representing whether to print operator support
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
require_full_compilation: Require that all computational operators be run in TRT
@@ -210,7 +210,6 @@ def partition(

Args:
gm: FX GraphModule to partition
verbose: Bool representing whether to print operator support
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
require_full_compilation: Whether to require that all operators be run in TRT