From d06c74af7f748ee6028eab1fed5ed3346882995e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 3 Oct 2023 11:55:48 -0700
Subject: [PATCH 01/33] chore: Switch to new export apis

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/aten_tracer.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index da346635a2..def04e7057 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-import unittest.mock
 from typing import Any, List, Tuple
 
 import torch
@@ -77,12 +76,9 @@ def trace(
     experimental_decompositions = kwargs.get(
         "enable_experimental_decompositions", False
     )
-    with unittest.mock.patch(
-        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
-    ):
-        graph_module = export(
-            model, tuple(trace_inputs), constraints=constraints
-        ).module()
 
-    logger.debug("Post export graph: " + str(graph_module.graph))
-    return graph_module
+    exp_program = export(
+        model, tuple(trace_inputs), constraints=constraints
+    ).run_decompositions(get_decompositions(experimental_decompositions))
+
+    return exp_program

From ad3b0311b33508a85ae33dfdd591962561e453ac Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Thu, 19 Oct 2023 15:16:13 -0700
Subject: [PATCH 02/33] feat: Add support for dynamic shapes and remove
 constraints API

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/_Input.py               |  7 ++-
 py/torch_tensorrt/dynamo/aten_tracer.py   | 53 +++++------------------
 tests/py/dynamo/models/test_dyn_models.py |  2 +
 3 files changed, 20 insertions(+), 42 deletions(-)

diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 6e43a23903..4dd3cf62c2 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -47,6 +47,7 @@ class _ShapeMode(Enum):
     high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
     torch_dtype: torch.dtype = torch.float32
     torch_tensor: torch.Tensor = None
+    name: str = ""
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
@@ -68,7 +69,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
             tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
                 Note: Entering "None" (or not specifying) will set the bound to [0, 2)
-
+            torch_tensor (torch.Tensor): Holds a corresponding torch tensor with this Input.
+            name (str, optional): Name of this input in the pytorch graph. Used to specify dynamic shapes in dynamo tracer.
         Examples:
             - Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last)
             - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
@@ -180,6 +182,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             else:
                 self.torch_tensor = self.example_tensor()
 
+        if "name" in kwargs:
+            self.name = kwargs["name"]
+
     def __str__(self) -> str:
         if self.shape_mode == Input._ShapeMode.STATIC:
             return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index f6d0ad4625..c894ca6f3c 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, List, Tuple
+from typing import Any, Tuple
 
 import torch
-from torch._export import dynamic_dim, export
+from torch.export import Dim, export
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import (
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
@@ -16,20 +16,6 @@
 logger = logging.getLogger(__name__)
 
 
-def get_random_tensor(
-    shape: List[Any], dtype: torch.dtype, device: torch.device
-) -> torch.Tensor:
-    if dtype == torch.int32 or dtype == torch.int64:
-        return torch.randint(2, 10, shape, dtype=dtype, device=device)
-    elif dtype in (torch.float64, torch.float32, torch.float16):
-        return torch.randn(shape, dtype=dtype, device=device)
-    else:
-        logger.critical(
-            "Invalid dtype detected in creating input tensors for tracing the graph."
-        )
-        raise
-
-
 def trace(
     model: torch.nn.Module | torch.fx.GraphModule,
     inputs: Tuple[Any, ...],
@@ -39,49 +25,34 @@ def trace(
     if "debug" in kwargs and kwargs["debug"]:
         set_log_level(logger.parent, logging.DEBUG)
 
-    # Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT
-    # Torch dynamo does not allow 0/1 value for dynamic dimensions
-    # for inputs during tracing. Hence we create new inputs for export
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_inputs = get_torch_inputs(inputs, device)
-    trace_inputs = []
-    constraints = []
+    dynamic_shapes = {}
     for idx, input in enumerate(inputs):
         if input.shape_mode == Input._ShapeMode.DYNAMIC:
             min_shape = input.shape["min_shape"]
             opt_shape = input.shape["opt_shape"]
             max_shape = input.shape["max_shape"]
             assert len(min_shape) == len(opt_shape) == len(max_shape)
-
-            constraint_dims = []
-            new_shape = []
+            dynamic_dims = {}
             for dim in range(len(min_shape)):
                 if min_shape[dim] == opt_shape[dim] == max_shape[dim]:
-                    new_shape.append(torch_inputs[idx].shape[dim])
+                    continue
                 else:
-                    constraint_dims.append(dim)
-                    if torch_inputs[idx].shape[dim] == 1:
-                        new_shape.append(torch_inputs[idx].shape[dim] + 1)
-                    else:
-                        new_shape.append(torch_inputs[idx].shape[dim])
-
-            trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device)
+                    dynamic_dims[dim] = Dim(
+                        input.name + "_" + str(dim),
+                        min=min_shape[dim],
+                        max=max_shape[dim],
+                    )
 
-            for dim in constraint_dims:
-                if min_shape[dim] > 1:
-                    constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
-                if max_shape[dim] > 1:
-                    constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])
-            trace_inputs.append(trace_input)
-        else:
-            trace_inputs.append(torch_inputs[idx])
+            dynamic_shapes[input.name] = dynamic_dims
 
     experimental_decompositions = kwargs.get(
         "enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     )
 
     exp_program = export(
-        model, tuple(trace_inputs), constraints=constraints
+        model, tuple(torch_inputs), dynamic_shapes=dynamic_shapes
     ).run_decompositions(get_decompositions(experimental_decompositions))
 
     return exp_program
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 057a95879d..d110845145 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -36,6 +36,7 @@ def forward(self, x):
                 opt_shape=(4, 3, 224, 224),
                 max_shape=(8, 3, 224, 224),
                 dtype=torch.float32,
+                name="x",
             )
         ],
         "device": torchtrt.Device("cuda:0"),
@@ -88,6 +89,7 @@ def forward(self, x):
                 opt_shape=(4, 3, 224, 224),
                 max_shape=(8, 3, 224, 224),
                 dtype=torch.float32,
+                name="x",
             )
         ],
         "device": torchtrt.Device("cuda:0"),

From 1582b72f2e1f094bda8fcb83d6a20f0e78177e39 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 23 Oct 2023 13:28:39 -0700
Subject: [PATCH 03/33] chore: add dynamic shape support for certain converters

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 .../dynamo/conversion/aten_ops_converters.py  | 18 +++++++++++++
 .../dynamo/conversion/impl/shape.py           | 25 ++++++++++++++++++-
 .../dynamo/conversion/impl/shuffle.py         | 19 ++++++++++++--
 3 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 70c4574b94..149e16c939 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -279,6 +279,24 @@ def aten_ops_sigmoid(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)  # type: ignore[misc]
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, None),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)  # type: ignore[misc]
 @enforce_tensor_types(
     {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index ef30b186c1..f4287feaf9 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -8,7 +8,7 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
@@ -16,6 +16,29 @@
 from torch_tensorrt.fx.types import TRTTensor
 
 
+def shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    """
+    This is the general shape layer implementation in TensorRT.
+    sym_size.int ops map to addShape layer in TensorRT and returns
+    the dynamic shape of the tensor optionally taking in a dim argument.
+    """
+    input_shape = ctx.net.add_shape(input_val).get_output(0)
+    if not dim:
+        max_dim = len(input_val.shape)
+        dim = dim if dim > 0 else dim + max_dim
+    indices = get_trt_tensor(ctx, dim, name + "_dim")
+    gather_dim = ctx.net.add_gather(input_shape, indices, axis=0).get_output(0)
+
+    return gather_dim
+
+
 def get_shape_with_dynamic_shape(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 3a4c160d77..2b7a658338 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -2,7 +2,7 @@
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -16,6 +16,21 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
+    shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+    shape_layer.axis = 0
+    shape_layer.name = f"{name}_output_shape"
+    layer.set_input(1, shape_layer.get_output(0))
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

From 4d01545db8ca89b67b0d44c5279d69a3b9876ac9 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Wed, 25 Oct 2023 12:46:09 -0700
Subject: [PATCH 04/33] chore: minor updates

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/aten_tracer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index c894ca6f3c..a28671daf0 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -28,7 +28,7 @@ def trace(
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_inputs = get_torch_inputs(inputs, device)
     dynamic_shapes = {}
-    for idx, input in enumerate(inputs):
+    for input in inputs:
         if input.shape_mode == Input._ShapeMode.DYNAMIC:
             min_shape = input.shape["min_shape"]
             opt_shape = input.shape["opt_shape"]

From 6731a571134d69c869dfcfd38de2c12143ab8e90 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Thu, 26 Oct 2023 14:50:34 -0700
Subject: [PATCH 05/33] chore: updates

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/compile.py               |  4 ++++
 py/torch_tensorrt/dynamo/partitioning/__init__.py |  7 ++++++-
 py/torch_tensorrt/dynamo/partitioning/common.py   | 10 ++++++++++
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index 5394c1382e..f3be9a223d 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -203,6 +203,10 @@ def compile_module(
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )
+    # Run symbolic shape analysis
+    partitioning.fake_tensor_prop(
+        partitioned_module, sample_inputs, to_torch_device(settings.device)
+    )
 
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py
index 1a8cc94099..8e67abda88 100644
--- a/py/torch_tensorrt/dynamo/partitioning/__init__.py
+++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,8 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
+from .common import (
+    fake_tensor_prop,
+    get_graph_converter_support,
+    get_submod_inputs,
+    run_shape_analysis,
+)
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 8348738afa..b345a4b814 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -2,6 +2,7 @@
 from typing import Any, Dict, Optional, Sequence, Set, Tuple
 
 import torch
+from torch.fx.passes.fake_tensor_prop import FakeTensorProp
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import DEBUG
 from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic
@@ -9,6 +10,15 @@
 logger = logging.getLogger(__name__)
 
 
+def fake_tensor_prop(
+    gm: torch.fx.GraphModule, inputs: Sequence[Input], device: torch.device
+) -> None:
+    torch_inputs = get_torch_inputs(inputs, device)
+    # Propagate fake tensors and generates metadata (shape, dtype) for the nodes in the graph
+    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+    FakeTensorProp(gm, mode=fake_mode).propagate(*torch_inputs)
+
+
 def run_shape_analysis(
     parent_module: torch.fx.GraphModule, inputs: Sequence[Input]
 ) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]:

From 0b60aae4522fe7ab04bcf0acbe29ed0d29a9bb1e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 14 Nov 2023 23:24:08 -0800
Subject: [PATCH 06/33] chore: add sym int converter

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 .../dynamo/conversion/aten_ops_converters.py     | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index b05713c360..b61f887acd 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -363,6 +363,22 @@ def aten_ops_sigmoid(
         args[0],
     )
 
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, None),
+    )
 
 @dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
 @enforce_tensor_types(

From 634612fe78e13a32e3d33233e3ba177ad732590d Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Thu, 16 Nov 2023 00:38:39 -0800
Subject: [PATCH 07/33] feat: Replace the existing shape propagation with
 symbolic shape propagation

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/_compiler.py         | 13 +--
 py/torch_tensorrt/dynamo/_tracer.py           |  2 +-
 .../dynamo/conversion/_conversion.py          |  9 ++-
 .../dynamo/conversion/impl/shuffle.py         |  9 ++-
 .../dynamo/partitioning/__init__.py           |  1 +
 .../dynamo/partitioning/common.py             | 80 +++++++++++++++++--
 6 files changed, 88 insertions(+), 26 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 9082e50664..a0a206091d 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -148,10 +148,10 @@ def compile(
 
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     enabled_precisions = set(enabled_precisions)
@@ -263,10 +263,6 @@ def compile_module(
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )
-    # Run symbolic shape analysis
-    partitioning.fake_tensor_prop(
-        partitioned_module, sample_inputs, to_torch_device(settings.device)
-    )
 
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
@@ -279,12 +275,7 @@ def compile_module(
             continue
 
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)
 
         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",
diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
index 43812fd062..bbc68192c0 100644
--- a/py/torch_tensorrt/dynamo/_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -69,7 +69,7 @@ def trace(
     torch_inputs = get_torch_inputs(inputs, device)
     dynamic_shapes = {}
     for input in inputs:
-        if input.shape_mode == Input._ShapeMode.DYNAMIC:
+        if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
             min_shape = input.shape["min_shape"]
             opt_shape = input.shape["opt_shape"]
             max_shape = input.shape["max_shape"]
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 1cdea63680..2aa34952ed 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
 import io
 from typing import Sequence
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -10,8 +11,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
-import tensorrt as trt
-
 
 def convert_module(
     module: torch.fx.GraphModule,
@@ -40,6 +39,12 @@ def convert_module(
     # such as aten.sum - such outputs can be truncated
     output_dtypes = []
     for output in module_outputs:
+        if not isinstance(output, torch.Tensor):
+            output = torch.tensor(output)
+            if isinstance(output, int):
+                output = output.to(torch.int32)
+            elif isinstance(output, float):
+                output = output.to(torch.float32)
         if settings.truncate_long_and_double and output.dtype == torch.float64:
             output_dtypes.append(torch.float32)
         elif settings.truncate_long_and_double and output.dtype == torch.int64:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 2b7a658338..a52995d4b7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -28,9 +28,10 @@ def reshape(
             else:
                 a = get_trt_tensor(ctx, s, f"{name}_{i}")
                 trt_shape.append(a)
-    shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
-    shape_layer.axis = 0
-    shape_layer.name = f"{name}_output_shape"
-    layer.set_input(1, shape_layer.get_output(0))
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py
index 8e67abda88..3c2a9ea199 100644
--- a/py/torch_tensorrt/dynamo/partitioning/__init__.py
+++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,6 +1,7 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
 from .common import (
+    construct_submodule_inputs,
     fake_tensor_prop,
     get_graph_converter_support,
     get_submod_inputs,
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index b345a4b814..e892834a5b 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -2,7 +2,6 @@
 from typing import Any, Dict, Optional, Sequence, Set, Tuple
 
 import torch
-from torch.fx.passes.fake_tensor_prop import FakeTensorProp
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import DEBUG
 from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic
@@ -10,13 +9,78 @@
 logger = logging.getLogger(__name__)
 
 
-def fake_tensor_prop(
-    gm: torch.fx.GraphModule, inputs: Sequence[Input], device: torch.device
-) -> None:
-    torch_inputs = get_torch_inputs(inputs, device)
-    # Propagate fake tensors and generates metadata (shape, dtype) for the nodes in the graph
-    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
-    FakeTensorProp(gm, mode=fake_mode).propagate(*torch_inputs)
+def contains_sym_int(tensor: torch.Tensor) -> bool:
+    """
+    Returns true if the given tensor has symbolic shape.
+    """
+    for dim in tensor:
+        if isinstance(dim, torch.SymInt):
+            return True
+    return False
+
+
+def construct_dynamic_input(input: Any) -> Input:
+    """
+    Constructs a torch_tensorrt.Input based on a symbolic input
+    Args:
+        input: A symbolic shape tensor (which can have a  mix of SymInt nodes and static values)
+    Returns:
+        A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
+    """
+    input_sym_shape = input.size()
+    min_shape = []
+    opt_shape = []
+    max_shape = []
+    for dim in input_sym_shape:
+        if isinstance(dim, torch.SymInt):
+            node = dim.node
+            expr = node.expr
+            shape_env = node.shape_env
+            var_range = shape_env.var_to_range.get(expr, None)
+            var_val = shape_env.var_to_val.get(expr, None)
+            assert var_range, var_val
+            # Torchdynamo 0/1 specialization outlier
+            if var_range.lower == 2:
+                min_shape.append(1)
+            else:
+                min_shape.append(var_range.lower)
+            opt_shape.append(var_val)
+            max_shape.append(var_range.upper)
+        else:
+            min_shape.append(dim)
+            opt_shape.append(dim)
+            max_shape.append(dim)
+
+    return Input(
+        min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input.dtype
+    )
+
+
+def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
+    """
+    Construct torch_tensorrt Inputs based on the module inputs.
+    The module inputs will have meta data which has the shape and dtype info
+    Args:
+        module: Input FX GraphModule
+    Returns:
+        Sequence of torch_tensorrt.Input's representing inputs to given module
+    """
+    torchtrt_inputs = []
+    module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"]
+    for input in module_inputs:
+        if input.meta and "val" in input.meta:
+            input_meta = input.meta["val"]
+            input_shape = input_meta.size()
+            if contains_sym_int(input_shape):
+                torchtrt_inputs.append(construct_dynamic_input(input_meta))
+            else:
+                torchtrt_inputs.append(Input(shape=input_shape, dtype=input_meta.dtype))
+        else:
+            raise AssertionError(
+                f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly"
+            )
+
+    return torchtrt_inputs
 
 
 def run_shape_analysis(

From 93edba415b6e2a7109b935157faedc833315ce6e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Thu, 16 Nov 2023 08:36:27 -0800
Subject: [PATCH 08/33] chore: fix imports

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/partitioning/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py
index 3c2a9ea199..5e5406e67c 100644
--- a/py/torch_tensorrt/dynamo/partitioning/__init__.py
+++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -2,7 +2,6 @@
 from ._global_partitioner import partition as global_partition
 from .common import (
     construct_submodule_inputs,
-    fake_tensor_prop,
     get_graph_converter_support,
     get_submod_inputs,
     run_shape_analysis,

From 7ad927248bd4fdfea8034f345448e883c01c9d3c Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Thu, 16 Nov 2023 08:42:28 -0800
Subject: [PATCH 09/33] chore: fix imports

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index b61f887acd..09b0092e85 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -363,6 +363,7 @@ def aten_ops_sigmoid(
         args[0],
     )
 
+
 @dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
 def aten_ops_symsize_int(
     ctx: ConversionContext,
@@ -380,6 +381,7 @@ def aten_ops_symsize_int(
         args_bounds_check(args, 1, None),
     )
 
+
 @dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
 @enforce_tensor_types(
     {

From f444d5459632397c401083bf64af9404d040c7f2 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 21 Nov 2023 00:14:51 -0800
Subject: [PATCH 10/33] chore: updates

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/_compiler.py          | 18 ++++++++++++++++--
 py/torch_tensorrt/dynamo/backend/backends.py   |  1 -
 .../dynamo/partitioning/common.py              |  2 --
 py/torch_tensorrt/dynamo/utils.py              |  3 ++-
 4 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index a0a206091d..0651e1ac42 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -40,7 +40,6 @@
     prepare_inputs,
     set_log_level,
     to_torch_device,
-    to_torch_tensorrt_device,
 )
 
 logger = logging.getLogger(__name__)
@@ -144,7 +143,7 @@ def compile(
 
     # Prepare torch_trt inputs
     inputs = prepare_inputs(inputs)
-    device = to_torch_tensorrt_device(device)
+    device = to_torch_device(device)
 
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
@@ -234,6 +233,21 @@ def compile_module(
             f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
         )
 
+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
+            if (not node.meta) or "val" not in node.meta and node.op != "output":
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
 
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 1fa2806181..bade91c553 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
-
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index e892834a5b..2cd5dfca76 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -44,8 +44,6 @@ def construct_dynamic_input(input: Any) -> Input:
                 min_shape.append(1)
             else:
                 min_shape.append(var_range.lower)
-            opt_shape.append(var_val)
-            max_shape.append(var_range.upper)
         else:
             min_shape.append(dim)
             opt_shape.append(dim)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 26de1fcb27..31bda92ad3 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -88,7 +88,8 @@ def get_torch_inputs(
             if isinstance(input, Input)
         ]
     return [
-        input.torch_tensor.to(device) for input in inputs if isinstance(input, Input)
+        input.torch_tensor.to(device) if isinstance(input, Input) else input
+        for input in inputs
     ]
 
 

From 6e5c5828095d299ae46434fe9e790385dc1ed3af Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 28 Nov 2023 11:54:13 -0800
Subject: [PATCH 11/33] chore: change device calls

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/_compiler.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 0651e1ac42..33ea1a9f13 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -40,6 +40,7 @@
     prepare_inputs,
     set_log_level,
     to_torch_device,
+    to_torch_tensorrt_device,
 )
 
 logger = logging.getLogger(__name__)
@@ -143,7 +144,7 @@ def compile(
 
     # Prepare torch_trt inputs
     inputs = prepare_inputs(inputs)
-    device = to_torch_device(device)
+    device = to_torch_tensorrt_device(device)
 
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))

From 83791f8665530f03227e906fedac62b53d2a1b28 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 5 Dec 2023 12:24:25 -0800
Subject: [PATCH 12/33] chore: fix metadata check

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
---
 py/torch_tensorrt/dynamo/_compiler.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 33ea1a9f13..b99de1788f 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -236,7 +236,10 @@ def compile_module(
 
     def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         for node in gm.graph.nodes:
-            if (not node.meta) or "val" not in node.meta and node.op != "output":
+            if node.op != "output" and (not node.meta) and "val" not in node.meta:
+                logger.debug(
+                    f"Node {node.name} of op type {node.op} does not have metadata"
+                )
                 return False
         return True
 

From 16394d91f947817cc05123e8e525a0f3e04aa471 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Sun, 7 Jan 2024 07:04:13 +0000
Subject: [PATCH 13/33] chore: minor fixes

---
 py/torch_tensorrt/dynamo/conversion/impl/shape.py               | 2 +-
 py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py | 2 --
 py/torch_tensorrt/dynamo/partitioning/common.py                 | 2 ++
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index f4287feaf9..24554b6f9a 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -32,7 +32,7 @@ def shape(
     input_shape = ctx.net.add_shape(input_val).get_output(0)
     if not dim:
         max_dim = len(input_val.shape)
-        dim = dim if dim > 0 else dim + max_dim
+        dim = dim if dim >= 0 else dim + max_dim
     indices = get_trt_tensor(ctx, dim, name + "_dim")
     gather_dim = ctx.net.add_gather(input_shape, indices, axis=0).get_output(0)
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index d6e12f5215..604eda8c96 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -11,7 +11,6 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
-from .view_to_reshape import view_to_reshape
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -22,7 +21,6 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
-        view_to_reshape,
     ]
 )
 
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 2cd5dfca76..e892834a5b 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -44,6 +44,8 @@ def construct_dynamic_input(input: Any) -> Input:
                 min_shape.append(1)
             else:
                 min_shape.append(var_range.lower)
+            opt_shape.append(var_val)
+            max_shape.append(var_range.upper)
         else:
             min_shape.append(dim)
             opt_shape.append(dim)

From b9a7ccd81c923c6f098bc5bf2f8f527241836c46 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 8 Jan 2024 21:48:22 +0000
Subject: [PATCH 14/33] chore: Add sym_size converter tests

---
 .../dynamo/conversion/aten_ops_converters.py  | 11 +---
 .../dynamo/conversion/impl/shape.py           |  8 +--
 tests/py/dynamo/models/test_dyn_models.py     | 50 +++++++++++++++++++
 3 files changed, 56 insertions(+), 13 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 74a8427fa7..f132c62ec6 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -400,15 +400,7 @@ def aten_ops_symsize_int(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.shape.shape(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args_bounds_check(args, 1, None),
-    )
-
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], kwargs["dim"])
 
 
 def index_dtype_validator(node: Node) -> bool:
@@ -420,6 +412,7 @@ def index_dtype_validator(node: Node) -> bool:
                 return False
     return True
 
+
 @dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
 @dynamo_tensorrt_converter(
     torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index 24554b6f9a..dc17764f80 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -30,13 +30,13 @@ def shape(
     the dynamic shape of the tensor optionally taking in a dim argument.
     """
     input_shape = ctx.net.add_shape(input_val).get_output(0)
-    if not dim:
+    if dim is not None:
         max_dim = len(input_val.shape)
         dim = dim if dim >= 0 else dim + max_dim
-    indices = get_trt_tensor(ctx, dim, name + "_dim")
-    gather_dim = ctx.net.add_gather(input_shape, indices, axis=0).get_output(0)
+        dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
+        input_shape = ctx.net.add_gather(input_shape, dim_tensor, axis=0).get_output(0)
 
-    return gather_dim
+    return input_shape
 
 
 def get_shape_with_dynamic_shape(
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index d110845145..75c6c51dab 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -113,3 +113,53 @@ def forward(self, x):
 
     with torch.no_grad():
         torch.cuda.empty_cache()
+
+
+@pytest.mark.unit
+def test_view(ir):
+    """
+    Tests the model (which is fully convertible) with dynamic shapes
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            input_shape = x.size()
+            y = x.view(input_shape[0], -1)
+            return y
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((6, 3, 4)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 4),
+                opt_shape=(4, 3, 4),
+                max_shape=(8, 3, 4),
+                dtype=torch.float32,
+                name="x",
+            )
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 1,
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input)[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        torch.cuda.empty_cache()

From 15cc6435e960441e38408bf879848b77eab38326 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 8 Jan 2024 23:04:03 +0000
Subject: [PATCH 15/33] chore: Update test utilities

---
 .../dynamo/conversion/aten_ops_converters.py  |  2 +-
 .../dynamo/conversion/impl/shape.py           | 22 +++++++---
 tests/py/dynamo/conversion/test_sym_size.py   | 43 +++++++++++++++++++
 3 files changed, 59 insertions(+), 8 deletions(-)
 create mode 100644 tests/py/dynamo/conversion/test_sym_size.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index f132c62ec6..fcfa48ebec 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -400,7 +400,7 @@ def aten_ops_symsize_int(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], kwargs["dim"])
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
 
 
 def index_dtype_validator(node: Node) -> bool:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index dc17764f80..2d2481936b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -8,7 +8,11 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+    to_numpy,
+)
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
@@ -29,12 +33,16 @@ def shape(
     sym_size.int ops map to addShape layer in TensorRT and returns
     the dynamic shape of the tensor optionally taking in a dim argument.
     """
-    input_shape = ctx.net.add_shape(input_val).get_output(0)
-    if dim is not None:
-        max_dim = len(input_val.shape)
-        dim = dim if dim >= 0 else dim + max_dim
-        dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
-        input_shape = ctx.net.add_gather(input_shape, dim_tensor, axis=0).get_output(0)
+    shape_layer = ctx.net.add_shape(input_val)
+    input_shape = shape_layer.get_output(0)
+    set_layer_name(shape_layer, target, name + "_shape", source_ir)
+
+    n_dims = len(input_val.shape)
+    dim = get_positive_dim(dim, n_dims)
+    dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
+    gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    input_shape = gather_layer.get_output(0)
 
     return input_shape
 
diff --git a/tests/py/dynamo/conversion/test_sym_size.py b/tests/py/dynamo/conversion/test_sym_size.py
new file mode 100644
index 0000000000..5952122247
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_sym_size.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSymSizeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_batch(self, input_shape):
+        class BatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 0)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            BatchDim(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_non_batch(self, input_shape):
+        class NonBatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 1)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            NonBatchDim(),
+            inputs,
+        )
+
+if __name__ == "__main__":
+    run_tests()

From 5234d74af917f3f286631163ce7dbd7b3ddd819d Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 8 Jan 2024 23:09:02 +0000
Subject: [PATCH 16/33] chore: add testcase for sym_size.int

---
 tests/py/dynamo/conversion/test_sym_size.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/py/dynamo/conversion/test_sym_size.py b/tests/py/dynamo/conversion/test_sym_size.py
index 5952122247..35bf75a509 100644
--- a/tests/py/dynamo/conversion/test_sym_size.py
+++ b/tests/py/dynamo/conversion/test_sym_size.py
@@ -39,5 +39,6 @@ def forward(self, x):
             inputs,
         )
 
+
 if __name__ == "__main__":
     run_tests()

From 51e8bb7d0f80f0cee1b248ab0ea3950611bf979e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Fri, 26 Jan 2024 01:47:29 -0800
Subject: [PATCH 17/33] chore: revert output type change

---
 py/torch_tensorrt/dynamo/conversion/_conversion.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 5cc07913bb..844cb6789a 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -39,12 +39,6 @@ def interpret_module_to_result(
     # such as aten.sum - such outputs can be truncated
     output_dtypes = []
     for output in module_outputs:
-        if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output)
-            if isinstance(output, int):
-                output = output.to(torch.int32)
-            elif isinstance(output, float):
-                output = output.to(torch.float32)
         if settings.truncate_long_and_double and output.dtype == torch.float64:
             output_dtypes.append(torch.float32)
         elif settings.truncate_long_and_double and output.dtype == torch.int64:

From 19c3fad9c87ef02b175fa5385be70a7d5d50652e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Fri, 26 Jan 2024 19:57:17 -0800
Subject: [PATCH 18/33] chore: add update_metadata utility

---
 .../lowering/passes/_aten_lowering_pass.py    |  2 +
 .../dynamo/lowering/passes/pass_utils.py      | 19 +++++++++-
 .../dynamo/lowering/passes/view_to_reshape.py | 37 +++++++++----------
 .../dynamo/partitioning/common.py             |  6 +--
 tests/py/dynamo/models/test_dyn_models.py     |  3 +-
 5 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 24fca9b2f3..489805cb43 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -11,6 +11,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -21,6 +22,7 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
index 31a55099c2..ecb614e355 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, Dict, List
 
 import torch
 
@@ -29,3 +29,20 @@ def get_tensor_placeholders(
     ]
 
     return placeholders
+
+
+def update_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: Dict[int, torch._ops.OpOverload]
+) -> None:
+    """
+    Given a graph and a node which has target_op in the graph,
+    a) If the node has metadata, store it in the map
+    b) If the node does not have metadata, retrieve it from the map
+       and assign to the node.
+    """
+    for idx, node in enumerate(gm.graph.nodes):
+        if node.target == target_op:
+            if idx not in metadata and node.meta:
+                metadata[idx] = node.meta
+            elif idx in metadata and not node.meta:
+                node.meta = metadata[idx]
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
index efc836814f..3308f84c58 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,10 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import Dict, List, Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    update_metadata,
 )
 
 logger = logging.getLogger(__name__)
@@ -13,29 +14,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default
 
     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)
 
     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)
 
-    return orig, replacement
+    # Store metadata of the orig_op and copy it to the replacement op
+    meta_map: Dict[int, torch._ops.OpOverload] = {}
+    update_metadata(gm, orig_op, meta_map)
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    update_metadata(gm, replacement_op, meta_map)
+
+    return gm
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index e892834a5b..26d2e22b7a 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -43,9 +43,9 @@ def construct_dynamic_input(input: Any) -> Input:
             if var_range.lower == 2:
                 min_shape.append(1)
             else:
-                min_shape.append(var_range.lower)
-            opt_shape.append(var_val)
-            max_shape.append(var_range.upper)
+                min_shape.append(int(var_range.lower))
+            opt_shape.append(int(var_val))
+            max_shape.append(int(var_range.upper))
         else:
             min_shape.append(dim)
             opt_shape.append(dim)
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 51f84e3684..f9f1d02c02 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -3,9 +3,8 @@
 import pytest
 import timm
 import torch
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
-
 import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
 assertions = unittest.TestCase()
 

From ed48551e11b4b9df9bbd329f6dd7f9be396ecbe9 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Fri, 26 Jan 2024 21:06:39 -0800
Subject: [PATCH 19/33] chore: change debug to warning if the graph does not
 have metadata

---
 py/torch_tensorrt/dynamo/_compiler.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 13c997337e..9bf1002e20 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -5,6 +5,7 @@
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
+import torch_tensorrt
 from torch.export import ExportedProgram
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -66,8 +67,6 @@
     to_torch_tensorrt_device,
 )
 
-import torch_tensorrt
-
 logger = logging.getLogger(__name__)
 
 
@@ -305,8 +304,8 @@ def compile_module(
     def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         for node in gm.graph.nodes:
             if node.op != "output" and (not node.meta) and "val" not in node.meta:
-                logger.debug(
-                    f"Node {node.name} of op type {node.op} does not have metadata"
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
                 )
                 return False
         return True

From 9aff04bad406c36ca3d2a5cc9f1f058903fe41f2 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 6 Feb 2024 16:00:41 -0800
Subject: [PATCH 20/33] chore: gpt2 changes + linting

---
 .../dynamo/conversion/converter_utils.py      |  2 +-
 .../dynamo/conversion/impl/grid.py            |  6 ++--
 .../dynamo/conversion/impl/select.py          | 12 ++++----
 .../dynamo/conversion/impl/upsample.py        |  2 +-
 .../dynamo/conversion/ops_evaluators.py       | 14 ++++++++--
 .../dynamo/lowering/passes/pass_utils.py      | 28 ++++++++++---------
 .../dynamo/lowering/passes/view_to_reshape.py | 14 ++++++----
 versions.py                                   |  7 ++---
 8 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index f90c869c15..1378f5da17 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -270,7 +270,7 @@ def create_constant(
     """
     numpy_value = to_numpy(value, dtype)
     constant = ctx.net.add_constant(
-        (1,) if isinstance(value, (int, float, bool)) else value.shape,
+        trt.Dims() if isinstance(value, (int, float, bool)) else value.shape,
         numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
     )
     constant.name = name
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
index 672fc97351..63ff93b0c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,11 @@
-from typing import Optional, Sequence
+from typing import Optional
 
 import tensorrt as trt
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 # nearest, linear, cubic
 GridSamplerInterpolationMode = {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index db586be65f..dc33129d24 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -90,7 +90,7 @@ def index(
     # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
     # If any is not this flag will be set to False
     _LOGGER.debug(
-        f"Determining whether aten.index constant-index optimization can be invoked"
+        "Determining whether aten.index constant-index optimization can be invoked"
     )
     is_numpy = all(
         isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
         indices_tensor = get_trt_tensor(
-            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
                 cum_adv_index = cum_adv_index + adv_index
                 multiplier = multiplier * input_shape[adv_indx_indices[i]]
             cum_adv_index = get_trt_tensor(
-                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+                ctx, cum_adv_index, name + "_index_sum_intermediate"
             )
         else:
             multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
             adv_indx_count
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
-            _LOGGER.debug(f"The indices are continuous in this case")
+            _LOGGER.debug("The indices are continuous in this case")
             concat_tensor_reshape.append(
                 get_trt_tensor(ctx, -1, name + "_dynamic_concat")
             )
@@ -287,7 +287,7 @@ def index(
                 source_ir,
             )
             unfold_tensor = regular_index_shuffle_layer.get_output(0)
-            _LOGGER.debug(f"The tensor is unfolded now")
+            _LOGGER.debug("The tensor is unfolded now")
             _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
 
             # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
             reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
         else:
-            _LOGGER.debug(f"The indices are not continuous in this case")
+            _LOGGER.debug("The indices are not continuous in this case")
             concat_final_tensor = []
             concat_final_tensor.append(cum_adv_index_shape_tensor)
             for i in range(0, rank):
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 3313730ec3..594bb4167c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -29,7 +29,7 @@ def upsample(
         resize_layer.scales = [1.0, 1.0] + list(scale_factors)
     else:
         raise RuntimeError(
-            f"At least one of out_shape and scale_factors should be specified."
+            "At least one of out_shape and scale_factors should be specified."
         )
 
     # interpolate mode
diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index f83e0e5008..b35f198028 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -2,7 +2,7 @@
 import operator
 from typing import Dict, Sequence, Tuple, Union
 
-import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -46,4 +46,14 @@ def aten_ops_arange_start_step(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return np.arange(*args)
+    # breakpoint()
+    fill_layer = ctx.net.add_fill(trt.Dims(), trt.FillOperation.LINSPACE)
+    fill_layer.set_input(0, args[1])
+    fill_layer.set_output_type(0, trt.DataType.INT32)
+    # fill_layer.set_input(1, 0)
+    # fill_layer.set_input(2, 1)
+    # start_tensor = get_trt_tensor(ctx, 0, "_start_tensor")
+    # fill_layer.set_input(1, start_tensor)
+    # delta_tensor = get_trt_tensor(ctx, torch.tensor([0], dtype=torch.int32), "_delta_tensor")
+    # fill_layer.set_input(2, delta_tensor)
+    return fill_layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
index ecb614e355..e3c0f46e9f 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, List
 
 import torch
 
@@ -31,18 +31,20 @@ def get_tensor_placeholders(
     return placeholders
 
 
-def update_metadata(
-    gm: torch.fx.GraphModule, target_op: Any, metadata: Dict[int, torch._ops.OpOverload]
+def get_metadata(gm: torch.fx.GraphModule, target_op: Any) -> List[torch._ops.OpOverload]:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
 ) -> None:
     """
-    Given a graph and a node which has target_op in the graph,
-    a) If the node has metadata, store it in the map
-    b) If the node does not have metadata, retrieve it from the map
-       and assign to the node.
+    Return the list which has the metadata of all the target_op nodes present in the graph.
     """
-    for idx, node in enumerate(gm.graph.nodes):
-        if node.target == target_op:
-            if idx not in metadata and node.meta:
-                metadata[idx] = node.meta
-            elif idx in metadata and not node.meta:
-                node.meta = metadata[idx]
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
index 3308f84c58..db0346348b 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,10 +1,11 @@
 import logging
-from typing import Dict, List, Sequence
+from typing import List, Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
-    update_metadata,
+    get_metadata,
+    set_metadata,
 )
 
 logger = logging.getLogger(__name__)
@@ -25,14 +26,15 @@ def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
         return replacement_op(input, shape)
 
-    # Store metadata of the orig_op and copy it to the replacement op
-    meta_map: Dict[int, torch._ops.OpOverload] = {}
-    update_metadata(gm, orig_op, meta_map)
+    # Store metadata of the orig_op
+    metadata = get_metadata(gm, orig_op)
+    # breakpoint()
 
     if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
         gm = clean_up_graph_after_modifications(gm)
         logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
 
-    update_metadata(gm, replacement_op, meta_map)
+    # Copy the orig_op's metadata to the replacement op
+    set_metadata(gm, replacement_op, metadata)
 
     return gm
diff --git a/versions.py b/versions.py
index 772737aab7..db418a06d2 100644
--- a/versions.py
+++ b/versions.py
@@ -1,11 +1,10 @@
-import yaml
-import re
 import os
+import re
 import subprocess
-
 from datetime import datetime
 from pathlib import Path
-from typing import List
+
+import yaml
 
 __version__ = "0.0.0"
 __cuda_version__ = "0.0"

From 440fcd5deb24101bb3b1c5856329dcaac7576c88 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 6 Feb 2024 16:01:01 -0800
Subject: [PATCH 21/33] chore: gpt2 changes + linting

---
 py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
index e3c0f46e9f..0ffc6d3c76 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -31,7 +31,9 @@ def get_tensor_placeholders(
     return placeholders
 
 
-def get_metadata(gm: torch.fx.GraphModule, target_op: Any) -> List[torch._ops.OpOverload]:
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[torch._ops.OpOverload]:
     """
     Return the list which has the metadata of all the target_op nodes present in the graph.
     """

From 002db3c36e14c32c1d6e6308238e4de7646bf671 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 6 Feb 2024 19:20:12 -0800
Subject: [PATCH 22/33] chore: add fallback option if val is missing in
 metadata

---
 .../dynamo/conversion/aten_ops_converters.py  |  1 -
 .../dynamo/partitioning/common.py             | 37 +++++++++++++------
 2 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 7e833e8b81..7f187e7134 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -413,7 +413,6 @@ def index_dtype_validator(node: Node) -> bool:
     return True
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
 @dynamo_tensorrt_converter(
     torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
 )
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 26d2e22b7a..109bda275f 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -19,19 +19,18 @@ def contains_sym_int(tensor: torch.Tensor) -> bool:
     return False
 
 
-def construct_dynamic_input(input: Any) -> Input:
+def construct_dynamic_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
     """
     Constructs a torch_tensorrt.Input based on a symbolic input
     Args:
-        input: A symbolic shape tensor (which can have a  mix of SymInt nodes and static values)
+        input_shape: A symbolic shape / regular shape of a tensor (which can have a  mix of SymInt nodes and static values)
     Returns:
         A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
     """
-    input_sym_shape = input.size()
     min_shape = []
     opt_shape = []
     max_shape = []
-    for dim in input_sym_shape:
+    for dim in input_shape:
         if isinstance(dim, torch.SymInt):
             node = dim.node
             expr = node.expr
@@ -52,10 +51,20 @@ def construct_dynamic_input(input: Any) -> Input:
             max_shape.append(dim)
 
     return Input(
-        min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input.dtype
+        min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input_dtype
     )
 
 
+def get_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
+    """
+    Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs
+    """
+    if contains_sym_int(input_shape):
+        return construct_dynamic_input(input_shape, input_dtype)
+    else:
+        return Input(shape=input_shape, dtype=input_dtype)
+
+
 def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
     """
     Construct torch_tensorrt Inputs based on the module inputs.
@@ -68,13 +77,19 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
     torchtrt_inputs = []
     module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"]
     for input in module_inputs:
-        if input.meta and "val" in input.meta:
-            input_meta = input.meta["val"]
-            input_shape = input_meta.size()
-            if contains_sym_int(input_shape):
-                torchtrt_inputs.append(construct_dynamic_input(input_meta))
+        if input.meta:
+            if "val" in input.meta:
+                input_meta = input.meta["val"]
+                input_shape = input_meta.size()
+                torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
+            elif "tensor_meta" in input.meta:
+                input_meta = input.meta["tensor_meta"]
+                input_shape = input_meta.shape
+                torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
             else:
-                torchtrt_inputs.append(Input(shape=input_shape, dtype=input_meta.dtype))
+                raise AssertionError(
+                    f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly"
+                )
         else:
             raise AssertionError(
                 f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly"

From 00cd17b973c2ebf6d32622f126f6b6158ce9a8d1 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 12 Feb 2024 17:27:30 -0800
Subject: [PATCH 23/33] chore: tmp changes

---
 .../dynamo/conversion/impl/slice/base.py        | 17 ++++++++++++++++-
 .../dynamo/conversion/impl/slice/ops.py         | 14 ++++++++++----
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
index 018ac63b8c..21a38a290d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
@@ -8,8 +8,23 @@
     has_dynamic_shape,
     set_layer_name,
 )
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
 from torch_tensorrt.fx.types import Shape, TRTTensor
 
+def get_dynamic_shape(ctx, target, source_ir, name, shape, input):
+    trt_shape = []
+    shape = input.shape
+    for i, s in enumerate(shape):
+        if isinstance(s, TRTTensor):
+            trt_shape.append(s)
+        else:
+            a = get_trt_tensor(ctx, s, f"{name}_{i}")
+            trt_shape.append(a)
+    shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+    shape_layer.axis = 0
+    shape_layer.name = f"{name}_output_shape"
+    
+    return shape_layer.get_output(0)
 
 def slice(
     ctx: ConversionContext,
@@ -23,7 +38,7 @@ def slice(
 ) -> TRTTensor:
     dynamic_shape = has_dynamic_shape(input.shape)
     if dynamic_shape:
-        shape = get_shape_with_dynamic_shape(ctx, target, source_ir, name, shape, input)
+        shape = get_dynamic_shape(ctx, target, source_ir, name, shape, input)
     layer = ctx.net.add_slice(
         input,
         start=start,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 5f1db00f33..dba4ad52a5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,9 +98,16 @@ def expand(
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
-    layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
-    set_layer_name(layer, target, name, source_ir)
-    return layer.get_output(0)
+    # layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
+
+    # set_layer_name(layer, target, name, source_ir)
+    # return layer.get_output(0)
+    breakpoint()
+    expand_output = slice(
+        ctx, target, source_ir, name, input_t, start, shape, stride
+    )
+    return expand_output
+    
 
 
 def chunk(

From 6ac70cd2cb2c219dd5f99ffa02faa8ae89f466cb Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 12 Feb 2024 17:27:37 -0800
Subject: [PATCH 24/33] chore: tmp changes

---
 py/torch_tensorrt/dynamo/conversion/impl/slice/base.py | 6 ++++--
 py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py  | 5 +----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
index 21a38a290d..64225227aa 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
@@ -3,14 +3,15 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
 )
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
 from torch_tensorrt.fx.types import Shape, TRTTensor
 
+
 def get_dynamic_shape(ctx, target, source_ir, name, shape, input):
     trt_shape = []
     shape = input.shape
@@ -23,9 +24,10 @@ def get_dynamic_shape(ctx, target, source_ir, name, shape, input):
     shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
     shape_layer.axis = 0
     shape_layer.name = f"{name}_output_shape"
-    
+
     return shape_layer.get_output(0)
 
+
 def slice(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index dba4ad52a5..70badd796c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -103,11 +103,8 @@ def expand(
     # set_layer_name(layer, target, name, source_ir)
     # return layer.get_output(0)
     breakpoint()
-    expand_output = slice(
-        ctx, target, source_ir, name, input_t, start, shape, stride
-    )
+    expand_output = slice(ctx, target, source_ir, name, input_t, start, shape, stride)
     return expand_output
-    
 
 
 def chunk(

From cd866609a18fb0dffe1e798a1831bcf698095de5 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Thu, 14 Mar 2024 02:01:07 -0700
Subject: [PATCH 25/33] feat: Add save API for torch-trt compiled models

---
 .github/scripts/install-torch-tensorrt.sh   |  3 +-
 py/torch_tensorrt/_compile.py               | 67 +++++++++++++++++++++
 py/torch_tensorrt/dynamo/_compiler.py       |  9 +--
 py/torch_tensorrt/dynamo/_defaults.py       |  1 -
 py/torch_tensorrt/dynamo/_exporter.py       | 17 +-----
 py/torch_tensorrt/dynamo/_settings.py       |  3 -
 tests/py/dynamo/models/test_export_serde.py | 58 +++++++++---------
 7 files changed, 104 insertions(+), 54 deletions(-)

diff --git a/.github/scripts/install-torch-tensorrt.sh b/.github/scripts/install-torch-tensorrt.sh
index 2930421d5b..9757fadeb4 100644
--- a/.github/scripts/install-torch-tensorrt.sh
+++ b/.github/scripts/install-torch-tensorrt.sh
@@ -2,7 +2,8 @@
 set -eou pipefail
 # Source conda so it's available to the script environment
 source ${BUILD_ENV_FILE}
-${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision pyyaml
+${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision --extra-index-url https://pypi.python.org/simple
+${CONDA_RUN} python -m pip install pyyaml mpmath==1.3.0
 export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
 ${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com
 
diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 9dd816e633..aa1bc53a0a 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -6,6 +6,7 @@
 
 import torch
 import torch.fx
+import torch_tensorrt.dynamo
 import torch_tensorrt.ts
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
@@ -29,6 +30,7 @@
 __all__ = [
     "compile",
     "convert_method_to_trt_engine",
+    "save",
 ]
 
 
@@ -332,3 +334,68 @@ def convert_method_to_trt_engine(
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+
+
+def save(
+    module: Any,
+    file_path: str = "",
+    *,
+    output_format: str = "exported_program",
+    inputs: Optional[Sequence[torch.Tensor]] = None,
+    retrace: bool = False,
+) -> None:
+    """
+    Save the model to disk in the specified output format.
+    Arguments:
+        module : Compiled Torch-TensorRT module (Options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
+        inputs (torch.Tensor): Torch input tensors
+    """
+    module_type = _parse_module_type(module)
+    accepted_formats = {"exported_program", "torchscript"}
+    if inputs and not all(isinstance(input, torch.Tensor) for input in inputs):
+        raise ValueError(
+            "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
+        )
+    if output_format not in accepted_formats:
+        raise ValueError(
+            f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
+        )
+    if not file_path:
+        raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+    if module_type == _ModuleType.nn:
+        raise ValueError(
+            "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
+        )
+    elif module_type == _ModuleType.ts:
+        if output_format == "exported_program":
+            raise ValueError(
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+            )
+        else:
+            torch.jit.save(module, file_path)
+    elif module_type == _ModuleType.ep:
+        if output_format == "torchscript":
+            raise ValueError(
+                "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
+            )
+        else:
+            torch.export.save(module, file_path)
+    elif module_type == _ModuleType.fx:
+        if not inputs:
+            raise ValueError(
+                "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
+            )
+        # The module type is torch.fx.GraphModule
+        if output_format == "torchscript":
+            module_ts = torch.jit.trace(module, inputs)
+            torch.jit.save(module_ts, file_path)
+        else:
+            if not retrace:
+                from torch_tensorrt.dynamo._exporter import export
+
+                exp_program = export(module, inputs)
+                torch.export.save(exp_program, file_path)
+            else:
+                exp_program = torch.export.export(module, tuple(inputs), strict=False)
+                torch.export.save(exp_program, file_path)
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 6312532f1c..b321eabcb2 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -30,7 +30,6 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,
@@ -48,7 +47,6 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._exporter import export
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -102,9 +100,8 @@ def compile(
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     dryrun: bool = DRYRUN,
     hardware_compatible: bool = HARDWARE_COMPATIBLE,
-    output_format: str = OUTPUT_FORMAT,
     **kwargs: Any,
-) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
+) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
 
     Takes a existing TorchScript module and a set of settings to configure the compiler
@@ -246,14 +243,12 @@ def compile(
         "dla_global_dram_size": dla_global_dram_size,
         "dryrun": dryrun,
         "hardware_compatible": hardware_compatible,
-        "output_format": output_format,
     }
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     trt_gm = compile_module(gm, inputs, settings)
-    trt_result = export(trt_gm, torch_inputs, output_format)
-    return trt_result
+    return trt_gm
 
 
 def compile_module(
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index ec038c0dba..3d48ab3def 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,6 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-OUTPUT_FORMAT = "exported_program"
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py
index c7e2f37795..bae20ac235 100644
--- a/py/torch_tensorrt/dynamo/_exporter.py
+++ b/py/torch_tensorrt/dynamo/_exporter.py
@@ -18,27 +18,16 @@
 def export(
     gm: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-    output_format: str,
 ) -> ExportedProgram:
     """Export the result of TensorRT compilation into the desired output format.
 
     Arguments:
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
         inputs (torch.Tensor): Torch input tensors
-        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
     """
-    if output_format == "torchscript" or output_format == "ts":
-        return torch.jit.trace(gm, inputs)
-    elif output_format == "exported_program" or output_format == "ep":
-        patched_module = transform(gm, inputs)
-        exp_program = create_trt_exp_program(patched_module)
-        return exp_program
-    elif output_format == "graph_module" or output_format == "fx":
-        return gm
-    else:
-        raise ValueError(
-            f"Invalid output format {output_format} specified. Supported options include exported_program (or) ep | torchscript (or) ts | graph_module (or) fx"
-        )
+    patched_module = transform(gm, inputs)
+    exp_program = create_trt_exp_program(patched_module)
+    return exp_program
 
 
 def transform(
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index c00b049f45..2420a227d8 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -19,7 +19,6 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,
@@ -71,7 +70,6 @@ class CompilationSettings:
             TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
             ouptut to a file if a string path is specified
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
-        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
     """
 
     precision: torch.dtype = PRECISION
@@ -99,4 +97,3 @@ class CompilationSettings:
     dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
     dryrun: Union[bool, str] = DRYRUN
     hardware_compatible: bool = HARDWARE_COMPATIBLE
-    output_format: str = OUTPUT_FORMAT
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index efa593890e..0905c5e859 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -42,18 +42,18 @@ def forward(self, x):
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
     deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
-
+    deser_trt_module = deser_trt_exp_program.module()
     # Check Pyt and TRT exported program outputs
-    cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0])
+    cos_sim = cosine_similarity(model(input), trt_module(input)[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Check Pyt and deserialized TRT exported program outputs
-    cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0])
+    cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -93,12 +93,13 @@ def forward(self, x):
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
     deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
+    deser_trt_module = deser_trt_exp_program.module()
     # Check Pyt and TRT exported program outputs
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
         assertions.assertTrue(
@@ -107,7 +108,7 @@ def forward(self, x):
         )
 
     # Check Pyt and deserialized TRT exported program outputs
-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(
@@ -149,12 +150,13 @@ def forward(self, x):
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
     deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
+    deser_trt_module = deser_trt_exp_program.module()
     # Check Pyt and TRT exported program outputs
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
         assertions.assertTrue(
@@ -163,7 +165,7 @@ def forward(self, x):
         )
 
     # Check Pyt and deserialized TRT exported program outputs
-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(
@@ -207,12 +209,12 @@ def forward(self, x):
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
     deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
-
+    deser_trt_module = deser_trt_exp_program.module()
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
         assertions.assertTrue(
@@ -220,7 +222,7 @@ def forward(self, x):
             msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )
 
-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(
@@ -248,19 +250,19 @@ def test_resnet18(ir):
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
     deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
-
+    deser_trt_module = deser_trt_exp_program.module()
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
 
     cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
     assertions.assertTrue(
@@ -303,12 +305,12 @@ def forward(self, x):
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
     deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
-
+    deser_trt_module = deser_trt_exp_program.module()
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
 
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
@@ -317,7 +319,7 @@ def forward(self, x):
             msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )
 
-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(

From eab0dba2955a87550fe12e7b67ae092597b8c453 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 18 Mar 2024 12:43:45 -0700
Subject: [PATCH 26/33] chore: Fix save failures

---
 core/runtime/TRTEngine.cpp                  |   2 +-
 py/torch_tensorrt/_compile.py               |   9 +-
 py/torch_tensorrt/dynamo/_exporter.py       | 135 +++++++++++++-------
 tests/py/dynamo/models/test_export_serde.py |   2 +-
 4 files changed, 100 insertions(+), 48 deletions(-)

diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 92e5d7a8ff..7a046f6d94 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
               exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
        << std::endl;
   }
-  ss << "  }" << std::endl;
+  ss << "  ]" << std::endl;
   ss << "  Device: " << device_info << std::endl;
   ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   // clang-format on
diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index aa1bc53a0a..443dec8869 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -397,5 +397,10 @@ def save(
                 exp_program = export(module, inputs)
                 torch.export.save(exp_program, file_path)
             else:
-                exp_program = torch.export.export(module, tuple(inputs), strict=False)
-                torch.export.save(exp_program, file_path)
+                from torch._higher_order_ops.torchbind import enable_torchbind_tracing
+
+                with enable_torchbind_tracing():
+                    exp_program = torch.export.export(
+                        module, tuple(inputs), strict=False
+                    )
+                    torch.export.save(exp_program, file_path)
diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py
index bae20ac235..cf06bc4531 100644
--- a/py/torch_tensorrt/dynamo/_exporter.py
+++ b/py/torch_tensorrt/dynamo/_exporter.py
@@ -1,3 +1,4 @@
+import copy
 import operator
 from typing import Any, Dict, Sequence, Tuple, cast
 
@@ -6,8 +7,11 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.export import ExportedProgram, ExportGraphSignature
 from torch.export.exported_program import (
+    CustomObjArgument,
     InputKind,
     InputSpec,
+    ModuleCallEntry,
+    ModuleCallSignature,
     OutputKind,
     OutputSpec,
     TensorArgument,
@@ -44,24 +48,27 @@ def transform(
 
     Returns an inlined torch.fx.GraphModule
     """
+    gm_export = copy.deepcopy(gm)
     # Run shape analysis
-    _, outputs_map = partitioning.run_shape_analysis(gm, inputs)
+    _, outputs_map = partitioning.run_shape_analysis(gm_export, inputs)
 
     # Inline TensorRT submodules
-    inline_trt_modules(gm, outputs_map)
+    inline_trt_modules(gm_export, outputs_map)
 
     # Inline pytorch submodules
-    inline_torch_modules(gm)
+    inline_torch_modules(gm_export)
 
     # Clean the graph
-    gm.delete_all_unused_submodules()
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
+    gm_export.delete_all_unused_submodules()
+    gm_export.graph.eliminate_dead_code()
+    gm_export.graph.lint()
 
-    return gm
+    return gm_export
 
 
-def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule:
+def lift(
+    gm: torch.fx.GraphModule, graph_signature: Any
+) -> Tuple[torch.fx.GraphModule, ExportGraphSignature, Dict[str, Any], Dict[str, Any]]:
     """
     Given an unlifted fx.GraphModule, lift all parameters, buffers into placeholders.
     Arguments:
@@ -75,6 +82,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
     # exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict
     # has all parameters registered as torch.tensors.
     state_dict = gm.state_dict()
+    constants = {}
 
     fake_mode = detect_fake_mode(
         tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
@@ -89,52 +97,68 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
             break
 
     # At first the user_inputs are only present in the graph_signature.input_specs and hence non_user_input_idx=0
-    # The input_specs should be of the form [params, buffers, constant_tensors, user_inputs]
+    # The input_specs should be of the form [params, buffers, constant_tensors, custom_obj, user_inputs]
     non_user_input_idx = 0
     for node in gm.graph.nodes:
         if node.op == "get_attr":
-            if node.target not in state_dict:
-                raise ValueError(
-                    f"The get_attr node : {node.name} with target: {node.target} value could not be found in state_dict. Please check the input exported_program's graphmodule parameters."
-                )
 
-            constant_tensor = state_dict[node.target]
-            input_kind = InputKind.CONSTANT_TENSOR
+            lift_val = None
+            input_kind = None
 
-            # state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
-            for name, _ in gm.named_parameters():
-                if node.target == name:
-                    input_kind = InputKind.PARAMETER
-                    state_dict[name] = torch.nn.Parameter(state_dict[name])
-                    break
-            for name, _ in gm.named_buffers():
-                if node.target == name:
-                    input_kind = InputKind.BUFFER
-                    break
+            if node.target not in state_dict:
+                constants[node.target] = getattr(gm, node.target)
+                input_kind = InputKind.CUSTOM_OBJ
+                lift_val = constants[node.target]
+            else:
+                lift_val = state_dict[node.target]
+
+                input_kind = InputKind.CONSTANT_TENSOR
+
+                # state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
+                for name, _ in gm.named_parameters():
+                    if node.target == name:
+                        input_kind = InputKind.PARAMETER
+                        state_dict[name] = torch.nn.Parameter(state_dict[name])
+                        break
+                for name, _ in gm.named_buffers():
+                    if node.target == name:
+                        input_kind = InputKind.BUFFER
+                        break
+
+            assert lift_val is not None and input_kind is not None
 
             # Replace get_attr nodes with placeholder nodes and copy metadata.
             with gm.graph.inserting_before(first_user_input):
-                const_placeholder_node = gm.graph.placeholder(node.target)
+                const_placeholder_node = gm.graph.placeholder(
+                    node.target.replace(".", "_")
+                )
                 # Copy the node meta into this new placeholder node
                 const_placeholder_node.meta = node.meta
-                const_placeholder_node.meta["val"] = cast(
-                    FakeTensor,
-                    torch.empty_strided(
-                        tuple(constant_tensor.shape),
-                        tuple([1] * len(constant_tensor.shape)),
-                    ),
-                )
+
+                if isinstance(lift_val, torch.Tensor):
+                    const_placeholder_node.meta["val"] = cast(
+                        FakeTensor,
+                        torch.empty_strided(
+                            tuple(lift_val.shape),
+                            tuple([1] * len(lift_val.shape)),
+                        ),
+                    )
 
                 node.replace_all_uses_with(const_placeholder_node)
                 gm.graph.erase_node(node)
 
                 # Add these parameters/buffers/constants to the existing graph signature
                 # before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
+                input_spec_arg = TensorArgument(name=const_placeholder_node.name)
+                if input_kind == InputKind.CUSTOM_OBJ:
+                    input_spec_arg = CustomObjArgument(
+                        name=const_placeholder_node.name, class_fqn=""
+                    )
                 graph_signature.input_specs.insert(
                     non_user_input_idx,
                     InputSpec(
                         kind=input_kind,
-                        arg=TensorArgument(name=const_placeholder_node.name),
+                        arg=input_spec_arg,
                         target=node.target,
                     ),
                 )
@@ -143,7 +167,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
     gm.graph.eliminate_dead_code()
     gm.graph.lint()
 
-    return gm, graph_signature, state_dict
+    return gm, graph_signature, state_dict, constants
 
 
 def get_duplicate_nodes(
@@ -281,18 +305,30 @@ def create_trt_exp_program(
         input_specs=input_specs, output_specs=output_specs
     )
 
+    module_call_graph = [
+        ModuleCallEntry(
+            "",
+            ModuleCallSignature(
+                inputs=[],
+                outputs=[],
+                in_spec=gm.graph._codegen.pytree_info.in_spec,
+                out_spec=gm.graph._codegen.pytree_info.out_spec,
+            ),
+        )
+    ]
+
     # Lift parameters/buffers/constants in the graph
     # torch.export serialization expects them to be lifted
-    gm, trt_graph_signature, state_dict = lift(gm, trt_graph_signature)
+    gm, trt_graph_signature, state_dict, constants = lift(gm, trt_graph_signature)
 
     trt_exp_program = ExportedProgram(
-        gm,
-        gm.graph,
-        trt_graph_signature,
-        state_dict,
-        {},
-        [],
-        [],
+        root=gm,
+        graph=gm.graph,
+        graph_signature=trt_graph_signature,
+        state_dict=state_dict,
+        range_constraints={},
+        module_call_graph=module_call_graph,
+        constants=constants,
     )
 
     return trt_exp_program
@@ -319,9 +355,13 @@ def inline_trt_modules(
         num_outputs = len(outputs_map[trt_module_node.name])
         # Insert a call_function node to perform inference on TRT engine
         with gm.graph.inserting_before(trt_module_node):
+            engine_name = f"{name}_engine"
+            setattr(gm, engine_name, trt_module.engine)
+            engine_node = gm.graph.get_attr(engine_name)
+
             trt_node = gm.graph.call_function(
                 torch.ops.tensorrt.execute_engine.default,
-                (trt_module_node.args, trt_module.engine),
+                (trt_module_node.args, engine_node),
             )
             trt_node.meta["val"] = []
             assert num_outputs > 0
@@ -337,6 +377,13 @@ def inline_trt_modules(
                     )
                 )
 
+            # meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties)
+            # Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but
+            # for custom object nodes, it should be CustomObjArgument
+            engine_node.meta["val"] = CustomObjArgument(
+                name=engine_node.name, class_fqn=""
+            )
+
         if num_outputs == 1:
             # Insert getitem nodes as outputs (for export serialization to work)
             with gm.graph.inserting_after(trt_node):
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index 0905c5e859..40fa01c2c9 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -146,7 +146,6 @@ def forward(self, x):
             )
         ],
         "ir": ir,
-        "debug": True,
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
@@ -306,6 +305,7 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+
     torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
     deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
     deser_trt_module = deser_trt_exp_program.module()

From b191d62bafac9740657cb9dc67ccafb213d4914c Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Mon, 18 Mar 2024 16:29:04 -0700
Subject: [PATCH 27/33] chore: update to 2.3 rc build

---
 py/requirements.txt | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/py/requirements.txt b/py/requirements.txt
index cd52d32436..419c325653 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -1,9 +1,9 @@
 numpy
 packaging
 pybind11==2.6.2
---extra-index-url https://download.pytorch.org/whl/nightly/cu121
-torch>=2.3.0.dev,<2.4.0
-torchvision>=0.18.0.dev,<0.19.0
+--index-url https://download.pytorch.org/whl/test/cu121
+torch>=2.3.0,<2.4.0
+torchvision>=0.18.0,<0.19.0
 --extra-index-url https://pypi.ngc.nvidia.com
 tensorrt==8.6.1
 pyyaml

From 5f34d4fe7231167e91fae32581fabc420c79904b Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 19 Mar 2024 13:36:37 -0700
Subject: [PATCH 28/33] chore: minor fixes

---
 py/torch_tensorrt/dynamo/conversion/converter_utils.py | 2 +-
 tests/py/dynamo/models/test_dyn_models.py              | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 7d5e59367b..f9d14917f1 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -270,7 +270,7 @@ def create_constant(
     """
     numpy_value = to_numpy(value, dtype)
     constant = ctx.net.add_constant(
-        trt.Dims() if isinstance(value, (int, float, bool)) else value.shape,
+        (1,) if isinstance(value, (int, float, bool)) else value.shape,
         numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
     )
     constant.name = name
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index f9f1d02c02..e4675b41be 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -152,7 +152,7 @@ def forward(self, x):
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input)[0])
+    cos_sim = cosine_similarity(model(input), trt_mod(input))
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",

From 8674a3c437d767c6e2b09db3a574af9a99c318d3 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 19 Mar 2024 16:18:17 -0700
Subject: [PATCH 29/33] chore: minor fixes

---
 py/torch_tensorrt/dynamo/_exporter.py         |  1 +
 .../lowering/test_aten_lowering_passes.py     | 12 ++++++++----
 tests/py/dynamo/models/test_models_export.py  | 19 ++++++++++---------
 tests/py/dynamo/testing_utilities.py          |  1 +
 4 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py
index cf06bc4531..d4a9fd3584 100644
--- a/py/torch_tensorrt/dynamo/_exporter.py
+++ b/py/torch_tensorrt/dynamo/_exporter.py
@@ -129,6 +129,7 @@ def lift(
 
             # Replace get_attr nodes with placeholder nodes and copy metadata.
             with gm.graph.inserting_before(first_user_input):
+                # Ensure name doesn't contain period as it is used for submodules
                 const_placeholder_node = gm.graph.placeholder(
                     node.target.replace(".", "_")
                 )
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index bc75a8aa3d..3afc5e5923 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -1,9 +1,12 @@
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
-
 import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
 
-from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
+from ..testing_utilities import (
+    DECIMALS_OF_AGREEMENT,
+    DECIMALS_OF_AGREEMENT_3,
+    lower_graph_testing,
+)
 
 
 class TestInputAsOutput(TestCase):
@@ -444,10 +447,11 @@ def forward(self, input, weight, bias):
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
+
         self.assertAlmostEqual(
             max_diff,
             0,
-            DECIMALS_OF_AGREEMENT,
+            DECIMALS_OF_AGREEMENT_3,
             msg=f"Linear TRT outputs don't match with the original model.",
         )
         torch._dynamo.reset()
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index fd7b40592a..bc8bf12c95 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -159,11 +159,11 @@ def test_bert_base_uncased(ir):
     model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
     input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
     input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
-    model = (
-        transformers_trace(model, input_names=["input_ids", "attention_mask"])
-        .eval()
-        .cuda()
-    )
+    # model = (
+    #     transformers_trace(model, input_names=["input_ids", "attention_mask"])
+    #     .eval()
+    #     .cuda()
+    # )
 
     compile_spec = {
         "inputs": [
@@ -182,8 +182,8 @@ def test_bert_base_uncased(ir):
         "enabled_precisions": {torch.float},
         "truncate_long_and_double": True,
         "ir": ir,
-        "min_block_size": 10,
-        "torch_executed_ops": {"torch.ops.aten.gelu.default"},
+        "min_block_size": 15,
+        "debug": True,
     }
     trt_mod = torchtrt.compile(model, **compile_spec)
     model_outputs = model(input, input2)
@@ -192,8 +192,9 @@ def test_bert_base_uncased(ir):
         len(model_outputs) == len(trt_model_outputs),
         msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
     )
-    for index, key in enumerate(model_outputs):
-        out, trt_out = model_outputs[key], trt_model_outputs[index]
+
+    for key, _ in model_outputs.items():
+        out, trt_out = model_outputs[key], trt_model_outputs[key]
         cos_sim = cosine_similarity(out, trt_out)
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py
index 742b9fc1a3..c815d2fde4 100644
--- a/tests/py/dynamo/testing_utilities.py
+++ b/tests/py/dynamo/testing_utilities.py
@@ -14,6 +14,7 @@
 )
 
 DECIMALS_OF_AGREEMENT = 4
+DECIMALS_OF_AGREEMENT_3 = 3
 
 
 def fx_dynamo_testing_backend(

From f4e8fe9bc3f114dd8da3760dca510f78a8f58a0d Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 19 Mar 2024 17:04:38 -0700
Subject: [PATCH 30/33] chore: remove duplicate bert test case

---
 tests/py/dynamo/models/test_models_export.py | 53 --------------------
 1 file changed, 53 deletions(-)

diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index bc8bf12c95..4d0f4e2e7f 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -105,55 +105,6 @@ def test_efficientnet_b0(ir):
     torch._dynamo.reset()
 
 
-@pytest.mark.unit
-def test_bert_base_uncased(ir):
-    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
-    input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
-    input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
-    model = (
-        transformers_trace(model, input_names=["input_ids", "attention_mask"])
-        .eval()
-        .cuda()
-    )
-
-    compile_spec = {
-        "inputs": [
-            torchtrt.Input(
-                input.shape,
-                dtype=input.dtype,
-                format=torch.contiguous_format,
-            ),
-            torchtrt.Input(
-                input.shape,
-                dtype=input.dtype,
-                format=torch.contiguous_format,
-            ),
-        ],
-        "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
-        "truncate_long_and_double": True,
-        "ir": ir,
-        "min_block_size": 10,
-    }
-    trt_mod = torchtrt.compile(model, **compile_spec)
-    model_outputs = model(input, input2)
-    trt_model_outputs = trt_mod(input, input2)
-    assertions.assertTrue(
-        len(model_outputs) == len(trt_model_outputs),
-        msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
-    )
-    for index, key in enumerate(model_outputs):
-        out, trt_out = model_outputs[key], trt_model_outputs[index]
-        cos_sim = cosine_similarity(out, trt_out)
-        assertions.assertTrue(
-            cos_sim > COSINE_THRESHOLD,
-            msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
-        )
-
-    # Clean up model env
-    torch._dynamo.reset()
-
-
 @pytest.mark.unit
 def test_bert_base_uncased(ir):
     model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
@@ -183,7 +134,6 @@ def test_bert_base_uncased(ir):
         "truncate_long_and_double": True,
         "ir": ir,
         "min_block_size": 15,
-        "debug": True,
     }
     trt_mod = torchtrt.compile(model, **compile_spec)
     model_outputs = model(input, input2)
@@ -204,9 +154,6 @@ def test_bert_base_uncased(ir):
     # Clean up model env
     torch._dynamo.reset()
 
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_resnet18_half(ir):

From 4ae6ab95a98f5b1f644c3bfd3824ac76df442dc7 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 19 Mar 2024 17:05:51 -0700
Subject: [PATCH 31/33] chore: remove comments

---
 tests/py/dynamo/models/test_models_export.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index 4d0f4e2e7f..84f6bf7a36 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -110,11 +110,6 @@ def test_bert_base_uncased(ir):
     model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
     input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
     input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
-    # model = (
-    #     transformers_trace(model, input_names=["input_ids", "attention_mask"])
-    #     .eval()
-    #     .cuda()
-    # )
 
     compile_spec = {
         "inputs": [

From 78f7eb550603134666b4e8f646d5dee43358e699 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Tue, 2 Apr 2024 05:45:27 -0700
Subject: [PATCH 32/33] chore: updates

---
 .../dynamo/conversion/aten_ops_converters.py  |  5 ++
 .../dynamo/conversion/impl/slice/ops.py       |  4 -
 .../dynamo/conversion/ops_evaluators.py       |  7 --
 .../dynamo/lowering/passes/view_to_reshape.py |  1 -
 .../dynamo/partitioning/__init__.py           |  1 -
 .../dynamo/partitioning/common.py             | 75 -------------------
 tests/py/dynamo/models/test_dyn_models.py     |  2 +-
 7 files changed, 6 insertions(+), 89 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 3f547f9d40..72998e1917 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -392,6 +392,11 @@ def aten_ops_sigmoid(
     )
 
 
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
 @dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
 def aten_ops_symsize_int(
     ctx: ConversionContext,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 70badd796c..e578ebee54 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -98,11 +98,7 @@ def expand(
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
-    # layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
 
-    # set_layer_name(layer, target, name, source_ir)
-    # return layer.get_output(0)
-    breakpoint()
     expand_output = slice(ctx, target, source_ir, name, input_t, start, shape, stride)
     return expand_output
 
diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index b35f198028..5ddd8c5e3a 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -46,14 +46,7 @@ def aten_ops_arange_start_step(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    # breakpoint()
     fill_layer = ctx.net.add_fill(trt.Dims(), trt.FillOperation.LINSPACE)
     fill_layer.set_input(0, args[1])
     fill_layer.set_output_type(0, trt.DataType.INT32)
-    # fill_layer.set_input(1, 0)
-    # fill_layer.set_input(2, 1)
-    # start_tensor = get_trt_tensor(ctx, 0, "_start_tensor")
-    # fill_layer.set_input(1, start_tensor)
-    # delta_tensor = get_trt_tensor(ctx, torch.tensor([0], dtype=torch.int32), "_delta_tensor")
-    # fill_layer.set_input(2, delta_tensor)
     return fill_layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
index db0346348b..b2da354122 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -28,7 +28,6 @@ def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
 
     # Store metadata of the orig_op
     metadata = get_metadata(gm, orig_op)
-    # breakpoint()
 
     if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
         gm = clean_up_graph_after_modifications(gm)
diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py
index 5e5406e67c..25487da065 100644
--- a/py/torch_tensorrt/dynamo/partitioning/__init__.py
+++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -3,6 +3,5 @@
 from .common import (
     construct_submodule_inputs,
     get_graph_converter_support,
-    get_submod_inputs,
     run_shape_analysis,
 )
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 109bda275f..270973c8c3 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -4,7 +4,6 @@
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import DEBUG
-from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic
 
 logger = logging.getLogger(__name__)
 
@@ -135,80 +134,6 @@ def get_submodule_io(
     return submod_inputs_shape_map, submod_outputs_shape_map
 
 
-def get_submod_inputs(
-    mod: torch.fx.GraphModule,
-    submod: torch.fx.GraphModule,
-    inputs: Sequence[Input],
-    device: torch.device,
-) -> Optional[Sequence[torch.Tensor]]:
-    """Helper function to get inputs to a Torch submodule
-
-    Args:
-        mod: Parent FX GraphModule
-        submod: Child FX GraphModule
-        inputs: Sample inputs to parent module
-    Returns:
-        Sequence of Tensors representing inputs to child module
-    """
-    acc_inputs: Any = None
-
-    def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None:
-        nonlocal acc_inputs
-        acc_inputs = inputs
-        return
-
-    # Register a hook to capture submodule input
-    handle = submod.register_forward_pre_hook(get_input)
-    # Iterate over min, opt, max shapes for dynamic inputs
-    inputs_map = {}
-
-    if input_is_dynamic(inputs):
-        for mode in ["min_shape", "opt_shape", "max_shape"]:
-            torch_inputs = get_torch_inputs(inputs, device, mode)
-            mod(*torch_inputs)
-            inputs_map[mode] = acc_inputs
-        handle.remove()
-    else:
-        torch_inputs = get_torch_inputs(inputs, device)
-        mod(*torch_inputs)
-        handle.remove()
-        assert isinstance(acc_inputs, tuple)
-        return [
-            Input(shape=acc_input.shape, dtype=acc_input.dtype)
-            for acc_input in acc_inputs
-        ]
-
-    num_submodule_inputs = (
-        len(inputs_map["min_shape"]) if inputs_map["min_shape"] else 0
-    )
-    submodule_inputs = []
-    for idx in range(num_submodule_inputs):
-        if not isinstance(inputs_map["min_shape"][idx], torch.Tensor):
-            input_val = torch.tensor(inputs_map["opt_shape"][idx], dtype=torch.int32)
-            logger.warning(
-                "Detected a zero-dimensional input. This might be a shape tensor input which is not currently supported. This might result in undefined behavior"
-            )
-            submodule_inputs.append(
-                Input(
-                    shape=[1],
-                    torch_tensor=input_val,
-                    dtype=input_val.dtype,
-                )
-            )
-        else:
-            submodule_inputs.append(
-                Input(
-                    min_shape=inputs_map["min_shape"][idx].shape,
-                    opt_shape=inputs_map["opt_shape"][idx].shape,
-                    max_shape=inputs_map["max_shape"][idx].shape,
-                    torch_tensor=inputs_map["opt_shape"][idx],
-                    dtype=inputs_map["opt_shape"][idx].dtype,
-                )
-            )
-
-    return submodule_inputs
-
-
 def get_graph_converter_support(
     graph_module: torch.fx.GraphModule,
     verbose: bool = DEBUG,
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index e4675b41be..822ee468a9 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -64,7 +64,7 @@ def forward(self, x):
 @pytest.mark.unit
 def test_base_dynamic_fallback(ir):
     """
-    Tests the model (which is fully convertible) with dynamic shapes
+    Tests the model with dynamic shapes where torch.abs op is forced to run in PyTorch
     """
 
     class MyModule(torch.nn.Module):

From e9b649d2466b87bd3200fe2bbd26b2de61d57f66 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri <peri.dheeraj@gmail.com>
Date: Thu, 4 Apr 2024 21:46:19 -0700
Subject: [PATCH 33/33] chore: revert changes

---
 .../dynamo/conversion/impl/slice/base.py      | 19 +------------------
 .../dynamo/conversion/impl/slice/ops.py       |  5 +++--
 .../dynamo/conversion/ops_evaluators.py       |  7 ++-----
 3 files changed, 6 insertions(+), 25 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
index 64225227aa..018ac63b8c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py
@@ -3,7 +3,6 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
@@ -12,22 +11,6 @@
 from torch_tensorrt.fx.types import Shape, TRTTensor
 
 
-def get_dynamic_shape(ctx, target, source_ir, name, shape, input):
-    trt_shape = []
-    shape = input.shape
-    for i, s in enumerate(shape):
-        if isinstance(s, TRTTensor):
-            trt_shape.append(s)
-        else:
-            a = get_trt_tensor(ctx, s, f"{name}_{i}")
-            trt_shape.append(a)
-    shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
-    shape_layer.axis = 0
-    shape_layer.name = f"{name}_output_shape"
-
-    return shape_layer.get_output(0)
-
-
 def slice(
     ctx: ConversionContext,
     target: Target,
@@ -40,7 +23,7 @@ def slice(
 ) -> TRTTensor:
     dynamic_shape = has_dynamic_shape(input.shape)
     if dynamic_shape:
-        shape = get_dynamic_shape(ctx, target, source_ir, name, shape, input)
+        shape = get_shape_with_dynamic_shape(ctx, target, source_ir, name, shape, input)
     layer = ctx.net.add_slice(
         input,
         start=start,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index e578ebee54..61d71fe9a0 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -99,8 +99,9 @@ def expand(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
 
-    expand_output = slice(ctx, target, source_ir, name, input_t, start, shape, stride)
-    return expand_output
+    layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
 
 
 def chunk(
diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index 5ddd8c5e3a..f83e0e5008 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -2,7 +2,7 @@
 import operator
 from typing import Dict, Sequence, Tuple, Union
 
-import tensorrt as trt
+import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -46,7 +46,4 @@ def aten_ops_arange_start_step(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    fill_layer = ctx.net.add_fill(trt.Dims(), trt.FillOperation.LINSPACE)
-    fill_layer.set_input(0, args[1])
-    fill_layer.set_output_type(0, trt.DataType.INT32)
-    return fill_layer.get_output(0)
+    return np.arange(*args)