From 83a290a410dbf013ff564e943ee44ecf45cb2b9c Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 30 Aug 2023 17:27:17 -0700
Subject: [PATCH 1/3] feat: Add ATen lowering pass system

- Add documentation, testing, and lowering pass management systems for ATen lowering passes
---
 docsrc/index.rst                              |   1 +
 examples/dynamo/README.rst                    |   1 +
 .../dynamo/dynamo_aten_lowering_passes.py     |  97 +++++++++++++++++++
 py/torch_tensorrt/dynamo/backend/backends.py  |  49 +---------
 py/torch_tensorrt/dynamo/lowering/__init__.py |   1 +
 .../dynamo/lowering/passes/__init__.py        |  27 ++++++
 .../lowering/passes/constant_folding.py       |  56 +++++++++++
 .../lowering/passes/repair_input_as_output.py |  53 ++++++++++
 setup.py                                      |   2 +
 .../lowering/test_aten_lowering_passes.py     |  59 +++++++++++
 tests/py/dynamo/testing_utilities.py          |  10 +-
 11 files changed, 306 insertions(+), 50 deletions(-)
 create mode 100644 examples/dynamo/dynamo_aten_lowering_passes.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/__init__.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
 create mode 100644 tests/py/dynamo/lowering/test_aten_lowering_passes.py

diff --git a/docsrc/index.rst b/docsrc/index.rst
index eee62bc2f7..c8d8ede907 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -73,6 +73,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
    tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
    tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
+   tutorials/_rendered_examples/dynamo/dynamo_aten_lowering_passes
 
 Python API Documenation
 ------------------------
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index fa863952e7..f4eb16fbba 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
+* :ref:`dynamo_aten_lowering_passes`: Custom modifications of a graph of ATen operators via lowering passes
diff --git a/examples/dynamo/dynamo_aten_lowering_passes.py b/examples/dynamo/dynamo_aten_lowering_passes.py
new file mode 100644
index 0000000000..5b0a795399
--- /dev/null
+++ b/examples/dynamo/dynamo_aten_lowering_passes.py
@@ -0,0 +1,97 @@
+"""
+.. _dynamo_aten_lowering_passes:
+
+Dynamo ATen Lowering Passes
+======================================================
+
+This interactive script is intended as an overview of the process by which ATen lowering passes are written and used."""
+
+# %%
+# 1. Lowering Pass Function
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# An ATen lowering pass function in Torch-TRT must satisfy two requirements:
+# - The function must take as input a single `torch.fx.GraphModule` and return the lowered
+#   `torch.fx.GraphModule`
+# - The function must leave the graph in a valid and invoke-able state, including performing any
+#   necessary linting and recompilation
+#
+# See below for an example of a lowering pass which repairs graphs that have inputs which are
+# also outputs, a disallowed configuration for TRT Engines.
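+#
+# As a minimal sketch of this contract, a hypothetical no-op pass (named
+# ``noop_pass`` purely for illustration) would simply relint, recompile, and
+# return the module:
+#
+# .. code-block:: python
+#
+#     def noop_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+#         gm.graph.lint()
+#         gm.recompile()
+#         return gm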
+ +# %% +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Repair scenarios where inputs are also outputs of the graph + + TRT does not allow such cases, so we insert a clone (identity) layer + """ + modified_graph = False + + # Extract graph placeholder Tensors + placeholders = [ + node + for node in gm.graph.nodes + if ( + node.op == "placeholder" + and isinstance(node.type, type) + and issubclass(node.type, torch.Tensor) + ) + ] + + for placeholder in placeholders: + # If any placeholder has any users which are direct graph outputs + if len(placeholder.users) >= 1 and any( + user.op == "output" for user in placeholder.users + ): + modified_graph = True + + # Get direct graph outputs which are direct uses of placeholders + direct_outputs = [user for user in placeholder.users if user.op == "output"] + + # Insert clone node for placeholder to ensure placeholder is not a direct output + with gm.graph.inserting_after(placeholder): + cloned_placeholder = gm.graph.call_function( + torch.ops.aten.clone.default, + args=(placeholder,), + ) + + # Replace placeholder as output with cloned version + for output in direct_outputs: + output.replace_input_with(placeholder, cloned_placeholder) + + # If the graph was modified, clean up the graph and ensure it is up-to-date + if modified_graph: + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}") + + return gm + + +# %% +# 2. Lowering Pass Registration +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# To add a lowering pass, use the convenience function `add_lowering_pass` in the module +# `torch_tensorrt.dynamo.lowering.passes`. See below for an example: + +# %% +from torch_tensorrt.dynamo.lowering.passes import add_lowering_pass + +add_lowering_pass(repair_input_as_output) + +# %% +# 3. Apply Available Lowering Passes +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# To apply all lowering passes to a graph, the convenience function `apply_lowering_passes` in the module +# `torch_tensorrt.dynamo.lowering.passes` can be used. This function is automatically invoked in the Torch-TRT Dynamo +# paths. Additionally, the graph after each modifying pass is logged in the debug logs for Torch-TRT runs. 
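+
+# %%
+# As a brief sketch of manual invocation (the two-layer module below is a
+# stand-in used purely for illustration), the registered passes can be
+# applied directly to any traced ``torch.fx.GraphModule``:

+from torch_tensorrt.dynamo.lowering.passes import apply_lowering_passes
+
+sample_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
+graph_module = torch.fx.symbolic_trace(sample_model)
+graph_module = apply_lowering_passes(graph_module)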
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 7fde8bbb41..022f3b193d 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -10,24 +10,12 @@ from torch._dynamo.utils import detect_fake_mode from torch._functorch.aot_autograd import _aot_export_function from torch._ops import OpOverload -from torch_tensorrt._utils import sanitized_torch_version from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo.compile import compile_module -from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level -from packaging import version - -# Modify import location of utilities based on Torch version -if version.parse(sanitized_torch_version()) < version.parse("2.1.1"): - from torch._inductor.freezing import ConstantFolder, replace_node_with_constant -else: - from torch._inductor.constant_folding import ( - ConstantFolder, - replace_node_with_constant, - ) - logger = logging.getLogger(__name__) @@ -84,7 +72,7 @@ def _pretraced_backend( fake_mode, "allow_non_fake_inputs", True ), fake_mode: # Invoke AOTAutograd to translate operators to aten - graph_module = aot_export_for_compile( + gm = aot_export_for_compile( gm, sample_inputs, decompositions=get_decompositions( @@ -94,10 +82,10 @@ def _pretraced_backend( logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) - constant_fold(graph_module) + gm = apply_lowering_passes(gm) trt_compiled = compile_module( - graph_module, + gm, sample_inputs, settings=settings, ) @@ -121,35 +109,6 @@ def _pretraced_backend( raise -@torch.utils._python_dispatch._disable_current_modes() # type: ignore -def constant_fold(gm: torch.fx.GraphModule) -> Any: - """Adapted from: - https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197 - - Folds constants in the graph module, not skipping constructors - - Modifies the graph in-place and replaces node with constants - """ - cf = ConstantFolder(gm, skip_constructors=False) - cf.run() - - for node, constant in cf.node_replacements.items(): - replace_node_with_constant(gm, node, constant) - - erased_params = [] - for node in gm.graph.nodes: - if node.op == "get_attr" and len(node.users) == 0: - delattr(gm, node.target) - erased_params.append(node) - - for node in erased_params: - gm.graph.erase_node(node) - - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - - def aot_export_for_compile( func: torch.fx.GraphModule, args: Sequence[torch.Tensor], diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index 6eda61a6fd..c83cf4665c 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -2,4 +2,5 @@ from ._fusers import * # noqa: F401 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401 from ._pre_aot_lowering import register_substitution # noqa: F401 +from .passes import add_lowering_pass, apply_lowering_passes from .substitutions import * # noqa: F401 diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py new file mode 100644 index 0000000000..445e42cfc9 --- /dev/null +++ 
b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py @@ -0,0 +1,27 @@ +from typing import Callable + +import torch +from torch.fx.passes.pass_manager import PassManager + +from .constant_folding import constant_fold +from .repair_input_as_output import repair_input_as_output + +ATEN_LOWERING_PASSES = PassManager.build_from_passlist( + [ + constant_fold, + repair_input_as_output, + ] +) + + +def add_lowering_pass( + lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule] +) -> None: + """Adds a lowering pass to the registry""" + ATEN_LOWERING_PASSES.add_pass(lowering_pass) + return + + +def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Applies the lowering passes to a graph module, returns the modified GraphModule""" + return ATEN_LOWERING_PASSES(gm) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py new file mode 100644 index 0000000000..d17d0a2528 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -0,0 +1,56 @@ +import logging + +import torch +from torch_tensorrt._utils import sanitized_torch_version + +from packaging import version + +# Modify import location of utilities based on Torch version +if version.parse(sanitized_torch_version()) < version.parse("2.1.1"): + from torch._inductor.freezing import ConstantFolder, replace_node_with_constant +else: + from torch._inductor.constant_folding import ( + ConstantFolder, + replace_node_with_constant, + ) + +logger = logging.getLogger(__name__) + + +@torch.utils._python_dispatch._disable_current_modes() # type: ignore +def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Adapted from: + https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197 + + Folds constants in the graph module, not skipping constructors + + Modifies the graph in-place and replaces node with constants + """ + cf = ConstantFolder(gm, skip_constructors=False) + cf.run() + + for node, constant in cf.node_replacements.items(): + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node in gm.graph.nodes: + # If get_attr node has no users, mark it for deletion + if node.op == "get_attr" and len(node.users) == 0: + # If the node's parameter is not a parameter of any other node, remove it + if not any( + other.target == node.target for other in gm.graph.nodes if other != node + ): + delattr(gm, node.target) + erased_params.append(node) + + # Remove unused nodes from the graph + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + logger.debug(f"Graph after constant folding:\n{gm.graph}") + + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py new file mode 100644 index 0000000000..6ce846637d --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py @@ -0,0 +1,53 @@ +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Repair scenarios where inputs are also outputs of the graph + + TRT does not allow such cases, so we insert a clone (identity) layer + """ + modified_graph = False + + # Extract graph placeholder Tensors + placeholders = [ + node + for node in gm.graph.nodes + if ( + node.op == 
"placeholder" + and isinstance(node.type, type) + and issubclass(node.type, torch.Tensor) + ) + ] + + for placeholder in placeholders: + # If any placeholder has any users which are direct graph outputs + if len(placeholder.users) >= 1 and any( + user.op == "output" for user in placeholder.users + ): + modified_graph = True + + # Get direct graph outputs which are direct uses of placeholders + direct_outputs = [user for user in placeholder.users if user.op == "output"] + + # Insert clone node for placeholder to ensure placeholder is not a direct output + with gm.graph.inserting_after(placeholder): + cloned_placeholder = gm.graph.call_function( + torch.ops.aten.clone.default, + args=(placeholder,), + ) + + # Replace placeholder as output with cloned version + for output in direct_outputs: + output.replace_input_with(placeholder, cloned_placeholder) + + if modified_graph: + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}") + + return gm diff --git a/setup.py b/setup.py index 6b013daf9e..d02adfb678 100644 --- a/setup.py +++ b/setup.py @@ -392,6 +392,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.substitutions", + "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", "torch_tensorrt.dynamo.runtime", "torch_tensorrt.dynamo.tools", @@ -419,6 +420,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary", "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering", "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions", + "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes", "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning", "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime", "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools", diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py new file mode 100644 index 0000000000..9c547867e2 --- /dev/null +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -0,0 +1,59 @@ +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing + + +class TestInputAsOutput(TestCase): + def test_input_as_output(self): + class InputAsOutput(torch.nn.Module): + def forward(self, x, y): + y_new = y + x + 1 + y_new = y_new * 7 + return (y_new, x, y) + + inputs = [ + torch.rand( + 5, + 7, + ).cuda(), + torch.rand( + 5, + 7, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(InputAsOutput()) + lower_graph_testing(fx_graph, inputs, min_block_size=1) + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = torch.cat( + [tensor.detach().cpu() for tensor in optimized_model(*inputs)] + ) + torch_model_results = torch.cat( + [tensor.detach().cpu() for tensor in fx_graph(*inputs)] + ) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"InputAsOutput TRT outputs don't 
match with the original model.", + ) + torch._dynamo.reset() + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index f311f2db2b..af5336813f 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -6,8 +6,8 @@ import torch from torch._dynamo.utils import detect_fake_mode from torch_tensorrt.dynamo import partitioning -from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile, constant_fold -from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions +from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions DECIMALS_OF_AGREEMENT = 4 @@ -40,16 +40,16 @@ def fx_dynamo_testing_backend( fake_mode, "allow_non_fake_inputs", True ), fake_mode: # Invoke AOTAutograd to translate operators to aten - graph_module = aot_export_for_compile( + gm = aot_export_for_compile( gm, sample_inputs, decompositions=get_decompositions(), ) - constant_fold(graph_module) + gm = apply_lowering_passes(gm) trt_compiled = custom_backend( - graph_module, + gm, sample_inputs, ) return trt_compiled From 1a03fc51de7750af5841727ad4886edce7c799e8 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:58:41 -0700 Subject: [PATCH 2/3] feat: Refactor pass manager and utilities - Improve logging and pass manager utilities - Add testing of new utilities --- .../dynamo/dynamo_aten_lowering_passes.py | 16 +++++++ .../dynamo/lowering/passes/__init__.py | 40 +++++++++++++++--- .../dynamo/lowering/passes/pass_manager.py | 42 +++++++++++++++++++ .../lowering/test_aten_lowering_passes.py | 38 +++++++++++++++++ 4 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py diff --git a/examples/dynamo/dynamo_aten_lowering_passes.py b/examples/dynamo/dynamo_aten_lowering_passes.py index 5b0a795399..ffe7083576 100644 --- a/examples/dynamo/dynamo_aten_lowering_passes.py +++ b/examples/dynamo/dynamo_aten_lowering_passes.py @@ -86,8 +86,24 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: # %% from torch_tensorrt.dynamo.lowering.passes import add_lowering_pass +# Adds the lowering pass at the end of the pass list add_lowering_pass(repair_input_as_output) +# Alternatively, specify an index to insert the lowering pass at a specific location +add_lowering_pass(repair_input_as_output, 1) + +# To remove a lowering pass, specify the index of the pass to remove: +from torch_tensorrt.dynamo.lowering.passes import remove_lowering_pass + +# Removes the lowering pass at index 1 +remove_lowering_pass(1) + + +# To view all lowering passes, in the order they will be run, use the following +from torch_tensorrt.dynamo.lowering.passes import dump_lowering_passes + +print(dump_lowering_passes()) + # %% # 3. 
Apply Available Lowering Passes # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py index 445e42cfc9..ff549b73a8 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py @@ -1,27 +1,55 @@ -from typing import Callable +import logging +from typing import Callable, Optional import torch -from torch.fx.passes.pass_manager import PassManager +# Import and order lowering passes and pass manager from .constant_folding import constant_fold +from .pass_manager import DynamoPassManager from .repair_input_as_output import repair_input_as_output -ATEN_LOWERING_PASSES = PassManager.build_from_passlist( +ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist( [ constant_fold, repair_input_as_output, ] ) +logger = logging.getLogger(__name__) + def add_lowering_pass( - lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule] + lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule], + index: Optional[int] = None, ) -> None: - """Adds a lowering pass to the registry""" - ATEN_LOWERING_PASSES.add_pass(lowering_pass) + """Adds a lowering pass to the registry, at a specified index if desired + + If no index is specified, the lowering pass is inserted at the end of the list + """ + ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) + logger.debug( + f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}" + ) + return + + +def remove_lowering_pass(index: int) -> None: + """Removes a lowering pass at a specific index from the registry""" + ATEN_LOWERING_PASSES.remove_pass_with_index(index) + logger.debug( + f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}" + ) return def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """Applies the lowering passes to a graph module, returns the modified GraphModule""" + logging.debug( + f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}" + ) return ATEN_LOWERING_PASSES(gm) + + +def dump_lowering_passes() -> str: + """Returns a string containing the lowering passes""" + return str(ATEN_LOWERING_PASSES) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py new file mode 100644 index 0000000000..51e2584364 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -0,0 +1,42 @@ +from typing import Any, Callable, List, Optional + +import torch +from torch.fx.passes.pass_manager import PassManager + + +class DynamoPassManager(PassManager): # type: ignore[misc] + def __init__( + self, + passes: Optional[ + List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]] + ] = None, + ): + super().__init__(passes) + + @classmethod + def build_from_passlist( + cls, + passes: Optional[List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]], + ) -> Any: + pm = DynamoPassManager(passes) + return pm + + def add_pass_with_index( + self, + lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule], + index: Optional[int] = None, + ) -> None: + if index is None: + self.passes.append(lowering_pass) + index = -1 + else: + self.passes.insert(index, lowering_pass) + + def remove_pass_with_index(self, index: int) -> None: + del self.passes[index] + + def __call__(self, source: Any) -> Any: + return super().__call__(source) + + def 
__str__(self) -> str:
+        return str(self.passes)
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 9c547867e2..c7dd0ecdb2 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -55,5 +55,43 @@ def forward(self, x, y):
         torch._dynamo.reset()
 
 
+class TestLoweringPassMembership(TestCase):
+    def insert_at_end(self):
+        from torch_tensorrt.dynamo.lowering.passes import (
+            ATEN_LOWERING_PASSES,
+            add_lowering_pass,
+            remove_lowering_pass,
+        )
+
+        def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+            return gm
+
+        add_lowering_pass(identity_pass)
+
+        self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[-1])
+
+        remove_lowering_pass(-1)
+
+        self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)
+
+    def insert_at_index(self):
+        from torch_tensorrt.dynamo.lowering.passes import (
+            ATEN_LOWERING_PASSES,
+            add_lowering_pass,
+            remove_lowering_pass,
+        )
+
+        def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+            return gm
+
+        add_lowering_pass(identity_pass, 0)
+
+        self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[0])
+
+        remove_lowering_pass(0)
+
+        self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)
+
+
 if __name__ == "__main__":
     run_tests()
From 7fa0a0cc395a2fad14a26a6d870f7276004e85bd Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 19 Sep 2023 19:27:02 -0700
Subject: [PATCH 3/3] fix: Address review comments and add upgrades

---
 .../writing_dynamo_aten_lowering_passes.rst   | 109 +++++++++++++++++
 docsrc/index.rst                              |   3 +-
 examples/dynamo/README.rst                    |   1 -
 .../dynamo/dynamo_aten_lowering_passes.py     | 113 ------------------
 py/torch_tensorrt/dynamo/aten_tracer.py       |   5 +-
 py/torch_tensorrt/dynamo/lowering/__init__.py |   2 +-
 .../dynamo/lowering/passes/__init__.py        |  56 +--------
 .../lowering/passes/_aten_lowering_pass.py    |  76 ++++++++++++
 .../lowering/test_aten_lowering_passes.py     |  18 ++-
 9 files changed, 199 insertions(+), 184 deletions(-)
 create mode 100644 docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
 delete mode 100644 examples/dynamo/dynamo_aten_lowering_passes.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

diff --git a/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
new file mode 100644
index 0000000000..d64f81d4aa
--- /dev/null
+++ b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
@@ -0,0 +1,109 @@
+.. _writing_dynamo_aten_lowering_passes:
+
+Writing Dynamo ATen Lowering Passes
+===================================
+
+Basics of a Lowering Pass
+-------------------------
+
+ATen lowering passes are Python functions which take as input a graph of ATen operators, apply some desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, custom operator insertion, or other operation on a `torch.fx.GraphModule`, then return the modified graph to the caller. These lowering passes generally modify the graph in-place and return the same input object.
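+
+A pass is therefore any function matching the signature ``Callable[[torch.fx.GraphModule], torch.fx.GraphModule]``. As a minimal sketch, a hypothetical no-op pass (shown here only to illustrate the contract) reduces to:
+
+.. code-block:: python
+
+    def noop_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        return gm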
+
+Lowering Pass Requirements
+--------------------------
+
+An ATen lowering pass function in Torch-TRT must satisfy two requirements:
+
+- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
+- The function must leave the graph in a valid and invoke-able state, including performing any necessary linting and recompilation
+
+See the FX documentation on graph manipulations for further background. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.
+
+Example Lowering Pass
+---------------------
+
+.. code-block:: python
+
+    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """Repair scenarios where inputs are also outputs of the graph
+
+        TRT does not allow such cases, so we insert a clone (identity) layer
+        """
+        modified_graph = False
+
+        # Extract graph placeholder Tensors
+        placeholders = [
+            node
+            for node in gm.graph.nodes
+            if (
+                node.op == "placeholder"
+                and isinstance(node.type, type)
+                and issubclass(node.type, torch.Tensor)
+            )
+        ]
+
+        for placeholder in placeholders:
+            # If any placeholder has any users which are direct graph outputs
+            if len(placeholder.users) >= 1 and any(
+                user.op == "output" for user in placeholder.users
+            ):
+                modified_graph = True
+
+                # Get direct graph outputs which are direct uses of placeholders
+                direct_outputs = [user for user in placeholder.users if user.op == "output"]
+
+                # Insert clone node for placeholder to ensure
+                # placeholder is not a direct output
+                with gm.graph.inserting_after(placeholder):
+                    cloned_placeholder = gm.graph.call_function(
+                        torch.ops.aten.clone.default,
+                        args=(placeholder,),
+                    )
+
+                # Replace placeholder as output with cloned version
+                for output in direct_outputs:
+                    output.replace_input_with(placeholder, cloned_placeholder)
+
+        # If the graph was modified, clean up the graph and ensure it is up-to-date
+        if modified_graph:
+            gm.graph.eliminate_dead_code()
+            gm.graph.lint()
+            gm.recompile()
+            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
+
+        return gm
+
+
+Registering Lowering Passes
+---------------------------
+
+Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend. Currently, we offer an ATen lowering pass registration decorator for convenience, which can be invoked either directly, or with the optional `index` keyword argument which controls where in the pass list the lowering pass will be inserted.
+
+For instance, to insert the pass at the default location (end of the list), the following code can be used:
+
+.. code-block:: python
+
+    @_aten_lowering_pass
+    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        ...
+
+Alternatively, to insert the pass at a custom index (such as the front of the list) in the passlist, the following code can be used:
+
+.. code-block:: python
+
+    @_aten_lowering_pass(index=0)
+    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        ...
+
+There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.
+
+.. code-block:: python
+
+    # Print all lowering passes in the list
+    print(dump_lowering_passes())
+
+    # Apply lowering passes to a GraphModule
+    apply_lowering_passes(graph_module)
+
+    # Remove the lowering pass at index 1
+    _remove_lowering_pass(index=1)
+
+**Note:** The above APIs are subject to change, as the lowering pass system evolves.
diff --git a/docsrc/index.rst b/docsrc/index.rst
index c8d8ede907..ded3b99c9d 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -73,7 +73,6 @@ Tutorials
    tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
    tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
    tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
-   tutorials/_rendered_examples/dynamo/dynamo_aten_lowering_passes
 
 Python API Documenation
 ------------------------
@@ -129,6 +128,7 @@ Contributor Documentation
 --------------------------------
 * :ref:`system_overview`
 * :ref:`writing_converters`
+* :ref:`writing_dynamo_aten_lowering_passes`
 * :ref:`useful_links`
 
 .. toctree::
@@ -138,6 +138,7 @@ Contributor Documentation
 
    contributors/system_overview
    contributors/writing_converters
+   contributors/writing_dynamo_aten_lowering_passes
    contributors/useful_links
 
 Indices
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index f4eb16fbba..fa863952e7 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -9,4 +9,3 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
-* :ref:`dynamo_aten_lowering_passes`: Custom modifications of a graph of ATen operators via lowering passes
diff --git a/examples/dynamo/dynamo_aten_lowering_passes.py b/examples/dynamo/dynamo_aten_lowering_passes.py
deleted file mode 100644
index ffe7083576..0000000000
--- a/examples/dynamo/dynamo_aten_lowering_passes.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-.. _dynamo_aten_lowering_passes:
-
-Dynamo ATen Lowering Passes
-======================================================
-
-This interactive script is intended as an overview of the process by which ATen lowering passes are written and used."""
-
-# %%
-# 1. Lowering Pass Function
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-#
-# An ATen lowering pass function in Torch-TRT must satisfy two requirements:
-# - The function must take as input a single `torch.fx.GraphModule` and return the lowered
-#   `torch.fx.GraphModule`
-# - The function must leave the graph in a valid and invoke-able state, including performing any
-#   necessary linting and recompilation
-#
-# See below for an example of a lowering pass which repairs graphs that have inputs which are
-# also outputs, a disallowed configuration for TRT Engines.
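-#
-# As a minimal sketch of this contract, a hypothetical no-op pass (named
-# ``noop_pass`` purely for illustration) would simply relint, recompile, and
-# return the module:
-#
-# .. code-block:: python
-#
-#     def noop_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-#         gm.graph.lint()
-#         gm.recompile()
-#         return gm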
- -# %% -import logging - -import torch - -logger = logging.getLogger(__name__) - - -def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - """Repair scenarios where inputs are also outputs of the graph - - TRT does not allow such cases, so we insert a clone (identity) layer - """ - modified_graph = False - - # Extract graph placeholder Tensors - placeholders = [ - node - for node in gm.graph.nodes - if ( - node.op == "placeholder" - and isinstance(node.type, type) - and issubclass(node.type, torch.Tensor) - ) - ] - - for placeholder in placeholders: - # If any placeholder has any users which are direct graph outputs - if len(placeholder.users) >= 1 and any( - user.op == "output" for user in placeholder.users - ): - modified_graph = True - - # Get direct graph outputs which are direct uses of placeholders - direct_outputs = [user for user in placeholder.users if user.op == "output"] - - # Insert clone node for placeholder to ensure placeholder is not a direct output - with gm.graph.inserting_after(placeholder): - cloned_placeholder = gm.graph.call_function( - torch.ops.aten.clone.default, - args=(placeholder,), - ) - - # Replace placeholder as output with cloned version - for output in direct_outputs: - output.replace_input_with(placeholder, cloned_placeholder) - - # If the graph was modified, clean up the graph and ensure it is up-to-date - if modified_graph: - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}") - - return gm - - -# %% -# 2. Lowering Pass Registration -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# To add a lowering pass, use the convenience function `add_lowering_pass` in the module -# `torch_tensorrt.dynamo.lowering.passes`. See below for an example: - -# %% -from torch_tensorrt.dynamo.lowering.passes import add_lowering_pass - -# Adds the lowering pass at the end of the pass list -add_lowering_pass(repair_input_as_output) - -# Alternatively, specify an index to insert the lowering pass at a specific location -add_lowering_pass(repair_input_as_output, 1) - -# To remove a lowering pass, specify the index of the pass to remove: -from torch_tensorrt.dynamo.lowering.passes import remove_lowering_pass - -# Removes the lowering pass at index 1 -remove_lowering_pass(1) - - -# To view all lowering passes, in the order they will be run, use the following -from torch_tensorrt.dynamo.lowering.passes import dump_lowering_passes - -print(dump_lowering_passes()) - -# %% -# 3. Apply Available Lowering Passes -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# To apply all lowering passes to a graph, the convenience function `apply_lowering_passes` in the module -# `torch_tensorrt.dynamo.lowering.passes` can be used. This function is automatically invoked in the Torch-TRT Dynamo -# paths. Additionally, the graph after each modifying pass is logged in the debug logs for Torch-TRT runs. 
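-
-# %%
-# As a brief sketch of manual invocation (the two-layer module below is a
-# stand-in used purely for illustration), the registered passes can be
-# applied directly to any traced ``torch.fx.GraphModule``:

-from torch_tensorrt.dynamo.lowering.passes import apply_lowering_passes
-
-sample_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
-graph_module = torch.fx.symbolic_trace(sample_model)
-graph_module = apply_lowering_passes(graph_module)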
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py index 32225e79fc..b271d0d6fb 100644 --- a/py/torch_tensorrt/dynamo/aten_tracer.py +++ b/py/torch_tensorrt/dynamo/aten_tracer.py @@ -6,8 +6,7 @@ import torch from torch._export import export -from torch_tensorrt.dynamo.backend.backends import constant_fold -from torch_tensorrt.dynamo.lowering import get_decompositions +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import set_log_level logger = logging.getLogger(__name__) @@ -29,6 +28,6 @@ def trace( "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions) ): graph_module = export(model, tuple(inputs)).module() - constant_fold(graph_module) + graph_module = apply_lowering_passes(graph_module) logger.debug("Post export graph: " + str(graph_module.graph)) return graph_module diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index c83cf4665c..34faa1d11b 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -2,5 +2,5 @@ from ._fusers import * # noqa: F401 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401 from ._pre_aot_lowering import register_substitution # noqa: F401 -from .passes import add_lowering_pass, apply_lowering_passes +from .passes import apply_lowering_passes from .substitutions import * # noqa: F401 diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py index ff549b73a8..ea393fab14 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py @@ -1,55 +1 @@ -import logging -from typing import Callable, Optional - -import torch - -# Import and order lowering passes and pass manager -from .constant_folding import constant_fold -from .pass_manager import DynamoPassManager -from .repair_input_as_output import repair_input_as_output - -ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist( - [ - constant_fold, - repair_input_as_output, - ] -) - -logger = logging.getLogger(__name__) - - -def add_lowering_pass( - lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule], - index: Optional[int] = None, -) -> None: - """Adds a lowering pass to the registry, at a specified index if desired - - If no index is specified, the lowering pass is inserted at the end of the list - """ - ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) - logger.debug( - f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}" - ) - return - - -def remove_lowering_pass(index: int) -> None: - """Removes a lowering pass at a specific index from the registry""" - ATEN_LOWERING_PASSES.remove_pass_with_index(index) - logger.debug( - f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}" - ) - return - - -def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - """Applies the lowering passes to a graph module, returns the modified GraphModule""" - logging.debug( - f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}" - ) - return ATEN_LOWERING_PASSES(gm) - - -def dump_lowering_passes() -> str: - """Returns a string containing the lowering passes""" - return str(ATEN_LOWERING_PASSES) +from ._aten_lowering_pass import * diff --git 
a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
new file mode 100644
index 0000000000..a4c7fad607
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -0,0 +1,76 @@
+import logging
+from typing import Callable, Optional
+
+import torch
+
+from .constant_folding import constant_fold
+from .pass_manager import DynamoPassManager
+from .repair_input_as_output import repair_input_as_output
+
+ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
+    [
+        constant_fold,
+        repair_input_as_output,
+    ]
+)
+
+logger = logging.getLogger(__name__)
+
+
+LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+
+
+def _aten_lowering_pass(
+    *args: LoweringPassSignature,
+    index: Optional[int] = None,
+) -> LoweringPassSignature:
+    """Adds a lowering pass to the registry, at a specified index if desired
+
+    If no index is specified, the lowering pass is inserted at the end of the list
+    """
+
+    def add_lowering_pass(
+        lowering_pass: LoweringPassSignature,
+    ) -> LoweringPassSignature:
+        ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
+        logger.debug(
+            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
+        )
+        return lowering_pass
+
+    # If there are arguments specified, the decorator may have been called as-is
+    if args:
+        # The decorator may only be called with the lowering pass
+        # The index must be specified as a keyword argument
+        if len(args) == 1 and callable(args[0]):
+            return add_lowering_pass(args[0])
+        else:
+            raise AssertionError(
+                f"aten_lowering_pass decorator called with invalid arguments {args}. "
+                "To specify an index to insert the pass, use the keyword 'index='"
+            )
+    # If no arguments are specified, the decorator was called with an index keyword
+    else:
+        return add_lowering_pass
+
+
+def _remove_lowering_pass(*, index: int) -> None:
+    """Removes a lowering pass at a specific index from the registry"""
+    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
+    logger.debug(
+        f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
+    )
+    return
+
+
+def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
+    logger.debug(
+        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
+    )
+    return ATEN_LOWERING_PASSES(gm)
+
+
+def dump_lowering_passes() -> str:
+    """Returns a string containing the lowering passes"""
+    return str(ATEN_LOWERING_PASSES)
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index c7dd0ecdb2..a63c5e3439 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -59,36 +59,34 @@ class TestLoweringPassMembership(TestCase):
     def insert_at_end(self):
         from torch_tensorrt.dynamo.lowering.passes import (
             ATEN_LOWERING_PASSES,
-            add_lowering_pass,
-            remove_lowering_pass,
+            _aten_lowering_pass,
+            _remove_lowering_pass,
         )
 
+        @_aten_lowering_pass
         def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
             return gm
 
-        add_lowering_pass(identity_pass)
-
         self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[-1])
 
-        remove_lowering_pass(-1)
+        _remove_lowering_pass(index=-1)
 
         self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)
 
     def insert_at_index(self):
         from torch_tensorrt.dynamo.lowering.passes import (
             ATEN_LOWERING_PASSES,
-            add_lowering_pass,
-            remove_lowering_pass,
+            _aten_lowering_pass,
+            _remove_lowering_pass,
         )
 
+        @_aten_lowering_pass(index=0)
         def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
             return gm
 
-        add_lowering_pass(identity_pass, 0)
-
         self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[0])
 
-        remove_lowering_pass(0)
+        _remove_lowering_pass(index=0)
 
         self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)