diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ac24623eef..4738ea80be 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -40,7 +40,7 @@ repos:
     rev: 'v1.4.1'
     hooks:
       - id: mypy
-        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
+        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.0.278
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index b30da1ffb8..2467fdd6ae 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -15,7 +15,6 @@
     get_decompositions,
     repair_input_aliasing,
 )
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import (
     parse_dynamo_kwargs,
     prepare_inputs,
@@ -68,9 +67,6 @@ def _pretraced_backend(
     try:
         logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))
 
-        # Perform Pre-AOT Lowering for Module-Level Replacement
-        gm = pre_aot_substitutions(gm)
-
         fake_mode = detect_fake_mode(sample_inputs)
 
         # Place backend tracing within FakeTensor context allowing nonfake Tensors
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 2b67ef0c91..1fbe0cd120 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,7 +1,4 @@
 from ._decompositions import get_decompositions  # noqa: F401
 from ._fusers import *  # noqa: F401
-from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
-from ._pre_aot_lowering import register_substitution  # noqa: F401
 from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes
-from .substitutions import *  # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
deleted file mode 100644
index 70cc5424af..0000000000
--- a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, Optional, Type
-
-import torch
-from torch._ops import OpOverload
-from torch.fx import GraphModule, Node
-from typing_extensions import TypeAlias
-
-logger = logging.getLogger(__name__)
-
-SubgraphInsertionFnType: TypeAlias = Callable[
-    [GraphModule, Node, Optional[torch.nn.Module]], Node
-]
-
-
-@dataclass(frozen=True)
-class Substitution:
-    """Class to store key functionality for module replacement"""
-
-    # torch.ops.___ name for replacement function for module
-    new_operator: torch._ops.OpOverload
-
-    # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
-    # and returning a replacement node, with type 'call_function', or raising an Error if
-    # incompatibility is detected
-    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
-    subgraph_insertion_fn: SubgraphInsertionFnType
-
-
-# Dictionary mapping module to Substitution instance
-SUBSTITUTION_REGISTRY: Dict[
-    (Type[torch.nn.Module] | Callable[..., Any]), Substitution
-] = dict()
-
-
-def register_substitution(
-    module_or_function_to_replace: (Type[torch.nn.Module] | Callable[..., Any]),
-    new_operator: OpOverload,
-    enabled: bool = True,
-) -> Callable[[SubgraphInsertionFnType], SubgraphInsertionFnType]:
-    """Decorator to register subgraph insertion functions
-
-    Args:
-        module_or_function_to_replace: nn.Module or node target Callable to replace
-        new_operator: Custom torch operator to replace with
-        enabled: Whether the substitution is enabled or disabled
-    Returns:
-        torch.fx.GraphModule
-    """
-
-    def enable_substitution(
-        subgraph_insertion_fn: SubgraphInsertionFnType,
-    ) -> SubgraphInsertionFnType:
-        """Function for use if substitution is enabled"""
-        replacement = Substitution(
-            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
-        )
-        SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement
-        return subgraph_insertion_fn
-
-    def disable_substitution(
-        subgraph_insertion_fn: SubgraphInsertionFnType,
-    ) -> SubgraphInsertionFnType:
-        """Function for use if substitution is disabled"""
-        return subgraph_insertion_fn
-
-    return enable_substitution if enabled else disable_substitution
-
-
-def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    """Perform graph substitutions prior to AOT tracing
-
-    Args:
-        gm: FX GraphModule to perform substitution on
-    Returns:
-        torch.fx.GraphModule
-
-    """
-    logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
-
-    # Iterate over graph nodes, extracting module calls, to check for interceptions
-    for n in gm.graph.nodes:
-        exists_in_registry = False
-        to_replace = None
-
-        if n.op == "call_module":
-            # Extract submodule from graph, validate in registry
-            submodule = gm.get_submodule(n.target)
-            to_replace = type(submodule)
-            exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
-        elif n.op == "call_function":
-            # Extract function from graph, validate in registry
-            to_replace = n.target
-            exists_in_registry = n.target in SUBSTITUTION_REGISTRY
-
-        # If submodule/function is a member of the substitution registry, replace it
-        if exists_in_registry:
-            try:
-                assert to_replace is not None
-                replacement = SUBSTITUTION_REGISTRY[to_replace]
-                op, insertion_fn = (
-                    replacement.new_operator,
-                    replacement.subgraph_insertion_fn,
-                )
-                logger.debug(f"Replacing node of type {to_replace} with {op}")
-
-                # Insert new node prior to older node
-                with gm.graph.inserting_before(n):
-                    new_node = insertion_fn(
-                        gm, n, submodule if n.op == "call_module" else None
-                    )
-
-                # If submodule is not a native torch.nn module, it must be manually excluded
-                # from Dynamo tracing
-                if n.op == "call_module" and not type(submodule).__module__.startswith(
-                    "torch.nn"
-                ):
-                    torch._dynamo.allowed_functions._allowed_function_ids.add(
-                        id(to_replace)
-                    )
-
-                # Replace all original node uses and clean up graph
-                n.replace_all_uses_with(new_node)
-                gm.graph.lint()
-                gm.recompile()
-
-            # A replacement can fail in the event that the specific instance of the submodule/function
-            # cannot be replaced
-            except Exception:
-                logger.debug(
-                    f"Encountered error while replacing {to_replace}",
-                    exc_info=True,
-                )
-                continue
-
-    # Perform cleanup and recompilation before returning module
-    gm.graph.lint()
-    gm.recompile()
-
-    logger.debug("Post-module replacement graph:\n" + str(gm.graph))
-
-    return gm
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index ffbe1c7f44..a883018c5e 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -9,6 +9,7 @@
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
+from .replace_max_pool_with_indices import replace_max_pool_with_indices
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -17,6 +18,7 @@
         repair_input_as_output,
         lower_efficient_attention,
         fuse_prims_broadcast,
+        replace_max_pool_with_indices,
     ]
 )
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py
new file mode 100644
index 0000000000..75395d6435
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py
@@ -0,0 +1,60 @@
+import logging
+import operator
+from typing import Sequence
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def replace_max_pool_with_indices(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace max pool ops returning unused indices with their index-free variants"""
+    replacement_dict = {
+        torch.ops.aten.max_pool1d_with_indices.default: torch.ops.aten.max_pool1d.default,
+        torch.ops.aten.max_pool2d_with_indices.default: torch.ops.aten.max_pool2d.default,
+        torch.ops.aten.max_pool3d_with_indices.default: torch.ops.aten.max_pool3d.default,
+    }
+
+    modified_graph = False
+
+    for node in gm.graph.nodes:
+        # If the node is a max pool op returning (values, indices), and its
+        # only user is a getitem node extracting element 0 (the pooled values),
+        # the indices are unused and the op can be replaced by its index-free variant
+        if (
+            node.target in replacement_dict
+            and len(node.users) == 1
+            and list(node.users)[0].target == operator.getitem
+            and list(node.users)[0].args[1] == 0
+        ):
+            modified_graph = True
+
+            # Replace the maxpool/getitem pair with a single index-free maxpool node
+            getitem_node = list(node.users)[0]
+
+            with gm.graph.inserting_after(getitem_node):
+                maxpool_fused = gm.graph.call_function(
+                    replacement_dict[node.target],
+                    args=node.args,
+                    kwargs=node.kwargs,
+                )
+
+            logger.debug(
+                f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused}, "
+                f"since the getitem node extracts only the pooled values and the returned indices are unused."
+            )
+
+            getitem_node.replace_all_uses_with(maxpool_fused)
+            gm.graph.erase_node(getitem_node)
+            gm.graph.erase_node(node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after fusing maxpool operators with indices:\n{gm.graph}")
+
+    return gm
diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py b/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py
deleted file mode 100644
index bd348b3e47..0000000000
--- a/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .einsum import *  # noqa: F403
-from .maxpool1d import *  # noqa: F403
diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py b/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py
deleted file mode 100644
index ea44a88be5..0000000000
--- a/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from typing import Any, Dict, Optional, Sequence, Tuple
-
-import torch
-import torch._custom_ops as library
-from torch.fx.node import Argument, Target
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
-from torch_tensorrt.fx.converter_registry import tensorrt_converter
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
-
-library.custom_op(
-    "tensorrt::einsum",
-    "(str equation, Tensor[] tensors) -> Tensor",
-)
-
-
-@library.impl("tensorrt::einsum")  # type: ignore[misc]
-@library.impl_abstract("tensorrt::einsum")  # type: ignore[misc]
-def einsum_generic(
-    *args: Any,
-    **kwargs: Any,
-) -> Any:
-    # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
-    return torch.einsum(
-        *args,
-        **kwargs,
-    )
-
-
-# TODO: @gs-olive Port to dynamo converter
-@tensorrt_converter(torch.ops.tensorrt.einsum.default)  # type: ignore[misc]
-def aten_ops_einsum(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> TRTTensor:
-    # Defines converter replacing the default operator for this function
-    assert isinstance(args[1], Sequence)
-    for input_trt in args[1]:
-        if not isinstance(input_trt, TRTTensor):
-            raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}")
-
-    einsum_layer = network.add_einsum(inputs=args[1], equation=args[0])
-
-    set_layer_name(einsum_layer, target, name)
-    return einsum_layer.get_output(0)
-
-
-@register_substitution(torch.einsum, torch.ops.tensorrt.einsum)  # type: ignore[misc]
-def einsum_insertion_fn(
-    gm: torch.fx.GraphModule,
-    node: torch.fx.Node,
-    submodule: Optional[torch.nn.Module] = None,
-) -> torch.fx.Node:
-    equation = node.args[0]
-
-    # Ensure inputs is a list of (Tensor) arguments
-    if isinstance(node.args[1], (tuple, list)):
-        inputs = node.args[1]
-    else:
-        inputs = node.args[1:]
-
-    assert (
-        1 <= len(inputs) <= 2
-    ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors"
-
-    # Ensure the input is formatted as an equation and
-    new_node: torch.fx.Node = gm.graph.call_function(
-        torch.ops.tensorrt.einsum,
-        args=(equation, inputs),
-        kwargs=node.kwargs,
-    )
-
-    return new_node
diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py
deleted file mode 100644
index 8bb44ac8e0..0000000000
--- a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-import torch._custom_ops as library
-from torch.fx.node import Argument, Target
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
-from torch_tensorrt.fx.converter_registry import tensorrt_converter
-from torch_tensorrt.fx.converters import acc_ops_converters
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
-
-# This file serves as an example and a tutorial for excluding custom modules from
-# torch.compile tracing. Each required step is labeled with a number indicating the
-# preferable implementation order.
-
-
-# 1. The Placeholder
-#
-# Specify the schema and namespace of the operator, as well as a placeholder function
-# representing the schema. The schema should be in torch JIT syntax, indicating input and output
-# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
-# Then, create a placeholder function with no operations, but having the same schema and naming as that
-# used in the decorator
-library.custom_op(
-    "tensorrt::maxpool1d",
-    "(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
-)
-
-
-# 2. The Generic Implementation
-#
-# Define the default implementation of the operator in torch syntax. This is used for autograd
-# and other tracing functionality. Generally, the torch.nn.functional analog of the operator to replace
-# is desirable. If the operator to replace is a custom module you've written, then add its Torch
-# implementation here. Note that the function header to the generic function can have specific arguments
-# as in the above placeholder
-@library.impl("tensorrt::maxpool1d")  # type: ignore[misc]
-@library.impl_abstract("tensorrt::maxpool1d")  # type: ignore[misc]
-def maxpool1d_generic(
-    *args: Any,
-    **kwargs: Any,
-) -> Any:
-    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
-    return torch.nn.functional.max_pool1d(
-        *args,
-        **kwargs,
-    )
-
-
-# 3. The Module Substitution Function
-#
-# Define a function which can intercept a node of the kind to be replaced, extract
-# the relevant data from that node/submodule, and then re-package the information
-# for use by an accelerated implementation (to be implemented in step 4). This function
-# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d).
-# It should refactor the args and kwargs as is needed by the accelerated implementation.
-#
-# If the submodule has weights or other Tensor fields which the accelerated implementation
-# needs, the function should insert the necessary nodes to access those weights. For example,
-# if the weight Tensor of a submodule is needed, one could write:
-#
-#     weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor)
-#     bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor)
-#     ...
-#     kwargs={"weight": weights,
-#             "bias": bias,
-#             ...
-#
-@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)  # type: ignore[misc]
-def maxpool1d_insertion_fn(
-    gm: torch.fx.GraphModule,
-    node: torch.fx.Node,
-    submodule: Optional[torch.nn.Module],
-) -> torch.fx.Node:
-    # Defines insertion function for new node
-    assert submodule is not None
-    new_node: torch.fx.Node = gm.graph.call_function(
-        torch.ops.tensorrt.maxpool1d,
-        args=node.args,
-        kwargs={
-            "kernel_size": submodule.kernel_size,
-            "stride": submodule.stride,
-            "padding": submodule.padding,
-            "dilation": submodule.dilation,
-            "ceil_mode": submodule.ceil_mode,
-        },
-    )
-
-    return new_node
-
-
-# 4. The Accelerated Implementation
-#
-# Define an accelerated implementation of the operator, and register it as necessary.
-# This accelerated implementation should consume the args/kwargs specified in step 3.
-# One should expect that torch.compile will compress all kwargs into the args field in
-# the order specified in the schema written in step 1.
-@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)  # type: ignore[misc]
-def tensorrt_maxpool1d(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> TRTTensor:
-    # Defines converter replacing the default operator for this function
-    kwargs_new = {
-        "input": args[0],
-        "kernel_size": args[1],
-        "stride": args[2],
-        "padding": args[3],
-        "dilation": args[4],
-        "ceil_mode": False if len(args) < 6 else args[5],
-    }
-
-    return acc_ops_converters.acc_ops_max_pool1d(
-        network, target, None, kwargs_new, name
-    )
-
-
-# 5. Add Imports
-#
-# Add your accelerated module file to the __init__.py in this directory, to ensure
-# all registrations are run. For instance, if the new module file is called new_mod.py,
-# one should add `from .new_mod import *` to the __init__.py
diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
index 5399bc5d6f..3f20c7efc8 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -22,8 +22,6 @@
 )
 from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 
-from .common import DEFAULT_SINGLE_NODE_PARTITIONS
-
 logger = logging.getLogger(__name__)
 
@@ -107,9 +105,7 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         operator_support: ops.OperatorSupportBase,
-        allowed_single_node_partition_ops: Optional[
-            Collection[str]
-        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
+        allowed_single_node_partition_ops: Optional[Collection[str]] = None,
         min_block_size: int = MIN_BLOCK_SIZE,
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ):
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index 19fccfc73f..49ee02b4cf 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -15,8 +15,6 @@
 )
 from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 
-from .common import DEFAULT_SINGLE_NODE_PARTITIONS
-
 logger = logging.getLogger(__name__)
 
@@ -41,9 +39,7 @@ def __init__(
         operator_support: OperatorSupport,
         *,
         non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[
-            Collection[str]
-        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
+        allowed_single_node_partition_ops: Optional[Collection[str]] = None,
         min_block_size: int = MIN_BLOCK_SIZE,
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ) -> None:
@@ -95,8 +91,8 @@ def propose_partitions(self) -> List[Partition]:
             compute_node_count = 0
             for node in partition.nodes:
                 # Partitions are exempted from min_block_size if they contain an allowed single-node op
-                if (
-                    node.op == "call_function"
+                if node.op == "call_function" and (
+                    self.allowed_single_node_partition_ops is not None
                     and ConverterRegistry.qualified_name_or_str(node.target)
                     in self.allowed_single_node_partition_ops
                 ):
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 14c068260f..bbc6f92af1 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -2,19 +2,12 @@
 from typing import Any, Optional, Sequence, Set, Tuple
 
 import torch
-from torch.fx.node import _get_qualified_name
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import DEBUG
-from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
 from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = {
-    _get_qualified_name(to_replace.new_operator)
-    for to_replace in SUBSTITUTION_REGISTRY.values()
-}
-
 
 def get_submod_inputs(
     mod: torch.fx.GraphModule,
diff --git a/setup.py b/setup.py
index d02adfb678..82f1ac42f7 100644
--- a/setup.py
+++ b/setup.py
@@ -391,7 +391,6 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.slice",
         "torch_tensorrt.dynamo.conversion.impl.unary",
         "torch_tensorrt.dynamo.lowering",
-        "torch_tensorrt.dynamo.lowering.substitutions",
         "torch_tensorrt.dynamo.lowering.passes",
         "torch_tensorrt.dynamo.partitioning",
         "torch_tensorrt.dynamo.runtime",
@@ -419,7 +418,6 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.slice": "py/torch_tensorrt/dynamo/conversion/impl/slice",
         "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary",
         "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
-        "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions",
         "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes",
         "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning",
         "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime",
diff --git a/tests/py/dynamo/backend/test_pre_aot_lowering.py b/tests/py/dynamo/backend/test_pre_aot_lowering.py
deleted file mode 100644
index 2ea02691bd..0000000000
--- a/tests/py/dynamo/backend/test_pre_aot_lowering.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import torch
-import torch_tensorrt
-from torch.testing._internal.common_utils import TestCase, run_tests
-
-from ..testing_utilities import lower_graph_testing
-
-
-class TestMaxPool1D(TestCase):
-    def test_pre_aot_lowering_maxpool1d(self):
-        class MaxPool1D(torch.nn.Module):
-            def __init__(self, *args, **kwargs) -> None:
-                super().__init__(*args, **kwargs)
-                self.maxpool = torch.nn.MaxPool1d(2)
-
-            def forward(self, x):
-                return self.maxpool(x)
-
-        # Operations expected to be included in the traced graph after decompositions
-        expected_ops = {torch.ops.tensorrt.maxpool1d.default}
-
-        inputs = [
-            torch.rand(
-                9,
-                16,
-                2,
-            ).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(MaxPool1D())
-        _, expected_ops_unseen = lower_graph_testing(
-            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
-        )
-
-        self.assertEquals(
-            len(expected_ops_unseen),
-            0,
-            f"The following expected ops were not encountered: {expected_ops_unseen}",
-        )
-
-        torch._dynamo.reset()
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-        )
-        optimized_model_results = optimized_model(*inputs).detach().cpu()
-        torch_model_results = fx_graph(*inputs).detach().cpu()
-
-        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
-        self.assertAlmostEqual(
-            max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model."
-        )
-
-
-class TestEinsum(TestCase):
-    def test_pre_aot_lowering_einsum(self):
-        class Einsum(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.einsum("ij,ji->ij", x, y)
-
-        # Operations expected to be included in the traced graph after decompositions
-        expected_ops = {torch.ops.tensorrt.einsum.default}
-
-        inputs = [
-            torch.rand(
-                16,
-                16,
-            ).cuda(),
-            torch.rand(
-                16,
-                16,
-            ).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(Einsum())
-        _, expected_ops_unseen = lower_graph_testing(
-            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
-        )
-
-        self.assertEquals(
-            len(expected_ops_unseen),
-            0,
-            f"The following expected ops were not encountered: {expected_ops_unseen}",
-        )
-
-        torch._dynamo.reset()
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-        )
-        optimized_model_results = optimized_model(*inputs).detach().cpu()
-        torch_model_results = fx_graph(*inputs).detach().cpu()
-
-        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
-        self.assertAlmostEqual(
-            max_diff, 0, f"Einsum TRT outputs don't match with the original model."
-        )
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 40e5a8f3e8..95a8c96e94 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -313,6 +313,185 @@ def forward(self, x):
             f"Var TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_maxpool1d_functional(self):
+        class MaxPool1d(torch.nn.Module):
+            def forward(self, x):
+                y = torch.nn.functional.max_pool1d(x, 3)
+                return y
+
+        # Operations expected to be present in / absent from the traced graph after lowering
+        expected_ops = {torch.ops.aten.max_pool2d.default}
+        unexpected_ops = {
+            torch.ops.aten.max_pool1d_with_indices.default,
+            torch.ops.aten.max_pool2d_with_indices.default,
+        }
+
+        inputs = [torch.randn(4, 8, 27).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(MaxPool1d())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"MaxPool1d TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_maxpool_2d_module(self):
+        class MaxPool2d(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.maxpool = torch.nn.MaxPool2d((5, 3), stride=(2, 1))
+
+            def forward(self, x):
+                y = self.maxpool(x)
+                return y
+
+        # Operations expected to be present in / absent from the traced graph after lowering
+        expected_ops = {torch.ops.aten.max_pool2d.default}
+        unexpected_ops = {torch.ops.aten.max_pool2d_with_indices.default}
+
+        inputs = [torch.randn(1, 3, 25, 30).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(MaxPool2d())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"MaxPool2d TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_maxpool_3d_module(self):
+        class MaxPool3d(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.maxpool = torch.nn.MaxPool3d(3)
+
+            def forward(self, x):
+                y = self.maxpool(x)
+                return y
+
+        # Operations expected to be present in / absent from the traced graph after lowering
+        expected_ops = {torch.ops.aten.max_pool3d.default}
+        unexpected_ops = {torch.ops.aten.max_pool3d_with_indices.default}
+
+        inputs = [torch.randn(4, 8, 27, 72, 96).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(MaxPool3d())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"MaxPool3d TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py
index b55194fa4c..9ec0fcf58e 100644
--- a/tests/py/dynamo/testing_utilities.py
+++ b/tests/py/dynamo/testing_utilities.py
@@ -12,7 +12,6 @@
     get_decompositions,
     repair_input_aliasing,
 )
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 
 DECIMALS_OF_AGREEMENT = 4
 
@@ -35,8 +34,6 @@ def fx_dynamo_testing_backend(
         use_fast_partitioner=use_fast_partitioner,
     )
 
-    gm = pre_aot_substitutions(gm)
-
     fake_mode = detect_fake_mode(sample_inputs)
 
     # Place backend tracing within FakeTensor context allowing nonfake Tensors