From ee05b59fb6d3a56174ebbd83fda7342f42d9c835 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 12 May 2023 10:22:07 -0700 Subject: [PATCH 1/4] feat: Prototype Module-Acceleration in Dynamo - Add support for excluding entire Torch modules from tracing in Dynamo using Torch custom operators - Develop new dataclass to store required replacement functions and operators in a streamlined way - Add new registry to store mapping between replacement operators and their corresponding dataclass - Add documentation for easy additions of new module-level exclusion operators --- .circleci/config.yml | 2 +- py/torch_tensorrt/dynamo/backend/__init__.py | 3 + py/torch_tensorrt/dynamo/backend/backends.py | 12 ++ .../backend/lowering/_pre_aot_lowering.py | 164 ++++++++++++++++++ 4 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 5422b31a5a..070e4cc544 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -258,7 +258,7 @@ commands: name: Set up python environment command: | pip3 install --upgrade pip - pip3 install wheel setuptools + pip3 install wheel setuptools pyyaml pip3 install nvidia-pyindex pip3 install tabulate pip3 install tensorrt==<< parameters.trt-version-long >> nvidia-cudnn-cu11==<< parameters.cudnn-version-long >> diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index 3743b263db..aef1538693 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -50,6 +50,9 @@ def compile( if debug: logger.setLevel(logging.DEBUG) + if debug: + logger.setLevel(logging.DEBUG) + logger.warn( "The Dynamo backend is an experimental feature, for which only the " + "following arguments are supported: " diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 8f6408492a..ea0d398a44 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -8,6 +8,9 @@ from torch_tensorrt.dynamo.backend.lowering._decompositions import ( get_decompositions, ) +from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( + pre_aot_module_replacement, +) from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, get_submod_inputs, @@ -46,6 +49,13 @@ def aot_torch_tensorrt_aten_backend( settings=settings, ) + logger.debug("Pre-module replacement graph:\n" + str(gm.graph)) + + # Enable Pre-AOT Lowering for Module-Level Replacement + gm = pre_aot_module_replacement(gm) + + logger.debug("Post-module replacement graph:\n" + str(gm.graph)) + # Invoke AOTAutograd to translate operators to aten return aot_module_simplified( gm, @@ -71,6 +81,8 @@ def _pretraced_backend( Compiled FX GraphModule """ try: + logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) + trt_compiled = _compile_module( gm, sample_inputs, diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py new file mode 100644 index 0000000000..e2f3653121 --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py @@ -0,0 +1,164 @@ +from dataclasses import dataclass +import traceback +from typing import Callable, Dict, Tuple +import torch +from torch._custom_op import custom_op +from torch.fx.node import Argument, Target +import logging + 
+from torch_tensorrt.fx.converter_registry import tensorrt_converter +from torch_tensorrt.fx.converters import acc_ops_converters +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +logger = logging.getLogger(__name__) + + +@custom_op( + "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor", + ns="tensorrt", +) +def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): + # Defines operator schema, name, namespace, and function header + ... + + +@maxpool1d.impl("cpu") +@maxpool1d.impl("cuda") +def maxpool1d_generic( + *args, + **kwargs, +): + # Defines a converter implementation for Autograd to use for shape analysis/propagation + return torch.nn.functional.max_pool1d( + *args, + **kwargs, + ) + + +def maxpool1d_insertion_fn( + gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node +) -> torch.fx.Node: + # Defines insertion function for new node + new_node = gm.graph.call_function( + torch.ops.tensorrt.maxpool1d, + args=node.args, + kwargs={ + "kernel_size": submodule.kernel_size, + "stride": submodule.stride, + "padding": submodule.padding, + "dilation": submodule.dilation, + "ceil_mode": submodule.ceil_mode, + }, + ) + + return new_node + + +@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default) +def aten_ops_maxpool1d( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> TRTTensor: + # Defines converter replacing the default operator for this function + kwargs_new = { + "input": args[0], + "kernel_size": args[1], + "stride": args[2], + "padding": args[3], + "dilation": args[4], + "ceil_mode": False if len(args) < 6 else args[5], + } + + return acc_ops_converters.acc_ops_max_pool1d( + network, target, None, kwargs_new, name + ) + + +@dataclass(frozen=True) +class ModuleReplacement: + """Class to store key functionality for module replacement""" + + # torch.ops.___ name for replacement function for module + new_operator: torch._ops.OpOverload + + # Function taking a containing graph, a submodule, and a 'call_module' node and returning + # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected + # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph + subgraph_insertion_fn: Callable[ + [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node + ] + + +# Dictionary mapping module to ModuleReplacement instance +MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = { + torch.nn.MaxPool1d: ModuleReplacement( + new_operator=torch.ops.tensorrt.maxpool1d, + subgraph_insertion_fn=maxpool1d_insertion_fn, + ), +} + + +def pre_aot_module_replacement(gm: torch.fx.GraphModule): + """Perform module-level graph replacement prior to AOT tracing + + Args: + gm: FX GraphModule to perform module replacement on + Returns: + torch.fx.GraphModule + + """ + # Ensure all parameters are in inference mode + for param in gm.parameters(): + param.requires_grad = False + + # Iterate over graph nodes, extracting module calls, to check for interceptions + for n in gm.graph.nodes: + if n.op == "call_module": + # Extract submodule from graph + submodule = gm.get_submodule(n.target) + + # If submodule is a member of the substitution registry, replace it + if type(submodule) in MODULE_SUBSTITUTION_REGISTRY: + + try: + replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)] + op, insertion_fn = ( + replacement.new_operator, + 
replacement.subgraph_insertion_fn, + ) + logger.debug( + f"Replacing module of type {type(submodule)} with {op}" + ) + + # Insert new node prior to older node + with gm.graph.inserting_before(n): + new_node = insertion_fn(gm, submodule, n) + + # If submodule is not a native torch.nn module, it must be manually excluded + # from Dynamo tracing + if not type(submodule).__module__.startswith("torch.nn"): + torch._dynamo.allowed_functions._allowed_function_ids.add( + id(type(submodule)) + ) + + # Replace all original node uses and delete node + n.replace_all_uses_with(new_node) + gm.graph.eliminate_dead_code() + gm.recompile() + + # A module replacement can fail in the event that the specific instance of the submodule cannot + # be replaced + except Exception: + logger.debug( + f"Encountered the following error while replacing {type(submodule)}" + ) + logger.debug(traceback.format_exc()) + continue + + # Perform cleanup and recompilation before returning module + gm.graph.eliminate_dead_code() + gm.recompile() + return gm From 60df50e4d1d91be65349a5e91025db29a257bb9b Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 24 May 2023 21:02:33 -0700 Subject: [PATCH 2/4] fix: Refactor code and add testing --- py/torch_tensorrt/dynamo/backend/__init__.py | 3 - .../dynamo/backend/lowering/__init__.py | 10 +- .../dynamo/backend/lowering/_partition.py | 12 +- .../backend/lowering/_pre_aot_lowering.py | 115 ++++++------------ .../lowering/module_substitutions/__init__.py | 1 + .../module_substitutions/maxpool1d.py | 75 ++++++++++++ .../backend/test/test_pre_aot_lowering.py | 55 +++++++++ .../dynamo/backend/test/utils.py | 5 + 8 files changed, 187 insertions(+), 89 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py create mode 100644 py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py create mode 100644 py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index aef1538693..3743b263db 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -50,9 +50,6 @@ def compile( if debug: logger.setLevel(logging.DEBUG) - if debug: - logger.setLevel(logging.DEBUG) - logger.warn( "The Dynamo backend is an experimental feature, for which only the " + "following arguments are supported: " diff --git a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py index 01b20cef6d..1a0cbab2df 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py @@ -1,7 +1,9 @@ -from torch_tensorrt.dynamo.backend.lowering._decompositions import ( +from ._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.backend.lowering._partition import ( - partition, - get_submod_inputs, +from ._pre_aot_lowering import ( + MODULE_SUBSTITUTION_REGISTRY, + module_substitution, ) +from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS +from .module_substitutions import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 5cd83d768c..431f08be86 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -1,9 +1,10 @@ import logging -from typing 
import Dict, List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Set import torch from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE +from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.graph_module import GraphModule from torch.fx.node import _get_qualified_name @@ -14,6 +15,11 @@ logger = logging.getLogger(__name__) +DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set( + "torch.ops." + str(module.new_operator) + for module in MODULE_SUBSTITUTION_REGISTRY.values() +) + class TRTPartitioner(CapabilityBasedPartitioner): """Partitioner to split an FX graph into subgraphs based on operator support @@ -35,7 +41,9 @@ def __init__( operator_support: OperatorSupport, *, non_compute_ops: Optional[Sequence[str]] = None, - allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[ + Sequence[str] + ] = DEFAULT_SINGLE_NODE_PARTITIONS, min_block_size=MIN_BLOCK_SIZE, ) -> None: super().__init__( diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py index e2f3653121..2c331824de 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py @@ -1,82 +1,12 @@ from dataclasses import dataclass -import traceback -from typing import Callable, Dict, Tuple +from typing import Any, Callable, Dict import torch -from torch._custom_op import custom_op -from torch.fx.node import Argument, Target import logging -from torch_tensorrt.fx.converter_registry import tensorrt_converter -from torch_tensorrt.fx.converters import acc_ops_converters -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor logger = logging.getLogger(__name__) -@custom_op( - "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor", - ns="tensorrt", -) -def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): - # Defines operator schema, name, namespace, and function header - ... 
- - -@maxpool1d.impl("cpu") -@maxpool1d.impl("cuda") -def maxpool1d_generic( - *args, - **kwargs, -): - # Defines a converter implementation for Autograd to use for shape analysis/propagation - return torch.nn.functional.max_pool1d( - *args, - **kwargs, - ) - - -def maxpool1d_insertion_fn( - gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node -) -> torch.fx.Node: - # Defines insertion function for new node - new_node = gm.graph.call_function( - torch.ops.tensorrt.maxpool1d, - args=node.args, - kwargs={ - "kernel_size": submodule.kernel_size, - "stride": submodule.stride, - "padding": submodule.padding, - "dilation": submodule.dilation, - "ceil_mode": submodule.ceil_mode, - }, - ) - - return new_node - - -@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default) -def aten_ops_maxpool1d( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> TRTTensor: - # Defines converter replacing the default operator for this function - kwargs_new = { - "input": args[0], - "kernel_size": args[1], - "stride": args[2], - "padding": args[3], - "dilation": args[4], - "ceil_mode": False if len(args) < 6 else args[5], - } - - return acc_ops_converters.acc_ops_max_pool1d( - network, target, None, kwargs_new, name - ) - - @dataclass(frozen=True) class ModuleReplacement: """Class to store key functionality for module replacement""" @@ -93,12 +23,37 @@ class ModuleReplacement: # Dictionary mapping module to ModuleReplacement instance -MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = { - torch.nn.MaxPool1d: ModuleReplacement( - new_operator=torch.ops.tensorrt.maxpool1d, - subgraph_insertion_fn=maxpool1d_insertion_fn, - ), -} +MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict() + + +def module_substitution( + module_to_replace: torch.nn.Module, + new_operator: torch._ops.OpOverload, + enabled: bool = True, +) -> Callable[[Any], Any]: + """Decorator to register subgraph insertion functions + + Args: + module_to_replace: nn.Module to replace + new_operator: Custom torch operator to replace with + enabled: Whether the substitution is enabled or disabled + Returns: + torch.fx.GraphModule + """ + + def register_substitution(subgraph_insertion_fn): + """Function for use if substitution is enabled""" + module_replacement = ModuleReplacement( + new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn + ) + MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement + return subgraph_insertion_fn + + def disable_substitution(subgraph_insertion_fn): + """Function for use if substitution is disabled""" + return subgraph_insertion_fn + + return register_substitution if enabled else disable_substitution def pre_aot_module_replacement(gm: torch.fx.GraphModule): @@ -144,7 +99,7 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule): id(type(submodule)) ) - # Replace all original node uses and delete node + # Replace all original node uses and clean up graph n.replace_all_uses_with(new_node) gm.graph.eliminate_dead_code() gm.recompile() @@ -153,9 +108,9 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule): # be replaced except Exception: logger.debug( - f"Encountered the following error while replacing {type(submodule)}" + f"Encountered error while replacing {type(submodule)}", + exc_info=True, ) - logger.debug(traceback.format_exc()) continue # Perform cleanup and recompilation before returning module diff --git 
a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py new file mode 100644 index 0000000000..4b8ba88e34 --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py @@ -0,0 +1 @@ +from .maxpool1d import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py new file mode 100644 index 0000000000..090df11835 --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py @@ -0,0 +1,75 @@ +from typing import Dict, Tuple +import torch +from torch._custom_op import custom_op +from torch.fx.node import Argument, Target + +from torch_tensorrt.fx.converter_registry import tensorrt_converter +from torch_tensorrt.fx.converters import acc_ops_converters +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from torch_tensorrt.dynamo.backend.lowering import module_substitution + + +@custom_op( + "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor", + ns="tensorrt", +) +def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): + # Defines operator schema, name, namespace, and function header + ... + + +@maxpool1d.impl("cpu") +@maxpool1d.impl("cuda") +def maxpool1d_generic( + *args, + **kwargs, +): + # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation + return torch.nn.functional.max_pool1d( + *args, + **kwargs, + ) + + +@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default) +def aten_ops_maxpool1d( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> TRTTensor: + # Defines converter replacing the default operator for this function + kwargs_new = { + "input": args[0], + "kernel_size": args[1], + "stride": args[2], + "padding": args[3], + "dilation": args[4], + "ceil_mode": False if len(args) < 6 else args[5], + } + + return acc_ops_converters.acc_ops_max_pool1d( + network, target, None, kwargs_new, name + ) + + +@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) +def maxpool1d_insertion_fn( + gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node +) -> torch.fx.Node: + # Defines insertion function for new node + new_node = gm.graph.call_function( + torch.ops.tensorrt.maxpool1d, + args=node.args, + kwargs={ + "kernel_size": submodule.kernel_size, + "stride": submodule.stride, + "padding": submodule.padding, + "dilation": submodule.dilation, + "ceil_mode": submodule.ceil_mode, + }, + ) + + return new_node diff --git a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py new file mode 100644 index 0000000000..2fa65bfabc --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py @@ -0,0 +1,55 @@ +import torch +from utils import lower_graph_testing +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.dynamo import compile + + +class TestMaxPool1D(TestCase): + def test_pre_aot_lowering_maxpool1d(self): + class MaxPool1D(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.maxpool = torch.nn.MaxPool1d(2) + + def forward(self, x): + return self.maxpool(x) + + # Operations expected to be included in 
the traced graph after decompositions + expected_ops = {torch.ops.tensorrt.maxpool1d.default} + + inputs = [ + torch.rand( + 9, + 16, + 2, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(MaxPool1D()) + _, expected_ops_unseen = lower_graph_testing( + fx_graph, inputs, expected_ops=expected_ops, min_block_size=1 + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = compile( + fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results)) + self.assertAlmostEqual( + max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model." + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py index 48f6443e32..e7dc435ac4 100644 --- a/py/torch_tensorrt/dynamo/backend/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -8,6 +8,9 @@ from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, ) +from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( + pre_aot_module_replacement, +) from torch._dynamo.backends.common import fake_tensor_unsupported @@ -31,6 +34,8 @@ def fx_dynamo_testing_backend( torch_executed_ops=torch_executed_ops, ) + gm = pre_aot_module_replacement(gm) + # Invoke AOTAutograd to translate operators to aten return aot_module_simplified( gm, From af53282dd2d042c64c52fd73dd8be8f26c82412d Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 31 May 2023 12:41:18 -0700 Subject: [PATCH 3/4] fix: Address review comments - Fix typing issues, add depedencies to `setup.py`, add qualified name checking for module registry --- .circleci/config.yml | 2 +- py/setup.py | 2 ++ py/torch_tensorrt/dynamo/backend/lowering/_partition.py | 2 +- .../dynamo/backend/lowering/_pre_aot_lowering.py | 8 +++++--- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 070e4cc544..5422b31a5a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -258,7 +258,7 @@ commands: name: Set up python environment command: | pip3 install --upgrade pip - pip3 install wheel setuptools pyyaml + pip3 install wheel setuptools pip3 install nvidia-pyindex pip3 install tabulate pip3 install tensorrt==<< parameters.trt-version-long >> nvidia-cudnn-cu11==<< parameters.cudnn-version-long >> diff --git a/py/setup.py b/py/setup.py index b870560ae5..eb382559f8 100644 --- a/py/setup.py +++ b/py/setup.py @@ -427,6 +427,8 @@ def run(self): ext_modules=ext_modules, install_requires=[ "torch >=2.1.dev,<2.2" if not LEGACY else "torch >=1.13.0,<2.0", + "pyyaml", + "packaging", ], setup_requires=[], cmdclass={ diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 431f08be86..496c91a089 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set( - "torch.ops." 
+ str(module.new_operator) + _get_qualified_name(module.new_operator) for module in MODULE_SUBSTITUTION_REGISTRY.values() ) diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py index 2c331824de..738a398a51 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Type import torch import logging @@ -23,11 +23,11 @@ class ModuleReplacement: # Dictionary mapping module to ModuleReplacement instance -MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict() +MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict() def module_substitution( - module_to_replace: torch.nn.Module, + module_to_replace: Type[torch.nn.Module], new_operator: torch._ops.OpOverload, enabled: bool = True, ) -> Callable[[Any], Any]: @@ -102,6 +102,7 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule): # Replace all original node uses and clean up graph n.replace_all_uses_with(new_node) gm.graph.eliminate_dead_code() + gm.graph.lint() gm.recompile() # A module replacement can fail in the event that the specific instance of the submodule cannot @@ -115,5 +116,6 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule): # Perform cleanup and recompilation before returning module gm.graph.eliminate_dead_code() + gm.graph.lint() gm.recompile() return gm From 356ed0a970f347dc4671b6e5cf8c1bd1e19df8f0 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 31 May 2023 13:40:26 -0700 Subject: [PATCH 4/4] fix: Add support for general-purpose exclusion - Add functionality for advanced exclusion of both function and module-type nodes in Torch-TRT - Add sample exclusion for `torch.einsum` function which can be accelerated as a single unit via TRT - Add utilities and improve module and function-level exclusion mechanisms - Add test cases for new exclusion mechanism --- py/torch_tensorrt/dynamo/backend/backends.py | 4 +- .../dynamo/backend/lowering/__init__.py | 6 +- .../dynamo/backend/lowering/_partition.py | 6 +- .../backend/lowering/_pre_aot_lowering.py | 121 ++++++++++-------- .../__init__.py | 1 + .../backend/lowering/substitutions/einsum.py | 79 ++++++++++++ .../maxpool1d.py | 8 +- .../backend/test/test_pre_aot_lowering.py | 46 +++++++ .../dynamo/backend/test/utils.py | 4 +- 9 files changed, 208 insertions(+), 67 deletions(-) rename py/torch_tensorrt/dynamo/backend/lowering/{module_substitutions => substitutions}/__init__.py (53%) create mode 100644 py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py rename py/torch_tensorrt/dynamo/backend/lowering/{module_substitutions => substitutions}/maxpool1d.py (89%) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ea0d398a44..4493f7d007 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -9,7 +9,7 @@ get_decompositions, ) from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( - pre_aot_module_replacement, + pre_aot_substitutions, ) from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, @@ -52,7 +52,7 @@ def aot_torch_tensorrt_aten_backend( logger.debug("Pre-module replacement graph:\n" + str(gm.graph)) # Enable Pre-AOT 
Lowering for Module-Level Replacement - gm = pre_aot_module_replacement(gm) + gm = pre_aot_substitutions(gm) logger.debug("Post-module replacement graph:\n" + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py index 1a0cbab2df..dd55d2c83c 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py @@ -2,8 +2,8 @@ get_decompositions, ) from ._pre_aot_lowering import ( - MODULE_SUBSTITUTION_REGISTRY, - module_substitution, + SUBSTITUTION_REGISTRY, + register_substitution, ) from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS -from .module_substitutions import * +from .substitutions import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 496c91a089..4d82bf4be5 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -4,7 +4,7 @@ import torch from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE -from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY +from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.graph_module import GraphModule from torch.fx.node import _get_qualified_name @@ -16,8 +16,8 @@ logger = logging.getLogger(__name__) DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set( - _get_qualified_name(module.new_operator) - for module in MODULE_SUBSTITUTION_REGISTRY.values() + _get_qualified_name(to_replace.new_operator) + for to_replace in SUBSTITUTION_REGISTRY.values() ) diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py index 738a398a51..9cca38bea4 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Dict, Type +from typing import Any, Callable, Dict, Optional, Type, Union import torch import logging @@ -8,59 +8,62 @@ @dataclass(frozen=True) -class ModuleReplacement: +class Substitution: """Class to store key functionality for module replacement""" # torch.ops.___ name for replacement function for module new_operator: torch._ops.OpOverload - # Function taking a containing graph, a submodule, and a 'call_module' node and returning - # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected + # Function taking a containing graph, a node, and optionally a submodule (if replacing a module) + # and returning a replacement node, with type 'call_function', or raising an Error if + # incompatibility is detected # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph subgraph_insertion_fn: Callable[ - [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node + [torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node ] -# Dictionary mapping module to ModuleReplacement instance -MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict() +# Dictionary mapping module to Substitution instance +SUBSTITUTION_REGISTRY: Dict[ + Union[Type[torch.nn.Module], Callable], Substitution +] = dict() -def module_substitution( - 
module_to_replace: Type[torch.nn.Module], +def register_substitution( + module_or_function_to_replace: Union[Type[torch.nn.Module], Callable], new_operator: torch._ops.OpOverload, enabled: bool = True, ) -> Callable[[Any], Any]: """Decorator to register subgraph insertion functions Args: - module_to_replace: nn.Module to replace + module_or_function_to_replace: nn.Module or node target Callable to replace new_operator: Custom torch operator to replace with enabled: Whether the substitution is enabled or disabled Returns: torch.fx.GraphModule """ - def register_substitution(subgraph_insertion_fn): + def enable_substitution(subgraph_insertion_fn): """Function for use if substitution is enabled""" - module_replacement = ModuleReplacement( + replacement = Substitution( new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn ) - MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement + SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement return subgraph_insertion_fn def disable_substitution(subgraph_insertion_fn): """Function for use if substitution is disabled""" return subgraph_insertion_fn - return register_substitution if enabled else disable_substitution + return enable_substitution if enabled else disable_substitution -def pre_aot_module_replacement(gm: torch.fx.GraphModule): - """Perform module-level graph replacement prior to AOT tracing +def pre_aot_substitutions(gm: torch.fx.GraphModule): + """Perform graph substitutions prior to AOT tracing Args: - gm: FX GraphModule to perform module replacement on + gm: FX GraphModule to perform substitution on Returns: torch.fx.GraphModule @@ -71,48 +74,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule): # Iterate over graph nodes, extracting module calls, to check for interceptions for n in gm.graph.nodes: + exists_in_registry = False + to_replace = None + if n.op == "call_module": - # Extract submodule from graph + # Extract submodule from graph, validate in registry submodule = gm.get_submodule(n.target) - - # If submodule is a member of the substitution registry, replace it - if type(submodule) in MODULE_SUBSTITUTION_REGISTRY: - - try: - replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)] - op, insertion_fn = ( - replacement.new_operator, - replacement.subgraph_insertion_fn, - ) - logger.debug( - f"Replacing module of type {type(submodule)} with {op}" + to_replace = type(submodule) + exists_in_registry = to_replace in SUBSTITUTION_REGISTRY + elif n.op == "call_function": + # Extract function from graph, validate in registry + to_replace = n.target + exists_in_registry = n.target in SUBSTITUTION_REGISTRY + + # If submodule/function is a member of the substitution registry, replace it + if exists_in_registry: + try: + replacement = SUBSTITUTION_REGISTRY[to_replace] + op, insertion_fn = ( + replacement.new_operator, + replacement.subgraph_insertion_fn, + ) + logger.debug(f"Replacing node of type {to_replace} with {op}") + + # Insert new node prior to older node + with gm.graph.inserting_before(n): + new_node = insertion_fn( + gm, n, submodule if n.op == "call_module" else None ) - # Insert new node prior to older node - with gm.graph.inserting_before(n): - new_node = insertion_fn(gm, submodule, n) - - # If submodule is not a native torch.nn module, it must be manually excluded - # from Dynamo tracing - if not type(submodule).__module__.startswith("torch.nn"): - torch._dynamo.allowed_functions._allowed_function_ids.add( - id(type(submodule)) - ) - - # Replace all original node uses and 
clean up graph - n.replace_all_uses_with(new_node) - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - - # A module replacement can fail in the event that the specific instance of the submodule cannot - # be replaced - except Exception: - logger.debug( - f"Encountered error while replacing {type(submodule)}", - exc_info=True, + # If submodule is not a native torch.nn module, it must be manually excluded + # from Dynamo tracing + if n.op == "call_module" and not type(submodule).__module__.startswith( + "torch.nn" + ): + torch._dynamo.allowed_functions._allowed_function_ids.add( + id(to_replace) ) - continue + + # Replace all original node uses and clean up graph + n.replace_all_uses_with(new_node) + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + # A replacement can fail in the event that the specific instance of the submodule/function + # cannot be replaced + except Exception: + logger.debug( + f"Encountered error while replacing {to_replace}", + exc_info=True, + ) + continue # Perform cleanup and recompilation before returning module gm.graph.eliminate_dead_code() diff --git a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py similarity index 53% rename from py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py rename to py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py index 4b8ba88e34..8d3acc8874 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py @@ -1 +1,2 @@ from .maxpool1d import * +from .einsum import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py new file mode 100644 index 0000000000..b65117f5ac --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py @@ -0,0 +1,79 @@ +from typing import Dict, Tuple +import torch +from torch._custom_op import custom_op +from torch.fx.node import Argument, Target + +from torch_tensorrt.fx.converter_registry import tensorrt_converter +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from torch_tensorrt.dynamo.backend.lowering import register_substitution + + +@custom_op( + "(str equation, Tensor[] tensors) -> Tensor", + ns="tensorrt", +) +def einsum(equation, tensors): + # Defines operator schema, name, namespace, and function header + ... 
+ + +@einsum.impl("cpu") +@einsum.impl("cuda") +def einsum_generic( + *args, + **kwargs, +): + # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation + return torch.einsum( + *args, + **kwargs, + ) + + +@tensorrt_converter(torch.ops.tensorrt.einsum.default) +def aten_ops_einsum( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> TRTTensor: + # Defines converter replacing the default operator for this function + for input_trt in args[1]: + if not isinstance(input_trt, TRTTensor): + raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}") + + einsum_layer = network.add_einsum(inputs=args[1], equation=args[0]) + + set_layer_name(einsum_layer, target, name) + return einsum_layer.get_output(0) + + +@register_substitution(torch.einsum, torch.ops.tensorrt.einsum) +def einsum_insertion_fn( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + _unused: None = None, +) -> torch.fx.Node: + equation = node.args[0] + + # Ensure inputs is a list of (Tensor) arguments + if isinstance(node.args[1], (tuple, list)): + inputs = node.args[1] + else: + inputs = node.args[1:] + + assert ( + 1 <= len(inputs) <= 2 + ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors" + + # Ensure the input is formatted as an equation and + new_node = gm.graph.call_function( + torch.ops.tensorrt.einsum, + args=(equation, inputs), + kwargs=node.kwargs, + ) + + return new_node diff --git a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py similarity index 89% rename from py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py rename to py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py index 090df11835..61c7a1c48f 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py @@ -7,7 +7,7 @@ from torch_tensorrt.fx.converters import acc_ops_converters from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.dynamo.backend.lowering import module_substitution +from torch_tensorrt.dynamo.backend.lowering import register_substitution @custom_op( @@ -55,9 +55,11 @@ def aten_ops_maxpool1d( ) -@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) +@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) def maxpool1d_insertion_fn( - gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node + gm: torch.fx.GraphModule, + node: torch.fx.Node, + submodule: torch.nn.Module, ) -> torch.fx.Node: # Defines insertion function for new node new_node = gm.graph.call_function( diff --git a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py index 2fa65bfabc..da44d6e826 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py @@ -51,5 +51,51 @@ def forward(self, x): ) +class TestEinsum(TestCase): + def test_pre_aot_lowering_einsum(self): + class Einsum(torch.nn.Module): + def forward(self, x, y): + return torch.einsum("ij,ji->ij", x, y) + + # Operations expected to be included in the traced graph after decompositions + expected_ops = {torch.ops.tensorrt.einsum.default} + + inputs = [ + torch.rand( + 16, + 16, + ).cuda(), + torch.rand( + 16, + 
16, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(Einsum()) + _, expected_ops_unseen = lower_graph_testing( + fx_graph, inputs, expected_ops=expected_ops, min_block_size=1 + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = compile( + fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results)) + self.assertAlmostEqual( + max_diff, 0, f"Einsum TRT outputs don't match with the original model." + ) + + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py index e7dc435ac4..7c679b7d4d 100644 --- a/py/torch_tensorrt/dynamo/backend/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -9,7 +9,7 @@ partition, ) from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( - pre_aot_module_replacement, + pre_aot_substitutions, ) from torch._dynamo.backends.common import fake_tensor_unsupported @@ -34,7 +34,7 @@ def fx_dynamo_testing_backend( torch_executed_ops=torch_executed_ops, ) - gm = pre_aot_module_replacement(gm) + gm = pre_aot_substitutions(gm) # Invoke AOTAutograd to translate operators to aten return aot_module_simplified(
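
Note: the following snippet is an illustrative sketch, not part of the patch series above. It shows how a downstream developer could register an additional substitution with the register_substitution decorator and torch._custom_op pattern introduced in these patches, mirroring substitutions/maxpool1d.py and substitutions/einsum.py. The torch.nn.LeakyReLU target and the tensorrt::leaky_relu operator name are assumptions chosen only for this example; a matching @tensorrt_converter for torch.ops.tensorrt.leaky_relu.default would also be required for TRT conversion, exactly as in the existing substitution files.

    # Hypothetical example: substituting torch.nn.LeakyReLU (names are illustrative,
    # not part of the patches above)
    import torch
    from torch._custom_op import custom_op

    from torch_tensorrt.dynamo.backend.lowering import register_substitution


    @custom_op(
        "(Tensor x, float negative_slope) -> Tensor",
        ns="tensorrt",
    )
    def leaky_relu(x, negative_slope):
        # Defines operator schema, name, namespace, and function header
        ...


    @leaky_relu.impl("cpu")
    @leaky_relu.impl("cuda")
    def leaky_relu_generic(*args, **kwargs):
        # Reference implementation used by AOT Autograd for shape analysis/propagation
        return torch.nn.functional.leaky_relu(*args, **kwargs)


    @register_substitution(torch.nn.LeakyReLU, torch.ops.tensorrt.leaky_relu)
    def leaky_relu_insertion_fn(
        gm: torch.fx.GraphModule,
        node: torch.fx.Node,
        submodule: torch.nn.Module,
    ) -> torch.fx.Node:
        # Swap the 'call_module' node for a 'call_function' node targeting the custom op,
        # capturing the module attribute as a keyword argument
        new_node = gm.graph.call_function(
            torch.ops.tensorrt.leaky_relu,
            args=node.args,
            kwargs={"negative_slope": submodule.negative_slope},
        )
        return new_node

Because the insertion function receives the containing graph, the node, and (for module targets) the submodule instance, module attributes such as negative_slope can be folded into keyword arguments on the replacement call_function node, which is the same mechanism the maxpool1d substitution uses for kernel_size, stride, padding, dilation, and ceil_mode.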