From b03d1bcff31ef4bff39fa2ccdea1414fd1e12489 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 29 May 2025 16:17:09 -0700 Subject: [PATCH 1/7] support for hierarchical adjacency partitioner --- examples/hierarchical_partitioner_example.py | 103 ++ .../partitioning/_hierarchical_partitioner.py | 439 +++++++++ .../dynamo/partitioning/splitter_base.py | 927 ++++++++++++++++++ 3 files changed, 1469 insertions(+) create mode 100644 examples/hierarchical_partitioner_example.py create mode 100644 py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py create mode 100644 py/torch_tensorrt/dynamo/partitioning/splitter_base.py diff --git a/examples/hierarchical_partitioner_example.py b/examples/hierarchical_partitioner_example.py new file mode 100644 index 0000000000..aa2df19516 --- /dev/null +++ b/examples/hierarchical_partitioner_example.py @@ -0,0 +1,103 @@ +# from torch_tensorrt.dynamo.partitioning._global_partitioner import partition +import torch +import torch.nn as nn +import torch_tensorrt +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_ATEN_CONVERTERS, +) +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition +from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import ( + hierarchical_partition, +) + + +class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(64) + self.bn2 = nn.BatchNorm2d(128) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = torch.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = torch.relu(x) + return x + + +def main(): + # Create model + model = SimpleModel().cuda() + # model = models.efficientnet_b0(pretrained=True).cuda() + model = model.eval() + + # Create example input + example_input = torch.randn(1, 3, 224, 224).cuda() + + exported_program = torch.export.export(model, (example_input,)) + exported_program = pre_export_lowering(exported_program) + exported_program = exported_program.run_decompositions(get_decompositions()) + + gm = exported_program.module() + + print(gm.graph) + + # Partition the model using the adjacency partitioner + # partitioned_model, op_support = partition( + # gm, + # verbose=True, + # min_block_size=1, + # torch_executed_ops=[ + # torch.ops.aten.relu.default, + # ], + # ) + + partitioned_model, op_support = hierarchical_partition( + gm, + verbose=True, + min_block_size=1, + backend_priority=["mlir", "tensorrt"], # , "inductor"], + backend_support_map={ + "mlir": { + # operator.getitem, + torch.ops.aten.conv2d.default, + torch.ops.aten.convolution.default, + }, + "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), + # "inductor": { + # torch.ops.aten.relu.default, + # }, + }, + torch_executed_ops=[ + torch.ops.aten._native_batch_norm_legit_no_training.default + ], + require_full_compilation=False, + skip_fusion=False, + ) + + print("\nPartitioned Model Structure:") + print(partitioned_model) + + with torch.no_grad(): + output = partitioned_model(example_input) + print("Partitioned output:", output) + print( + "Partitioned output == original output:", + torch.allclose(model(example_input), output, 1e-2, 1e-2), + ) + + +if __name__ == "__main__": + main() diff --git a/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py new file mode 100644 index 0000000000..d5505b1eee --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py @@ -0,0 +1,439 @@ +import logging +from typing import Collection, Dict, List, Optional, Set, Tuple + +import torch +import torch.fx.passes.operator_support as ops +from torch._ops import OpOverload +from torch.fx.node import Target, _get_qualified_name +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeList, NodeSet +from torch_tensorrt.dynamo._defaults import ( + DEBUG, + MIN_BLOCK_SIZE, + REQUIRE_FULL_COMPILATION, +) +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_ATEN_CONVERTERS, + ConverterRegistry, +) +from torch_tensorrt.dynamo.partitioning.splitter_base import ( + FxNetAccFusionsFinder, + FxNetAccNodesFinder, + Subgraph, + _SplitterBase, + _SplitterSettingBase, +) + +logger = logging.getLogger(__name__) + + +class BackendOpSupportTester(ops.OperatorSupportBase): # type: ignore + """Class to determine whether operators are supported by specific backends""" + + def __init__( + self, + backend_support_map: Dict[str, Set[OpOverload]], + backend_priority: List[str], + torch_executed_ops: Collection[Target] = set(), + ) -> None: + super().__init__() + + # Initialize sets of supported/unsupported operators + self.supported_operators: Dict[str, int] = {} + self.unsupported_operators: Dict[str, int] = {} + self.torch_executed_ops = torch_executed_ops + # Map of backend names to sets of supported operators + self.backend_support_map = backend_support_map + # Ordered list of backend names, from highest to lowest priority + self.backend_priority = backend_priority + + def is_node_supported( + self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + ) -> Tuple[bool, Optional[str]]: + node_name = ConverterRegistry.qualified_name_or_str(node.target) + + for i, backend_name in enumerate(self.backend_priority): + supported_ops = self.backend_support_map.get(backend_name, set()) + supported_ops = {_get_qualified_name(op) for op in supported_ops} + + if ( + (node_name in supported_ops or node.op == "get_attr") + and node_name not in self.torch_executed_ops + and node.target not in self.torch_executed_ops + ): + # If node is a proper, supported computational node, store the operator + if not node.is_impure() and node.op != "get_attr": + if node_name not in self.supported_operators: + self.supported_operators[f"{backend_name}_{node_name}"] = 1 + else: + self.supported_operators[f"{backend_name}_{node_name}"] += 1 + + return True, backend_name + else: + if i == len(self.backend_priority) - 1 and not node.is_impure(): + if node_name not in self.unsupported_operators: + self.unsupported_operators[node_name] = 1 + else: + self.unsupported_operators[node_name] += 1 + + return False, None + + def print_support_overview(self, num_acc_subgraphs: Optional[int] = None) -> None: + if num_acc_subgraphs is not None: + logger.debug( + f"\nNumber of Accelerated Subgraphs Generated: {num_acc_subgraphs}" + ) + + # Reformat support messages for debugger to print node overview as a single string + supported_nodes_str = "\nSupported Nodes:\n" + for node_name, count in self.supported_operators.items(): + supported_nodes_str += f"- {node_name} + Operator Count: {count}\n" + + logger.debug(supported_nodes_str) + + if self.unsupported_operators: + unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n" + for node_name, count in self.unsupported_operators.items(): + unsupported_nodes_str += f"- {node_name} + Operator Count: {count}\n" + + logger.debug(unsupported_nodes_str) + else: + logger.debug("\nAll Nodes Supported\n") + + +class HierarchicalTRTPartitioner(_SplitterBase): + """Hierarchical partitioner to split an FX graph into subgraphs based on backend priority + + This partitioner extends the TRTPartitioner of adjacency_partitioner with backend priority awareness, + allowing different parts of the model to be executed on different backends based on + operator support and priority ordering. + + Args: + module: FX GraphModule to partition + operator_support: OperatorSupport class describing allowed operators + backend_support_map: Dictionary mapping backend names to sets of supported operators + backend_priority: Ordered list of backend names, from highest to lowest priority + allowed_single_node_partition_ops: Nodes which can be included in single-node partitions + min_block_size: Minimum number of computational operators per block + require_full_compilation: Require that all computational operators be run in TRT + Returns: + torch.fx.GraphModule + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: ops.OperatorSupportBase, + backend_support_map: Dict[str, Set[Target]], + backend_priority: List[str], + allowed_single_node_partition_ops: Optional[Collection[str]] = None, + min_block_size: int = MIN_BLOCK_SIZE, + require_full_compilation: bool = REQUIRE_FULL_COMPILATION, + return_tuple: bool = False, + skip_fusion: bool = False, + ): + """ + Preprocesses graph before splitting with backend priority awareness + """ + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + self.backend_support_map = backend_support_map + self.backend_priority = backend_priority + + self.settings = _SplitterSettingBase( + min_acc_module_size=min_block_size, + allow_non_tensor=True, + skip_fusion=skip_fusion, + ) + self.operator_support = operator_support + + # Get all accelerated nodes based on operator support conditions + self.acc_nodes = FxNetAccNodesFinder( + self.module, self.operator_support, self.settings.allow_non_tensor + )() + + if self.settings.skip_fusion: + self.fusions = {} + else: + self.fusions = FxNetAccFusionsFinder(module, set(self.acc_nodes))() + + # Modify deps to add more deps for fused nodes + self.deps = self.find_deps() + self.update_deps_for_fusions() + + self.non_acc_submodule_name = "_run_on_gpu_" + self._node_submodule_map: Dict[str, str] = {} + + self.num_accelerated_subgraphs: Optional[int] = None + self.allowed_single_node_partition_ops = allowed_single_node_partition_ops + self.require_full_compilation = require_full_compilation + self._return_tuple = return_tuple + + def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: + """ + This pass finds ACC submodules with less than specified size and merges + them with adjacent GPU submodules. + """ + result: List[Subgraph] = [] + for subgraph in subgraphs: + if subgraph.is_acc: + if ( + len(subgraph.nodes) >= self.settings.min_acc_module_size + or self.require_full_compilation + or ( + self.allowed_single_node_partition_ops is not None + and any( + ConverterRegistry.qualified_name_or_str(node.target) + in self.allowed_single_node_partition_ops + for node in subgraph.nodes + ) + ) + ): + result.append(subgraph) + else: + logger.debug( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) + # if the last subgraph result[-1] is non-acc or has the same backend, merge the current subgraph into it + if result and ( + not result[-1].is_acc or result[-1].backend == subgraph.backend + ): + result[-1].nodes.extend(subgraph.nodes) + else: + # if the last subgraph result[-1] has different backends, then append the current subgraph as non-acc + subgraph.is_acc = False + subgraph.backend = "None" + result.append(subgraph) + else: + if result and not result[-1].is_acc: + result[-1].nodes.extend(subgraph.nodes) + else: + if result: + if result[-1].backend == subgraph.backend: + result[-1].nodes.extend(subgraph.nodes) + else: + result.append(subgraph) + else: + result.append(subgraph) + return result + + def partition_graph(self) -> torch.fx.GraphModule: + """Partitions the GraphModule into subgraphs based on backend priority + + Returns a GraphModule with submodules for each segment + """ + # Delegate nodes based on operator coverage + subgraphs = self.put_nodes_into_subgraphs() + + # A graph is fully supported if there is a single partition and all operators are supported/convertible + full_support = len([s for s in subgraphs if s.is_acc]) == 1 and not getattr( + self.operator_support, "unsupported_operators", True + ) + + if not full_support and self.require_full_compilation: + raise AssertionError( + "require_full_compilation=True was specified, but model is not fully supported or multiple partitions are found" + ) + + if ( + full_support + and self.require_full_compilation + and self.settings.min_acc_module_size != MIN_BLOCK_SIZE + ): + logger.warning( + "Detected both require_full_compilation and min_block_size compilation " + "arguments were specified. Disregarding min_block_size argument for " + "fully supported model." + ) + + # Remove segments smaller than the block size (with exceptions) + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + + # Set the number of accelerated subgraphs to be generated + self.num_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) + + # Tag the accelerated nodes and split the graph accordingly + self.tag(subgraphs) + return self.split() + + def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: + """Generates starter nodes for partitioning + segmentation""" + # Starter accelerated nodes are all callable accelerated ops + starter_acc_nodes = { + node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS + } + + # Started non-accelerated nodes are the rest of the callable nodes + starter_non_acc_nodes = { + node + for node in self.module.graph.nodes + if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS) + } + return starter_non_acc_nodes, starter_acc_nodes + + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + # We start graph traversal from leaf nodes + current_cpu_nodes, current_acc_nodes = self.starter_nodes() + visited_nodes: NodeSet = set() + + # Determine which subgraph to start from based on which subgraph has + # 0-dep node + acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes) + + current_subgraph_nodes: NodeList = [] + + # Result accumulator + subgraphs: list[Subgraph] = [] + while current_cpu_nodes or current_acc_nodes: + # Find the first node that should belong to the current subgraph and has all dependencies resolved + current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes + node = next( + (n for n in current_nodes if self.deps[n] <= visited_nodes), + None, + ) + + # If no node was found, then it's time to flip the mode and start a new subgraph + if node is None: + if not current_subgraph_nodes: + raise FxNetSplitterInternalError("Subgraph can't be empty") + + subgraphs.append( + Subgraph( + is_acc=acc_subgraph, + nodes=current_subgraph_nodes, + backend=( + current_subgraph_nodes[-1].backend + if acc_subgraph + else "None" + ), + ) + ) + acc_subgraph = not acc_subgraph + current_subgraph_nodes = [] + continue + + # If the backend changed, then it's time to start a new subgraph + if ( + current_subgraph_nodes + and current_subgraph_nodes[-1].backend != node.backend + ): + if not current_subgraph_nodes: + raise FxNetSplitterInternalError("Subgraph can't be empty") + + subgraphs.append( + Subgraph( + is_acc=acc_subgraph, + nodes=current_subgraph_nodes, + backend=current_subgraph_nodes[-1].backend, + ) + ) + current_subgraph_nodes = [] + continue + + current_nodes.remove(node) + visited_nodes.add(node) + current_subgraph_nodes.append(node) + + # Add fusion buddies + if node in self.fusions: + if node in self.acc_nodes: + current_acc_nodes.update(self.fusions[node] - visited_nodes) + else: + current_cpu_nodes.update(self.fusions[node] - visited_nodes) + + # Put depending nodes into the queue + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + # Add downstream nodes + if user in self.acc_nodes: + current_acc_nodes.add(user) + else: + current_cpu_nodes.add(user) + + # Check if the last subgraph was not created + if current_subgraph_nodes: + subgraphs.append( + Subgraph( + is_acc=acc_subgraph, + nodes=current_subgraph_nodes, + backend=( + current_subgraph_nodes[-1].backend if acc_subgraph else "None" + ), + ) + ) + + if not subgraphs: + raise FxNetSplitterInternalError("Couldn't create subgraphs") + + return subgraphs + + +class FxNetSplitterInternalError(Exception): + pass + + +def hierarchical_partition( + gm: torch.fx.GraphModule, + verbose: bool = DEBUG, + min_block_size: int = MIN_BLOCK_SIZE, + torch_executed_ops: Collection[Target] = set(), + backend_support_map: Optional[Dict[str, Set[OpOverload]]] = None, + backend_priority: Optional[List[str]] = None, + require_full_compilation: bool = REQUIRE_FULL_COMPILATION, + skip_fusion: bool = False, +) -> Tuple[torch.fx.GraphModule, BackendOpSupportTester]: + """Partition an FX GraphModule with aten ops into submodules using hierarchical partitioning + based on backend priority and operator support + + Args: + gm: FX GraphModule to partition + verbose: Bool representing whether to print operator support + min_block_size: Minimum number of operators per TRT-Engine Block + backend_support_map: Dictionary mapping backend names to sets of supported operators + backend_priority: Ordered list of backend names, from highest to lowest priority + require_full_compilation: Require that all computational operators be run in TRT + skip_fusion: Skip fusions found by FxNetAccFusionsFinder + Returns: + torch.fx.GraphModule, BackendOpSupportTester + """ + # Ensure graph is clean prior to partitioning + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + # Default backend support map if none provided + if backend_support_map is None: + backend_support_map = { + "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), + "inductor": set(), + } + + # Default backend priority if none provided + if backend_priority is None: + backend_priority = ["tensorrt", "inductor"] + + # Construct BackendOpSupportTester + supported_ops = BackendOpSupportTester( + backend_support_map=backend_support_map, + backend_priority=backend_priority, + torch_executed_ops=torch_executed_ops, + ) + partitioner = HierarchicalTRTPartitioner( + gm, + supported_ops, + backend_support_map=backend_support_map, + backend_priority=backend_priority, + min_block_size=min_block_size, + require_full_compilation=require_full_compilation, + skip_fusion=skip_fusion, + ) + + partitioned_graph = partitioner.partition_graph() + + if verbose: + supported_ops.print_support_overview(partitioner.num_accelerated_subgraphs) + + return partitioned_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/splitter_base.py b/py/torch_tensorrt/dynamo/partitioning/splitter_base.py new file mode 100644 index 0000000000..a7a4f280ab --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/splitter_base.py @@ -0,0 +1,927 @@ +# mypy: allow-untyped-defs +import argparse +import copy +import logging +from collections import defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg +from torch.fx.passes.graph_drawer import FxGraphDrawer +from torch.fx.passes.graph_manipulation import get_size_of_node +from torch.fx.passes.operator_support import OperatorSupportBase, get_node_target +from torch.fx.passes.shape_prop import ShapeProp +from torch.fx.passes.split_utils import split_by_tags +from torch.fx.passes.tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + NodeList, + NodeSet, + Tensors, + is_node_output_tensor, +) + +__all__ = [ + "FxNetAccNodesFinder", + "FxNetSplitterInternalError", + "Subgraph", + "SplitResult", + "generate_inputs_for_submodules", +] +_LOGGER = logging.getLogger(__name__) + +DEFAULT_MIN_ACC_MODULE_SIZE = 1 +DEFAULT_SKIP_FUSION = False +DEFAULT_ALLOW_NON_TENSOR = False + + +class _SplitterSettingBase: + def __init__( + self, + min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, + skip_fusion=DEFAULT_SKIP_FUSION, + allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR, + max_acc_splits: int = -1, + ): + parser = argparse.ArgumentParser() + parser.add_argument( + "--min-acc-module-size", + "--min_acc_module_size", + required=False, + type=int, + help="Minimum size limit of an accelerator subgraph.", + ) + parser.add_argument( + "--max-acc-splits", + "--max_acc_splits", + required=False, + type=int, + help="Enforce a maximum number of split subgraphs.", + ) + parser.add_argument( + "--skip-fusion", + "--skip_fusion", + default=False, + action="store_true", + help="If true then no fusion groups. Fusion group is used to " + "enforce no non-tensor data flow between submodules. If we don't " + "have this constrain, setting this to false is recommended as it " + "can reduce overhead.", + ) + parser.add_argument( + "--allow-non-tensor", + "--allow_non_tensor", + default=False, + action="store_true", + help="For some backends non-tensor data flow between cpu and them " + "are not allowed. Therefore, if a node supported by accelerator but " + "it has non-tensor inputs or outputs to a cpu node we would want to " + "consider it as a cpu node during splitting. However, for some backends " + "we might not care about non-tensor data flow and we can set this option " + "to true to disable the functionality that prevent non-tensor data flow.", + ) + args, _unknown = parser.parse_known_args() + + self.min_acc_module_size: int = ( + args.min_acc_module_size + if args.min_acc_module_size + else min_acc_module_size + ) + self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion + self.allow_non_tensor: bool = ( + args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + ) + self.max_acc_splits: int = max_acc_splits + + +@compatibility(is_backward_compatible=False) +class FxNetAccNodesFinder: + """ + Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor + input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. + + I.e. if we have a chain: + + ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 + + where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. + + This behavior can be turned off by passing allow_non_tensor=True. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: OperatorSupportBase, + allow_non_tensor: bool, + ): + self.module = module + self.operator_support = operator_support + self.allow_non_tensor = allow_non_tensor + self.acc_nodes: NodeSet = set() + + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): + """ + Transitively excludes nodes from ACC supported set. + For every node in the worklist: + - removes its downstream ACC nodes from ACC supported set, + - if any downstream ACC node produces non-tensor output, + then it gets added into the worklist. + """ + while cpu_worklist: + node = cpu_worklist.pop(0) + + for user in node.users: + if user in self.acc_nodes: + self.acc_nodes.remove(user) + if not is_node_output_tensor(user): + cpu_worklist.append(user) + + def reduce_acc_nodes_non_tensor_input(self): + """ + Excludes nodes from ACC supported set that have direct + upstream CPU nodes that produce non-tensor outputs. + """ + non_tensor_cpu_nodes: NodeList = [] + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + if node in self.acc_nodes: + continue + if is_node_output_tensor(node): + continue + non_tensor_cpu_nodes.append(node) + + self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) + + def reduce_acc_nodes_non_tensor_output(self): + """ + Excludes nodes from ACC supported set that produce non-tensor + outputs and have downstream CPU nodes. + """ + while True: + new_cpu_nodes: NodeList = [] + + for acc_node in self.acc_nodes: + if is_node_output_tensor(acc_node): + continue + for user in acc_node.users: + if user not in self.acc_nodes: + new_cpu_nodes.append(acc_node) + break + + if not new_cpu_nodes: + break + + for new_cpu_node in new_cpu_nodes: + self.acc_nodes.remove(new_cpu_node) + + self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) + + def __call__(self) -> NodeSet: + submodules = dict(self.module.named_modules()) + for n in self.module.graph.nodes: + n.backend = "None" + if n.op in CALLABLE_NODE_OPS: + is_supported, backend = self.operator_support.is_node_supported( + submodules, n + ) + if is_supported: + n.backend = backend + self.acc_nodes.add(n) + + if not self.allow_non_tensor: + self.reduce_acc_nodes_non_tensor_input() + self.reduce_acc_nodes_non_tensor_output() + + return self.acc_nodes + + +@compatibility(is_backward_compatible=False) +class FxNetSplitterInternalError(Exception): + pass + + +@compatibility(is_backward_compatible=False) +@dataclass +class Subgraph: + is_acc: bool + backend: str + nodes: NodeList + device_ordinal: Optional[int] = None + + +@compatibility(is_backward_compatible=False) +class SplitResult(NamedTuple): + """ + Stores the results of the splitter. + + Attributes: + split_module: root module after splitting. + submodule_inputs: a dict that maps submodule name to its inputs. + non_acc_submodule_prefix: the prefix for non acc submodules. For + acc submodule the prefix is always "_run_on_acc_". + """ + + split_module: torch.fx.GraphModule + submodule_inputs: dict[str, Any] + non_acc_submodule_prefix: str + + +@compatibility(is_backward_compatible=False) +def generate_inputs_for_submodules( + model: torch.nn.Module, + inputs: Sequence[Any], + target_submodules: Iterable[str], + deepcopy: bool = False, +) -> dict[str, Any]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. + """ + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_inputs): + results[submodule_to_names[module]] = ( + copy.deepcopy(module_inputs) if deepcopy else module_inputs + ) + + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append(mod.register_forward_pre_hook(pre_forward)) + + def clean_up_handles(): + for h in handles: + h.remove() + + try: + with torch.no_grad(): + model(*inputs) + except Exception as e: + clean_up_handles() + raise e + + clean_up_handles() + return results + + +class _SplitterBase: + """ + Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. + Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. + Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. + + Given the following graph: + ==> b ==> + // \\ + a d + \\ // + ==> c ==> + + class SimpleModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.cos(a) + d = b + c + return d + + and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, + we will get the following split result: + + main: + def forward(self, a): + run_on_acc_0_0 = self._run_on_acc_0_0(a) + getitem = run_on_acc_0_0[0] + getitem_1 = run_on_acc_0_0[1] + run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) + return run_on_cpu_1_1 + + _run_on_acc_0_0: + def forward(self, a): + sin_1 = torch.sin(a) + cos_1 = torch.cos(a) + return (sin_1, cos_1) + + _run_on_cpu_1_1: + def forward(self, sin_1, cos_1): + add_1 = sin_1 + cos_1 + return add_1 + """ + + # PCIe bandwidth for the backend, default to 100 GB/s + PCIe_BW = 100 * 2**30 + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Sequence[Any], + operator_support: OperatorSupportBase, + settings: _SplitterSettingBase, + non_acc_submodule_name: str = "_run_on_cpu_", + return_tuple: bool = False, + nodes_finder: Optional[FxNetAccNodesFinder] = None, + ): + """ + Preprocesses graph before splitting: + - finds nodes supported by ACC, + - finds fusion groups for ACC nodes having non-tensor IO, + - builds a graph of direct dependencies, + - builds a map of fused nodes to their fusions. + As a result we get self.acc_nodes, self.deps and self.fusions. + """ + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + ShapeProp(self.module).propagate(*sample_input) + + self.settings = settings + self.operator_support = operator_support + self.sample_input = sample_input + if nodes_finder is None: + nodes_finder = FxNetAccNodesFinder( + self.module, self.operator_support, self.settings.allow_non_tensor + ) + self.acc_nodes = nodes_finder() + + if self.settings.skip_fusion: + self.fusions = {} + else: + self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() + + # Modify deps to add more deps for fused nodes + self.deps = self.find_deps() + self.update_deps_for_fusions() + + self.non_acc_submodule_name = non_acc_submodule_name + self._node_submodule_map: dict[str, str] = {} + self._return_tuple = return_tuple + + self.tags: list[str] = [] + + # =============================================================== + # Helpers for ctor and initial state + # =============================================================== + + def get_node_submodule_map(self) -> dict[str, str]: + """Returns a map from node name to submodule name, e.g. + node: main_module_impl_impl_over_arch_unary_multiple_embedding + _pooling_embedding_pooling_sparse_entity_equivalence_key + _proxy_embedding_bag + maps to submodule name of: _run_on_acc_1 + """ + return self._node_submodule_map + + def find_deps(self) -> dict[torch.fx.Node, NodeSet]: + """ + Builds a graph of node dependencies. Leaf nodes don't have any + dependencies and the "output" node doesn't have nodes depending on it. + + Resulting graph has only direct dependencies, i.e. there are no + transitive dependencies. + """ + deps: dict[torch.fx.Node, NodeSet] = defaultdict(set) + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op != "output": + deps[user].add(node) + return deps + + def update_deps_for_fusions(self): + """ + Updates graph of dependencies so that: + - nodes from the same fusion depend on the same set of outer nodes, + - outer nodes depending on a fusion depend on all nodes in that fusion. + """ + for node in self.fusions: + fusion = self.fusions[node] + for fused_neighbor in fusion: + self.deps[node].update(self.deps[fused_neighbor] - fusion) + + for user in fused_neighbor.users: + if user not in fusion: + self.deps[user].add(node) + + # =============================================================== + # Helpers for preview + # =============================================================== + + def _lower_model_to_backend( + self, mod: torch.fx.GraphModule, inputs: Tensors + ) -> torch.nn.Module: + """ + Lower the model to a backend. + """ + + return mod + + def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str: + """ + When an error occurs during lowering or running the lowered mod, we use this + function to find culprits in the `mod` that causes the error. + """ + + return "Unable to find a culprit because _find_culprit() function is not implemented." + + def _draw_graph_based_on_node_support( + self, mod: torch.fx.GraphModule, supported_nodes: NodeList + ): + color_map = { + "default": "AliceBlue", + "supported": "chartreuse1", + "unsupported": "crimson", + } + + class CustomDrawer(FxGraphDrawer): + def _get_node_style(self, node): + template = super()._get_node_style(node) + if node in supported_nodes: + template["fillcolor"] = color_map["supported"] + elif node.op in CALLABLE_NODE_OPS: + template["fillcolor"] = color_map["unsupported"] + else: + template["fillcolor"] = color_map["default"] + + return template + + drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) + dot_graph = drawer.get_main_dot_graph() + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw("node_support.dot") # type: ignore[attr-defined] + + def node_support_preview(self, dump_graph: bool = False): + submodules = dict(self.module.named_modules()) + + supported_nodes: NodeList = [] + supported_node_types = defaultdict(set) + unsupported_node_types = defaultdict(set) + + def get_dtype(arg): + tensor_meta = arg.meta.get("tensor_meta") + return getattr(tensor_meta, "dtype", None) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + target = get_node_target(submodules, node) + + # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. + arg_dtypes = [ + get_dtype(arg) if isinstance(arg, torch.fx.Node) else None + for arg in node.args + ] + + # Find last non-None element. If all elements are None, return max_len. + last_index = len(arg_dtypes) - next( + ( + i + for i, dtype in enumerate(reversed(arg_dtypes)) + if dtype is not None + ), + len(arg_dtypes), + ) + + # Strip None elements at the end. + arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) + kwarg_dtypes_tuple = tuple( + (k, get_dtype(arg)) + for k, arg in node.kwargs.items() + if isinstance(arg, torch.fx.Node) + ) + + if self.operator_support.is_node_supported(submodules, node)[0]: + supported_nodes.append(node) + supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + else: + unsupported_node_types[target].add( + (arg_dtypes_tuple, kwarg_dtypes_tuple) + ) + + if dump_graph: + self._draw_graph_based_on_node_support(self.module, supported_nodes) + + reports = "\nSupported node types in the model:\n" + for t, dtypes in supported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + reports += "\nUnsupported node types in the model:\n" + for t, dtypes in unsupported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + print(reports) + + # Return reports for testing purpose + return reports + + def split_preview(self, dump_graph: bool = False): + reports = "" + subgraphs = self.put_nodes_into_subgraphs() + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + for i, subgraph in enumerate(subgraphs): + reports += ( + f"_run_on_acc_{i}: " + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{i}: " + ) + reports += f"{len(subgraph.nodes)} node(s)\n" + + self.tag(subgraphs) + split_mod = self.split(remove_tag=True) + split_mod.eval() + + if dump_graph: + drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True) + dot_graphs = drawer.get_all_dot_graphs() + for name, dot_graph in dot_graphs.items(): + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw(f"{name}.dot") # type: ignore[attr-defined] + + max_qps: float = self.PCIe_BW + bottleneck_module = "" + + for node in split_mod.graph.nodes: + if node.op == "call_module" and "acc" in node.target: + reports += f"\nProcessing acc submodule {node.target}\n" + + submod = getattr(split_mod, node.target) + + def get_submod_inputs(main_mod, submod, example_inputs): + sub_inputs = None + + def get_inputs(self, inputs): + nonlocal sub_inputs + sub_inputs = inputs + + handle = submod.register_forward_pre_hook(get_inputs) + main_mod(*example_inputs) + handle.remove() + return sub_inputs + + submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input) + ShapeProp(submod).propagate(*submod_inputs) + + total_input_bytes = 0 + total_output_bytes = 0 + + reports += "Checking inputs...\n" + for n in submod.graph.nodes: + if n.op == "placeholder": + if not is_node_output_tensor(n): + reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_input_bytes += get_size_of_node(submod, n)[0] + if n.op == "output": + output_node = n + + reports += "Checking outputs...\n" + + def get_bytes(node: torch.fx.Node): + nonlocal total_output_bytes + nonlocal reports + if not is_node_output_tensor(node): + reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_output_bytes += get_size_of_node(submod, node)[0] + + map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined] + qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) + reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," + reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" + + if qps < max_qps: + max_qps = qps + bottleneck_module = node.target + + try: + lowered_submod = self._lower_model_to_backend(submod, submod_inputs) + except RuntimeError: + reports += "Run into an error during lowering!\n" + reports += self._find_culprit(submod, submod_inputs) + continue + + try: + lowered_submod(*submod_inputs) + except RuntimeError: + reports += "Run into an error during inference!\n" + reports += self._find_culprit(submod, submod_inputs) + else: + reports += "Lowering and running succeed!\n" + + reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," + reports += f" bottleneck is submodule {bottleneck_module}." + print(reports) + + # return the reports for testing purposes + return reports + + # =============================================================== + # Helpers for extend_acc_subgraph() method + # =============================================================== + + def find_reverse_deps( + self, tag_id: Optional[int] = None + ) -> dict[torch.fx.Node, NodeSet]: + """ + Builds reversed topological node dependencies, if tag_id is specified, + we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. + """ + result: dict[torch.fx.Node, NodeSet] = defaultdict(set) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): + result[node].add(user) + + return result + + def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]): + processed_node = set() + + for node, fusion in self.fusions.items(): + if node in processed_node: + continue + + new_dep = set() + + # Create a new dependency set which include all the + # dependencies of the nodes in the fusion group + for n in fusion: + new_dep.update(deps[n]) + + # Exclude nodes in the fusion + new_dep.difference_update(fusion) + + # Update dependency + for n in fusion: + deps[n] = new_dep + + for arg in n.all_input_nodes: + if arg not in fusion: + deps[arg].update(fusion) + + processed_node.add(n) + + def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: + """ + Finds parent nodes of the `tag` subgraph. + + Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph + and is not a placeholder, we consider it as the parent node of the subgraph. + """ + parent_nodes = set() + + for node in self.module.graph.nodes: + if node.op in CALLABLE_NODE_OPS and node.tag == tag: + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: + parent_nodes.add(arg) + + return parent_nodes + + def extend_acc_subgraph(self, tag: str): + """ + Extend the acc subgraph with `tag` going the reversed topological direction. + """ + # Dict that maps node to its users and ignore users that + # are in the subgraph that has greater tag + deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) + self.update_reverse_deps_for_fusions(deps) + + # Parent nodes of the subgraph + parent_nodes = self.find_parent_nodes_of_subgraph(tag) + + visited_nodes: NodeSet = set() + + while parent_nodes: + node = None + + # Find a acc node that depends on visited nodes only + for n in parent_nodes: + if deps[n] <= visited_nodes and n in self.acc_nodes: + node = n + break + + if node is None: + break + + # Put the node into `tag` subgraph + node.tag = tag # type: ignore[attr-defined] + parent_nodes.remove(node) + visited_nodes.add(node) + + # If node is in a fusion group, add all fusion buddies to parent nodes + if node in self.fusions: + for fusion_node in self.fusions[node]: + if fusion_node not in visited_nodes: + parent_nodes.add(fusion_node) + + # Add inputs of the node to parent nodes + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: + parent_nodes.add(arg) + + # =============================================================== + # Helpers for split() method + # =============================================================== + + def starter_nodes(self) -> tuple[NodeSet, NodeSet]: + """ + Finds nodes that consume module inputs or get_attr nodes. + """ + starter_cpu_nodes: NodeSet = set() + starter_acc_nodes: NodeSet = set() + for node in self.module.graph.nodes: + if node.op not in {"placeholder", "get_attr"}: + continue + for user in node.users: + if user in self.acc_nodes: + starter_acc_nodes.add(user) + else: + starter_cpu_nodes.add(user) + return starter_cpu_nodes, starter_acc_nodes + + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + # We start graph traversal from leaf nodes + current_cpu_nodes, current_acc_nodes = self.starter_nodes() + visited_nodes: NodeSet = set() + + # Determine which subgraph to start from based on which subgraph has + # 0-dep node + acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes) + + current_subgraph_nodes: NodeList = [] + + # Result accumulator + subgraphs: list[Subgraph] = [] + while current_cpu_nodes or current_acc_nodes: + # Find the first node that should belong to the current subgraph and has all dependencies resolved + current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes + node = next( + (n for n in current_nodes if self.deps[n] <= visited_nodes), + None, + ) + + # If nothing was found, then it's time to flip the mode and start a new subgraph + if node is None: + if not current_subgraph_nodes: + raise FxNetSplitterInternalError("Subgraph can't be empty") + + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + acc_subgraph = not acc_subgraph + current_subgraph_nodes = [] + continue + + current_nodes.remove(node) + visited_nodes.add(node) + current_subgraph_nodes.append(node) + + # Add fusion buddies + if node in self.fusions: + if node in self.acc_nodes: + current_acc_nodes.update(self.fusions[node] - visited_nodes) + else: + current_cpu_nodes.update(self.fusions[node] - visited_nodes) + + # Put depending nodes into the queue + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + # Add downstream nodes + if user in self.acc_nodes: + current_acc_nodes.add(user) + else: + current_cpu_nodes.add(user) + + # Check if the last subgraph was not created + if current_subgraph_nodes: + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + + if not subgraphs: + raise FxNetSplitterInternalError("Couldn't create subgraphs") + + return subgraphs + + def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]: + """ + This pass finds ACC submodules with less than specified size and merges + them with adjacent CPU submodules. + """ + result: list[Subgraph] = [] + for subgraph in subgraphs: + if subgraph.is_acc: + if len(subgraph.nodes) >= self.settings.min_acc_module_size: + result.append(subgraph) + else: + print( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) + if result: + result[-1].nodes.extend(subgraph.nodes) + else: + subgraph.is_acc = False + result.append(subgraph) + else: + if result and not result[-1].is_acc: + result[-1].nodes.extend(subgraph.nodes) + else: + result.append(subgraph) + return result + + def tag(self, subgraphs: list[Subgraph]): + self.tags = [] + for subgraph in subgraphs: + tag = ( + f"_run_on_acc_{subgraph.backend}_{len(self.tags)}" + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{len(self.tags)}" + ) + self.tags.append(tag) + for node in subgraph.nodes: + if hasattr(node, "tag"): + raise FxNetSplitterInternalError(f"Node {node} was already tagged") + + node.tag = tag # type: ignore[attr-defined] + self._node_submodule_map[node.name] = tag + + def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: + split_module = split_by_tags( + self.module, self.tags, return_tuple=self._return_tuple + ) + if remove_tag: + for node in self.module.graph.nodes: + if hasattr(node, "tag"): + del node.tag + return split_module # type: ignore[return-value] + + def __call__(self) -> torch.fx.GraphModule: + subgraphs = self.put_nodes_into_subgraphs() + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) + non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count + print( + f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs" + ) + self.tag(subgraphs) + return self.split() + + def generate_split_results(self) -> SplitResult: + split_module = self() + submodule_names = [] + for name, _mod in split_module.named_children(): + submodule_names.append(name) + if ( + self.settings.max_acc_splits > 0 + and len(submodule_names) > self.settings.max_acc_splits + ): + raise ValueError( + "Cannot fulfill max_acc_splits limit. " + "This may cause split fragmentation and " + "result in performance issues." + ) + + submodule_inputs = generate_inputs_for_submodules( + split_module, self.sample_input, submodule_names + ) + return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) From cf4cf79af390686f5695e3b5823869bdcc64415c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 2 Jun 2025 14:58:50 -0700 Subject: [PATCH 2/7] remove splitter_base file and support inductor compilation --- examples/hierarchical_partitioner_example.py | 37 +- py/torch_tensorrt/dynamo/_compiler.py | 82 +- .../dynamo/partitioning/__init__.py | 1 + .../partitioning/_hierarchical_partitioner.py | 159 ++- .../dynamo/partitioning/splitter_base.py | 927 ------------------ 5 files changed, 242 insertions(+), 964 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/partitioning/splitter_base.py diff --git a/examples/hierarchical_partitioner_example.py b/examples/hierarchical_partitioner_example.py index aa2df19516..4c84cfbe13 100644 --- a/examples/hierarchical_partitioner_example.py +++ b/examples/hierarchical_partitioner_example.py @@ -1,4 +1,3 @@ -# from torch_tensorrt.dynamo.partitioning._global_partitioner import partition import torch import torch.nn as nn import torch_tensorrt @@ -10,12 +9,11 @@ ) from torch_tensorrt.dynamo.lowering import ( get_decompositions, - post_lowering, pre_export_lowering, ) from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import ( - hierarchical_partition, + hierarchical_adjacency_partition, ) @@ -54,6 +52,8 @@ def main(): print(gm.graph) + original_output = model(example_input) + # Partition the model using the adjacency partitioner # partitioned_model, op_support = partition( # gm, @@ -64,21 +64,18 @@ def main(): # ], # ) - partitioned_model, op_support = hierarchical_partition( + partitioned_model, op_support = hierarchical_adjacency_partition( gm, verbose=True, min_block_size=1, - backend_priority=["mlir", "tensorrt"], # , "inductor"], + backend_priority=["inductor", "tensorrt"], backend_support_map={ - "mlir": { + "inductor": { # operator.getitem, torch.ops.aten.conv2d.default, torch.ops.aten.convolution.default, }, "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), - # "inductor": { - # torch.ops.aten.relu.default, - # }, }, torch_executed_ops=[ torch.ops.aten._native_batch_norm_legit_no_training.default @@ -90,12 +87,26 @@ def main(): print("\nPartitioned Model Structure:") print(partitioned_model) + print("0. Original_output:", original_output) + + with torch.no_grad(): + partitioned_output = partitioned_model(example_input) + print("1. Partitioned output:", partitioned_output) + print( + "Partitioned output == Original output:", + torch.allclose(original_output, partitioned_output, 1e-2, 1e-2), + ) + + compiled_model = torch_tensorrt.compile( + model, inputs=[example_input], min_block_size=1 + ) with torch.no_grad(): - output = partitioned_model(example_input) - print("Partitioned output:", output) + compiled_output = compiled_model(example_input) + print("2. Compiled_output:", compiled_output) + print( - "Partitioned output == original output:", - torch.allclose(model(example_input), output, 1e-2, 1e-2), + "Compiled output == Original output:", + torch.allclose(original_output, compiled_output, 1e-2, 1e-2), ) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d7092f1e0f..781820c5cd 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -30,6 +30,9 @@ interpret_module_to_result, repair_double_inputs, ) +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_ATEN_CONVERTERS, +) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) @@ -799,6 +802,18 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments." ) + ############ TODO: testing only ############ + use_hierarchical_partitioner = False + backend_priority = ["inductor", "tensorrt"] + backend_support_map = { + "inductor": { + # operator.getitem, + torch.ops.aten.conv2d.default, + torch.ops.aten.convolution.default, + }, + "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), + } + ############################################# # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -845,7 +860,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: submodule_node_dict[node.name] = node # Store TRT replicas of Torch subgraphs - trt_modules = {} + compiled_modules = {} # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those @@ -924,15 +939,45 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: torch.cuda.empty_cache() # Create TRT engines from submodule if not settings.dryrun: - trt_module = convert_module( - submodule, - submodule_inputs, - settings=settings, - name=name, - engine_cache=engine_cache, - ) + if use_hierarchical_partitioner: + # compile submodule with pytorch inductor + if "_run_on_acc_inductor" in name: + sub_inputs = [] + for input in submodule_inputs: + sub_input = ( + torch.randn(input.shape) + .to(dtype.to(input.dtype, t=torch.dtype)) + .cuda() + ) + sub_inputs.append(sub_input) + + compiled_func = torch._inductor.compile( + submodule, + sub_inputs, + ) + # Wrap the compiled function to be a torch.nn.Module + compiled_submodule = FunctionWrapper(compiled_func) + + elif "_run_on_acc_tensorrt" in name: + compiled_submodule = convert_module( + submodule, + submodule_inputs, + settings=settings, + name=name, + engine_cache=engine_cache, + ) + else: + raise ValueError(f"Unknown backend for submodule: {name}") + else: + compiled_submodule = convert_module( + submodule, + submodule_inputs, + settings=settings, + name=name, + engine_cache=engine_cache, + ) - trt_modules[name] = trt_module + compiled_modules[name] = compiled_submodule if _debugger_config: @@ -973,10 +1018,14 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: parse_graph_io(gm, dryrun_tracker) # Replace all FX Modules with TRT Modules - for name, trt_module in trt_modules.items(): - setattr(partitioned_module, name, trt_module) + for name, compiled_module in compiled_modules.items(): + setattr(partitioned_module, name, compiled_module) if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows: - getattr(partitioned_module, name).setup_engine() + if use_hierarchical_partitioner: + if "_run_on_acc_tensorrt" in name: + getattr(partitioned_module, name).setup_engine() + else: + getattr(partitioned_module, name).setup_engine() # Reset settings object to user specification after fallback to global partitioning mode if fast_partitioner_failed: @@ -1322,3 +1371,12 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any: ) return replace_execute_engine_no_op_node(exp_program) + + +class FunctionWrapper(torch.nn.Module): + def __init__(self, func): + super().__init__() + self.func = func + + def forward(self, *args, **kwargs): + return self.func(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py index 25487da065..4ef0c271d1 100644 --- a/py/torch_tensorrt/dynamo/partitioning/__init__.py +++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py @@ -1,5 +1,6 @@ from ._adjacency_partitioner import partition as fast_partition from ._global_partitioner import partition as global_partition +from ._hierarchical_partitioner import hierarchical_adjacency_partition from .common import ( construct_submodule_inputs, get_graph_converter_support, diff --git a/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py index d5505b1eee..e036d0e2be 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py @@ -1,11 +1,23 @@ import logging +from dataclasses import dataclass from typing import Collection, Dict, List, Optional, Set, Tuple import torch import torch.fx.passes.operator_support as ops from torch._ops import OpOverload +from torch.fx._compatibility import compatibility from torch.fx.node import Target, _get_qualified_name -from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeList, NodeSet +from torch.fx.passes.splitter_base import ( + _SplitterBase, + _SplitterSettingBase, +) +from torch.fx.passes.tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + NodeList, + NodeSet, + is_node_output_tensor, +) from torch_tensorrt.dynamo._defaults import ( DEBUG, MIN_BLOCK_SIZE, @@ -15,17 +27,19 @@ DYNAMO_ATEN_CONVERTERS, ConverterRegistry, ) -from torch_tensorrt.dynamo.partitioning.splitter_base import ( - FxNetAccFusionsFinder, - FxNetAccNodesFinder, - Subgraph, - _SplitterBase, - _SplitterSettingBase, -) logger = logging.getLogger(__name__) +@compatibility(is_backward_compatible=False) +@dataclass +class Subgraph: + is_acc: bool + backend: str + nodes: NodeList + device_ordinal: Optional[int] = None + + class BackendOpSupportTester(ops.OperatorSupportBase): # type: ignore """Class to determine whether operators are supported by specific backends""" @@ -100,8 +114,8 @@ def print_support_overview(self, num_acc_subgraphs: Optional[int] = None) -> Non logger.debug("\nAll Nodes Supported\n") -class HierarchicalTRTPartitioner(_SplitterBase): - """Hierarchical partitioner to split an FX graph into subgraphs based on backend priority +class HierarchicalAdjacencyPartitioner(_SplitterBase): # type: ignore + """Hierarchical Adjacency Partitioner to split an FX graph into subgraphs based on backend priority This partitioner extends the TRTPartitioner of adjacency_partitioner with backend priority awareness, allowing different parts of the model to be executed on different backends based on @@ -370,12 +384,133 @@ def put_nodes_into_subgraphs(self) -> list[Subgraph]: return subgraphs + def tag(self, subgraphs: list[Subgraph]): + self.tags = [] + for subgraph in subgraphs: + tag = ( + f"_run_on_acc_{subgraph.backend}_{len(self.tags)}" + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{len(self.tags)}" + ) + self.tags.append(tag) + for node in subgraph.nodes: + if hasattr(node, "tag"): + raise FxNetSplitterInternalError(f"Node {node} was already tagged") + + node.tag = tag # type: ignore[attr-defined] + self._node_submodule_map[node.name] = tag + + +@compatibility(is_backward_compatible=False) +class FxNetAccNodesFinder: + """ + Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor + input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. + + I.e. if we have a chain: + + ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 + + where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. + + This behavior can be turned off by passing allow_non_tensor=True. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: ops.OperatorSupportBase, + allow_non_tensor: bool, + ): + self.module = module + self.operator_support = operator_support + self.allow_non_tensor = allow_non_tensor + self.acc_nodes: NodeSet = set() + + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): + """ + Transitively excludes nodes from ACC supported set. + For every node in the worklist: + - removes its downstream ACC nodes from ACC supported set, + - if any downstream ACC node produces non-tensor output, + then it gets added into the worklist. + """ + while cpu_worklist: + node = cpu_worklist.pop(0) + + for user in node.users: + if user in self.acc_nodes: + self.acc_nodes.remove(user) + if not is_node_output_tensor(user): + cpu_worklist.append(user) + + def reduce_acc_nodes_non_tensor_input(self): + """ + Excludes nodes from ACC supported set that have direct + upstream CPU nodes that produce non-tensor outputs. + """ + non_tensor_cpu_nodes: NodeList = [] + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + if node in self.acc_nodes: + continue + if is_node_output_tensor(node): + continue + non_tensor_cpu_nodes.append(node) + + self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) + + def reduce_acc_nodes_non_tensor_output(self): + """ + Excludes nodes from ACC supported set that produce non-tensor + outputs and have downstream CPU nodes. + """ + while True: + new_cpu_nodes: NodeList = [] + + for acc_node in self.acc_nodes: + if is_node_output_tensor(acc_node): + continue + for user in acc_node.users: + if user not in self.acc_nodes: + new_cpu_nodes.append(acc_node) + break + + if not new_cpu_nodes: + break + + for new_cpu_node in new_cpu_nodes: + self.acc_nodes.remove(new_cpu_node) + + self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) + + def __call__(self) -> NodeSet: + submodules = dict(self.module.named_modules()) + for n in self.module.graph.nodes: + n.backend = "None" + if n.op in CALLABLE_NODE_OPS: + is_supported, backend = self.operator_support.is_node_supported( + submodules, n + ) + if is_supported: + n.backend = backend + self.acc_nodes.add(n) + + if not self.allow_non_tensor: + self.reduce_acc_nodes_non_tensor_input() + self.reduce_acc_nodes_non_tensor_output() + + return self.acc_nodes + +@compatibility(is_backward_compatible=False) class FxNetSplitterInternalError(Exception): pass -def hierarchical_partition( +def hierarchical_adjacency_partition( gm: torch.fx.GraphModule, verbose: bool = DEBUG, min_block_size: int = MIN_BLOCK_SIZE, @@ -421,7 +556,7 @@ def hierarchical_partition( backend_priority=backend_priority, torch_executed_ops=torch_executed_ops, ) - partitioner = HierarchicalTRTPartitioner( + partitioner = HierarchicalAdjacencyPartitioner( gm, supported_ops, backend_support_map=backend_support_map, diff --git a/py/torch_tensorrt/dynamo/partitioning/splitter_base.py b/py/torch_tensorrt/dynamo/partitioning/splitter_base.py deleted file mode 100644 index a7a4f280ab..0000000000 --- a/py/torch_tensorrt/dynamo/partitioning/splitter_base.py +++ /dev/null @@ -1,927 +0,0 @@ -# mypy: allow-untyped-defs -import argparse -import copy -import logging -from collections import defaultdict -from collections.abc import Iterable, Sequence -from dataclasses import dataclass -from typing import Any, NamedTuple, Optional - -import torch -from torch.fx._compatibility import compatibility -from torch.fx.node import map_arg -from torch.fx.passes.graph_drawer import FxGraphDrawer -from torch.fx.passes.graph_manipulation import get_size_of_node -from torch.fx.passes.operator_support import OperatorSupportBase, get_node_target -from torch.fx.passes.shape_prop import ShapeProp -from torch.fx.passes.split_utils import split_by_tags -from torch.fx.passes.tools_common import ( - CALLABLE_NODE_OPS, - FxNetAccFusionsFinder, - NodeList, - NodeSet, - Tensors, - is_node_output_tensor, -) - -__all__ = [ - "FxNetAccNodesFinder", - "FxNetSplitterInternalError", - "Subgraph", - "SplitResult", - "generate_inputs_for_submodules", -] -_LOGGER = logging.getLogger(__name__) - -DEFAULT_MIN_ACC_MODULE_SIZE = 1 -DEFAULT_SKIP_FUSION = False -DEFAULT_ALLOW_NON_TENSOR = False - - -class _SplitterSettingBase: - def __init__( - self, - min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, - skip_fusion=DEFAULT_SKIP_FUSION, - allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR, - max_acc_splits: int = -1, - ): - parser = argparse.ArgumentParser() - parser.add_argument( - "--min-acc-module-size", - "--min_acc_module_size", - required=False, - type=int, - help="Minimum size limit of an accelerator subgraph.", - ) - parser.add_argument( - "--max-acc-splits", - "--max_acc_splits", - required=False, - type=int, - help="Enforce a maximum number of split subgraphs.", - ) - parser.add_argument( - "--skip-fusion", - "--skip_fusion", - default=False, - action="store_true", - help="If true then no fusion groups. Fusion group is used to " - "enforce no non-tensor data flow between submodules. If we don't " - "have this constrain, setting this to false is recommended as it " - "can reduce overhead.", - ) - parser.add_argument( - "--allow-non-tensor", - "--allow_non_tensor", - default=False, - action="store_true", - help="For some backends non-tensor data flow between cpu and them " - "are not allowed. Therefore, if a node supported by accelerator but " - "it has non-tensor inputs or outputs to a cpu node we would want to " - "consider it as a cpu node during splitting. However, for some backends " - "we might not care about non-tensor data flow and we can set this option " - "to true to disable the functionality that prevent non-tensor data flow.", - ) - args, _unknown = parser.parse_known_args() - - self.min_acc_module_size: int = ( - args.min_acc_module_size - if args.min_acc_module_size - else min_acc_module_size - ) - self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion - self.allow_non_tensor: bool = ( - args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor - ) - self.max_acc_splits: int = max_acc_splits - - -@compatibility(is_backward_compatible=False) -class FxNetAccNodesFinder: - """ - Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor - input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. - - I.e. if we have a chain: - - ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 - - where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. - - This behavior can be turned off by passing allow_non_tensor=True. - """ - - def __init__( - self, - module: torch.fx.GraphModule, - operator_support: OperatorSupportBase, - allow_non_tensor: bool, - ): - self.module = module - self.operator_support = operator_support - self.allow_non_tensor = allow_non_tensor - self.acc_nodes: NodeSet = set() - - def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): - """ - Transitively excludes nodes from ACC supported set. - For every node in the worklist: - - removes its downstream ACC nodes from ACC supported set, - - if any downstream ACC node produces non-tensor output, - then it gets added into the worklist. - """ - while cpu_worklist: - node = cpu_worklist.pop(0) - - for user in node.users: - if user in self.acc_nodes: - self.acc_nodes.remove(user) - if not is_node_output_tensor(user): - cpu_worklist.append(user) - - def reduce_acc_nodes_non_tensor_input(self): - """ - Excludes nodes from ACC supported set that have direct - upstream CPU nodes that produce non-tensor outputs. - """ - non_tensor_cpu_nodes: NodeList = [] - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - if node in self.acc_nodes: - continue - if is_node_output_tensor(node): - continue - non_tensor_cpu_nodes.append(node) - - self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) - - def reduce_acc_nodes_non_tensor_output(self): - """ - Excludes nodes from ACC supported set that produce non-tensor - outputs and have downstream CPU nodes. - """ - while True: - new_cpu_nodes: NodeList = [] - - for acc_node in self.acc_nodes: - if is_node_output_tensor(acc_node): - continue - for user in acc_node.users: - if user not in self.acc_nodes: - new_cpu_nodes.append(acc_node) - break - - if not new_cpu_nodes: - break - - for new_cpu_node in new_cpu_nodes: - self.acc_nodes.remove(new_cpu_node) - - self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) - - def __call__(self) -> NodeSet: - submodules = dict(self.module.named_modules()) - for n in self.module.graph.nodes: - n.backend = "None" - if n.op in CALLABLE_NODE_OPS: - is_supported, backend = self.operator_support.is_node_supported( - submodules, n - ) - if is_supported: - n.backend = backend - self.acc_nodes.add(n) - - if not self.allow_non_tensor: - self.reduce_acc_nodes_non_tensor_input() - self.reduce_acc_nodes_non_tensor_output() - - return self.acc_nodes - - -@compatibility(is_backward_compatible=False) -class FxNetSplitterInternalError(Exception): - pass - - -@compatibility(is_backward_compatible=False) -@dataclass -class Subgraph: - is_acc: bool - backend: str - nodes: NodeList - device_ordinal: Optional[int] = None - - -@compatibility(is_backward_compatible=False) -class SplitResult(NamedTuple): - """ - Stores the results of the splitter. - - Attributes: - split_module: root module after splitting. - submodule_inputs: a dict that maps submodule name to its inputs. - non_acc_submodule_prefix: the prefix for non acc submodules. For - acc submodule the prefix is always "_run_on_acc_". - """ - - split_module: torch.fx.GraphModule - submodule_inputs: dict[str, Any] - non_acc_submodule_prefix: str - - -@compatibility(is_backward_compatible=False) -def generate_inputs_for_submodules( - model: torch.nn.Module, - inputs: Sequence[Any], - target_submodules: Iterable[str], - deepcopy: bool = False, -) -> dict[str, Any]: - """ - Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this - function doesn't work. - - Args: - model: root model. - inputs: inputs to the root model. - target_submodules: submodules that we want to generate inputs for. - - Returns: - A dict that maps from submodule name to its inputs. - """ - - handles = [] - results = {} - submodule_to_names = {mod: name for name, mod in model.named_modules()} - - def pre_forward(module, module_inputs): - results[submodule_to_names[module]] = ( - copy.deepcopy(module_inputs) if deepcopy else module_inputs - ) - - for name, mod in model.named_modules(): - if name in target_submodules: - handles.append(mod.register_forward_pre_hook(pre_forward)) - - def clean_up_handles(): - for h in handles: - h.remove() - - try: - with torch.no_grad(): - model(*inputs) - except Exception as e: - clean_up_handles() - raise e - - clean_up_handles() - return results - - -class _SplitterBase: - """ - Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. - Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. - Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. - - Given the following graph: - ==> b ==> - // \\ - a d - \\ // - ==> c ==> - - class SimpleModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.cos(a) - d = b + c - return d - - and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, - we will get the following split result: - - main: - def forward(self, a): - run_on_acc_0_0 = self._run_on_acc_0_0(a) - getitem = run_on_acc_0_0[0] - getitem_1 = run_on_acc_0_0[1] - run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) - return run_on_cpu_1_1 - - _run_on_acc_0_0: - def forward(self, a): - sin_1 = torch.sin(a) - cos_1 = torch.cos(a) - return (sin_1, cos_1) - - _run_on_cpu_1_1: - def forward(self, sin_1, cos_1): - add_1 = sin_1 + cos_1 - return add_1 - """ - - # PCIe bandwidth for the backend, default to 100 GB/s - PCIe_BW = 100 * 2**30 - - def __init__( - self, - module: torch.fx.GraphModule, - sample_input: Sequence[Any], - operator_support: OperatorSupportBase, - settings: _SplitterSettingBase, - non_acc_submodule_name: str = "_run_on_cpu_", - return_tuple: bool = False, - nodes_finder: Optional[FxNetAccNodesFinder] = None, - ): - """ - Preprocesses graph before splitting: - - finds nodes supported by ACC, - - finds fusion groups for ACC nodes having non-tensor IO, - - builds a graph of direct dependencies, - - builds a map of fused nodes to their fusions. - As a result we get self.acc_nodes, self.deps and self.fusions. - """ - assert isinstance(module, torch.fx.GraphModule) - - self.module = module - ShapeProp(self.module).propagate(*sample_input) - - self.settings = settings - self.operator_support = operator_support - self.sample_input = sample_input - if nodes_finder is None: - nodes_finder = FxNetAccNodesFinder( - self.module, self.operator_support, self.settings.allow_non_tensor - ) - self.acc_nodes = nodes_finder() - - if self.settings.skip_fusion: - self.fusions = {} - else: - self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() - - # Modify deps to add more deps for fused nodes - self.deps = self.find_deps() - self.update_deps_for_fusions() - - self.non_acc_submodule_name = non_acc_submodule_name - self._node_submodule_map: dict[str, str] = {} - self._return_tuple = return_tuple - - self.tags: list[str] = [] - - # =============================================================== - # Helpers for ctor and initial state - # =============================================================== - - def get_node_submodule_map(self) -> dict[str, str]: - """Returns a map from node name to submodule name, e.g. - node: main_module_impl_impl_over_arch_unary_multiple_embedding - _pooling_embedding_pooling_sparse_entity_equivalence_key - _proxy_embedding_bag - maps to submodule name of: _run_on_acc_1 - """ - return self._node_submodule_map - - def find_deps(self) -> dict[torch.fx.Node, NodeSet]: - """ - Builds a graph of node dependencies. Leaf nodes don't have any - dependencies and the "output" node doesn't have nodes depending on it. - - Resulting graph has only direct dependencies, i.e. there are no - transitive dependencies. - """ - deps: dict[torch.fx.Node, NodeSet] = defaultdict(set) - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - for user in node.users: - if user.op != "output": - deps[user].add(node) - return deps - - def update_deps_for_fusions(self): - """ - Updates graph of dependencies so that: - - nodes from the same fusion depend on the same set of outer nodes, - - outer nodes depending on a fusion depend on all nodes in that fusion. - """ - for node in self.fusions: - fusion = self.fusions[node] - for fused_neighbor in fusion: - self.deps[node].update(self.deps[fused_neighbor] - fusion) - - for user in fused_neighbor.users: - if user not in fusion: - self.deps[user].add(node) - - # =============================================================== - # Helpers for preview - # =============================================================== - - def _lower_model_to_backend( - self, mod: torch.fx.GraphModule, inputs: Tensors - ) -> torch.nn.Module: - """ - Lower the model to a backend. - """ - - return mod - - def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str: - """ - When an error occurs during lowering or running the lowered mod, we use this - function to find culprits in the `mod` that causes the error. - """ - - return "Unable to find a culprit because _find_culprit() function is not implemented." - - def _draw_graph_based_on_node_support( - self, mod: torch.fx.GraphModule, supported_nodes: NodeList - ): - color_map = { - "default": "AliceBlue", - "supported": "chartreuse1", - "unsupported": "crimson", - } - - class CustomDrawer(FxGraphDrawer): - def _get_node_style(self, node): - template = super()._get_node_style(node) - if node in supported_nodes: - template["fillcolor"] = color_map["supported"] - elif node.op in CALLABLE_NODE_OPS: - template["fillcolor"] = color_map["unsupported"] - else: - template["fillcolor"] = color_map["default"] - - return template - - drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) - dot_graph = drawer.get_main_dot_graph() - # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. - dot_graph.write_raw("node_support.dot") # type: ignore[attr-defined] - - def node_support_preview(self, dump_graph: bool = False): - submodules = dict(self.module.named_modules()) - - supported_nodes: NodeList = [] - supported_node_types = defaultdict(set) - unsupported_node_types = defaultdict(set) - - def get_dtype(arg): - tensor_meta = arg.meta.get("tensor_meta") - return getattr(tensor_meta, "dtype", None) - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - target = get_node_target(submodules, node) - - # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. - arg_dtypes = [ - get_dtype(arg) if isinstance(arg, torch.fx.Node) else None - for arg in node.args - ] - - # Find last non-None element. If all elements are None, return max_len. - last_index = len(arg_dtypes) - next( - ( - i - for i, dtype in enumerate(reversed(arg_dtypes)) - if dtype is not None - ), - len(arg_dtypes), - ) - - # Strip None elements at the end. - arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) - kwarg_dtypes_tuple = tuple( - (k, get_dtype(arg)) - for k, arg in node.kwargs.items() - if isinstance(arg, torch.fx.Node) - ) - - if self.operator_support.is_node_supported(submodules, node)[0]: - supported_nodes.append(node) - supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) - else: - unsupported_node_types[target].add( - (arg_dtypes_tuple, kwarg_dtypes_tuple) - ) - - if dump_graph: - self._draw_graph_based_on_node_support(self.module, supported_nodes) - - reports = "\nSupported node types in the model:\n" - for t, dtypes in supported_node_types.items(): - for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: - reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" - - reports += "\nUnsupported node types in the model:\n" - for t, dtypes in unsupported_node_types.items(): - for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: - reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" - - print(reports) - - # Return reports for testing purpose - return reports - - def split_preview(self, dump_graph: bool = False): - reports = "" - subgraphs = self.put_nodes_into_subgraphs() - acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) - cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num - reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" - reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" - - subgraphs = self.remove_small_acc_subgraphs(subgraphs) - acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) - cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num - reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" - reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" - - for i, subgraph in enumerate(subgraphs): - reports += ( - f"_run_on_acc_{i}: " - if subgraph.is_acc - else f"{self.non_acc_submodule_name}{i}: " - ) - reports += f"{len(subgraph.nodes)} node(s)\n" - - self.tag(subgraphs) - split_mod = self.split(remove_tag=True) - split_mod.eval() - - if dump_graph: - drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True) - dot_graphs = drawer.get_all_dot_graphs() - for name, dot_graph in dot_graphs.items(): - # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. - dot_graph.write_raw(f"{name}.dot") # type: ignore[attr-defined] - - max_qps: float = self.PCIe_BW - bottleneck_module = "" - - for node in split_mod.graph.nodes: - if node.op == "call_module" and "acc" in node.target: - reports += f"\nProcessing acc submodule {node.target}\n" - - submod = getattr(split_mod, node.target) - - def get_submod_inputs(main_mod, submod, example_inputs): - sub_inputs = None - - def get_inputs(self, inputs): - nonlocal sub_inputs - sub_inputs = inputs - - handle = submod.register_forward_pre_hook(get_inputs) - main_mod(*example_inputs) - handle.remove() - return sub_inputs - - submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input) - ShapeProp(submod).propagate(*submod_inputs) - - total_input_bytes = 0 - total_output_bytes = 0 - - reports += "Checking inputs...\n" - for n in submod.graph.nodes: - if n.op == "placeholder": - if not is_node_output_tensor(n): - reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" - else: - total_input_bytes += get_size_of_node(submod, n)[0] - if n.op == "output": - output_node = n - - reports += "Checking outputs...\n" - - def get_bytes(node: torch.fx.Node): - nonlocal total_output_bytes - nonlocal reports - if not is_node_output_tensor(node): - reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" - else: - total_output_bytes += get_size_of_node(submod, node)[0] - - map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined] - qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) - reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," - reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" - - if qps < max_qps: - max_qps = qps - bottleneck_module = node.target - - try: - lowered_submod = self._lower_model_to_backend(submod, submod_inputs) - except RuntimeError: - reports += "Run into an error during lowering!\n" - reports += self._find_culprit(submod, submod_inputs) - continue - - try: - lowered_submod(*submod_inputs) - except RuntimeError: - reports += "Run into an error during inference!\n" - reports += self._find_culprit(submod, submod_inputs) - else: - reports += "Lowering and running succeed!\n" - - reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," - reports += f" bottleneck is submodule {bottleneck_module}." - print(reports) - - # return the reports for testing purposes - return reports - - # =============================================================== - # Helpers for extend_acc_subgraph() method - # =============================================================== - - def find_reverse_deps( - self, tag_id: Optional[int] = None - ) -> dict[torch.fx.Node, NodeSet]: - """ - Builds reversed topological node dependencies, if tag_id is specified, - we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. - """ - result: dict[torch.fx.Node, NodeSet] = defaultdict(set) - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - - if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): - result[node].add(user) - - return result - - def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]): - processed_node = set() - - for node, fusion in self.fusions.items(): - if node in processed_node: - continue - - new_dep = set() - - # Create a new dependency set which include all the - # dependencies of the nodes in the fusion group - for n in fusion: - new_dep.update(deps[n]) - - # Exclude nodes in the fusion - new_dep.difference_update(fusion) - - # Update dependency - for n in fusion: - deps[n] = new_dep - - for arg in n.all_input_nodes: - if arg not in fusion: - deps[arg].update(fusion) - - processed_node.add(n) - - def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: - """ - Finds parent nodes of the `tag` subgraph. - - Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph - and is not a placeholder, we consider it as the parent node of the subgraph. - """ - parent_nodes = set() - - for node in self.module.graph.nodes: - if node.op in CALLABLE_NODE_OPS and node.tag == tag: - for arg in node.all_input_nodes: - if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: - parent_nodes.add(arg) - - return parent_nodes - - def extend_acc_subgraph(self, tag: str): - """ - Extend the acc subgraph with `tag` going the reversed topological direction. - """ - # Dict that maps node to its users and ignore users that - # are in the subgraph that has greater tag - deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) - self.update_reverse_deps_for_fusions(deps) - - # Parent nodes of the subgraph - parent_nodes = self.find_parent_nodes_of_subgraph(tag) - - visited_nodes: NodeSet = set() - - while parent_nodes: - node = None - - # Find a acc node that depends on visited nodes only - for n in parent_nodes: - if deps[n] <= visited_nodes and n in self.acc_nodes: - node = n - break - - if node is None: - break - - # Put the node into `tag` subgraph - node.tag = tag # type: ignore[attr-defined] - parent_nodes.remove(node) - visited_nodes.add(node) - - # If node is in a fusion group, add all fusion buddies to parent nodes - if node in self.fusions: - for fusion_node in self.fusions[node]: - if fusion_node not in visited_nodes: - parent_nodes.add(fusion_node) - - # Add inputs of the node to parent nodes - for arg in node.all_input_nodes: - if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: - parent_nodes.add(arg) - - # =============================================================== - # Helpers for split() method - # =============================================================== - - def starter_nodes(self) -> tuple[NodeSet, NodeSet]: - """ - Finds nodes that consume module inputs or get_attr nodes. - """ - starter_cpu_nodes: NodeSet = set() - starter_acc_nodes: NodeSet = set() - for node in self.module.graph.nodes: - if node.op not in {"placeholder", "get_attr"}: - continue - for user in node.users: - if user in self.acc_nodes: - starter_acc_nodes.add(user) - else: - starter_cpu_nodes.add(user) - return starter_cpu_nodes, starter_acc_nodes - - def put_nodes_into_subgraphs(self) -> list[Subgraph]: - # We start graph traversal from leaf nodes - current_cpu_nodes, current_acc_nodes = self.starter_nodes() - visited_nodes: NodeSet = set() - - # Determine which subgraph to start from based on which subgraph has - # 0-dep node - acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes) - - current_subgraph_nodes: NodeList = [] - - # Result accumulator - subgraphs: list[Subgraph] = [] - while current_cpu_nodes or current_acc_nodes: - # Find the first node that should belong to the current subgraph and has all dependencies resolved - current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes - node = next( - (n for n in current_nodes if self.deps[n] <= visited_nodes), - None, - ) - - # If nothing was found, then it's time to flip the mode and start a new subgraph - if node is None: - if not current_subgraph_nodes: - raise FxNetSplitterInternalError("Subgraph can't be empty") - - subgraphs.append( - Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) - ) - acc_subgraph = not acc_subgraph - current_subgraph_nodes = [] - continue - - current_nodes.remove(node) - visited_nodes.add(node) - current_subgraph_nodes.append(node) - - # Add fusion buddies - if node in self.fusions: - if node in self.acc_nodes: - current_acc_nodes.update(self.fusions[node] - visited_nodes) - else: - current_cpu_nodes.update(self.fusions[node] - visited_nodes) - - # Put depending nodes into the queue - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - - # Add downstream nodes - if user in self.acc_nodes: - current_acc_nodes.add(user) - else: - current_cpu_nodes.add(user) - - # Check if the last subgraph was not created - if current_subgraph_nodes: - subgraphs.append( - Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) - ) - - if not subgraphs: - raise FxNetSplitterInternalError("Couldn't create subgraphs") - - return subgraphs - - def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]: - """ - This pass finds ACC submodules with less than specified size and merges - them with adjacent CPU submodules. - """ - result: list[Subgraph] = [] - for subgraph in subgraphs: - if subgraph.is_acc: - if len(subgraph.nodes) >= self.settings.min_acc_module_size: - result.append(subgraph) - else: - print( - "Eliminating acc subgraph because it's smaller than the threshold: " - f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" - ) - if result: - result[-1].nodes.extend(subgraph.nodes) - else: - subgraph.is_acc = False - result.append(subgraph) - else: - if result and not result[-1].is_acc: - result[-1].nodes.extend(subgraph.nodes) - else: - result.append(subgraph) - return result - - def tag(self, subgraphs: list[Subgraph]): - self.tags = [] - for subgraph in subgraphs: - tag = ( - f"_run_on_acc_{subgraph.backend}_{len(self.tags)}" - if subgraph.is_acc - else f"{self.non_acc_submodule_name}{len(self.tags)}" - ) - self.tags.append(tag) - for node in subgraph.nodes: - if hasattr(node, "tag"): - raise FxNetSplitterInternalError(f"Node {node} was already tagged") - - node.tag = tag # type: ignore[attr-defined] - self._node_submodule_map[node.name] = tag - - def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: - split_module = split_by_tags( - self.module, self.tags, return_tuple=self._return_tuple - ) - if remove_tag: - for node in self.module.graph.nodes: - if hasattr(node, "tag"): - del node.tag - return split_module # type: ignore[return-value] - - def __call__(self) -> torch.fx.GraphModule: - subgraphs = self.put_nodes_into_subgraphs() - subgraphs = self.remove_small_acc_subgraphs(subgraphs) - acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) - non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count - print( - f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs" - ) - self.tag(subgraphs) - return self.split() - - def generate_split_results(self) -> SplitResult: - split_module = self() - submodule_names = [] - for name, _mod in split_module.named_children(): - submodule_names.append(name) - if ( - self.settings.max_acc_splits > 0 - and len(submodule_names) > self.settings.max_acc_splits - ): - raise ValueError( - "Cannot fulfill max_acc_splits limit. " - "This may cause split fragmentation and " - "result in performance issues." - ) - - submodule_inputs = generate_inputs_for_submodules( - split_module, self.sample_input, submodule_names - ) - return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) From fc7d71b135b599364b6ffa8bf7846458951afcc2 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 12 Jun 2025 13:21:39 -0700 Subject: [PATCH 3/7] fix bugs and clean codes --- examples/hierarchical_partitioner_example.py | 40 +++++++++--------- py/torch_tensorrt/dynamo/_compiler.py | 41 +++++++++++-------- .../partitioning/_hierarchical_partitioner.py | 39 ++++++++++++------ 3 files changed, 71 insertions(+), 49 deletions(-) diff --git a/examples/hierarchical_partitioner_example.py b/examples/hierarchical_partitioner_example.py index 4c84cfbe13..5875f69df6 100644 --- a/examples/hierarchical_partitioner_example.py +++ b/examples/hierarchical_partitioner_example.py @@ -1,9 +1,6 @@ import torch import torch.nn as nn import torch_tensorrt -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( - DYNAMO_ATEN_CONVERTERS, -) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) @@ -15,6 +12,7 @@ from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import ( hierarchical_adjacency_partition, ) +from torchvision import models class SimpleModel(nn.Module): @@ -50,18 +48,18 @@ def main(): gm = exported_program.module() - print(gm.graph) + print(gm) original_output = model(example_input) - # Partition the model using the adjacency partitioner + # Partition the model using the adjacency partitioner, compared with below # partitioned_model, op_support = partition( # gm, # verbose=True, # min_block_size=1, - # torch_executed_ops=[ - # torch.ops.aten.relu.default, - # ], + # torch_executed_ops={ + # "torch.ops.aten.relu.default", + # }, # ) partitioned_model, op_support = hierarchical_adjacency_partition( @@ -71,21 +69,18 @@ def main(): backend_priority=["inductor", "tensorrt"], backend_support_map={ "inductor": { - # operator.getitem, - torch.ops.aten.conv2d.default, - torch.ops.aten.convolution.default, + "torch.ops.aten.convolution.default", }, - "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), + "tensorrt": CONVERTERS.keys(), + }, + torch_executed_ops={ + "torch.ops.aten._native_batch_norm_legit_no_training.default" }, - torch_executed_ops=[ - torch.ops.aten._native_batch_norm_legit_no_training.default - ], require_full_compilation=False, - skip_fusion=False, + skip_fusion=True, ) - print("\nPartitioned Model Structure:") - print(partitioned_model) + print("\nPartitioned Model Structure:\n", partitioned_model) print("0. Original_output:", original_output) @@ -98,8 +93,15 @@ def main(): ) compiled_model = torch_tensorrt.compile( - model, inputs=[example_input], min_block_size=1 + model, + inputs=[example_input], + min_block_size=1, + torch_executed_ops={ + "torch.ops.aten._native_batch_norm_legit_no_training.default" + }, ) + print("\nCompiled Model Structure:\n", compiled_model) + with torch.no_grad(): compiled_output = compiled_model(example_input) print("2. Compiled_output:", compiled_output) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 781820c5cd..94494a82cc 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,7 +5,17 @@ import os import platform import warnings -from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Collection, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import torch from torch.export import ExportedProgram @@ -30,9 +40,6 @@ interpret_module_to_result, repair_double_inputs, ) -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( - DYNAMO_ATEN_CONVERTERS, -) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) @@ -803,15 +810,13 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: ) ############ TODO: testing only ############ - use_hierarchical_partitioner = False + use_hierarchical_partitioner = True backend_priority = ["inductor", "tensorrt"] backend_support_map = { "inductor": { - # operator.getitem, - torch.ops.aten.conv2d.default, - torch.ops.aten.convolution.default, + "torch.ops.aten.convolution.default", }, - "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), + "tensorrt": CONVERTERS.keys(), } ############################################# # Partition module into components that can be TRT-accelerated @@ -944,11 +949,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: if "_run_on_acc_inductor" in name: sub_inputs = [] for input in submodule_inputs: - sub_input = ( - torch.randn(input.shape) - .to(dtype.to(input.dtype, t=torch.dtype)) - .cuda() - ) + sub_input = input.torch_tensor.to( + dtype.to(input.dtype, t=torch.dtype) + ).cuda() sub_inputs.append(sub_input) compiled_func = torch._inductor.compile( @@ -956,7 +959,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: sub_inputs, ) # Wrap the compiled function to be a torch.nn.Module - compiled_submodule = FunctionWrapper(compiled_func) + compiled_submodule = InductorModule(compiled_func) elif "_run_on_acc_tensorrt" in name: compiled_submodule = convert_module( @@ -1373,10 +1376,12 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any: return replace_execute_engine_no_op_node(exp_program) -class FunctionWrapper(torch.nn.Module): - def __init__(self, func): +class InductorModule(torch.nn.Module): # type: ignore[misc] + """Wrapper module for inductor compiled function.""" + + def __init__(self, func: Callable[..., Any]) -> None: super().__init__() self.func = func - def forward(self, *args, **kwargs): + def forward(self, *args: Any, **kwargs: Any) -> Any: return self.func(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py index e036d0e2be..ee38b8315c 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py @@ -1,12 +1,11 @@ import logging from dataclasses import dataclass -from typing import Collection, Dict, List, Optional, Set, Tuple +from typing import Collection, Dict, List, Optional, Tuple import torch import torch.fx.passes.operator_support as ops -from torch._ops import OpOverload from torch.fx._compatibility import compatibility -from torch.fx.node import Target, _get_qualified_name +from torch.fx.node import Target from torch.fx.passes.splitter_base import ( _SplitterBase, _SplitterSettingBase, @@ -24,12 +23,17 @@ REQUIRE_FULL_COMPILATION, ) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( - DYNAMO_ATEN_CONVERTERS, + DYNAMO_CONVERTERS as CONVERTERS, +) +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( ConverterRegistry, ) logger = logging.getLogger(__name__) +NON_COMPUTE_NODES = {"torch.ops.aten.view", "_operator.getitem"} +NON_ACC_BACKEND_NAME = "None" + @compatibility(is_backward_compatible=False) @dataclass @@ -45,7 +49,7 @@ class BackendOpSupportTester(ops.OperatorSupportBase): # type: ignore def __init__( self, - backend_support_map: Dict[str, Set[OpOverload]], + backend_support_map: Dict[str, Collection[Target]], backend_priority: List[str], torch_executed_ops: Collection[Target] = set(), ) -> None: @@ -62,12 +66,14 @@ def __init__( def is_node_supported( self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node - ) -> Tuple[bool, Optional[str]]: + ) -> Tuple[bool, str]: node_name = ConverterRegistry.qualified_name_or_str(node.target) for i, backend_name in enumerate(self.backend_priority): supported_ops = self.backend_support_map.get(backend_name, set()) - supported_ops = {_get_qualified_name(op) for op in supported_ops} + supported_ops = { + ConverterRegistry.qualified_name_or_str(op) for op in supported_ops + } if ( (node_name in supported_ops or node.op == "get_attr") @@ -89,7 +95,7 @@ def is_node_supported( else: self.unsupported_operators[node_name] += 1 - return False, None + return False, NON_ACC_BACKEND_NAME def print_support_overview(self, num_acc_subgraphs: Optional[int] = None) -> None: if num_acc_subgraphs is not None: @@ -137,7 +143,7 @@ def __init__( self, module: torch.fx.GraphModule, operator_support: ops.OperatorSupportBase, - backend_support_map: Dict[str, Set[Target]], + backend_support_map: Dict[str, Collection[Target]], backend_priority: List[str], allowed_single_node_partition_ops: Optional[Collection[str]] = None, min_block_size: int = MIN_BLOCK_SIZE, @@ -488,8 +494,15 @@ def reduce_acc_nodes_non_tensor_output(self): def __call__(self) -> NodeSet: submodules = dict(self.module.named_modules()) + backend = NON_ACC_BACKEND_NAME for n in self.module.graph.nodes: - n.backend = "None" + # Group non-compute nodes with previous compute nodes + if ConverterRegistry.qualified_name_or_str(n.target) in NON_COMPUTE_NODES: + n.backend = backend + if backend != NON_ACC_BACKEND_NAME: + self.acc_nodes.add(n) + continue + if n.op in CALLABLE_NODE_OPS: is_supported, backend = self.operator_support.is_node_supported( submodules, n @@ -497,6 +510,8 @@ def __call__(self) -> NodeSet: if is_supported: n.backend = backend self.acc_nodes.add(n) + else: + n.backend = NON_ACC_BACKEND_NAME if not self.allow_non_tensor: self.reduce_acc_nodes_non_tensor_input() @@ -515,7 +530,7 @@ def hierarchical_adjacency_partition( verbose: bool = DEBUG, min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), - backend_support_map: Optional[Dict[str, Set[OpOverload]]] = None, + backend_support_map: Optional[Dict[str, Collection[Target]]] = None, backend_priority: Optional[List[str]] = None, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, skip_fusion: bool = False, @@ -542,7 +557,7 @@ def hierarchical_adjacency_partition( # Default backend support map if none provided if backend_support_map is None: backend_support_map = { - "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), + "tensorrt": CONVERTERS.keys(), "inductor": set(), } From 2898e5c5573fa71ac97fb0f551b92b928aedee53 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 12 Jun 2025 18:32:39 -0700 Subject: [PATCH 4/7] revert _compiler.py and add details in the example --- examples/hierarchical_partitioner_example.py | 135 ++++++++++++++----- py/torch_tensorrt/dynamo/_compiler.py | 89 ++---------- 2 files changed, 111 insertions(+), 113 deletions(-) diff --git a/examples/hierarchical_partitioner_example.py b/examples/hierarchical_partitioner_example.py index 5875f69df6..206ed945f7 100644 --- a/examples/hierarchical_partitioner_example.py +++ b/examples/hierarchical_partitioner_example.py @@ -1,6 +1,11 @@ +from typing import Any, Callable + import torch import torch.nn as nn import torch_tensorrt +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo._compiler import convert_module from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) @@ -8,13 +13,26 @@ get_decompositions, pre_export_lowering, ) -from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import ( hierarchical_adjacency_partition, ) +from torch_tensorrt.dynamo.utils import ( + get_output_metadata, +) from torchvision import models +class InductorModule(torch.nn.Module): # type: ignore[misc] + """Wrapper module for inductor compiled function.""" + + def __init__(self, func: Callable[..., Any]) -> None: + super().__init__() + self.func = func + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.func(*args, **kwargs) + + class SimpleModel(nn.Module): def __init__(self): super().__init__() @@ -48,20 +66,11 @@ def main(): gm = exported_program.module() - print(gm) + print("Original Model Structure:\n", gm) original_output = model(example_input) - # Partition the model using the adjacency partitioner, compared with below - # partitioned_model, op_support = partition( - # gm, - # verbose=True, - # min_block_size=1, - # torch_executed_ops={ - # "torch.ops.aten.relu.default", - # }, - # ) - + # 1. Partition the model into blocks that can be executed by different backends partitioned_model, op_support = hierarchical_adjacency_partition( gm, verbose=True, @@ -80,35 +89,87 @@ def main(): skip_fusion=True, ) - print("\nPartitioned Model Structure:\n", partitioned_model) - - print("0. Original_output:", original_output) + print("1. Partitioned Model Structure:\n", partitioned_model) + + # 2. Compile each submodule with the corresponding backend + submodule_node_dict = {} + for node in partitioned_model.graph.nodes: + if "_run_on_acc" not in node.name: + continue + submodule_node_dict[node.name] = node + + # Store compiled replicas of Torch subgraphs + compiled_modules = {} + + for name, _ in partitioned_model.named_children(): + submodule = getattr(partitioned_model, name) + if not isinstance(submodule, torch.fx.graph_module.GraphModule): + continue + + if "_run_on_acc" not in name: + submodule.to("cuda") + continue + + if name not in submodule_node_dict: + raise ValueError( + f"node_name: {name} does not exist in the submodule node dictionary" + ) + + # set the submodule metadata back to the parent module_node + metadata_list = get_output_metadata(submodule) + assert len(metadata_list) > 0 + metadata_keys = ["val", "tensor_meta"] + for key in metadata_keys: + if key not in submodule_node_dict[name].meta: + meta_val_list = [ + metadata[key] for metadata in metadata_list if key in metadata + ] + submodule_node_dict[name].meta[key] = meta_val_list + break + + # Get the submodule inputs for min, opt, max shapes of the graph inputs + submodule_inputs = partitioning.construct_submodule_inputs(submodule) + assert submodule_inputs is not None + + # compile submodule with pytorch inductor backend + if "_run_on_acc_inductor" in name: + sub_inputs = [] + for input in submodule_inputs: + sub_input = input.torch_tensor.to( + dtype.to(input.dtype, t=torch.dtype) + ).cuda() + sub_inputs.append(sub_input) + + compiled_func = torch._inductor.compile( + submodule, + sub_inputs, + ) + # Wrap the compiled function to be a torch.nn.Module + compiled_submodule = InductorModule(compiled_func) + + # compile submodule with tensorrt backend + elif "_run_on_acc_tensorrt" in name: + compiled_submodule = convert_module( + submodule, + submodule_inputs, + name=name, + ) + else: + raise ValueError(f"Unknown backend for submodule: {name}") + + compiled_modules[name] = compiled_submodule + + # Replace all FX Modules with compiled Modules + for name, compiled_module in compiled_modules.items(): + setattr(partitioned_model, name, compiled_module) + + print("2. Compiled Model Structure:\n", partitioned_model) with torch.no_grad(): partitioned_output = partitioned_model(example_input) - print("1. Partitioned output:", partitioned_output) - print( - "Partitioned output == Original output:", - torch.allclose(original_output, partitioned_output, 1e-2, 1e-2), - ) - - compiled_model = torch_tensorrt.compile( - model, - inputs=[example_input], - min_block_size=1, - torch_executed_ops={ - "torch.ops.aten._native_batch_norm_legit_no_training.default" - }, - ) - print("\nCompiled Model Structure:\n", compiled_model) - - with torch.no_grad(): - compiled_output = compiled_model(example_input) - print("2. Compiled_output:", compiled_output) - print( - "Compiled output == Original output:", - torch.allclose(original_output, compiled_output, 1e-2, 1e-2), + "3. Verify that Partitioned output == Original output:", + torch.allclose(partitioned_output, original_output, 1e-2, 1e-2), ) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 94494a82cc..d7092f1e0f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,17 +5,7 @@ import os import platform import warnings -from typing import ( - Any, - Callable, - Collection, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch from torch.export import ExportedProgram @@ -809,16 +799,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments." ) - ############ TODO: testing only ############ - use_hierarchical_partitioner = True - backend_priority = ["inductor", "tensorrt"] - backend_support_map = { - "inductor": { - "torch.ops.aten.convolution.default", - }, - "tensorrt": CONVERTERS.keys(), - } - ############################################# # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -865,7 +845,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: submodule_node_dict[node.name] = node # Store TRT replicas of Torch subgraphs - compiled_modules = {} + trt_modules = {} # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those @@ -944,43 +924,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: torch.cuda.empty_cache() # Create TRT engines from submodule if not settings.dryrun: - if use_hierarchical_partitioner: - # compile submodule with pytorch inductor - if "_run_on_acc_inductor" in name: - sub_inputs = [] - for input in submodule_inputs: - sub_input = input.torch_tensor.to( - dtype.to(input.dtype, t=torch.dtype) - ).cuda() - sub_inputs.append(sub_input) - - compiled_func = torch._inductor.compile( - submodule, - sub_inputs, - ) - # Wrap the compiled function to be a torch.nn.Module - compiled_submodule = InductorModule(compiled_func) - - elif "_run_on_acc_tensorrt" in name: - compiled_submodule = convert_module( - submodule, - submodule_inputs, - settings=settings, - name=name, - engine_cache=engine_cache, - ) - else: - raise ValueError(f"Unknown backend for submodule: {name}") - else: - compiled_submodule = convert_module( - submodule, - submodule_inputs, - settings=settings, - name=name, - engine_cache=engine_cache, - ) + trt_module = convert_module( + submodule, + submodule_inputs, + settings=settings, + name=name, + engine_cache=engine_cache, + ) - compiled_modules[name] = compiled_submodule + trt_modules[name] = trt_module if _debugger_config: @@ -1021,14 +973,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: parse_graph_io(gm, dryrun_tracker) # Replace all FX Modules with TRT Modules - for name, compiled_module in compiled_modules.items(): - setattr(partitioned_module, name, compiled_module) + for name, trt_module in trt_modules.items(): + setattr(partitioned_module, name, trt_module) if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows: - if use_hierarchical_partitioner: - if "_run_on_acc_tensorrt" in name: - getattr(partitioned_module, name).setup_engine() - else: - getattr(partitioned_module, name).setup_engine() + getattr(partitioned_module, name).setup_engine() # Reset settings object to user specification after fallback to global partitioning mode if fast_partitioner_failed: @@ -1374,14 +1322,3 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any: ) return replace_execute_engine_no_op_node(exp_program) - - -class InductorModule(torch.nn.Module): # type: ignore[misc] - """Wrapper module for inductor compiled function.""" - - def __init__(self, func: Callable[..., Any]) -> None: - super().__init__() - self.func = func - - def forward(self, *args: Any, **kwargs: Any) -> Any: - return self.func(*args, **kwargs) From 0d88e8d3f13f596bbd8de48944862cf066df3643 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 13 Jun 2025 09:56:57 -0700 Subject: [PATCH 5/7] update example folder --- examples/{ => dynamo}/hierarchical_partitioner_example.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{ => dynamo}/hierarchical_partitioner_example.py (100%) diff --git a/examples/hierarchical_partitioner_example.py b/examples/dynamo/hierarchical_partitioner_example.py similarity index 100% rename from examples/hierarchical_partitioner_example.py rename to examples/dynamo/hierarchical_partitioner_example.py From d8509fa21d0ba16eff148148cdda40aa2c59c367 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 13 Jun 2025 14:21:05 -0700 Subject: [PATCH 6/7] add in contributors doc --- docsrc/contributors/partitioning.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docsrc/contributors/partitioning.rst b/docsrc/contributors/partitioning.rst index 8c83ddcadc..77880cef6a 100644 --- a/docsrc/contributors/partitioning.rst +++ b/docsrc/contributors/partitioning.rst @@ -239,3 +239,16 @@ In this example we will collect the arithmetic ops in a TensorRT segment and the In some cases this approach may create adjacent segments in the partition which have the same target. As a clean-up step we can consolidate these adjacent segments to further reduce the number of segments in the final partition. The merge segments step identifies a list of segments that are adjacent in the graph, have the same target, and are not marked as `do_not_merge`. The nodes from these segments will be combined into a single new segment that will replace the merged segments in the partition. The `do_not_merge` marking is used to prevent merging of segments created for conditional nodes and loops that are handled as special cases in graph stitching and should not be merged with adjacent segments of the same type. + + +Hierarchical Partitioner for Dynamo +=================================== + +The Hierarchical Partitioner is an extension to the standard TensorRT partitioner that allows for more sophisticated partitioning strategies by considering backend priority and operator support. This is particularly useful when you want to distribute different parts of your model across multiple backends based on their capabilities and priorities. + +We currently support hierarchical adjacency partitioner, which extends the standard adjacency partitioner with the following capabilities: + +1. **Backend priority ordering**: Assign operators to backends based on a priority order, ensuring that operators are assigned to the highest-priority backend that supports them. +2. **Multi-backend support**: Distribute model execution across multiple backends based on operator support. + +Please refer to `hierarchical_partitioner_example `_ for more details. From ae4d55830184b58582e74865e9994cc19385b7c3 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 17 Jun 2025 13:06:20 -0700 Subject: [PATCH 7/7] add tests and rebase --- .../hierarchical_partitioner_example.py | 1 - .../partitioning/_adjacency_partitioner.py | 1 - .../partitioning/_global_partitioner.py | 1 - .../partitioning/_hierarchical_partitioner.py | 18 +- .../test_hierarchical_partitioning.py | 303 ++++++++++++++++++ 5 files changed, 310 insertions(+), 14 deletions(-) create mode 100644 tests/py/dynamo/partitioning/test_hierarchical_partitioning.py diff --git a/examples/dynamo/hierarchical_partitioner_example.py b/examples/dynamo/hierarchical_partitioner_example.py index 206ed945f7..4cae68ebac 100644 --- a/examples/dynamo/hierarchical_partitioner_example.py +++ b/examples/dynamo/hierarchical_partitioner_example.py @@ -73,7 +73,6 @@ def main(): # 1. Partition the model into blocks that can be executed by different backends partitioned_model, op_support = hierarchical_adjacency_partition( gm, - verbose=True, min_block_size=1, backend_priority=["inductor", "tensorrt"], backend_support_map={ diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 2cb7fe43f5..e2f544c2a7 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -261,7 +261,6 @@ def partition( Args: gm: FX GraphModule to partition - verbose: Bool representing whether to print operator support min_block_size: Minimum number of operators per TRT-Engine Block torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage require_full_compilation: Require that all computational operators be run in TRT diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 3279db00cf..707497b227 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -210,7 +210,6 @@ def partition( Args: gm: FX GraphModule to partition - verbose: Bool representing whether to print operator support min_block_size: Minimum number of operators per TRT-Engine Block torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage require_full_compilation: Whether to require that all operators be run in TRT diff --git a/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py index ee38b8315c..12a1a091ea 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py @@ -18,7 +18,6 @@ is_node_output_tensor, ) from torch_tensorrt.dynamo._defaults import ( - DEBUG, MIN_BLOCK_SIZE, REQUIRE_FULL_COMPILATION, ) @@ -390,8 +389,8 @@ def put_nodes_into_subgraphs(self) -> list[Subgraph]: return subgraphs - def tag(self, subgraphs: list[Subgraph]): - self.tags = [] + def tag(self, subgraphs: list[Subgraph]) -> None: + self.tags: list[str] = [] for subgraph in subgraphs: tag = ( f"_run_on_acc_{subgraph.backend}_{len(self.tags)}" @@ -403,7 +402,7 @@ def tag(self, subgraphs: list[Subgraph]): if hasattr(node, "tag"): raise FxNetSplitterInternalError(f"Node {node} was already tagged") - node.tag = tag # type: ignore[attr-defined] + node.tag = tag self._node_submodule_map[node.name] = tag @@ -433,7 +432,7 @@ def __init__( self.allow_non_tensor = allow_non_tensor self.acc_nodes: NodeSet = set() - def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList) -> None: """ Transitively excludes nodes from ACC supported set. For every node in the worklist: @@ -450,7 +449,7 @@ def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): if not is_node_output_tensor(user): cpu_worklist.append(user) - def reduce_acc_nodes_non_tensor_input(self): + def reduce_acc_nodes_non_tensor_input(self) -> None: """ Excludes nodes from ACC supported set that have direct upstream CPU nodes that produce non-tensor outputs. @@ -468,7 +467,7 @@ def reduce_acc_nodes_non_tensor_input(self): self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) - def reduce_acc_nodes_non_tensor_output(self): + def reduce_acc_nodes_non_tensor_output(self) -> None: """ Excludes nodes from ACC supported set that produce non-tensor outputs and have downstream CPU nodes. @@ -527,7 +526,6 @@ class FxNetSplitterInternalError(Exception): def hierarchical_adjacency_partition( gm: torch.fx.GraphModule, - verbose: bool = DEBUG, min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), backend_support_map: Optional[Dict[str, Collection[Target]]] = None, @@ -540,7 +538,6 @@ def hierarchical_adjacency_partition( Args: gm: FX GraphModule to partition - verbose: Bool representing whether to print operator support min_block_size: Minimum number of operators per TRT-Engine Block backend_support_map: Dictionary mapping backend names to sets of supported operators backend_priority: Ordered list of backend names, from highest to lowest priority @@ -583,7 +580,6 @@ def hierarchical_adjacency_partition( partitioned_graph = partitioner.partition_graph() - if verbose: - supported_ops.print_support_overview(partitioner.num_accelerated_subgraphs) + supported_ops.print_support_overview(partitioner.num_accelerated_subgraphs) return partitioned_graph, supported_ops diff --git a/tests/py/dynamo/partitioning/test_hierarchical_partitioning.py b/tests/py/dynamo/partitioning/test_hierarchical_partitioning.py new file mode 100644 index 0000000000..0553fb4b45 --- /dev/null +++ b/tests/py/dynamo/partitioning/test_hierarchical_partitioning.py @@ -0,0 +1,303 @@ +from copy import deepcopy + +import numpy as np +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import partitioning + + +class TestHierarchicalAdjacencyPartitioning(TestCase): + def test_hierarchical_adjacency_partition_fully_supported_one_op(self): + class FullySupportedOneOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y) + + fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) + partitioned_graph, _ = partitioning.hierarchical_adjacency_partition( + deepcopy(fx_graph), + ) + self.assertEqual( + len( + [ + 1 + for submod in list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 0, + "Single operators should not be segmented", + ) + + def test_hierarchical_adjacency_partition_fully_supported_one_op_require_full_compilation( + self, + ): + class FullySupportedOneOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y) + + fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) + partitioned_graph, _ = partitioning.hierarchical_adjacency_partition( + deepcopy(fx_graph), require_full_compilation=True + ) + self.assertEqual( + len( + [ + 1 + for submod in list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 1, + "Single operators can be segmented if full compilation is required", + ) + + def test_hierarchical_adjacency_partition_fully_supported_multi_op(self): + class FullySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_ = torch.ops.aten.sub.Tensor(x, y) + concat_ = torch.ops.aten.cat.default(x, sum_) + relu_ = torch.ops.aten.relu.default(concat_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) + partitioned_graph, _ = partitioning.hierarchical_adjacency_partition( + deepcopy(fx_graph), min_block_size=2 + ) + self.assertEqual( + len( + [ + 1 + for submod in list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 1, + "All operators are supported, there should be one segment", + ) + + def test_hierarchical_adjacency_partition_partially_supported_multi_op(self): + class PartiallySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_1 = torch.ops.aten.add.Tensor(x, y) + sum_2 = torch.ops.aten.add.Tensor(x, sum_1) + sum_ = np.sum(sum_1) + np.sum(sum_2) + relu_ = torch.ops.aten.relu.default(sum_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) + partitioned_graph, _ = partitioning.hierarchical_adjacency_partition( + deepcopy(fx_graph), min_block_size=2 + ) + self.assertEqual( + len( + [ + 1 + for submod in list(partitioned_graph.named_children()) + if "_run_on_acc" in submod[0] + ] + ), + 2, + "Unsupported operators interleave supported ones, expected 2 segments", + ) + + def test_hierarchical_adjacency_partition_partially_supported_with_torch_executed_ops( + self, + ): + class PartiallySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_1 = torch.ops.aten.add.Tensor(x, y) + sum_2 = torch.ops.aten.add.Tensor(x, sum_1) + sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2) + relu_ = torch.ops.aten.relu.default(sum_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + torch_executed_ops = {torch.ops.aten.add.Tensor} + + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) + partitioned_graph, _ = partitioning.hierarchical_adjacency_partition( + deepcopy(fx_graph), + min_block_size=1, + torch_executed_ops=torch_executed_ops, + ) + + unexpected_ops = torch_executed_ops + expected_ops = {torch.ops.aten.relu.default, torch.ops.aten.pow.Tensor_Scalar} + + unexpected_ops_seen = set() + expected_ops_seen = set() + + for name, gm in partitioned_graph.named_children(): + if "_run_on_acc" in name: + for node in gm.graph.nodes: + if node.op == "call_function": + if node.target in unexpected_ops: + unexpected_ops_seen.add(node.target) + elif node.target in expected_ops: + expected_ops_seen.add(node.target) + + expected_ops_unseen = expected_ops.difference(expected_ops_seen) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1) + self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1) + self.bn1 = torch.nn.BatchNorm2d(64) + self.bn2 = torch.nn.BatchNorm2d(128) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = torch.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = torch.relu(x) + return x + + def test_hierarchical_adjacency_partition_with_two_backends(self): + from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, + ) + from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + pre_export_lowering, + ) + + model = self.SimpleModel().cuda().eval() + example_input = torch.randn(1, 3, 224, 224).cuda() + + exported_program = torch.export.export(model, (example_input,)) + exported_program = pre_export_lowering(exported_program) + exported_program = exported_program.run_decompositions(get_decompositions()) + gm = exported_program.module() + + partitioned_graph, _ = partitioning.hierarchical_adjacency_partition( + gm, + min_block_size=1, + backend_priority=["inductor", "tensorrt"], + backend_support_map={ + "inductor": { + "torch.ops.aten.convolution.default", + }, + "tensorrt": CONVERTERS.keys(), + }, + ) + + inductor_subgraphs_num = 0 + tensorrt_subgraphs_num = 0 + + for name, gm in partitioned_graph.named_children(): + if "_run_on_acc_inductor" in name: + inductor_subgraphs_num += 1 + elif "_run_on_acc_tensorrt" in name: + tensorrt_subgraphs_num += 1 + else: + raise ValueError(f"Unknown backend: {name}") + + self.assertEqual( + inductor_subgraphs_num, + 2, + "There should be 2 subgraphs running on inductor backend", + ) + self.assertEqual( + tensorrt_subgraphs_num, + 2, + "There should be 2 subgraph running on tensorrt backend", + ) + + def test_hierarchical_adjacency_partition_with_two_backends_with_torch_executed_ops( + self, + ): + from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, + ) + from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + pre_export_lowering, + ) + + model = self.SimpleModel().cuda().eval() + example_input = torch.randn(1, 3, 224, 224).cuda() + + exported_program = torch.export.export(model, (example_input,)) + exported_program = pre_export_lowering(exported_program) + exported_program = exported_program.run_decompositions(get_decompositions()) + gm = exported_program.module() + + partitioned_graph, _ = partitioning.hierarchical_adjacency_partition( + gm, + min_block_size=1, + backend_priority=["inductor", "tensorrt"], + backend_support_map={ + "inductor": { + "torch.ops.aten.convolution.default", + }, + "tensorrt": CONVERTERS.keys(), + }, + torch_executed_ops={ + "torch.ops.aten._native_batch_norm_legit_no_training.default" + }, + ) + + inductor_subgraphs_num = 0 + tensorrt_subgraphs_num = 0 + torch_gpu_subgraphs_num = 0 + + for name, gm in partitioned_graph.named_children(): + if "_run_on_acc_inductor" in name: + inductor_subgraphs_num += 1 + elif "_run_on_acc_tensorrt" in name: + tensorrt_subgraphs_num += 1 + elif "_run_on_gpu" in name: + torch_gpu_subgraphs_num += 1 + else: + raise ValueError(f"Unknown backend: {name}") + + self.assertEqual( + torch_gpu_subgraphs_num, + 2, + "There should be 2 subgraphs running on torch gpu backend", + ) + self.assertEqual( + inductor_subgraphs_num, + 2, + "There should be 2 subgraphs running on inductor backend", + ) + self.assertEqual( + tensorrt_subgraphs_num, + 2, + "There should be 2 subgraph running on tensorrt backend", + ) + + +if __name__ == "__main__": + run_tests()