diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 604eda8c96..d6e12f5215 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -11,6 +11,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -21,6 +22,7 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
new file mode 100644
index 0000000000..efc836814f
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -0,0 +1,54 @@
+import logging
+from typing import Callable, List, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def view_to_reshape(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
+    orig, replacement = view_replacement()
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    return gm
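+
+# A minimal sketch of standalone usage (illustrative only; `ViewModule` below
+# is a hypothetical example module, not part of the lowering API):
+#
+#     class ViewModule(torch.nn.Module):
+#         def forward(self, x):
+#             return torch.ops.aten.view.default(x, (1, -1))
+#
+#     gm = torch.fx.symbolic_trace(ViewModule())
+#     gm = view_to_reshape(gm, [torch.rand(2, 8)])
+#     # gm.graph now calls aten.reshape.default in place of aten.view.default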
+
+
+def view_replacement() -> (
+    Tuple[
+        torch.fx.GraphModule,
+        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+    ]
+):
+    """Constructs the original and replacement functions for view"""
+
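+    # torch.fx.subgraph_rewriter.replace_pattern symbolically traces both
+    # callables, then swaps each match of `orig` in the graph with `replacement`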
+    # Original graph
+    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.view.default(input, shape)
+
+    # Replacement graph
+    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.reshape.default(input, shape)
+
+    return orig, replacement
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 184e7c9c54..11b989bd90 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -1,7 +1,8 @@
 import torch
-import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
 
+import torch_tensorrt
+
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
 
@@ -375,5 +376,70 @@ def forward(self, input, weight, bias):
         torch._dynamo.reset()
 
 
+class TestLowerViewToReshape(TestCase):
+    def test_view_to_reshape(self):
+        class ViewToReshape(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.view.default(input, (1, 1, -1))
+                return out
+
+        inputs = [
+            torch.rand((3, 4, 5, 32)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
+        expected_ops = {torch.ops.aten.reshape.default}
+        unexpected_ops = {
+            torch.ops.aten.view.default,
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"ViewToReshape TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()