diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 25db510d3c..37ad31826a 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -9,7 +9,7 @@ get_decompositions, ) from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( - pre_aot_module_replacement, + pre_aot_substitutions, ) from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, @@ -49,7 +49,7 @@ def aot_torch_tensorrt_aten_backend( ) # Perform Pre-AOT Lowering for Module-Level Replacement - gm = pre_aot_module_replacement(gm) + gm = pre_aot_substitutions(gm) # Invoke AOTAutograd to translate operators to aten return aot_module_simplified( diff --git a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py index 1a0cbab2df..dd55d2c83c 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py @@ -2,8 +2,8 @@ get_decompositions, ) from ._pre_aot_lowering import ( - MODULE_SUBSTITUTION_REGISTRY, - module_substitution, + SUBSTITUTION_REGISTRY, + register_substitution, ) from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS -from .module_substitutions import * +from .substitutions import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 496c91a089..4d82bf4be5 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -4,7 +4,7 @@ import torch from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE -from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY +from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.graph_module import GraphModule from torch.fx.node import _get_qualified_name @@ -16,8 +16,8 @@ logger = logging.getLogger(__name__) DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set( - _get_qualified_name(module.new_operator) - for module in MODULE_SUBSTITUTION_REGISTRY.values() + _get_qualified_name(to_replace.new_operator) + for to_replace in SUBSTITUTION_REGISTRY.values() ) diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py index 34430be9f0..8a47fc04d2 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Dict, Type +from typing import Any, Callable, Dict, Optional, Type, Union import torch import logging @@ -8,59 +8,62 @@ @dataclass(frozen=True) -class ModuleReplacement: +class Substitution: """Class to store key functionality for module replacement""" # torch.ops.___ name for replacement function for module new_operator: torch._ops.OpOverload - # Function taking a containing graph, a submodule, and a 'call_module' node and returning - # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected + # Function taking a containing graph, a node, and optionally a submodule (if replacing a module) + # and returning a replacement node, with type 'call_function', or raising an Error if + # incompatibility is detected # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph subgraph_insertion_fn: Callable[ - [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node + [torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node ] -# Dictionary mapping module to ModuleReplacement instance -MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict() +# Dictionary mapping module to Substitution instance +SUBSTITUTION_REGISTRY: Dict[ + Union[Type[torch.nn.Module], Callable], Substitution +] = dict() -def module_substitution( - module_to_replace: Type[torch.nn.Module], +def register_substitution( + module_or_function_to_replace: Union[Type[torch.nn.Module], Callable], new_operator: torch._ops.OpOverload, enabled: bool = True, ) -> Callable[[Any], Any]: """Decorator to register subgraph insertion functions Args: - module_to_replace: nn.Module to replace + module_or_function_to_replace: nn.Module or node target Callable to replace new_operator: Custom torch operator to replace with enabled: Whether the substitution is enabled or disabled Returns: torch.fx.GraphModule """ - def register_substitution(subgraph_insertion_fn): + def enable_substitution(subgraph_insertion_fn): """Function for use if substitution is enabled""" - module_replacement = ModuleReplacement( + replacement = Substitution( new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn ) - MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement + SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement return subgraph_insertion_fn def disable_substitution(subgraph_insertion_fn): """Function for use if substitution is disabled""" return subgraph_insertion_fn - return register_substitution if enabled else disable_substitution + return enable_substitution if enabled else disable_substitution -def pre_aot_module_replacement(gm: torch.fx.GraphModule): - """Perform module-level graph replacement prior to AOT tracing +def pre_aot_substitutions(gm: torch.fx.GraphModule): + """Perform graph substitutions prior to AOT tracing Args: - gm: FX GraphModule to perform module replacement on + gm: FX GraphModule to perform substitution on Returns: torch.fx.GraphModule @@ -73,48 +76,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule): # Iterate over graph nodes, extracting module calls, to check for interceptions for n in gm.graph.nodes: + exists_in_registry = False + to_replace = None + if n.op == "call_module": - # Extract submodule from graph + # Extract submodule from graph, validate in registry submodule = gm.get_submodule(n.target) - - # If submodule is a member of the substitution registry, replace it - if type(submodule) in MODULE_SUBSTITUTION_REGISTRY: - - try: - replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)] - op, insertion_fn = ( - replacement.new_operator, - replacement.subgraph_insertion_fn, - ) - logger.debug( - f"Replacing module of type {type(submodule)} with {op}" + to_replace = type(submodule) + exists_in_registry = to_replace in SUBSTITUTION_REGISTRY + elif n.op == "call_function": + # Extract function from graph, validate in registry + to_replace = n.target + exists_in_registry = n.target in SUBSTITUTION_REGISTRY + + # If submodule/function is a member of the substitution registry, replace it + if exists_in_registry: + try: + replacement = SUBSTITUTION_REGISTRY[to_replace] + op, insertion_fn = ( + replacement.new_operator, + replacement.subgraph_insertion_fn, + ) + logger.debug(f"Replacing node of type {to_replace} with {op}") + + # Insert new node prior to older node + with gm.graph.inserting_before(n): + new_node = insertion_fn( + gm, n, submodule if n.op == "call_module" else None ) - # Insert new node prior to older node - with gm.graph.inserting_before(n): - new_node = insertion_fn(gm, submodule, n) - - # If submodule is not a native torch.nn module, it must be manually excluded - # from Dynamo tracing - if not type(submodule).__module__.startswith("torch.nn"): - torch._dynamo.allowed_functions._allowed_function_ids.add( - id(type(submodule)) - ) - - # Replace all original node uses and clean up graph - n.replace_all_uses_with(new_node) - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - - # A module replacement can fail in the event that the specific instance of the submodule cannot - # be replaced - except Exception: - logger.debug( - f"Encountered error while replacing {type(submodule)}", - exc_info=True, + # If submodule is not a native torch.nn module, it must be manually excluded + # from Dynamo tracing + if n.op == "call_module" and not type(submodule).__module__.startswith( + "torch.nn" + ): + torch._dynamo.allowed_functions._allowed_function_ids.add( + id(to_replace) ) - continue + + # Replace all original node uses and clean up graph + n.replace_all_uses_with(new_node) + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + # A replacement can fail in the event that the specific instance of the submodule/function + # cannot be replaced + except Exception: + logger.debug( + f"Encountered error while replacing {to_replace}", + exc_info=True, + ) + continue # Perform cleanup and recompilation before returning module gm.graph.eliminate_dead_code() diff --git a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py similarity index 53% rename from py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py rename to py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py index 4b8ba88e34..8d3acc8874 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py @@ -1 +1,2 @@ from .maxpool1d import * +from .einsum import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py new file mode 100644 index 0000000000..57c4a93e62 --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py @@ -0,0 +1,80 @@ +from typing import Dict, Tuple +import torch +from torch._custom_op.impl import custom_op +from torch.fx.node import Argument, Target + +from torch_tensorrt.fx.converter_registry import tensorrt_converter +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from torch_tensorrt.dynamo.backend.lowering import register_substitution + + +@custom_op( + qualname="tensorrt::einsum", + manual_schema="(str equation, Tensor[] tensors) -> Tensor", +) +def einsum(equation, tensors): + # Defines operator schema, name, namespace, and function header + ... + + +@einsum.impl("cpu") +@einsum.impl("cuda") +@einsum.impl_abstract() +def einsum_generic( + *args, + **kwargs, +): + # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation + return torch.einsum( + *args, + **kwargs, + ) + + +@tensorrt_converter(torch.ops.tensorrt.einsum.default) +def aten_ops_einsum( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> TRTTensor: + # Defines converter replacing the default operator for this function + for input_trt in args[1]: + if not isinstance(input_trt, TRTTensor): + raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}") + + einsum_layer = network.add_einsum(inputs=args[1], equation=args[0]) + + set_layer_name(einsum_layer, target, name) + return einsum_layer.get_output(0) + + +@register_substitution(torch.einsum, torch.ops.tensorrt.einsum) +def einsum_insertion_fn( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + _unused: None = None, +) -> torch.fx.Node: + equation = node.args[0] + + # Ensure inputs is a list of (Tensor) arguments + if isinstance(node.args[1], (tuple, list)): + inputs = node.args[1] + else: + inputs = node.args[1:] + + assert ( + 1 <= len(inputs) <= 2 + ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors" + + # Ensure the input is formatted as an equation and + new_node = gm.graph.call_function( + torch.ops.tensorrt.einsum, + args=(equation, inputs), + kwargs=node.kwargs, + ) + + return new_node diff --git a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py similarity index 95% rename from py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py rename to py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py index 38481aadcc..020d3a0ca9 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py @@ -7,7 +7,7 @@ from torch_tensorrt.fx.converters import acc_ops_converters from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.dynamo.backend.lowering import module_substitution +from torch_tensorrt.dynamo.backend.lowering import register_substitution # This file serves as an example and a tutorial for excluding custom modules from @@ -71,9 +71,11 @@ def maxpool1d_generic( # "bias": bias, # ... # -@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) +@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) def maxpool1d_insertion_fn( - gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node + gm: torch.fx.GraphModule, + node: torch.fx.Node, + submodule: torch.nn.Module, ) -> torch.fx.Node: # Defines insertion function for new node new_node = gm.graph.call_function( diff --git a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py index 2fa65bfabc..da44d6e826 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py @@ -51,5 +51,51 @@ def forward(self, x): ) +class TestEinsum(TestCase): + def test_pre_aot_lowering_einsum(self): + class Einsum(torch.nn.Module): + def forward(self, x, y): + return torch.einsum("ij,ji->ij", x, y) + + # Operations expected to be included in the traced graph after decompositions + expected_ops = {torch.ops.tensorrt.einsum.default} + + inputs = [ + torch.rand( + 16, + 16, + ).cuda(), + torch.rand( + 16, + 16, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(Einsum()) + _, expected_ops_unseen = lower_graph_testing( + fx_graph, inputs, expected_ops=expected_ops, min_block_size=1 + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = compile( + fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results)) + self.assertAlmostEqual( + max_diff, 0, f"Einsum TRT outputs don't match with the original model." + ) + + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py index e7dc435ac4..7c679b7d4d 100644 --- a/py/torch_tensorrt/dynamo/backend/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -9,7 +9,7 @@ partition, ) from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( - pre_aot_module_replacement, + pre_aot_substitutions, ) from torch._dynamo.backends.common import fake_tensor_unsupported @@ -34,7 +34,7 @@ def fx_dynamo_testing_backend( torch_executed_ops=torch_executed_ops, ) - gm = pre_aot_module_replacement(gm) + gm = pre_aot_substitutions(gm) # Invoke AOTAutograd to translate operators to aten return aot_module_simplified(