diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index b321eabcb2..ae93bf344a 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -189,10 +189,10 @@ def compile(
     )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     enabled_precisions = set(enabled_precisions)
@@ -308,6 +308,24 @@ def compile_module(
             f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
         )
 
+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
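+            # A node lacks usable metadata if its meta dict is empty or has no "val" entry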
+            if node.op != "output" and (not node.meta or "val" not in node.meta):
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
+                )
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
 
@@ -366,12 +384,7 @@ def compile_module(
         )
 
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)
 
         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 1fa2806181..bade91c553 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
-
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 45949a1c8d..72998e1917 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -392,6 +392,22 @@ def aten_ops_sigmoid(
     )
 
 
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 def index_dtype_validator(node: Node) -> bool:
     index = node.args[1]
     for ind in index:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
index 672fc97351..63ff93b0c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,11 @@
-from typing import Optional, Sequence
+from typing import Optional
 
 import tensorrt as trt
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 # nearest, linear, cubic
 GridSamplerInterpolationMode = {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index db586be65f..dc33129d24 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -90,7 +90,7 @@ def index(
     # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
     # If any is not this flag will be set to False
     _LOGGER.debug(
-        f"Determining whether aten.index constant-index optimization can be invoked"
+        "Determining whether aten.index constant-index optimization can be invoked"
     )
     is_numpy = all(
         isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
         indices_tensor = get_trt_tensor(
-            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
                 cum_adv_index = cum_adv_index + adv_index
                 multiplier = multiplier * input_shape[adv_indx_indices[i]]
             cum_adv_index = get_trt_tensor(
-                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+                ctx, cum_adv_index, name + "_index_sum_intermediate"
             )
         else:
             multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
             adv_indx_count
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
-            _LOGGER.debug(f"The indices are continuous in this case")
+            _LOGGER.debug("The indices are continuous in this case")
             concat_tensor_reshape.append(
                 get_trt_tensor(ctx, -1, name + "_dynamic_concat")
             )
@@ -287,7 +287,7 @@ def index(
                 source_ir,
             )
             unfold_tensor = regular_index_shuffle_layer.get_output(0)
-            _LOGGER.debug(f"The tensor is unfolded now")
+            _LOGGER.debug("The tensor is unfolded now")
             _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
 
             # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
             reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
         else:
-            _LOGGER.debug(f"The indices are not continuous in this case")
+            _LOGGER.debug("The indices are not continuous in this case")
             concat_final_tensor = []
             concat_final_tensor.append(cum_adv_index_shape_tensor)
             for i in range(0, rank):
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index ef30b186c1..2d2481936b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -8,7 +8,11 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+    to_numpy,
+)
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
@@ -16,6 +20,33 @@
 from torch_tensorrt.fx.types import TRTTensor
 
 
+def shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    """
+    This is the general shape layer implementation in TensorRT.
+    sym_size.int ops map to the add_shape layer in TensorRT and return
+    the dynamic shape of the tensor, gathering the size of the requested dim.
+    """
+    shape_layer = ctx.net.add_shape(input_val)
+    input_shape = shape_layer.get_output(0)
+    set_layer_name(shape_layer, target, name + "_shape", source_ir)
+
+    n_dims = len(input_val.shape)
+    dim = get_positive_dim(dim, n_dims)
+    dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
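+    # Gather the value at index dim from the 1-D shape tensor to extract that dimension's size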
+    gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    input_shape = gather_layer.get_output(0)
+
+    return input_shape
+
+
 def get_shape_with_dynamic_shape(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 49ddb76e2c..6d848c4be3 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -17,7 +17,23 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
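+        # Concatenate the per-dimension tensors into a single 1-D shape tensor and
+        # set it as the shuffle layer's dynamic reshape dims (second input).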
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 5f1db00f33..61d71fe9a0 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,6 +98,7 @@ def expand(
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
+
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 3313730ec3..594bb4167c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -29,7 +29,7 @@ def upsample(
         resize_layer.scales = [1.0, 1.0] + list(scale_factors)
     else:
         raise RuntimeError(
-            f"At least one of out_shape and scale_factors should be specified."
+            "At least one of out_shape and scale_factors should be specified."
         )
 
     # interpolate mode
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
index 31a55099c2..0ffc6d3c76 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List
 
 import torch
 
@@ -29,3 +29,24 @@ def get_tensor_placeholders(
     ]
 
     return placeholders
+
+
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[Any]:
+    """
+    Return a list of the metadata (node.meta) of all target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[Any]
+) -> None:
+    """
+    Set the metadata of each target_op node in the graph to the corresponding entry of the given metadata list.
+    """
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
index e2ef051f06..b2da354122 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import List, Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    get_metadata,
+    set_metadata,
 )
 
 logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default
 
     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)
 
     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)
 
-    return orig, replacement
+    # Store metadata of the orig_op
+    metadata = get_metadata(gm, orig_op)
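+    # This assumes replace_pattern yields one reshape node per original view node, in order;
+    # set_metadata below asserts that the counts match.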
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    # Copy the orig_op's metadata to the replacement op
+    set_metadata(gm, replacement_op, metadata)
+
+    return gm
diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py
index 1a8cc94099..25487da065 100644
--- a/py/torch_tensorrt/dynamo/partitioning/__init__.py
+++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,7 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
+from .common import (
+    construct_submodule_inputs,
+    get_graph_converter_support,
+    run_shape_analysis,
+)
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 8348738afa..270973c8c3 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -4,11 +4,99 @@
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import DEBUG
-from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic
 
 logger = logging.getLogger(__name__)
 
 
+def contains_sym_int(tensor: torch.Size) -> bool:
+    """
+    Returns true if the given shape contains a symbolic (SymInt) dimension.
+    """
+    for dim in tensor:
+        if isinstance(dim, torch.SymInt):
+            return True
+    return False
+
+
+def construct_dynamic_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
+    """
+    Constructs a torch_tensorrt.Input based on a symbolic input
+    Args:
+        input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
+        input_dtype: The torch dtype of the input tensor
+    Returns:
+        A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
+    """
+    min_shape = []
+    opt_shape = []
+    max_shape = []
+    for dim in input_shape:
+        if isinstance(dim, torch.SymInt):
+            node = dim.node
+            expr = node.expr
+            shape_env = node.shape_env
+            var_range = shape_env.var_to_range.get(expr, None)
+            var_val = shape_env.var_to_val.get(expr, None)
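+            # var_range holds the recorded [lower, upper] bounds of this symbol; var_val is its example (hint) value, used as the opt shape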
+            assert var_range is not None and var_val is not None
+            # Torchdynamo 0/1 specialization outlier: a recorded lower bound of 2 can still correspond to a runtime size of 1
+            if var_range.lower == 2:
+                min_shape.append(1)
+            else:
+                min_shape.append(int(var_range.lower))
+            opt_shape.append(int(var_val))
+            max_shape.append(int(var_range.upper))
+        else:
+            min_shape.append(dim)
+            opt_shape.append(dim)
+            max_shape.append(dim)
+
+    return Input(
+        min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input_dtype
+    )
+
+
+def get_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
+    """
+    Based on the types of the dimensions in input_shape, construct a regular or dynamic shaped Input
+    """
+    if contains_sym_int(input_shape):
+        return construct_dynamic_input(input_shape, input_dtype)
+    else:
+        return Input(shape=input_shape, dtype=input_dtype)
+
+
+def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
+    """
+    Construct torch_tensorrt Inputs based on the module inputs.
+    The module inputs are expected to carry metadata with shape and dtype info.
+    Args:
+        module: Input FX GraphModule
+    Returns:
+        Sequence of torch_tensorrt.Input's representing inputs to given module
+    """
+    torchtrt_inputs = []
+    module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"]
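+    # Placeholder nodes carry the shape/dtype metadata ("val" or "tensor_meta") recorded during export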
+    for input in module_inputs:
+        if input.meta:
+            if "val" in input.meta:
+                input_meta = input.meta["val"]
+                input_shape = input_meta.size()
+                torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
+            elif "tensor_meta" in input.meta:
+                input_meta = input.meta["tensor_meta"]
+                input_shape = input_meta.shape
+                torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
+            else:
+                raise AssertionError(
+                    f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly"
+                )
+        else:
+            raise AssertionError(
+                f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly"
+            )
+
+    return torchtrt_inputs
+
+
 def run_shape_analysis(
     parent_module: torch.fx.GraphModule, inputs: Sequence[Input]
 ) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]:
@@ -46,80 +134,6 @@ def get_submodule_io(
     return submod_inputs_shape_map, submod_outputs_shape_map
 
 
-def get_submod_inputs(
-    mod: torch.fx.GraphModule,
-    submod: torch.fx.GraphModule,
-    inputs: Sequence[Input],
-    device: torch.device,
-) -> Optional[Sequence[torch.Tensor]]:
-    """Helper function to get inputs to a Torch submodule
-
-    Args:
-        mod: Parent FX GraphModule
-        submod: Child FX GraphModule
-        inputs: Sample inputs to parent module
-    Returns:
-        Sequence of Tensors representing inputs to child module
-    """
-    acc_inputs: Any = None
-
-    def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None:
-        nonlocal acc_inputs
-        acc_inputs = inputs
-        return
-
-    # Register a hook to capture submodule input
-    handle = submod.register_forward_pre_hook(get_input)
-    # Iterate over min, opt, max shapes for dynamic inputs
-    inputs_map = {}
-
-    if input_is_dynamic(inputs):
-        for mode in ["min_shape", "opt_shape", "max_shape"]:
-            torch_inputs = get_torch_inputs(inputs, device, mode)
-            mod(*torch_inputs)
-            inputs_map[mode] = acc_inputs
-        handle.remove()
-    else:
-        torch_inputs = get_torch_inputs(inputs, device)
-        mod(*torch_inputs)
-        handle.remove()
-        assert isinstance(acc_inputs, tuple)
-        return [
-            Input(shape=acc_input.shape, dtype=acc_input.dtype)
-            for acc_input in acc_inputs
-        ]
-
-    num_submodule_inputs = (
-        len(inputs_map["min_shape"]) if inputs_map["min_shape"] else 0
-    )
-    submodule_inputs = []
-    for idx in range(num_submodule_inputs):
-        if not isinstance(inputs_map["min_shape"][idx], torch.Tensor):
-            input_val = torch.tensor(inputs_map["opt_shape"][idx], dtype=torch.int32)
-            logger.warning(
-                "Detected a zero-dimensional input. This might be a shape tensor input which is not currently supported. This might result in undefined behavior"
-            )
-            submodule_inputs.append(
-                Input(
-                    shape=[1],
-                    torch_tensor=input_val,
-                    dtype=input_val.dtype,
-                )
-            )
-        else:
-            submodule_inputs.append(
-                Input(
-                    min_shape=inputs_map["min_shape"][idx].shape,
-                    opt_shape=inputs_map["opt_shape"][idx].shape,
-                    max_shape=inputs_map["max_shape"][idx].shape,
-                    torch_tensor=inputs_map["opt_shape"][idx],
-                    dtype=inputs_map["opt_shape"][idx].dtype,
-                )
-            )
-
-    return submodule_inputs
-
-
 def get_graph_converter_support(
     graph_module: torch.fx.GraphModule,
     verbose: bool = DEBUG,
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 22590fe73d..549636b3c7 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -88,7 +88,8 @@ def get_torch_inputs(
             if isinstance(input, Input)
         ]
     return [
-        input.torch_tensor.to(device) for input in inputs if isinstance(input, Input)
+        input.torch_tensor.to(device) if isinstance(input, Input) else input
+        for input in inputs
     ]
 
 
diff --git a/tests/py/dynamo/conversion/test_sym_size.py b/tests/py/dynamo/conversion/test_sym_size.py
new file mode 100644
index 0000000000..35bf75a509
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_sym_size.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSymSizeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_batch(self, input_shape):
+        class BatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 0)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            BatchDim(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_non_batch(self, input_shape):
+        class NonBatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 1)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            NonBatchDim(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index ceb4a6dd2c..822ee468a9 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -3,9 +3,8 @@
 import pytest
 import timm
 import torch
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
-
 import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
 assertions = unittest.TestCase()
 
@@ -65,7 +64,7 @@ def forward(self, x):
 @pytest.mark.unit
 def test_base_dynamic_fallback(ir):
     """
-    Tests the model (which is fully convertible) with dynamic shapes
+    Tests the model with dynamic shapes where torch.abs op is forced to run in PyTorch
     """
 
     class MyModule(torch.nn.Module):
@@ -114,3 +113,53 @@ def forward(self, x):
 
     with torch.no_grad():
         torch.cuda.empty_cache()
+
+
+@pytest.mark.unit
+def test_view(ir):
+    """
+    Tests a model containing an aten.view op (fully convertible) with dynamic shapes
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            input_shape = x.size()
+            y = x.view(input_shape[0], -1)
+            return y
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((6, 3, 4)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 4),
+                opt_shape=(4, 3, 4),
+                max_shape=(8, 3, 4),
+                dtype=torch.float32,
+                name="x",
+            )
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 1,
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        torch.cuda.empty_cache()
diff --git a/versions.py b/versions.py
index 772737aab7..db418a06d2 100644
--- a/versions.py
+++ b/versions.py
@@ -1,11 +1,10 @@
-import yaml
-import re
 import os
+import re
 import subprocess
-
 from datetime import datetime
 from pathlib import Path
-from typing import List
+
+import yaml
 
 __version__ = "0.0.0"
 __cuda_version__ = "0.0"