From 2579ccdaaf6448869a7ce283d38737748e02f89e Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 1 May 2023 14:14:53 -0700 Subject: [PATCH 1/2] fix: Improve partitioning + lowering systems - Improve torch.compile Dynamo partitioning system by incorporating key arguments including `min_block_size` and `torch_executed_ops`, which are available for use in TorchScript - Improve torch.compile lowering system by adding key new decompositions to improve coverage and reduce the number of unique operators requiring implementation - Update testing framework to use utilities, reducing code replication - Add extensive testing of new partitioning system and lowering phases --- py/torch_tensorrt/dynamo/backend/__init__.py | 16 +- py/torch_tensorrt/dynamo/backend/_defaults.py | 2 +- py/torch_tensorrt/dynamo/backend/_settings.py | 8 +- py/torch_tensorrt/dynamo/backend/backends.py | 5 +- .../backend/lowering/_decompositions.py | 15 ++ .../dynamo/backend/lowering/_partition.py | 151 ++++++++++++++---- .../dynamo/backend/test/test_lowering.py | 106 +++++++++--- .../dynamo/backend/test/test_partitioning.py | 64 +++++++- .../dynamo/backend/test/utils.py | 94 ++++++++++- 9 files changed, 390 insertions(+), 71 deletions(-) diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index eba389ecec..0846dec144 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -4,7 +4,7 @@ import torch_tensorrt from functools import partial -from typing import Any +from typing import Any, Sequence from torch_tensorrt import EngineCapability, Device from torch_tensorrt.fx.utils import LowerPrecision @@ -15,7 +15,7 @@ PRECISION, DEBUG, MAX_WORKSPACE_SIZE, - MAX_NUM_TRT_ENGINES, + MIN_BLOCK_SIZE, ) @@ -41,7 +41,7 @@ def compile( calibrator=None, truncate_long_and_double=False, require_full_compilation=False, - min_block_size=3, + min_block_size=MIN_BLOCK_SIZE, torch_executed_ops=[], torch_executed_modules=[], **kwargs, @@ -50,7 +50,7 @@ def compile( logger.warn( "The Dynamo backend is an experimental feature, for which only the " + "following arguments are supported: " - + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" + + "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}" ) if not isinstance(inputs, collections.abc.Sequence): @@ -80,6 +80,8 @@ def compile( precision=lower_precision, debug=debug, workspace_size=workspace_size, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, **kwargs, ) @@ -100,7 +102,8 @@ def create_backend( precision: LowerPrecision = PRECISION, debug: bool = DEBUG, workspace_size: int = MAX_WORKSPACE_SIZE, - max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, + min_block_size: int = MIN_BLOCK_SIZE, + torch_executed_ops: Sequence[str] = set(), **kwargs, ): """Create torch.compile backend given specified arguments @@ -117,7 +120,8 @@ def create_backend( debug=debug, precision=precision, workspace_size=workspace_size, - max_num_trt_engines=max_num_trt_engines, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, ) return partial( diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py index 814331e158..06c4efa5fc 100644 --- a/py/torch_tensorrt/dynamo/backend/_defaults.py +++ b/py/torch_tensorrt/dynamo/backend/_defaults.py @@ -4,4 +4,4 @@ PRECISION = LowerPrecision.FP32 DEBUG = False MAX_WORKSPACE_SIZE = 20 << 30 
-MAX_NUM_TRT_ENGINES = 10 +MIN_BLOCK_SIZE = 3 diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py index 7677b1bd57..8c1a807343 100644 --- a/py/torch_tensorrt/dynamo/backend/_settings.py +++ b/py/torch_tensorrt/dynamo/backend/_settings.py @@ -1,11 +1,12 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Sequence from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, MAX_WORKSPACE_SIZE, - MAX_NUM_TRT_ENGINES, + MIN_BLOCK_SIZE, ) @@ -14,4 +15,5 @@ class CompilationSettings: precision: LowerPrecision = PRECISION debug: bool = DEBUG workspace_size: int = MAX_WORKSPACE_SIZE - max_num_trt_engines: int = MAX_NUM_TRT_ENGINES + min_block_size: int = MIN_BLOCK_SIZE + torch_executed_ops: Sequence[str] = field(default_factory=set) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 9df3f1c686..962cbe8eba 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -100,7 +100,10 @@ def _compile_module( """ # Partition module into components that can be TRT-accelerated partitioned_module = partition( - gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines + gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, ) # Iterate over all components that can be accelerated diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py index 7aff1a79d1..d0bd5ed3b8 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py @@ -41,5 +41,20 @@ def inplace_op(*args, **kwargs): replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) +@register_decomposition(aten.std, registry=DECOMPOSITIONS) +def std_replacement(*args, **kwargs) -> torch.Tensor: + return torch.sqrt(torch.var(*args, **kwargs)) + + +@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS) +def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: + return torch.reciprocal(torch.sqrt(*args, **kwargs)) + + +@register_decomposition(aten.alias, registry=DECOMPOSITIONS) +def alias_replacement(x: torch.Tensor) -> torch.Tensor: + return x + + def get_decompositions(): return DECOMPOSITIONS diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 1885d18705..b4d1b18db9 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -1,62 +1,159 @@ -from typing import Dict, Optional, Sequence +import logging +from typing import Dict, List, Optional, Sequence import torch -from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition +from torch.fx.graph_module import GraphModule +from torch.fx.node import _get_qualified_name from torch.fx.passes.operator_support import OperatorSupport from torch_tensorrt.fx.converter_registry import CONVERTERS +logger = logging.getLogger(__name__) + + +class TRTPartitioner(CapabilityBasedPartitioner): + """Partitioner to 
split an FX graph into subgraphs based on operator support
+
+    Args:
+        graph_module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        non_compute_ops: Operators which are not considered computational (e.g. getattr)
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+            Generally useful for module-level exclusion ops which are intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        graph_module: GraphModule,
+        operator_support: OperatorSupport,
+        *,
+        non_compute_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        min_block_size=MIN_BLOCK_SIZE,
+    ) -> None:
+        super().__init__(
+            graph_module,
+            operator_support,
+            allows_single_node_partition=True,
+            non_compute_ops=non_compute_ops,
+            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
+        )
+
+        self.min_block_size = min_block_size
+
+    def propose_partitions(self) -> List[Partition]:
+        # Propose partitions using the default, then refine the results
+        initial_proposed_partitions = super().propose_partitions()
+        partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
+
+        # For each partition, determine whether or not the number of computational operators
+        # exceeds the threshold, and if not, remove that partition
+        partitions_to_remove = {}
+        for id, partition in partitions.items():
+            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
+            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
+            exempted_partition = False
+
+            compute_node_count = 0
+            for node in partition.nodes:
+                # Partitions are exempted from min_block_size if they contain an allowed single-node op
+                if (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target)
+                    in self.allowed_single_node_partition_ops
+                ):
+                    exempted_partition = True
+                    break
+                elif (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target) not in non_compute_ops
+                ):
+                    compute_node_count += 1
+
+            if compute_node_count < self.min_block_size and not exempted_partition:
+                partitions_to_remove[id] = compute_node_count
+
+        # Remove any nodes violating the criteria specified by the user
+        for id, count in partitions_to_remove.items():
+            logger.debug(
+                f"Removing partition which has {count} < {self.min_block_size} computational operators"
+            )
+            del partitions[id]
+
+        return [partitions[k] for k in sorted(partitions.keys())]
+
+    def partition_and_fuse(self) -> GraphModule:
+        partitions = self.propose_partitions()
+        fused_gm = self.fuse_partitions(partitions)
+        return fused_gm
+
+
class TorchTensorRTOperatorSupport(OperatorSupport):
    """Class to determine whether operators within a module are supported"""

-    def __init__(self, support_dict=None):
+    def __init__(self, support_dict=None, torch_executed_ops=set()):
        super().__init__(support_dict)

        # Initialize sets of supported/unsupported operators
        self.supported_operators = set()
        self.unsupported_operators = set()
+        self.torch_executed_ops = torch_executed_ops

    def is_node_supported(
        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
-        if node.target in CONVERTERS.keys():
-            # If node is a proper computational node, store the operator
+        node_name = (
+            _get_qualified_name(node.target)
+            if not isinstance(node.target, str)
+            else node.target
+        )
+
+        if (
+            node.target in
CONVERTERS.keys() + and node_name not in self.torch_executed_ops + ): + # If node is a proper, supported computational node, store the operator if not node.is_impure(): - node_name = node._pretty_print_target(node.target) self.supported_operators.add(node_name) return True else: if not node.is_impure(): - node_name = node._pretty_print_target(node.target) self.unsupported_operators.add(node_name) return False def print_support_overview(self, num_trt_blocks: Optional[int] = None): if num_trt_blocks is not None: - print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}") + logger.debug( + f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}" + ) - print("\nSupported Nodes:") + logger.debug("\nSupported Nodes:") for node_name in self.supported_operators: - print("-", node_name) + logger.debug("-", node_name) if len(self.unsupported_operators) != 0: - print("\nUnsupported Nodes:") + logger.debug("\nUnsupported or Excluded Nodes:") for node_name in self.unsupported_operators: - print("-", node_name) - print("\n") + logger.debug("-", node_name) + logger.debug("\n") else: - print("\nAll Nodes Supported\n") + logger.debug("\nAll Nodes Supported\n") def partition( gm: torch.fx.GraphModule, verbose: bool = True, - max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, + min_block_size: int = MIN_BLOCK_SIZE, + torch_executed_ops: Sequence[str] = set(), ) -> torch.fx.GraphModule: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -64,29 +161,21 @@ def partition( Args: gm: FX GraphModule to partition verbose: Bool representing whether to print operator support - max_num_trt_engines: Maximum number of allowed TRT engines in partitioning + min_block_size: Minimum number of operators per TRT-Engine Block + torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage Returns: torch.fx.GraphModule """ - supported_ops = TorchTensorRTOperatorSupport() - partitioner = CapabilityBasedPartitioner(gm, supported_ops) + supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops) + partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size) - # Determine partitions, and raise error if the degree of partitioning - # exceeds a specified threshold + # Determine partitions based on user specifications and operator support + # Then, fuse partitions and display overview of supported/unsupported operators partitions = partitioner.propose_partitions() - num_blocks = len(partitions) - if num_blocks > max_num_trt_engines: - raise AssertionError( - f"The graph module has {num_blocks} TRT Engines which is larger than the " - + f"threshold={max_num_trt_engines}. Falling back to non-TRT module." 
- ) - - # Fuse partitions and display overview of supported/unsupported operators fused_graph = partitioner.fuse_partitions(partitions) - num_blocks = len(partitions) if verbose: - supported_ops.print_support_overview(num_blocks) + supported_ops.print_support_overview(len(partitions)) return fused_graph diff --git a/py/torch_tensorrt/dynamo/backend/test/test_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_lowering.py index d14acb815b..6b7651957f 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_lowering.py @@ -1,12 +1,12 @@ from functools import partial -from utils import fx_dynamo_testing_backend +from utils import lower_graph_testing from torch.testing._internal.common_utils import run_tests, TestCase import torch class TestLowering(TestCase): def test_lowering_inplace_op(self): - class FullySupported(torch.nn.Module): + class InPlace(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -18,35 +18,95 @@ def forward(self, x, y): # Operations expected to be included in the traced graph after decompositions expected_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default} - # Trace module and set up custom backend to track intermediate graphs - fx_graph = torch.fx.symbolic_trace(FullySupported()) - partitioned_graphs = [] - custom_backend = partial( - fx_dynamo_testing_backend, - store_intermediate_graphs=partitioned_graphs, - ) - - # Invoke compilation - compiled_graph = torch.compile(fx_graph, backend=custom_backend) - compiled_graph( + inputs = [ torch.rand( 5, - ).cuda(), + ), torch.rand( 5, - ).cuda(), + ), + ] + + fx_graph = torch.fx.symbolic_trace(InPlace()) + _, expected_ops_unseen = lower_graph_testing( + fx_graph, inputs, expected_ops=expected_ops, min_block_size=2 ) - # Iterate over intermediate graphs, attempt to match nodes - for fx_module in partitioned_graphs: - for _, submodule in fx_module.named_children(): - for node in submodule.graph.nodes: + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + def test_lowering_alias_replacement(self): + class Alias(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - if node.op == "call_function" and node.target in expected_ops: - expected_ops.remove(node.target) + def forward(self, x): + y = torch.ops.aten.alias.default(x) + return y + + # Operations expected to be removed in the traced graph after decompositions + unexpected_ops = {torch.ops.aten.alias.default} + + inputs = [ + torch.rand( + 5, + ), + ] + + fx_graph = torch.fx.symbolic_trace(Alias()) + unexpected_ops_seen, _ = lower_graph_testing( + fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1 + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + def test_lowering_rsqrt(self): + class Rsqrt(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x): + y = torch.ops.aten.rsqrt.default(x) + return y + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.reciprocal.default} + unexpected_ops = {torch.ops.aten.rsqrt.default} + + inputs = [ + torch.randint( + 1, + 10, + (5,), + ), + ] + + fx_graph = torch.fx.symbolic_trace(Rsqrt()) + unexpected_ops_seen, 
expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) - self.assertEqual( - len(expected_ops), 0, "All operators should have been decomposed" + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", ) diff --git a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py index fccdd3c32e..edc9ea7ee1 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py @@ -1,5 +1,6 @@ from torch_tensorrt.dynamo.backend.lowering import partition from torch.testing._internal.common_utils import run_tests, TestCase +from utils import lower_graph_testing import torch from copy import deepcopy import numpy as np @@ -16,7 +17,7 @@ def forward(self, x, y): fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) partitioned_graph = partition(deepcopy(fx_graph)) - self.assertEqual( + self.assertEquals( len(list(partitioned_graph.named_children())), 0, "Single operators should not be segmented", @@ -36,7 +37,7 @@ def forward(self, x, y): fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) partitioned_graph = partition(deepcopy(fx_graph)) - self.assertEqual( + self.assertEquals( len(list(partitioned_graph.named_children())), 1, "All operators are supported, there should be one segment", @@ -56,13 +57,68 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph)) - self.assertEqual( + partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) + self.assertEquals( len(list(partitioned_graph.named_children())), 2, "Unsupported operators interleave supported ones, expected 2 segments", ) + def test_partition_partially_supported_with_torch_executed_ops(self): + class PartiallySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_1 = torch.ops.aten.add.Tensor(x, y) + sum_2 = torch.ops.aten.add.Tensor(x, sum_1) + sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2) + relu_ = torch.ops.aten.relu.default(sum_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + unexpected_ops = {torch.ops.aten.add.Tensor} + + inputs = [ + torch.randint( + 1, + 10, + (5,), + ), + torch.randint( + 1, + 10, + (5,), + ), + ] + + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) + (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing( + fx_graph, + inputs, + unexpected_ops=unexpected_ops, + min_block_size=2, + torch_executed_ops={"torch.ops.aten.add.Tensor"}, + testing_partitioning=True, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(partitioned_graphs), + 1, + "Without control flow breaks, there should only be a single graph", + ) + self.assertEquals( + len(list(partitioned_graphs[0].named_children())), + 1, + "Certain operators are set to run in Torch, expected 1 segment", + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py index 
466a600db8..d59b710faf 100644
--- a/py/torch_tensorrt/dynamo/backend/test/utils.py
+++ b/py/torch_tensorrt/dynamo/backend/test/utils.py
@@ -1,6 +1,6 @@
 from copy import deepcopy
 from functools import partial
-from typing import List, Sequence
+from typing import Any, List, Sequence, Set
 import torch
 from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
@@ -20,11 +20,15 @@ def fx_dynamo_testing_backend(
    sample_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List,
+    min_block_size: int = 3,
+    torch_executed_ops: Sequence[str] = set(),
):
    """Helper Dynamo backend exclusively for testing"""
    custom_backend = partial(
        compile_module_testing,
        store_intermediate_graphs=store_intermediate_graphs,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
    )

    # Invoke AOTAutograd to translate operators to aten
@@ -41,9 +45,13 @@ def compile_module_testing(
    example_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List,
+    min_block_size: int = 3,
+    torch_executed_ops: Sequence[str] = set(),
) -> torch.fx.GraphModule:
    """Helper compiler exclusively for testing"""
-    partitioned_module = partition(gm)
+    partitioned_module = partition(
+        gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops
+    )

    # Store intermediate graph from partitioned module
    store_intermediate_graphs.append(deepcopy(partitioned_module))
@@ -52,6 +60,18 @@ def compile_module_testing(
def same_output_format(trt_output, torch_output, enforce_tensor_type=True):
+    """Determines whether two objects containing Tensors have the same format
+
+    ((Tensor, Tensor), Tensor) and (Tensor, (Tensor, Tensor)) do not
+    have the same format, for example.
+
+    Args:
+        trt_output: TensorRT output
+        torch_output: Torch output
+        enforce_tensor_type: Whether to enforce Tensor type equivalence
+    Returns:
+        bool: True if the outputs have the same format
+    """
    # For each encountered collection type, ensure the torch and trt outputs agree
    # on type and size, checking recursively through all member elements.
if isinstance(trt_output, tuple): @@ -92,3 +112,73 @@ def same_output_format(trt_output, torch_output, enforce_tensor_type=True): return type(trt_output) is type(torch_output) else: return True + + +def lower_graph_testing( + fx_graph: torch.fx.GraphModule, + inputs: Any, + *, + expected_ops: Set = set(), + unexpected_ops: Set = set(), + min_block_size: int = 3, + torch_executed_ops: Sequence[str] = set(), + testing_partitioning: bool = False, +): + """Helper function to assist with graph lowering for testing of Dynamo torch_compile + + Args: + fx_graph: Graph to lower + inputs: Input values to the FX graph + expected_ops: Operations to be expected in the lowered graph + unexpected_ops: Operations not to be expected in the lowered graph + min_block_size: Minimum number of operators per TRT-Engine Block + torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage + testing_partitioning: Whether partitioning is being tested (to analyze only TRT-supported ops) + Returns: + If testing_partitioning: + List[torch.fx.GraphModule], Set, Set: List of partitioned graph outputs, unexpected ops seen, expected ops unseen + Else: + Set, Set: unexpected ops seen and expected ops unseen (If the run was successful, both sets should be empty) + """ + # Trace module and set up custom backend to track intermediate graphs + partitioned_graphs = [] + custom_backend = partial( + fx_dynamo_testing_backend, + store_intermediate_graphs=partitioned_graphs, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, + ) + + # Invoke compilation + compiled_graph = torch.compile(fx_graph, backend=custom_backend) + compiled_graph(*inputs) + + unexpected_ops_seen = set() + expected_ops_seen = set() + + def classify_node(node: torch.fx.Node): + if node.target in unexpected_ops: + unexpected_ops_seen.add(node.target) + elif node.target in expected_ops: + expected_ops_seen.add(node.target) + + # Iterate over intermediate graphs, attempt to match nodes + # If an unexpected or expected op is encountered, register it + for fx_module in partitioned_graphs: + # For each function call in the set of graph nodes, classify the node + for top_level_node in fx_module.graph.nodes: + if top_level_node.op == "call_function" and not testing_partitioning: + classify_node(top_level_node) + elif top_level_node.op == "call_module": + for node in fx_module.get_submodule(top_level_node.target).graph.nodes: + classify_node(node) + + # Return unexpected ops seen and expected ops unseen + # If the run was successful, both sets should be empty + expected_ops_unseen = expected_ops.difference(expected_ops_seen) + + if testing_partitioning: + return unexpected_ops_seen, expected_ops_unseen, partitioned_graphs + + else: + return unexpected_ops_seen, expected_ops_unseen From 357d18dc0d525eb585bfedf139c5b10290cb1455 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 18 May 2023 13:24:23 -0700 Subject: [PATCH 2/2] fix: Remove remnants of max_num_trt_engines - Update default `min_block_size` --- py/torch_tensorrt/dynamo/backend/_defaults.py | 2 +- py/torch_tensorrt/dynamo/backend/test/test_partitioning.py | 2 +- py/torch_tensorrt/dynamo/test/test_dynamo_backend.py | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py index 06c4efa5fc..b1ee62dfa3 100644 --- a/py/torch_tensorrt/dynamo/backend/_defaults.py +++ b/py/torch_tensorrt/dynamo/backend/_defaults.py 
@@ -4,4 +4,4 @@ PRECISION = LowerPrecision.FP32 DEBUG = False MAX_WORKSPACE_SIZE = 20 << 30 -MIN_BLOCK_SIZE = 3 +MIN_BLOCK_SIZE = 5 diff --git a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py index edc9ea7ee1..fb5430b384 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py @@ -36,7 +36,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partition(deepcopy(fx_graph)) + partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) self.assertEquals( len(list(partitioned_graph.named_children())), 1, diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py index 531d0cc317..b86817df56 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -24,7 +24,6 @@ def test_resnet18(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, - "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -55,7 +54,6 @@ def test_mobilenet_v2(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, - "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -86,7 +84,6 @@ def test_efficientnet_b0(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, - "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -126,7 +123,6 @@ def test_bert_base_uncased(ir): "enabled_precisions": {torch.float}, "truncate_long_and_double": True, "ir": ir, - "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec)
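
Usage sketch for the options introduced by this patch. This is a minimal example, assuming a CUDA device with a working TensorRT installation; the SmallModel module and the specific option values are placeholders chosen for illustration. It relies only on the interfaces shown in the diff above: create_backend accepts min_block_size and torch_executed_ops and returns a backend that can be handed directly to torch.compile, mirroring the pattern used in the test utilities.

import torch

from torch_tensorrt.dynamo.backend import create_backend


class SmallModel(torch.nn.Module):
    """Placeholder module; any traceable model exercises the same code path."""

    def forward(self, x, y):
        out = torch.add(x, y)
        out = torch.relu(out)
        return torch.pow(out, 2)


model = SmallModel().eval().cuda()
inputs = [torch.randn(5).cuda(), torch.randn(5).cuda()]

# Build the Dynamo backend with the new partitioning controls:
# - partitions with fewer than `min_block_size` computational ops fall back to Torch
# - any operator listed in `torch_executed_ops` is excluded from TRT conversion
backend = create_backend(
    min_block_size=2,
    torch_executed_ops={"torch.ops.aten.add.Tensor"},
)

# Hand the backend to torch.compile; TRT engines are built on the first call
compiled_model = torch.compile(model, backend=backend)
result = compiled_model(*inputs)

With these settings, the add should run in Torch while relu and pow form a single TRT-eligible block, matching the behavior exercised by test_partition_partially_supported_with_torch_executed_ops above.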