diff --git a/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
new file mode 100644
index 0000000000..d64f81d4aa
--- /dev/null
+++ b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
@@ -0,0 +1,109 @@
+.. _writing_dynamo_aten_lowering_passes:
+
+Writing Dynamo ATen Lowering Passes
+===================================
+
+Basics of a Lowering Pass
+-------------------------
+
+ATen lowering passes are Python functions which take as input a graph of ATen operators in the form of a ``torch.fx.GraphModule``, apply a desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, or custom operator insertion, and then return the modified graph to the caller. These lowering passes generally modify the graph in-place and return the same object that was passed in.
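+
+As a minimal sketch (the pass name and the modification shown are placeholders), a lowering pass has the following shape:
+
+.. code-block:: python
+
+    def my_minimal_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        # Inspect or rewrite nodes of the ATen graph here
+        for node in gm.graph.nodes:
+            ...  # apply the desired modification
+
+        # Ensure the graph remains valid and invocable after modification
+        gm.graph.lint()
+        gm.recompile()
+        return gm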
+
+Lowering Pass Requirements
+--------------------------
+
+An ATen lowering pass function in Torch-TRT must satisfy two requirements:
+
+- The function must take as input a single ``torch.fx.GraphModule`` and return the lowered ``torch.fx.GraphModule``
+- The function must leave the graph in a valid and invocable state, including performing any necessary linting and recompilation
+
+See the FX documentation on `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ for details on modifying graphs. The example below shows a lowering pass which repairs graphs whose inputs are also direct outputs, a configuration disallowed for TRT engines.
+
+Example Lowering Pass
+---------------------
+
+.. code-block:: python
+
+    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """Repair scenarios where inputs are also outputs of the graph
+
+        TRT does not allow such cases, so we insert a clone (identity) layer
+        """
+        modified_graph = False
+
+        # Extract graph placeholder Tensors
+        placeholders = [
+            node
+            for node in gm.graph.nodes
+            if (
+                node.op == "placeholder"
+                and isinstance(node.type, type)
+                and issubclass(node.type, torch.Tensor)
+            )
+        ]
+
+        for placeholder in placeholders:
+            # If any placeholder has any users which are direct graph outputs
+            if len(placeholder.users) >= 1 and any(
+                user.op == "output" for user in placeholder.users
+            ):
+                modified_graph = True
+
+                # Get direct graph outputs which are direct uses of placeholders
+                direct_outputs = [user for user in placeholder.users if user.op == "output"]
+
+                # Insert clone node for placeholder to ensure
+                # placeholder is not a direct output
+                with gm.graph.inserting_after(placeholder):
+                    cloned_placeholder = gm.graph.call_function(
+                        torch.ops.aten.clone.default,
+                        args=(placeholder,),
+                    )
+
+                # Replace placeholder as output with cloned version
+                for output in direct_outputs:
+                    output.replace_input_with(placeholder, cloned_placeholder)
+
+        # If the graph was modified, clean up the graph and ensure it is up-to-date
+        if modified_graph:
+            gm.graph.eliminate_dead_code()
+            gm.graph.lint()
+            gm.recompile()
+            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
+
+        return gm
+
+
+Registering Lowering Passes
+---------------------------
+
+Lowering passes are currently registered in ``py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py``, using a ``DynamoPassManager`` (an extension of the ``torch.fx.passes.pass_manager.PassManager`` utility) to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT ``torch.compile`` backend. For convenience, an ATen lowering pass registration decorator is also provided; it can be invoked either directly or with the optional ``index`` keyword argument, which controls where in the pass list the lowering pass will be inserted.
+
+For instance, to insert the pass at the default location (end of the list), the following code can be used:
+
+.. code-block:: python
+
+    @_aten_lowering_pass
+    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        ...
+
+Alternatively, to insert the pass at a custom index in the pass list (such as the front of the list), the following code can be used:
+
+.. code-block:: python
+
+    @_aten_lowering_pass(index=0)
+    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        ...
+
+Utilities are also provided in ``torch_tensorrt.dynamo.lowering.passes`` for displaying the currently available lowering pass list, applying those passes to an arbitrary ``torch.fx.GraphModule``, and removing the lowering pass at a specific index.
+
+.. code-block:: python
+
+    # Print all lowering passes in the list
+    print(dump_lowering_passes())
+
+    # Apply lowering passes to a GraphModule
+    apply_lowering_passes(graph_module)
+
+    # Remove the lowering pass at index 1
+    _remove_lowering_pass(index=1)
+
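+As an end-to-end sketch (the logging-only pass, model, and input shapes below are placeholders, and the decorator is imported directly from its defining module), a registered custom pass is applied automatically when compiling through the Torch-TensorRT ``torch.compile`` backend:
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+        _aten_lowering_pass,
+    )
+
+    @_aten_lowering_pass
+    def log_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        # Example pass: log the lowered ATen graph and return it unchanged
+        print(gm.graph)
+        return gm
+
+    model = torch.nn.Linear(8, 8).eval().cuda()
+    inputs = [torch.randn(2, 8).cuda()]
+
+    # The registered pass runs during lowering in the torch.compile backend
+    optimized = torch_tensorrt.compile(model, "torch_compile", inputs, min_block_size=1)
+
+Since lowering passes run after AOTAutograd tracing, the graph seen by ``log_graph`` contains ATen operators rather than the original module's calls.
+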
+**Note:** The above APIs are subject to change, as the lowering pass system evolves.
diff --git a/docsrc/index.rst b/docsrc/index.rst
index eee62bc2f7..ded3b99c9d 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -128,6 +128,7 @@ Contributor Documentation
 --------------------------------
 * :ref:`system_overview`
 * :ref:`writing_converters`
+* :ref:`writing_dynamo_aten_lowering_passes`
 * :ref:`useful_links`
 
 .. toctree::
@@ -137,6 +138,7 @@ Contributor Documentation
 
    contributors/system_overview
    contributors/writing_converters
+   contributors/writing_dynamo_aten_lowering_passes
    contributors/useful_links
 
 Indices
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index 32225e79fc..b271d0d6fb 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -6,8 +6,7 @@
 
 import torch
 from torch._export import export
-from torch_tensorrt.dynamo.backend.backends import constant_fold
-from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.utils import set_log_level
 
 logger = logging.getLogger(__name__)
@@ -29,6 +28,6 @@ def trace(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
         graph_module = export(model, tuple(inputs)).module()
-        constant_fold(graph_module)
+        graph_module = apply_lowering_passes(graph_module)
     logger.debug("Post export graph: " + str(graph_module.graph))
     return graph_module
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 7fde8bbb41..022f3b193d 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -10,24 +10,12 @@
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
-from torch_tensorrt._utils import sanitized_torch_version
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level
 
-from packaging import version
-
-# Modify import location of utilities based on Torch version
-if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
-else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -84,7 +72,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_for_compile(
+            gm = aot_export_for_compile(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -94,10 +82,10 @@ def _pretraced_backend(
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-            constant_fold(graph_module)
+            gm = apply_lowering_passes(gm)
 
             trt_compiled = compile_module(
-                graph_module,
+                gm,
                 sample_inputs,
                 settings=settings,
             )
@@ -121,35 +109,6 @@ def _pretraced_backend(
             raise
 
 
-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 6eda61a6fd..34faa1d11b 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -2,4 +2,5 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from .passes import apply_lowering_passes
 from .substitutions import *  # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py
new file mode 100644
index 0000000000..ea393fab14
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -0,0 +1 @@
+from ._aten_lowering_pass import *
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
new file mode 100644
index 0000000000..a4c7fad607
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -0,0 +1,76 @@
+import logging
+from typing import Callable, Optional
+
+import torch
+
+from .constant_folding import constant_fold
+from .pass_manager import DynamoPassManager
+from .repair_input_as_output import repair_input_as_output
+
+ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
+    [
+        constant_fold,
+        repair_input_as_output,
+    ]
+)
+
+logger = logging.getLogger(__name__)
+
+
+LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+
+
+def _aten_lowering_pass(
+    *args: LoweringPassSignature,
+    index: Optional[int] = None,
+) -> LoweringPassSignature:
+    """Adds a lowering pass to the registry, at a specified index if desired
+
+    If no index is specified, the lowering pass is inserted at the end of the list
+    """
+
+    def add_lowering_pass(
+        lowering_pass: LoweringPassSignature,
+    ) -> LoweringPassSignature:
+        ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
+        logger.debug(
+            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
+        )
+        return lowering_pass
+
+    # If there are arguments specified, the decorator may have been called as-is
+    if args:
+        # The decorator may only be called with the lowering pass
+        # The index must be specified as a keyword argument
+        if len(args) == 1 and callable(args[0]):
+            return add_lowering_pass(args[0])
+        else:
+            raise AssertionError(
+                f"aten_lowering_pass decorator called with invalid arguments {args} "
+                "To specify an index to insert the pass, use the keyword 'index='"
+            )
+    # If no arguments are specified, the decorator was called with an index keyword
+    else:
+        return add_lowering_pass
+
+
+def _remove_lowering_pass(*, index: int) -> None:
+    """Removes a lowering pass at a specific index from the registry"""
+    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
+    logger.debug(
+        f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
+    )
+    return
+
+
+def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
+    logger.debug(
+        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
+    )
+    return ATEN_LOWERING_PASSES(gm)
+
+
+def dump_lowering_passes() -> str:
+    """Returns a string containing the lowering passes"""
+    return str(ATEN_LOWERING_PASSES)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
new file mode 100644
index 0000000000..d17d0a2528
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -0,0 +1,56 @@
+import logging
+
+import torch
+from torch_tensorrt._utils import sanitized_torch_version
+
+from packaging import version
+
+# Modify import location of utilities based on Torch version
+if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
+    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+else:
+    from torch._inductor.constant_folding import (
+        ConstantFolder,
+        replace_node_with_constant,
+    )
+
+logger = logging.getLogger(__name__)
+
+
+@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
+def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
+
+    Folds constants in the graph module, not skipping constructors
+
+    Modifies the graph in-place and replaces node with constants
+    """
+    cf = ConstantFolder(gm, skip_constructors=False)
+    cf.run()
+
+    for node, constant in cf.node_replacements.items():
+        replace_node_with_constant(gm, node, constant)
+
+    erased_params = []
+    for node in gm.graph.nodes:
+        # If get_attr node has no users, mark it for deletion
+        if node.op == "get_attr" and len(node.users) == 0:
+            # If the node's parameter is not a parameter of any other node, remove it
+            if not any(
+                other.target == node.target for other in gm.graph.nodes if other != node
+            ):
+                delattr(gm, node.target)
+            erased_params.append(node)
+
+    # Remove unused nodes from the graph
+    for node in erased_params:
+        gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+    logger.debug(f"Graph after constant folding:\n{gm.graph}")
+
+    return gm
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
new file mode 100644
index 0000000000..51e2584364
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
@@ -0,0 +1,42 @@
+from typing import Any, Callable, List, Optional
+
+import torch
+from torch.fx.passes.pass_manager import PassManager
+
+
+class DynamoPassManager(PassManager):  # type: ignore[misc]
+    def __init__(
+        self,
+        passes: Optional[
+            List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]
+        ] = None,
+    ):
+        super().__init__(passes)
+
+    @classmethod
+    def build_from_passlist(
+        cls,
+        passes: Optional[List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]],
+    ) -> Any:
+        pm = DynamoPassManager(passes)
+        return pm
+
+    def add_pass_with_index(
+        self,
+        lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule],
+        index: Optional[int] = None,
+    ) -> None:
+        if index is None:
+            self.passes.append(lowering_pass)
+            index = -1
+        else:
+            self.passes.insert(index, lowering_pass)
+
+    def remove_pass_with_index(self, index: int) -> None:
+        del self.passes[index]
+
+    def __call__(self, source: Any) -> Any:
+        return super().__call__(source)
+
+    def __str__(self) -> str:
+        return str(self.passes)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
new file mode 100644
index 0000000000..6ce846637d
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
@@ -0,0 +1,53 @@
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Repair scenarios where inputs are also outputs of the graph
+
+    TRT does not allow such cases, so we insert a clone (identity) layer
+    """
+    modified_graph = False
+
+    # Extract graph placeholder Tensors
+    placeholders = [
+        node
+        for node in gm.graph.nodes
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.Tensor)
+        )
+    ]
+
+    for placeholder in placeholders:
+        # If any placeholder has any users which are direct graph outputs
+        if len(placeholder.users) >= 1 and any(
+            user.op == "output" for user in placeholder.users
+        ):
+            modified_graph = True
+
+            # Get direct graph outputs which are direct uses of placeholders
+            direct_outputs = [user for user in placeholder.users if user.op == "output"]
+
+            # Insert clone node for placeholder to ensure placeholder is not a direct output
+            with gm.graph.inserting_after(placeholder):
+                cloned_placeholder = gm.graph.call_function(
+                    torch.ops.aten.clone.default,
+                    args=(placeholder,),
+                )
+
+            # Replace placeholder as output with cloned version
+            for output in direct_outputs:
+                output.replace_input_with(placeholder, cloned_placeholder)
+
+    if modified_graph:
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+        logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
+
+    return gm
diff --git a/setup.py b/setup.py
index 6b013daf9e..d02adfb678 100644
--- a/setup.py
+++ b/setup.py
@@ -392,6 +392,7 @@ def run(self):
     "torch_tensorrt.dynamo.conversion.impl.unary",
     "torch_tensorrt.dynamo.lowering",
     "torch_tensorrt.dynamo.lowering.substitutions",
+    "torch_tensorrt.dynamo.lowering.passes",
     "torch_tensorrt.dynamo.partitioning",
     "torch_tensorrt.dynamo.runtime",
     "torch_tensorrt.dynamo.tools",
@@ -419,6 +420,7 @@ def run(self):
     "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary",
     "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
     "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions",
+    "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes",
     "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning",
     "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime",
     "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools",
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
new file mode 100644
index 0000000000..a63c5e3439
--- /dev/null
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -0,0 +1,95 @@
+import torch
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
+
+
+class TestInputAsOutput(TestCase):
+    def test_input_as_output(self):
+        class InputAsOutput(torch.nn.Module):
+            def forward(self, x, y):
+                y_new = y + x + 1
+                y_new = y_new * 7
+                return (y_new, x, y)
+
+        inputs = [
+            torch.rand(
+                5,
+                7,
+            ).cuda(),
+            torch.rand(
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InputAsOutput())
+        lower_graph_testing(fx_graph, inputs, min_block_size=1)
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"InputAsOutput TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
+class TestLoweringPassMembership(TestCase):
+    def insert_at_end(self):
+        from torch_tensorrt.dynamo.lowering.passes import (
+            ATEN_LOWERING_PASSES,
+            _aten_lowering_pass,
+            _remove_lowering_pass,
+        )
+
+        @_aten_lowering_pass
+        def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+            return gm
+
+        self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[-1])
+
+        _remove_lowering_pass(index=-1)
+
+        self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)
+
+    def insert_at_index(self):
+        from torch_tensorrt.dynamo.lowering.passes import (
+            ATEN_LOWERING_PASSES,
+            _aten_lowering_pass,
+            _remove_lowering_pass,
+        )
+
+        @_aten_lowering_pass(index=0)
+        def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+            return gm
+
+        self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[0])
+
+        _remove_lowering_pass(index=0)
+
+        self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py
index f311f2db2b..af5336813f 100644
--- a/tests/py/dynamo/testing_utilities.py
+++ b/tests/py/dynamo/testing_utilities.py
@@ -6,8 +6,8 @@
 import torch
 from torch._dynamo.utils import detect_fake_mode
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile, constant_fold
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 
 DECIMALS_OF_AGREEMENT = 4
@@ -40,16 +40,16 @@ def fx_dynamo_testing_backend(
         fake_mode, "allow_non_fake_inputs", True
     ), fake_mode:
         # Invoke AOTAutograd to translate operators to aten
-        graph_module = aot_export_for_compile(
+        gm = aot_export_for_compile(
             gm,
             sample_inputs,
             decompositions=get_decompositions(),
         )
 
-        constant_fold(graph_module)
+        gm = apply_lowering_passes(gm)
 
         trt_compiled = custom_backend(
-            graph_module,
+            gm,
             sample_inputs,
         )
         return trt_compiled