From 96bdead71c60fe08b51812c959fa8566f5d52c24 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Wed, 22 Mar 2023 15:18:21 -0700
Subject: [PATCH 01/25] Implementation of slice and select operations

---
 .../fx/converters/aten_ops_converters.py      | 16 +++++
 .../fx/converters/converter_utils.py          | 24 +++++++
 py/torch_tensorrt/fx/converters/operator.py   | 66 ++++++++++++-------
 .../converters/aten_op/test_select_aten.py    | 56 ++++++++++++++++
 4 files changed, 138 insertions(+), 24 deletions(-)
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py

diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 3f55ab3827..275887ae57 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -572,3 +572,19 @@ def aten_ops_sigmoid(
         "input": args[0],
     }
     return add_sigmoid(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.select)
+def aten_ops_select(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+        "index": args[2],
+    }
+    return add_select(network, target.kwargs_new, name)
diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index ba6421701e..0b9561b6db 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -543,3 +543,27 @@ def type_cast(
     layer_i.set_output_type(0, cast_type)
     set_layer_name(layer_i, target, f"{name}_dtype_change")
     return layer_i.get_output(0)
+
+
+def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
+    """
+    Convert a PyTorch Tensor to a Numpy Array. If the tensor is
+    quantized it will be dequantized first.
+
+    Args:
+        tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
+
+    Returns:
+        A Numpy array.
+    """
+
+    if tensor is None:
+        return tensor
+
+    assert isinstance(
+        tensor, torch.Tensor
+    ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
+    if tensor.is_quantized:
+        tensor = tensor.dequantize()
+
+    return tensor.cpu().detach().contiguous().numpy()
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 539b766336..de9ac6bc4e 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -22,6 +22,7 @@
 from .converter_utils import prepend_ones
 from .converter_utils import has_dynamic_shape
 from .converter_utils import get_shape_with_dynamic_shape
+from .converter_utils import to_numpy
 
 from ..types import (
     Shape,
@@ -278,30 +279,6 @@ def trunc_div(
     return output
 
 
-def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
-    """
-    Convert a PyTorch Tensor to a Numpy Array. If the tensor is
-    quantized it will be dequantized first.
-
-    Args:
-        tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
-
-    Returns:
-        A Numpy array.
-    """
-
-    if tensor is None:
-        return tensor
-
-    assert isinstance(
-        tensor, torch.Tensor
-    ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
-    if tensor.is_quantized:
-        tensor = tensor.dequantize()
-
-    return tensor.cpu().detach().contiguous().numpy()
-
-
 def trt_dtype_to_torch_dtype(trt_dtype):
     table = {
         trt.bool: torch.bool,
@@ -1050,3 +1027,44 @@ def add_expand(network, target, kwargs, name):
     layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name)
     return layer.get_output(0)
+
+
+def add_select(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"slice_tensor received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        if dim == 0:
+            raise RuntimeError(
+                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
+            )
+        dim = dim - 1
+    else:
+        if dynamic_shape:
+            # Check whether slice target dim is dynamic shape dim
+            assert (
+                input_val.shape[dim] != -1
+            ), "Can't select on negative shape dimension!"
+    index = kwargs[2]
+    if index >= input_val.shape[dim]:
+        raise RuntimeError(
+            f"cannot have index greater than the dimension length! {input_val.shape[dim]}"
+        )
+    output_shape = list(input_val.shape)
+    output_shape[dim] = 1
+    if dynamic_shape > 0:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    layer = network.add_gather(input_val, dim, index)
+    out = layer.getOutput(0)
+    if len(out.shape) != 1:
+        layer = network.add_shuffle(out)
+    return layer.getOutput(0)
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
new file mode 100644
index 0000000000..8868db2668
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
@@ -0,0 +1,56 @@
+import unittest
+
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSelectConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_dim_index", 2, 1),
+        ]
+    )
+    def test_select(self, _, dim_test, index_test):
+        class TestModule(torch.nn.Module):
+            def __init__(self, dim, index):
+                super().__init__()
+                self.dim = dim
+                self.index = index
+
+            def forward(self, input):
+                return torch.select(input, self.dim, self.index)
+
+        input = [torch.randn(1, 3, 32)]
+        self.run_test(
+            TestModule(dim_test, index_test),
+            input,
+            expected_ops={torch.ops.aten.select},
+            test_explicit_precision=True,
+        )
+
+    # def test_select_with_dynamic_shape(self, _, dim_test, index_test):
+    #     class TestModule(torch.nn.Module):
+    #         def __init__(self, dim, index):
+    #             super().__init__()
+    #             self.dim = dim
+    #             self.index = index
+    #         def forward(self, input):
+    #             return torch.select(input, self.dim, self.index)
+
+    #     input_spec = [
+    #         InputTensorSpec(
+    #             shape=(-1, 3, 32),
+    #             dtype=torch.float32,
+    #             shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))],
+    #         ),
+    #     ]
+    #     self.run_test_with_dynamic_shape(
+    #         TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select}
+    #     )
+
+
+if __name__ == "__main__":
+    run_tests()

From c8811cd1788cf220fc4d809e34b56cc28c71a19f Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Wed, 22 Mar 2023 17:52:02 -0700
Subject: [PATCH 02/25] select test implementation

---
 .../fx/converters/aten_ops_converters.py             |  4 ++--
 py/torch_tensorrt/fx/converters/operator.py          | 12 ++++++++----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 042298a471..228194fbe9 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -572,7 +572,7 @@ def aten_ops_sigmoid(
     return add_sigmoid(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.select)
+@tensorrt_converter(torch.ops.aten.select.int)
 def aten_ops_select(
     network: TRTNetwork,
     target: Target,
@@ -585,4 +585,4 @@ def aten_ops_select(
         "dim": args[1],
         "index": args[2],
     }
-    return add_select(network, target.kwargs_new, name)
+    return add_select(network, target, kwargs_new, name)
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index bdda8338d1..cf5ffe0349 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -1153,7 +1153,8 @@ def add_select(network, target, kwargs, name):
             assert (
                 input_val.shape[dim] != -1
             ), "Can't select on negative shape dimension!"
-    index = kwargs[2]
+    index = kwargs["index"]
+
     if index >= input_val.shape[dim]:
         raise RuntimeError(
             f"cannot have index greater than the dimension length! {input_val.shape[dim]}"
@@ -1164,8 +1165,11 @@ def add_select(network, target, kwargs, name):
         output_shape = get_shape_with_dynamic_shape(
             network, output_shape, input_val, target, name
         )
-    layer = network.add_gather(input_val, dim, index)
-    out = layer.getOutput(0)
+    input_shape = network.add_shape(input_val).get_output(0)
+    dim_value = torch.tensor(dim, dtype=torch.int32)
+    axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0)
+    layer = network.add_gather(input_shape, axis, index)
+    out = layer.get_output(0)
     if len(out.shape) != 1:
         layer = network.add_shuffle(out)
-    return layer.getOutput(0)
+    return layer.get_output(0)

From 4f18c0f800e2580c97a988c78d9c747b16fe850f Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Wed, 22 Mar 2023 17:52:40 -0700
Subject: [PATCH 03/25] select aten test

---
 .../fx/test/converters/aten_op/test_select_aten.py   | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
index 8868db2668..e21ab0dd61 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
@@ -13,21 +13,19 @@ class TestSelectConverter(DispatchTestCase):
             ("select_dim_index", 2, 1),
         ]
     )
-    def test_select(self, _, dim_test, index_test):
+    def test_select(self, _, dim, index):
         class TestModule(torch.nn.Module):
-            def __init__(self, dim, index):
+            def __init__(self):
                 super().__init__()
-                self.dim = dim
-                self.index = index
 
             def forward(self, input):
-                return torch.select(input, self.dim, self.index)
+                return torch.select(input, dim, index)
 
         input = [torch.randn(1, 3, 32)]
         self.run_test(
-            TestModule(dim_test, index_test),
+            TestModule(),
             input,
-            expected_ops={torch.ops.aten.select},
+            expected_ops={torch.ops.aten.select.int},
             test_explicit_precision=True,
         )
 

From 8303cd55669177b189f1000e769e8d55356836ef Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Fri, 24 Mar 2023 16:50:22 -0700
Subject: [PATCH 04/25] aten::matmul, aten::slice, aten::select converters

---
 .../fx/converters/aten_ops_converters.py      | 34 +++++++
 py/torch_tensorrt/fx/converters/operator.py   | 89 ++++++++++++++++++-
 .../converters/aten_op/test_matmul_aten.py    | 27 ++++++
 .../converters/aten_op/test_select_aten.py    | 73 ++++++++++-----
 .../converters/aten_op/test_slice_aten.py     | 58 ++++++++++++
 5 files changed, 255 insertions(+), 26 deletions(-)
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py

diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 228194fbe9..1dbfa14076 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -586,3 +586,37 @@ def aten_ops_select(
         "index": args[2],
     }
     return add_select(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.slice.Tensor)
+def aten_ops_slice(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+        "start": args[2],
+        "stop": args[3],
+        "step": args[4],
+    }
+    return add_slice(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.matmul)
+@tensorrt_converter(torch.ops.aten.mm.default)
+def aten_ops_matmul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return add_matmul(network, target, kwargs_new, name)
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index cf5ffe0349..5955e598f5 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -1165,11 +1165,92 @@ def add_select(network, target, kwargs, name):
         output_shape = get_shape_with_dynamic_shape(
             network, output_shape, input_val, target, name
         )
-    input_shape = network.add_shape(input_val).get_output(0)
-    dim_value = torch.tensor(dim, dtype=torch.int32)
-    axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0)
-    layer = network.add_gather(input_shape, axis, index)
+    index_value = torch.tensor(index, dtype=torch.int32)
+    indices_tensor = network.add_constant(
+        index_value.shape, to_numpy(index_value)
+    ).get_output(0)
+    layer = network.add_gather(input_val, indices_tensor, dim)
     out = layer.get_output(0)
     if len(out.shape) != 1:
         layer = network.add_shuffle(out)
     return layer.get_output(0)
+
+
+def add_slice(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"slice_tensor received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        if dim == 0:
+            raise RuntimeError(
+                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
+            )
+        dim = dim - 1
+    else:
+        if dynamic_shape:
+            # Check whether slice target dim is dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+
+    start_int = cast(int, kwargs["start"])
+    stop_int = cast(int, kwargs["stop"])
+    step_int = cast(int, kwargs["step"])
+    start = [0] * len(input_val.shape)
+    start[dim] = start_int
+    stride = [1] * len(start)
+    stride[dim] = step_int
+    output_shape = list(input_val.shape)
+    output_shape[dim] = (stop_int - start_int) // step_int + 1
+
+    if dynamic_shape > 0:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    layer = network.add_slice(
+        input_val,
+        start=start,
+        shape=[] if dynamic_shape else output_shape,
+        stride=stride,
+    )
+    if dynamic_shape:
+        layer.set_input(2, output_shape)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_matmul(network, target, kwargs, name):
+    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
+    other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")
+
+    for i in [input_val, other_val]:
+        if not isinstance(i, TRTTensor):
+            raise RuntimeError(
+                f"matmul received input {i} that is not part of the TensorRT region!"
+            )
+
+    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
+    preset_diff = 0
+
+    if len(input_val.shape) == 1:
+        preset_diff -= 1
+        input_matrix_op = trt.MatrixOperation.VECTOR
+
+    if len(other_val.shape) == 1:
+        preset_diff += 1
+        other_matrix_op = trt.MatrixOperation.VECTOR
+
+    input_val, other_val = broadcast(
+        network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
+    )
+    layer = network.add_matrix_multiply(
+        input_val, input_matrix_op, other_val, other_matrix_op
+    )
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
new file mode 100644
index 0000000000..0b0cd8d0b5
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
@@ -0,0 +1,27 @@
+import unittest
+
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestMatMulConverter(DispatchTestCase):
+    def test_matmul(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.matmul(x, y)
+
+        inputOne = torch.randn(2, 32)
+        inputTwo = torch.randn(32, 2)
+        inputs = [inputOne, inputTwo]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
index e21ab0dd61..1d5cb84f31 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py
@@ -7,10 +7,10 @@
 from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
 
 
-class TestSelectConverter(DispatchTestCase):
+class TestSelectConverterOne(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 2, 1),
+            ("select_dim_index", 1, 0),
         ]
     )
     def test_select(self, _, dim, index):
@@ -21,7 +21,7 @@ def __init__(self):
             def forward(self, input):
                 return torch.select(input, dim, index)
 
-        input = [torch.randn(1, 3, 32)]
+        input = [torch.randn(1, 2)]
         self.run_test(
             TestModule(),
             input,
@@ -29,25 +29,54 @@ def forward(self, input):
             test_explicit_precision=True,
         )
 
-    # def test_select_with_dynamic_shape(self, _, dim_test, index_test):
-    #     class TestModule(torch.nn.Module):
-    #         def __init__(self, dim, index):
-    #             super().__init__()
-    #             self.dim = dim
-    #             self.index = index
-    #         def forward(self, input):
-    #             return torch.select(input, self.dim, self.index)
-
-    #     input_spec = [
-    #         InputTensorSpec(
-    #             shape=(-1, 3, 32),
-    #             dtype=torch.float32,
-    #             shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))],
-    #         ),
-    #     ]
-    #     self.run_test_with_dynamic_shape(
-    #         TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select}
-    #     )
+
+class TestSelectConverterTwo(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_dim_index", 1, 0),
+        ]
+    )
+    def test_select(self, _, dim, index):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.select(input, dim, index)
+
+        input = [torch.randn(4, 4, 4, 4)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.select.int},
+            test_explicit_precision=True,
+        )
+
+
+class TestSelectConverterWithDynamicShape(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_dim_index", 1, 0),
+        ]
+    )
+    def test_select_with_dynamic_shape(self, _, dim, index):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.select(input, dim, index)
+
+        input_spec = [
+            InputTensorSpec(
+                shape=(-1, 3, 3),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_spec, expected_ops={torch.ops.aten.select.int}
+        )
 
 
 if __name__ == "__main__":
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py
new file mode 100644
index 0000000000..b018aff73e
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py
@@ -0,0 +1,58 @@
+import unittest
+
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSelectConverterImplicitBatch(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_dim_start_stop_step", 0, 0, 7, 2),
+        ]
+    )
+    def test_slice(self, _, dim, start, stop, step):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
+                return out
+
+        input = [torch.randn(10, 2, 3, 1)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.slice.Tensor},
+        )
+
+
+class TestSelectConverterExplicitBatch(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_dim_start_stop_step", 1, 0, 7, 2),
+        ]
+    )
+    def test_slice(self, _, dim, start, stop, step):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
+                return out
+
+        input = [torch.randn(10, 10, 3, 1)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.slice.Tensor},
+            test_explicit_precision=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

From f1098f2520f30a6082597c93895c350905c8245d Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 20 Mar 2023 14:45:35 -0700
Subject: [PATCH 05/25] feat: Add sample torch.compile backend for tensorrt
 aten path

- Add backend adapted from previous `fx2trt_compiler` provided by
Dynamo
- Currently, the TRTSplitter needs work to fully support the `aten` path
- Additionally, the existing `aten` pass was reworked to exclude the
`torch._dynamo.export` call, which may be necessary here
---
 .../fx/tracer/dispatch_tracer/aten_tracer.py  |   8 +-
 .../tensorrt_dynamo_backend.py                | 107 ++++++++++++++++++
 2 files changed, 113 insertions(+), 2 deletions(-)
 create mode 100644 py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py

diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
index e60c8f8d13..356ddc978e 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
@@ -130,7 +130,7 @@ def trace(f, args, *rest):
 
 
 @req_torch_version("2.dev")
-def opt_trace(f, args, *rest):
+def opt_trace(f, args, perform_trace=True, *rest):
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
     These passes should be general and functional purpose
@@ -148,7 +148,11 @@ def opt_trace(f, args, *rest):
         replace_inplace_ops,  # remove it once functionalization is enabled
     ]
 
-    fx_module, _ = trace(f, args)
+    if perform_trace:
+        fx_module, _ = trace(f, args)
+    else:
+        fx_module = f
+
     print(fx_module.graph)
     for passes in passes_list:
         pr: PassResult = passes(fx_module)
diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
new file mode 100644
index 0000000000..bb6e68b0b5
--- /dev/null
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
@@ -0,0 +1,107 @@
+import torch
+import traceback
+import torch._dynamo as td
+
+from torch_tensorrt.fx.fx2trt import (
+    InputTensorSpec,
+    TRTInterpreter,
+)
+import tensorrt as trt
+from torch_tensorrt.fx.tools.trt_splitter import (
+    TRTSplitter,
+    TRTSplitterSetting,
+)
+from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer
+from torch_tensorrt.fx.trt_module import TRTModule
+from torch_tensorrt.fx.utils import LowerPrecision
+
+from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+
+MAX_SPLITS_THRESHOLD = 10
+
+
+def tensorrt_backend(gm, sample_inputs):
+    # Invoke AOTAutograd to compile model
+    return aot_module_simplified(
+        gm,
+        sample_inputs,
+        fw_compiler=make_boxed_compiler(fx2trt_compiler),
+    )
+
+
+def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs):
+    model = gm
+    inputs = example_inputs
+
+    # Perform lowering pass on model
+    model = aten_tracer.opt_trace(model, inputs, perform_trace=False)
+
+    # Split out unsupported ops --> Needs rewrite/revision for ATEN
+    splitter_setting = TRTSplitterSetting()
+    splitter_setting.use_implicit_batch_dim = False
+    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
+
+    splitter.node_support_preview()
+    split_mod = splitter()
+    num_piece = 0
+
+    for name, _ in split_mod.named_children():
+        print(f"Graph is split into {name}")
+        num_pieces += 1
+
+    # Select threshold above which segmentation is not beneficial and run graph in Torch
+    if num_pieces > MAX_SPLITS_THRESHOLD:
+        raise AssertionError(
+            f"The graph module is split into {num_piece} which is large than the \
+            threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module."
+        )
+
+    precision = LowerPrecision.FP32
+
+    def get_submod_inputs(mod, submod, inputs):
+        acc_inputs = None
+
+        def get_input(self, inputs):
+            nonlocal acc_inputs
+            acc_inputs = inputs
+
+        handle = submod.register_forward_pre_hook(get_input)
+        mod(*inputs)
+        handle.remove()
+        return acc_inputs
+
+    for name, _ in split_mod.named_children():
+        if "_run_on_acc" in name:
+            submod = getattr(split_mod, name)
+            acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+            interp = TRTInterpreter(
+                submod,
+                InputTensorSpec.from_tensors(acc_inputs),
+                explicit_batch_dimension=True,
+                logger_level=trt.Logger.VERBOSE,
+            )
+            r = interp.run(
+                max_workspace_size=20 << 30,
+                lower_precision=precision,
+                profiling_verbosity=trt.ProfilingVerbosity.VERBOSE,
+            )
+
+            trt_mod = TRTModule(*r)
+
+            setattr(split_mod, name, trt_mod)
+
+    return split_mod
+
+
+@td.register_backend
+def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
+    try:
+        trt_compiled = fx2trt(gm, example_inputs)
+        return trt_compiled
+    except Exception:
+        traceback.print_exc()
+        print(
+            "FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead"
+        )
+        return gm.forward

From 243bf9bc340e27837a33c3d6fc3c0998381aff0a Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 21 Mar 2023 16:17:51 -0700
Subject: [PATCH 06/25] Add decompositions to aot call

---
 .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py      | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
index bb6e68b0b5..a76162b93b 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
@@ -17,6 +17,9 @@
 
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
+from torch._inductor.decomposition import decompositions
+
+DECOMPOSITIONS = decompositions.copy()
 MAX_SPLITS_THRESHOLD = 10
 
 
@@ -26,6 +29,7 @@ def tensorrt_backend(gm, sample_inputs):
         gm,
         sample_inputs,
         fw_compiler=make_boxed_compiler(fx2trt_compiler),
+        decompositions=DECOMPOSITIONS,
     )
 
 

From 76fd3c8207bdf017af294f1883863a755045b1a8 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 27 Mar 2023 15:31:22 -0700
Subject: [PATCH 07/25] Mark FX2TRT converter as fake tensor unsupported

---
 .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py       | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
index a76162b93b..20cea4ffd5 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
@@ -15,6 +15,8 @@
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt.fx.utils import LowerPrecision
 
+from torch._dynamo.backends.common import fake_tensor_unsupported
+
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
 from torch._inductor.decomposition import decompositions
@@ -99,6 +101,7 @@ def get_input(self, inputs):
 
 
 @td.register_backend
+@fake_tensor_unsupported
 def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
     try:
         trt_compiled = fx2trt(gm, example_inputs)

From 6a8102c14f3c0fa7a200222979888e9d213d0d84 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 28 Mar 2023 18:52:12 -0700
Subject: [PATCH 08/25] Minor naming bugfix

---
 .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py      | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
index 20cea4ffd5..55c5e2df33 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
@@ -49,7 +49,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs):
 
     splitter.node_support_preview()
     split_mod = splitter()
-    num_piece = 0
+    num_pieces = 0
 
     for name, _ in split_mod.named_children():
         print(f"Graph is split into {name}")
@@ -58,7 +58,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs):
     # Select threshold above which segmentation is not beneficial and run graph in Torch
     if num_pieces > MAX_SPLITS_THRESHOLD:
         raise AssertionError(
-            f"The graph module is split into {num_piece} which is large than the \
+            f"The graph module is split into {num_pieces} which is large than the \
             threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module."
         )
 

From e97ed50eeb17b661cb7da060b5dd24bc32d9bb43 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Fri, 7 Apr 2023 11:12:12 -0700
Subject: [PATCH 09/25] Implementing aten::chunk, aten::layer_norm,
 aten::softmax, aten::where, aten::rsub, aten::rsqrt

---
 .../fx/converters/acc_ops_converters.py       | 220 +-------------
 .../fx/converters/aten_ops_converters.py      | 113 ++++++++
 py/torch_tensorrt/fx/converters/operator.py   | 269 +++++++++++++++++-
 .../converters/aten_op/test_chunk_aten.py     |  58 ++++
 .../aten_op/test_layer_norm_aten.py           |  45 +++
 .../converters/aten_op/test_rsqrt_aten.py     |   0
 .../test/converters/aten_op/test_rsub_aten.py |   0
 .../converters/aten_op/test_softmax_aten.py   |  44 +++
 .../converters/aten_op/test_squeeze_aten.py   |  67 +++++
 .../converters/aten_op/test_where_aten.py     |  56 ++++
 10 files changed, 662 insertions(+), 210 deletions(-)
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py
 create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py

diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index e556e81bb5..a321bb8dfe 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -678,7 +678,13 @@ def acc_ops_batch_norm(
 
 
 @tensorrt_converter(acc_ops.layer_norm)
-def acc_ops_layer_norm(network, target, args, kwargs, name):
+def acc_ops_layer_norm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return add_layer_norm(network, target, kwargs, name)
 
 
@@ -690,37 +696,7 @@ def acc_ops_softmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if kwargs["dim"] is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, kwargs["dim"])
-
-    dim = get_positive_dim(dim, input_ranks)
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
-
-    layer = network.add_softmax(input_val)
-    layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_softmax(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.tile)
@@ -956,9 +932,7 @@ def acc_ops_sqrt(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.UnaryOperation.SQRT
-    return add_unary_layer(network, input_val, operation_type, target, name)
+    return add_sqrt(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.reciprocal)
@@ -1619,40 +1593,7 @@ def acc_ops_squeeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
-    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
-    # dim, which is a very rare case. For now we just claim not supporting dim=None.
-    assert dim is not None, "We don't support dim=None right now for squeeze."
-
-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    )
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-        dim -= 1
-
-    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
-    assert (
-        len(get_dynamic_dims(input_val.shape)) <= 1
-    ), "Currently more than one dynamic dim for input to squeeze is not supported."
-
-    output_shape = []
-    for i, s in enumerate(input_val.shape):
-        if i == dim and s == 1:
-            continue
-        output_shape.append(s)
-    layer = network.add_shuffle(input_val)
-    layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_squeeze(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.add)
@@ -2022,89 +1963,7 @@ def acc_ops_where(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    condition_t = kwargs["condition"]
-    x_t = kwargs["x"]
-    y_t = kwargs["y"]
-
-    if type(x_t) != TRTTensor:
-        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
-
-    if type(y_t) != TRTTensor:
-        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
-
-    # get output shape
-
-    x_shape = list(x_t.shape)
-    y_shape = list(y_t.shape)
-    condition_shape = list(condition_t.shape)
-    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
-
-    # expand shape
-    if type(condition_t) != TRTTensor:
-        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
-        if condition_shape != output_shape:
-            condition_t.expand(output_shape)
-        condition_t = condition_t.to(torch.int32)
-        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
-        condition_layer = network.add_identity(condition_const)
-        condition_layer.set_output_type(0, trt.bool)
-        set_layer_name(condition_layer, target, f"{name}_condition")
-        condition_val = condition_layer.get_output(0)
-    else:
-        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
-        if condition_shape != output_shape:
-            condition_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": condition_t, "sizes": output_shape},
-                name=f"{name}_expand",
-            )
-        else:
-            condition_val = condition_t
-
-    if type(x_t) != TRTTensor:
-        if x_shape != output_shape:
-            # special case where 1 element in x_t
-            if len(x_t.shape) == 0:
-                x_t = x_t.unsqueeze(0)
-            x_t = x_t.expand(output_shape)
-        x_val = get_trt_tensor(network, x_t, f"{name}_x")
-    else:
-        x_val = x_t
-        if x_shape != output_shape:
-            x_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": x_val, "sizes": output_shape},
-                name=f"{name}_x_expand",
-            )
-
-    if type(y_t) != TRTTensor:
-        if y_shape != output_shape:
-            # special case where 1 element in y_t
-            if len(y_t.shape) == 0:
-                y_t = y_t.unsqueeze(0)
-            y_t = y_t.expand(output_shape)
-        y_val = get_trt_tensor(network, y_t, f"{name}_y")
-    else:
-        y_val = y_t
-        if y_shape != output_shape:
-            y_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": y_val, "sizes": output_shape},
-                name=f"{name}_y_expand",
-            )
-
-    select_layer = network.add_select(condition_val, x_val, y_val)
-
-    set_layer_name(select_layer, target, f"{name}_select")
-
-    return select_layer.get_output(0)
+    return add_where(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True)
@@ -2721,62 +2580,7 @@ def acc_ops_chunk(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    chunks = cast(int, kwargs["chunks"])
-    dim = cast(int, kwargs["dim"])
-    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"chunk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dynamic_shape = has_dynamic_shape(input_val.shape)
-    if network.has_implicit_batch_dimension:
-        input_dim_size += 1
-        dim = get_positive_dim(dim, input_dim_size)
-        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
-        dim -= 1
-    else:
-        if dynamic_shape:
-            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
-        dim = get_positive_dim(dim, input_dim_size)
-
-    if chunks > input_val.shape[dim]:
-        warnings.warn(
-            f"Asked for {chunks} chunks along dimention "
-            f"{dim} on tensor with size {input_val.shape}, chunks "
-            f"will default to {input_val.shape[dim]}",
-            RuntimeWarning,
-        )
-        chunks = input_val.shape[dim]
-
-    start = [0] * len(input_val.shape)
-    stride = [1] * len(start)
-    offset = 0
-    split_size = (input_val.shape[dim] + chunks - 1) // chunks
-
-    max_offset = input_val.shape[dim]
-    # add slice layers
-    output = []
-    for i in range(chunks):
-        shape = list(input_val.shape)
-        shape[dim] = min(split_size, max_offset - offset)
-        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_{i}"
-            )
-        start[dim] = offset
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
-        if dynamic_shape:
-            layer.set_input(2, shape)
-        offset += split_size
-        set_layer_name(layer, target, f"{name}_{i}")
-        output.append(layer.get_output(0))
-    return output
+    return add_chunk(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True)
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 1dbfa14076..d47f30a790 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -620,3 +620,116 @@ def aten_ops_matmul(
         "other": args[1],
     }
     return add_matmul(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "normalized_shape": args[1],
+        "weight": args[2],
+        "bias": args[3],
+        "eps": args[4],
+    }
+    return add_layer_norm(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_softmax(network, target, kwargs_new, name)
+
+
+# FIXME: need to look at case where dim is tuple
+@tensorrt_converter(torch.ops.aten.squeeze.dim)
+@tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_squeeze(network, target, kwargs_new, name)
+
+
+# FIXME: need to confirm lower basic passes
+# @tensorrt_converter(torch.ops.aten.chunk)
+# def aten_ops_chunk(
+#     network: TRTNetwork,
+#     target: Target,
+#     args: Tuple[Argument, ...],
+#     kwargs: Dict[str, Argument],
+#     name: str,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     kwargs_new = {
+#         "input": args[0],
+#         "chunks": args[1],
+#         "dim": args[2],
+#     }
+#     return add_chunk(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.where.self)
+def aten_ops_where(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "condition": args[0],
+        "x": args[1],
+        "y": args[2],
+    }
+    return add_where(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.rsub)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+        "alpha": args[2],
+    }
+    return add_rsub(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.rsqrt)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return add_rsqrt(network, target, kwargs_new, name)
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 5955e598f5..8d45278548 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -580,7 +580,7 @@ def layer_norm(
     set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
 
     # X-E[x]
-    sub_trt = operator.add_binary_elementwise_layer(
+    sub_trt = add_binary_elementwise_layer(
         network,
         input_val,
         mean_expected_layer.get_output(0),
@@ -594,7 +594,7 @@ def layer_norm(
         trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
     )
     pow_tensor.name = f"{name}_power"
-    pow_var = operator.add_binary_elementwise_layer(
+    pow_var = add_binary_elementwise_layer(
         network,
         sub_trt,
         pow_tensor.get_output(0),
@@ -739,6 +739,7 @@ def add_layer_norm(network, target, kwargs, name):
         _LOGGER.error(
             "Unable to find layer norm plugin, fall back to TensorRT implementation."
         )
+        args = []
         return layer_norm(network, target, args, kwargs, name)
     layer = network.add_plugin_v2([input_val], plugin)
     layer.name = name
@@ -1254,3 +1255,267 @@ def add_matmul(network, target, kwargs, name):
     )
     set_layer_name(layer, target, name)
     return layer.get_output(0)
+
+
+def add_softmax(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"softmax received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim: int) -> int:
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if kwargs["dim"] is None:
+        dim = get_softmax_dim(input_ranks)
+    else:
+        dim = cast(int, kwargs["dim"])
+
+    dim = get_positive_dim(dim, input_ranks)
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim -= 1
+
+    layer = network.add_softmax(input_val)
+    layer.axes = 1 << dim
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_squeeze(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"squeeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
+    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
+    # dim, which is a very rare case. For now we just claim not supporting dim=None.
+    assert dim is not None, "We don't support dim=None right now for squeeze."
+
+    dim = get_positive_dim(
+        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    )
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
+        dim -= 1
+
+    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
+    assert (
+        len(get_dynamic_dims(input_val.shape)) <= 1
+    ), "Currently more than one dynamic dim for input to squeeze is not supported."
+
+    output_shape = []
+    for i, s in enumerate(input_val.shape):
+        if i == dim and s == 1:
+            continue
+        output_shape.append(s)
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(output_shape)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_chunk(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    chunks = cast(int, kwargs["chunks"])
+    dim = cast(int, kwargs["dim"])
+    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"chunk received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        input_dim_size += 1
+        dim = get_positive_dim(dim, input_dim_size)
+        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
+        dim -= 1
+    else:
+        if dynamic_shape:
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+        dim = get_positive_dim(dim, input_dim_size)
+
+    if chunks > input_val.shape[dim]:
+        warnings.warn(
+            f"Asked for {chunks} chunks along dimention "
+            f"{dim} on tensor with size {input_val.shape}, chunks "
+            f"will default to {input_val.shape[dim]}",
+            RuntimeWarning,
+        )
+        chunks = input_val.shape[dim]
+
+    start = [0] * len(input_val.shape)
+    stride = [1] * len(start)
+    offset = 0
+    split_size = (input_val.shape[dim] + chunks - 1) // chunks
+
+    max_offset = input_val.shape[dim]
+    # add slice layers
+    output = []
+    for i in range(chunks):
+        shape = list(input_val.shape)
+        shape[dim] = min(split_size, max_offset - offset)
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, shape, input_val, target, f"{name}_{i}"
+            )
+        start[dim] = offset
+        layer = network.add_slice(
+            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
+        offset += split_size
+        set_layer_name(layer, target, f"{name}_{i}")
+        output.append(layer.get_output(0))
+    return output
+
+
+def add_where(network, target, kwargs, name):
+    condition_t = kwargs["condition"]
+    x_t = kwargs["x"]
+    y_t = kwargs["y"]
+
+    if type(x_t) != TRTTensor:
+        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
+
+    if type(y_t) != TRTTensor:
+        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
+
+    # get output shape
+
+    x_shape = list(x_t.shape)
+    y_shape = list(y_t.shape)
+    condition_shape = list(condition_t.shape)
+    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
+
+    # expand shape
+    if type(condition_t) != TRTTensor:
+        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
+        if condition_shape != output_shape:
+            condition_t.expand(output_shape)
+        condition_t = condition_t.to(torch.int32)
+        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
+        condition_layer = network.add_identity(condition_const)
+        condition_layer.set_output_type(0, trt.bool)
+        set_layer_name(condition_layer, target, f"{name}_condition")
+        condition_val = condition_layer.get_output(0)
+    else:
+        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
+        if condition_shape != output_shape:
+            condition_val = add_expand(
+                network,
+                target,
+                None,
+                {"input": condition_t, "sizes": output_shape},
+                name=f"{name}_expand",
+            )
+        else:
+            condition_val = condition_t
+
+    if type(x_t) != TRTTensor:
+        if x_shape != output_shape:
+            # special case where 1 element in x_t
+            if len(x_t.shape) == 0:
+                x_t = x_t.unsqueeze(0)
+            x_t = x_t.expand(output_shape)
+        x_val = get_trt_tensor(network, x_t, f"{name}_x")
+    else:
+        x_val = x_t
+        if x_shape != output_shape:
+            x_val = add_expand(
+                network,
+                target,
+                None,
+                {"input": x_val, "sizes": output_shape},
+                name=f"{name}_x_expand",
+            )
+
+    if type(y_t) != TRTTensor:
+        if y_shape != output_shape:
+            # special case where 1 element in y_t
+            if len(y_t.shape) == 0:
+                y_t = y_t.unsqueeze(0)
+            y_t = y_t.expand(output_shape)
+        y_val = get_trt_tensor(network, y_t, f"{name}_y")
+    else:
+        y_val = y_t
+        if y_shape != output_shape:
+            y_val = add_expand(
+                network,
+                target,
+                None,
+                {"input": y_val, "sizes": output_shape},
+                name=f"{name}_y_expand",
+            )
+
+    select_layer = network.add_select(condition_val, x_val, y_val)
+
+    set_layer_name(select_layer, target, f"{name}_select")
+
+    return select_layer.get_output(0)
+
+
+def add_scale(network, target, kwargs, name):
+    other = kwargs["other"]
+    scale = kwargs["scale"]
+    if isinstance(other, TRTTensor):
+        other_dtype = torch_dtype_from_trt(other.dtype)
+        is_other_trt_tensor = True
+
+    if not is_other_trt_tensor:
+        warnings.warn(
+            f"The value to be scaled is constant"
+            "In this case, please consider constant fold the model first."
+        )
+        return other * scale
+    layer = network.add_scale(other, trt.ScaleMode.UNIFORM, 0, scale, 1)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_rsub(network, target, kwargs, name):
+    scaled_tensor = add_scale(network, target, kwargs, name)
+    input = kwargs["input"]
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        scaled_tensor,
+        trt.ElementWiseOperation.SUB,
+        target,
+        name,
+    )
+
+
+def add_sqrt(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.SQRT
+    return add_unary_layer(network, input_val, operation_type, target, name)
+
+
+def add_rsqrt(network, target, kwargs, name):
+    sqrt_trt = add_sqrt(network, target, kwargs, name)
+    div_trt = add_binary_elementwise_layer(
+        network,
+        1,
+        sqrt_trt,
+        trt.ElementWiseOperation.DIV,
+        target,
+        f"{name}_div_trt",
+    )
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py
new file mode 100644
index 0000000000..8fae6da293
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py
@@ -0,0 +1,58 @@
+import unittest
+
+import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSelectConverterImplicitBatch(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_chunk_dim", 6, 0),
+        ]
+    )
+    def test_chunk(self, _, chunk, dim):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.ops.aten.chunk(input, chunk, dim)
+                return out
+
+        input = [torch.randn(11)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.chunk},
+        )
+
+
+class TestSelectConverterExplicitBatch(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_chunk_dim", 6, 0),
+        ]
+    )
+    def test_chunk(self, _, chunk, dim):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.ops.aten.chunk(input, chunk, dim)
+                return out
+
+        input = [torch.randn(12)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.chunk},
+            test_explicit_precision=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
new file mode 100644
index 0000000000..cf97e828d0
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -0,0 +1,45 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestLayerNormConverter(DispatchTestCase):
+    def test_layer_norm(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+            def forward(self, x):
+                return self.ln(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default}
+        )
+
+
+def test_layernorm_with_dynamic_shape(self):
+    class TestModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+        def forward(self, x):
+            return self.ln(x)
+
+    input_specs = [
+        InputTensorSpec(
+            shape=(-1, 3, 224, 224),
+            dtype=torch.float32,
+            shape_ranges=[(1, 3, 1, 1)],
+        ),
+    ]
+
+    self.run_test_with_dynamic_shape(
+        TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
+    )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py
new file mode 100644
index 0000000000..31e293fc91
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py
@@ -0,0 +1,44 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSoftMaxConverter(DispatchTestCase):
+    def test_softmax(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(1)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+    def test_softmax_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(2)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py
new file mode 100644
index 0000000000..5dd15a89e7
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSqueezeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim", (0), (2, 1)),
+            ("3d_one_dim", (0), (2, 2, 1)),
+            # ("3d_two_dim", (0, 1), (2, 2, 1)),
+            # ("4d_dim", (0, 1, 2), (2, 2, 2, 1)),
+        ]
+    )
+    def test_squeeze(self, _, dim, init_size):
+        class Squeeze(nn.Module):
+            def forward(self, x):
+                return torch.squeeze(x, dim)
+
+        inputs = [torch.randn(*init_size)]
+        expected_op = {}
+        if isinstance(dim, int) == 1:
+            expected_op = {torch.ops.aten.squeeze.dim}
+        else:
+            expected_op = {torch.ops.aten.squeeze.dims}
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops=expected_op,
+        )
+
+
+class TestSqueezeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
+            ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
+            # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]),
+        ]
+    )
+    def test_squeeze(self, _, dim, init_size, shape_range):
+        class Squeeze(nn.Module):
+            def forward(self, x):
+                return torch.squeeze(x, dim)
+
+        if isinstance(dim, int) == 1:
+            expected_op = {torch.ops.aten.squeeze.dim}
+        else:
+            expected_op = {torch.ops.aten.squeeze.dims}
+        input_specs = [
+            InputTensorSpec(
+                shape=init_size,
+                dtype=torch.float32,
+                shape_ranges=shape_range,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Squeeze(),
+            input_specs,
+            expected_ops=expected_op,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py
new file mode 100644
index 0000000000..6c050eee2f
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestWhereConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2)),
+            ("2d_broadcast_condition_xshape_yshape", (x < 0), (2, 2), (2, 1)),
+            ("3d_condition_xshape_yshape", (x > 0), (2, 2, 1), (2, 2, 1)),
+            ("2d_3d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2, 1)),
+        ]
+    )
+    def test_(self, _, condition, x_size, y_size):
+        class Where(nn.Module):
+            def forward(self, x):
+                return torch.where(x, dim)
+
+        inputX = [torch.randn(*x_size)]
+        inputOther = [torch.randn(*y_size)]
+        expected_op = {}
+        self.run_test(
+            Where(),
+            inputs,
+            expected_ops=torch.ops.aten.where.self,
+        )
+
+
+# class TestWhereConverter(DispatchTestCase):
+#     @parameterized.expand(
+#         [
+#             ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
+#             ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
+#             #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]),
+#         ]
+#     )
+#     def test_where(self, _, dim, init_size, shape_range):
+#         class Squeeze(nn.Module):
+#             def forward(self, x):
+#                 return torch.squeeze(x, dim)
+
+#         input_specs = [
+#             InputTensorSpec(
+#                 shape=init_size,
+#                 dtype=torch.float32,
+#                 shape_ranges=shape_range,
+#             ),
+#         ]
+#         self.run_test_with_dynamic_shape(
+#             Squeeze(),
+#             input_specs,
+#             expected_ops=torch.ops.aten.where.self,
+#         )

From c5a4744867e8637a58042972bfee133372dcfbb1 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Mon, 10 Apr 2023 09:13:14 -0700
Subject: [PATCH 10/25] Transformer operator changes

---
 .../fx/converters/converter_utils.py          | 33 ++++++++++
 py/torch_tensorrt/fx/converters/operator.py   | 64 +++++++++++++------
 .../fx/passes/lower_basic_pass_aten.py        |  1 +
 .../converters/aten_op/test_rsqrt_aten.py     | 29 +++++++++
 .../test/converters/aten_op/test_rsub_aten.py | 29 +++++++++
 .../converters/aten_op/test_squeeze_aten.py   |  4 +-
 .../converters/aten_op/test_where_aten.py     | 57 +++++++++--------
 .../tensorrt_dynamo_backend.py                |  2 +-
 8 files changed, 171 insertions(+), 48 deletions(-)

diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 551c18652d..9d405767ea 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -288,6 +288,39 @@ def prepend_ones(
     return layer.get_output(0)
 
 
+def broadcastable(
+    a: TRTTensor,
+    b: TRTTensor,
+) -> bool:
+    "Check if two tensors are broadcastable according to torch rules"
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+    print("a shape is", a_shape)
+    print("b shape is", b_shape)
+    # check from the trailing
+    diff = len(a_shape) - len(b_shape)
+    if diff == 0:
+        return True
+    if diff > 0:
+        max = len(a_shape)
+        min = len(b_shape)
+        greater_tensor = a_shape
+        lesser_tensor = b_shape
+    elif diff < 0:
+        max = len(b_shape)
+        min = len(a_shape)
+        greater_tensor = b_shape
+        lesser_tensor = a_shape
+    j = min - 1
+    for i in range(max - 1, diff - 1, -1):
+        if not (
+            greater_tensor[i] != lesser_tensor[j]
+            and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
+        ):
+            return False
+    return True
+
+
 def broadcast(
     network: TRTNetwork,
     a: TRTTensor,
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 8d45278548..4449b7146e 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -15,6 +15,7 @@
 from .converter_utils import set_layer_name
 from .converter_utils import get_trt_tensor
 from .converter_utils import broadcast
+from .converter_utils import broadcastable
 from .converter_utils import squeeze_left
 from .converter_utils import dtype_uniform
 from .converter_utils import get_trt_plugin
@@ -1119,7 +1120,6 @@ def add_expand(network, target, kwargs, name):
     # TRT does not support different dimension size
     assert len(shape) == ranks
     shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
-
     inshape = tuple(input_val.shape)
     shape = tuple(shape)
     start = tuple([0] * ranks)
@@ -1299,27 +1299,36 @@ def add_squeeze(network, target, kwargs, name):
             f"squeeze received input {input_val} that is not part "
             "of the TensorRT region!"
         )
+    dims = []
+    if "dim" in kwargs:
+        if isinstance(kwargs["dim"], int):
+            dims.append(cast(Optional[int], kwargs["dim"]))
+        else:
+            for dim in kwargs["dim"]:
+                dims.append(cast(Optional[int], dim))
 
-    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
+    # dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
     # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
     # dim, which is a very rare case. For now we just claim not supporting dim=None.
-    assert dim is not None, "We don't support dim=None right now for squeeze."
+    assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
 
-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    )
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-        dim -= 1
+    for dim in dims:
+        dim = get_positive_dim(
+            dim,
+            len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0),
+        )
+        if network.has_implicit_batch_dimension:
+            assert dim != 0, "We don't support squeeze batch dim when it's implicit."
+            dim -= 1
 
-    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
-    assert (
-        len(get_dynamic_dims(input_val.shape)) <= 1
-    ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
+        assert (
+            len(get_dynamic_dims(input_val.shape)) <= 1
+        ), "Currently more than one dynamic dim for input to squeeze is not supported."
 
     output_shape = []
     for i, s in enumerate(input_val.shape):
-        if i == dim and s == 1:
+        if (i in dims) and s == 1:
             continue
         output_shape.append(s)
     layer = network.add_shuffle(input_val)
@@ -1392,14 +1401,32 @@ def add_where(network, target, kwargs, name):
     x_t = kwargs["x"]
     y_t = kwargs["y"]
 
+    x_t_dim = len(tuple(x_t.shape))
+    y_t_dim = len(tuple(y_t.shape))
+    condition_t_dim = len(tuple(condition_t.shape))
+
     if type(x_t) != TRTTensor:
         assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
 
     if type(y_t) != TRTTensor:
         assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
 
+    if not (broadcastable(x_t, y_t)):
+        assert f"The two torch tensors should be broadcastable"
+
     # get output shape
+    # purpose of this is to bring x_t and y_t rank same as
+    # output_shape to input it to the add_expand operation
+    # condition_t will have dimension of either x_t or y_t
+    x_t, y_t = broadcast(network, x_t, y_t, f"{name}_x", f"{name}_y")
+    if len(tuple(condition_t.shape)) != len(tuple(x_t.shape)):
+        condition_t, x_t = broadcast(
+            network, condition_t, x_t, f"{name}_condition", f"{name}_x"
+        )
 
+    print("x_t shape", x_t.shape)
+    print("y_t shape", y_t.shape)
+    print("condition_t shape", condition_t.shape)
     x_shape = list(x_t.shape)
     y_shape = list(y_t.shape)
     condition_shape = list(condition_t.shape)
@@ -1418,11 +1445,10 @@ def add_where(network, target, kwargs, name):
         condition_val = condition_layer.get_output(0)
     else:
         assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
-        if condition_shape != output_shape:
+        if condition_shape != condition_t_dim:
             condition_val = add_expand(
                 network,
                 target,
-                None,
                 {"input": condition_t, "sizes": output_shape},
                 name=f"{name}_expand",
             )
@@ -1430,7 +1456,7 @@ def add_where(network, target, kwargs, name):
             condition_val = condition_t
 
     if type(x_t) != TRTTensor:
-        if x_shape != output_shape:
+        if x_shape != x_t_dim:
             # special case where 1 element in x_t
             if len(x_t.shape) == 0:
                 x_t = x_t.unsqueeze(0)
@@ -1442,7 +1468,6 @@ def add_where(network, target, kwargs, name):
             x_val = add_expand(
                 network,
                 target,
-                None,
                 {"input": x_val, "sizes": output_shape},
                 name=f"{name}_x_expand",
             )
@@ -1456,11 +1481,10 @@ def add_where(network, target, kwargs, name):
         y_val = get_trt_tensor(network, y_t, f"{name}_y")
     else:
         y_val = y_t
-        if y_shape != output_shape:
+        if y_shape != y_t_dim:
             y_val = add_expand(
                 network,
                 target,
-                None,
                 {"input": y_val, "sizes": output_shape},
                 name=f"{name}_y_expand",
             )
diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
index 00063c3e21..30aeee6944 100644
--- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
+++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -258,6 +258,7 @@ def remove_ops(
     for n in module.graph.nodes:
         if n.op == "call_function" and n.target in (
             torch.ops.aten._unsafe_view.default,
+            torch.ops.aten.view.default,
         ):
             modified = True
             node = n
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
index e69de29bb2..da3aa30cb7 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestRSubConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_alpha", (2, 1), 2),
+            ("3d_dim_alpha", (2, 1, 2), 2),
+        ]
+    )
+    def test_rsqrt(self, _, x, alpha):
+        class rsqrt(nn.Module):
+            def forward(self, input):
+                return torch.rsqrt(input, input, alpha)
+
+        inputs = [torch.randn(x) + 1]
+        self.run_test(
+            rsqrt(),
+            inputs,
+            expected_ops=torch.ops.aten.rsqrt,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
index e69de29bb2..9be23fc419 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestRSubConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_alpha", (2, 1), 2),
+            ("3d_dim_alpha", (2, 1, 2), 2),
+        ]
+    )
+    def test_rsub(self, _, x, alpha):
+        class rsub(nn.Module):
+            def forward(self, input):
+                return torch.rsub(input, input, alpha)
+
+        inputs = [torch.randn(x)]
+        self.run_test(
+            rsub(),
+            inputs,
+            expected_ops=torch.ops.aten.rsub,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py
index 5dd15a89e7..5c655422de 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py
@@ -10,8 +10,8 @@ class TestSqueezeConverter(DispatchTestCase):
         [
             ("2d_dim", (0), (2, 1)),
             ("3d_one_dim", (0), (2, 2, 1)),
-            # ("3d_two_dim", (0, 1), (2, 2, 1)),
-            # ("4d_dim", (0, 1, 2), (2, 2, 2, 1)),
+            ("3d_two_dim", (0, 1), (2, 1, 1)),
+            ("4d_dim", (0, 1, 2), (2, 2, 1, 1)),
         ]
     )
     def test_squeeze(self, _, dim, init_size):
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py
index 6c050eee2f..0d4849c21f 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py
@@ -8,49 +8,56 @@
 class TestWhereConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2)),
-            ("2d_broadcast_condition_xshape_yshape", (x < 0), (2, 2), (2, 1)),
-            ("3d_condition_xshape_yshape", (x > 0), (2, 2, 1), (2, 2, 1)),
-            ("2d_3d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2, 1)),
+            ("2d_condition_xshape_yshape", (2, 2), (2, 2)),
+            ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)),
+            ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)),
+            ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)),
         ]
     )
-    def test_(self, _, condition, x_size, y_size):
+    def test_(self, _, x_size, y_size):
         class Where(nn.Module):
-            def forward(self, x):
-                return torch.where(x, dim)
+            def forward(self, condition, x, y):
+                return torch.where(condition, x, y)
 
-        inputX = [torch.randn(*x_size)]
-        inputOther = [torch.randn(*y_size)]
-        expected_op = {}
+        inputX = torch.randn(*x_size)
+        inputOther = torch.randn(*y_size)
+        condition = inputX < 0
         self.run_test(
             Where(),
-            inputs,
-            expected_ops=torch.ops.aten.where.self,
+            (condition, inputX, inputOther),
+            expected_ops={torch.ops.aten.where.self},
         )
 
 
+# FIXME: How to specify condition for dynamic shape
+# InputTensorSpec like case below where one input is dynamic another is not
 # class TestWhereConverter(DispatchTestCase):
 #     @parameterized.expand(
 #         [
-#             ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
-#             ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
+#             ("2d_dim", (-1, 2), [((1, 2), (2, 2), (2, 2))], (2,2))
+#             #("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
 #             #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]),
 #         ]
 #     )
-#     def test_where(self, _, dim, init_size, shape_range):
-#         class Squeeze(nn.Module):
-#             def forward(self, x):
-#                 return torch.squeeze(x, dim)
-
-#         input_specs = [
-#             InputTensorSpec(
-#                 shape=init_size,
+#     def test_where(self, _, x_size, x_size_range, y_size):
+#         class Where(nn.Module):
+#             def forward(self, condition, x, y):
+#                 return torch.where(condition, x, y)
+#         inputX = InputTensorSpec(
+#                 shape=x_size,
 #                 dtype=torch.float32,
-#                 shape_ranges=shape_range,
-#             ),
+#                 shape_ranges=x_size_range,
+#                 )
+#         inputOther = torch.randn(*y_size)
+#         condition = (inputOther < 0)
+#         input_specs = [
+#             inputX, inputOther, condition
 #         ]
 #         self.run_test_with_dynamic_shape(
-#             Squeeze(),
+#             Where(),
 #             input_specs,
 #             expected_ops=torch.ops.aten.where.self,
 #         )
+
+# if __name__ == "__main__":
+#     run_tests()
diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
index 55c5e2df33..e53f0bc64e 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
@@ -22,7 +22,7 @@
 from torch._inductor.decomposition import decompositions
 
 DECOMPOSITIONS = decompositions.copy()
-MAX_SPLITS_THRESHOLD = 10
+MAX_SPLITS_THRESHOLD = 100
 
 
 def tensorrt_backend(gm, sample_inputs):

From 8d4e4b4f89c8e413a4420e4ba468e2e3d3e284ce Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Mon, 10 Apr 2023 23:39:59 -0700
Subject: [PATCH 11/25] Fixing acc split test

---
 py/torch_tensorrt/fx/converters/operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 5955e598f5..53e0d88557 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -1207,7 +1207,7 @@ def add_slice(network, target, kwargs, name):
     stride = [1] * len(start)
     stride[dim] = step_int
     output_shape = list(input_val.shape)
-    output_shape[dim] = (stop_int - start_int) // step_int + 1
+    output_shape[dim] = (stop_int - start_int) // step_int
 
     if dynamic_shape > 0:
         output_shape = get_shape_with_dynamic_shape(

From 1ab9af5ae5f6842fcd00f7e12d4fe2b308c1fbfe Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Tue, 11 Apr 2023 12:42:05 -0700
Subject: [PATCH 12/25] Bug fix for add_slice

---
 py/torch_tensorrt/fx/converters/operator.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 53e0d88557..37dd84d84e 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -2,6 +2,7 @@
 import operator
 import warnings
 import logging
+import math
 from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 
 import tensorrt as trt
@@ -1207,7 +1208,7 @@ def add_slice(network, target, kwargs, name):
     stride = [1] * len(start)
     stride[dim] = step_int
     output_shape = list(input_val.shape)
-    output_shape[dim] = (stop_int - start_int) // step_int
+    output_shape[dim] = math.ceil((stop_int - start_int) // step_int)
 
     if dynamic_shape > 0:
         output_shape = get_shape_with_dynamic_shape(

From 8de6c9d449518512c7a4f3c12cda06495866f284 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Tue, 11 Apr 2023 16:05:49 -0700
Subject: [PATCH 13/25] dynamic test for slice

---
 .../converters/aten_op/test_slice_aten.py     | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py
index b018aff73e..6ddc082657 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py
@@ -34,6 +34,7 @@ class TestSelectConverterExplicitBatch(DispatchTestCase):
     @parameterized.expand(
         [
             ("select_dim_start_stop_step", 1, 0, 7, 2),
+            ("select_dim_start_stop_step_exact", 1, 0, 10, 2),
         ]
     )
     def test_slice(self, _, dim, start, stop, step):
@@ -54,5 +55,35 @@ def forward(self, input):
         )
 
 
+class TestSelectConverterDynamicShape(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("select_dim_start_stop_step", 1, 0, 7, 2),
+            ("select_dim_start_stop_step", 1, 0, 10, 2),
+        ]
+    )
+    def test_slice(self, _, dim, start, stop, step):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
+                return out
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(1, 10, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(),
+            input_specs,
+            expected_ops={torch.ops.aten.slice.Tensor},
+        )
+
+
 if __name__ == "__main__":
     run_tests()

From ab89d2b045eb93f549c5c6de65b22830a47f9386 Mon Sep 17 00:00:00 2001
From: Apurba Bose <44209735+apbose@users.noreply.github.com>
Date: Fri, 14 Apr 2023 12:09:12 -0700
Subject: [PATCH 14/25] Correct the output_shape dimension for add_slice

---
 py/torch_tensorrt/fx/converters/operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 37dd84d84e..374db3f611 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -1208,7 +1208,7 @@ def add_slice(network, target, kwargs, name):
     stride = [1] * len(start)
     stride[dim] = step_int
     output_shape = list(input_val.shape)
-    output_shape[dim] = math.ceil((stop_int - start_int) // step_int)
+    output_shape[dim] = math.ceil((stop_int - start_int) / step_int)
 
     if dynamic_shape > 0:
         output_shape = get_shape_with_dynamic_shape(

From 09a52b99082604279c59b66aa5cb5b9d417db71f Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Wed, 19 Apr 2023 08:04:31 -0700
Subject: [PATCH 15/25] matmul changes, bmm changes and adding broadcastable

---
 .../fx/converters/converter_utils.py          |  37 +++++-
 py/torch_tensorrt/fx/converters/operator.py   |  13 ++-
 .../fx/passes/lower_basic_pass_aten.py        |  28 ++++-
 .../converters/aten_op/test_matmul_aten.py    | 107 ++++++++++++++++--
 4 files changed, 167 insertions(+), 18 deletions(-)

diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 551c18652d..8205d32ecb 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -77,7 +77,7 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
     return dim
 
 
-def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
+def set_layer_name(layer: TRTLayer, target: Target, name: str, is_acc=True) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
 
@@ -87,7 +87,7 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
     """
-    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"
+    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" if is_acc else f"aten_ops.{target.__name__}"
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
 
 
@@ -288,6 +288,39 @@ def prepend_ones(
     return layer.get_output(0)
 
 
+def broadcastable(
+    a: TRTTensor,
+    b: TRTTensor,
+) -> bool:
+    "Check if two tensors are broadcastable according to torch rules"
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+    print("a shape is", a_shape)
+    print("b shape is", b_shape)
+    # check from the trailing
+    diff = len(a_shape) - len(b_shape)
+    if diff == 0:
+        return True
+    if diff > 0:
+        max = len(a_shape)
+        min = len(b_shape)
+        greater_tensor = a_shape
+        lesser_tensor = b_shape
+    elif diff < 0:
+        max = len(b_shape)
+        min = len(a_shape)
+        greater_tensor = b_shape
+        lesser_tensor = a_shape
+    j = min - 1
+    for i in range(max - 1, diff - 1, -1):
+        if not (
+            greater_tensor[i] != lesser_tensor[j]
+            and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
+        ):
+            return False
+    return True
+
+
 def broadcast(
     network: TRTNetwork,
     a: TRTTensor,
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 374db3f611..674942b1ce 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -16,6 +16,7 @@
 from .converter_utils import set_layer_name
 from .converter_utils import get_trt_tensor
 from .converter_utils import broadcast
+from .converter_utils import broadcastable
 from .converter_utils import squeeze_left
 from .converter_utils import dtype_uniform
 from .converter_utils import get_trt_plugin
@@ -1117,7 +1118,17 @@ def add_expand(network, target, kwargs, name):
 
     ranks = len(input_val.shape)
     # TRT does not support different dimension size
-    assert len(shape) == ranks
+    #though this condition is not seen in the case of bmm 
+    # where input_t and shape dimensions are not equal
+    assert len(shape) >= ranks
+    if(len(shape) != ranks):
+            shape_tuple = tuple([0] * len(shape))
+            shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape")
+            input_val, shape_tensor = broadcast(network, input_val, shape_tensor, 
+                                  f"{name}_input_val",
+                                  f"{name}_shape_val")
+            ranks = len(shape)
+            
     shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
 
     inshape = tuple(input_val.shape)
diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
index 00063c3e21..f9b5b20fbf 100644
--- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
+++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -416,29 +416,46 @@ def compose_bmm(
             node = n
             input_n = node.all_input_nodes[0]
             other_n = node.all_input_nodes[1]
+
+             # If no input nodes are available, the bmm argument itself could be an input
+            # Alternatively, if the node has no users, it can be eliminated
+            if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
+                return PassResult(module, modified)
+            
             output = next(iter(node.users))
             input_input_n = input_n.all_input_nodes[0]
             if (
                 input_input_n.target != torch.ops.aten.expand.default
                 and input_n.target != torch.ops.aten.view.default
             ):
-                raise RuntimeError(
-                    "Bmm is addressed in fixed pattern. A new pattern is met!"
+                _LOGGER.warn(
+                    "Bmm is addressed in fixed pattern. "
+                    + f"A new pattern {input_input_n.target}, {input_n.target} is met! "
+                    + "Skipping bmm lowering on this operation"
                 )
+                return PassResult(module, modified)
+            
             real_input = input_input_n.all_input_nodes[0]
             input_other_n = other_n.all_input_nodes[0]
             if (
                 input_other_n.target != torch.ops.aten.expand.default
                 and other_n.target != torch.ops.aten.view.default
             ):
-                raise RuntimeError(
-                    "Bmm is addressed in fixed pattern. A new pattern is met!"
+                _LOGGER.warn(
+                    "Bmm is addressed in fixed pattern. "
+                    + f"A new pattern {input_other_n.target}, {other_n.target} is met! "
+                    + "Skipping bmm lowering on this operation"
                 )
+                return PassResult(module, modified)
+            
             real_other = input_other_n.all_input_nodes[0]
             if len(real_other.meta["val"].size()) == 2:
                 new_func = aten_compose_bmm_2d
-            if len(real_other.meta["val"].size()) == 3:
+            elif len(real_other.meta["val"].size()) == 3:
                 new_func = aten_compose_bmm_3d
+            else:
+                # No valid bmm replacement exists for the specified dimensions
+                return PassResult(module, modified)
 
             with module.graph.inserting_after(node):
                 new_args = (real_input, real_other)
@@ -449,6 +466,7 @@ def compose_bmm(
                     kwargs=None,
                 )
             output.replace_all_uses_with(new_node)
+            modified = True
 
     module.graph.eliminate_dead_code()
     module.recompile()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
index 0b0cd8d0b5..de4911bb08 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
@@ -6,22 +6,109 @@
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
 
+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
+
 
 class TestMatMulConverter(DispatchTestCase):
-    def test_matmul(self):
-        class TestModule(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.matmul(x, y)
-
-        inputOne = torch.randn(2, 32)
-        inputTwo = torch.randn(32, 2)
-        inputs = [inputOne, inputTwo]
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("2_2", (2, 3), (3, 1)),
+            #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # (2,3), (3,) torch.ops.aten.mv.default
+            # Following cases use torch.ops.aten.bmm.defauly 
+            #("4_3", (3,1,3,2), (2,2,3)),
+            #("3_4", (3,1,3,2), (2,2,3)),
+            #("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            #("4_2", (1, 2, 2, 3), (3, 2)),
+        ]
+    )
+    def test_matmul_other_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.other = nn.Parameter(torch.randn(*other_shape))
+
+            def forward(self, input):
+                return torch.matmul(input, self.other)
+
+        inputs = [torch.randn(*input_shape)]
+        
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=(len(input_shape) >= 1),
+        )
+
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("1_2", (1, 3), (3, 2)),
+            #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # (2,3), (3,) torch.ops.aten.mv.default
+            # Following cases use torch.ops.aten.bmm.defauly 
+            #("4_3", (3,1,3,2), (2,2,3)),
+            #("3_4", (3,1,3,2), (2,2,3)),
+            #("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            #("4_2", (1, 2, 2, 3), (3, 2)),
+            
+        ]
+    )
+    def test_matmul_input_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.input = nn.Parameter(torch.randn(*input_shape))
+
+            def forward(self, other):
+                return torch.matmul(self.input, other)
+
+        inputs = [torch.randn(*other_shape)]
+
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=True 
+            #test_explicit_batch_dim=(len(other_shape) <= 2),
+        )
+
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            # ("2_3", (2, 3), (2, 3, 4)),
+            # ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)),
+            # ("4_2", (2, 1, 2, 3), (3, 2)),
+            # ("2_1", (2, 3), (3,)),
+            # ("1_2", (3,), (3, 2)),
+            # ("1_1", (3,), (3,)),
+        ]
+    )
+    def test_matmul(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def forward(self, input, other):
+                return torch.matmul(input, other)
+
+        inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
+        test_explicit_batch_dim = not(
+            input_shape[0] == other_shape[0]
+            and len(input_shape) > 2
+            and len(other_shape) > 2
+        )
         self.run_test(
-            TestModule(),
+            MatMul(),
             inputs,
             expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=test_explicit_batch_dim,
         )
 
+    #FIXME: dynamic shape is giving bmm
 
 if __name__ == "__main__":
-    run_tests()
+    run_tests()
\ No newline at end of file

From d1fd1d7b30548281f8155abfda872e7e76a8f362 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Wed, 19 Apr 2023 08:43:49 -0700
Subject: [PATCH 16/25] Correcting pre-commit hooks

---
 .../fx/converters/converter_utils.py          |  8 +++-
 py/torch_tensorrt/fx/converters/operator.py   | 18 ++++-----
 .../fx/passes/lower_basic_pass_aten.py        |  8 ++--
 .../converters/aten_op/test_matmul_aten.py    | 38 +++++++++----------
 4 files changed, 39 insertions(+), 33 deletions(-)

diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 8205d32ecb..2d79014ebd 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -87,7 +87,13 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str, is_acc=True) -> N
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
     """
-    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" if is_acc else f"aten_ops.{target.__name__}"
+    target_name = (
+        target
+        if isinstance(target, str)
+        else f"acc_ops.{target.__name__}"
+        if is_acc
+        else f"aten_ops.{target.__name__}"
+    )
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
 
 
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 674942b1ce..9487894506 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -1118,17 +1118,17 @@ def add_expand(network, target, kwargs, name):
 
     ranks = len(input_val.shape)
     # TRT does not support different dimension size
-    #though this condition is not seen in the case of bmm 
+    # though this condition is not seen in the case of bmm
     # where input_t and shape dimensions are not equal
     assert len(shape) >= ranks
-    if(len(shape) != ranks):
-            shape_tuple = tuple([0] * len(shape))
-            shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape")
-            input_val, shape_tensor = broadcast(network, input_val, shape_tensor, 
-                                  f"{name}_input_val",
-                                  f"{name}_shape_val")
-            ranks = len(shape)
-            
+    if len(shape) != ranks:
+        shape_tuple = tuple([0] * len(shape))
+        shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape")
+        input_val, shape_tensor = broadcast(
+            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
+        )
+        ranks = len(shape)
+
     shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
 
     inshape = tuple(input_val.shape)
diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
index f9b5b20fbf..6790962621 100644
--- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
+++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -417,11 +417,11 @@ def compose_bmm(
             input_n = node.all_input_nodes[0]
             other_n = node.all_input_nodes[1]
 
-             # If no input nodes are available, the bmm argument itself could be an input
+            # If no input nodes are available, the bmm argument itself could be an input
             # Alternatively, if the node has no users, it can be eliminated
             if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
                 return PassResult(module, modified)
-            
+
             output = next(iter(node.users))
             input_input_n = input_n.all_input_nodes[0]
             if (
@@ -434,7 +434,7 @@ def compose_bmm(
                     + "Skipping bmm lowering on this operation"
                 )
                 return PassResult(module, modified)
-            
+
             real_input = input_input_n.all_input_nodes[0]
             input_other_n = other_n.all_input_nodes[0]
             if (
@@ -447,7 +447,7 @@ def compose_bmm(
                     + "Skipping bmm lowering on this operation"
                 )
                 return PassResult(module, modified)
-            
+
             real_other = input_other_n.all_input_nodes[0]
             if len(real_other.meta["val"].size()) == 2:
                 new_func = aten_compose_bmm_2d
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
index de4911bb08..e0dc05fded 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
@@ -19,13 +19,13 @@ class TestMatMulConverter(DispatchTestCase):
         [
             ("2_2", (2, 3), (3, 2)),
             ("2_2", (2, 3), (3, 1)),
-            #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
             # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly 
-            #("4_3", (3,1,3,2), (2,2,3)),
-            #("3_4", (3,1,3,2), (2,2,3)),
-            #("3_4", (2, 2, 3), (3, 1, 3, 3)),
-            #("4_2", (1, 2, 2, 3), (3, 2)),
+            # Following cases use torch.ops.aten.bmm.defauly
+            # ("4_3", (3,1,3,2), (2,2,3)),
+            # ("3_4", (3,1,3,2), (2,2,3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
         ]
     )
     def test_matmul_other_constant(self, _, input_shape, other_shape):
@@ -38,7 +38,7 @@ def forward(self, input):
                 return torch.matmul(input, self.other)
 
         inputs = [torch.randn(*input_shape)]
-        
+
         self.run_test(
             MatMul(),
             inputs,
@@ -50,14 +50,13 @@ def forward(self, input):
         [
             ("2_2", (2, 3), (3, 2)),
             ("1_2", (1, 3), (3, 2)),
-            #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
             # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly 
-            #("4_3", (3,1,3,2), (2,2,3)),
-            #("3_4", (3,1,3,2), (2,2,3)),
-            #("3_4", (2, 2, 3), (3, 1, 3, 3)),
-            #("4_2", (1, 2, 2, 3), (3, 2)),
-            
+            # Following cases use torch.ops.aten.bmm.defauly
+            # ("4_3", (3,1,3,2), (2,2,3)),
+            # ("3_4", (3,1,3,2), (2,2,3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
         ]
     )
     def test_matmul_input_constant(self, _, input_shape, other_shape):
@@ -75,8 +74,8 @@ def forward(self, other):
             MatMul(),
             inputs,
             expected_ops={torch.ops.aten.mm.default},
-            test_explicit_batch_dim=True 
-            #test_explicit_batch_dim=(len(other_shape) <= 2),
+            test_explicit_batch_dim=True
+            # test_explicit_batch_dim=(len(other_shape) <= 2),
         )
 
     @parameterized.expand(
@@ -96,7 +95,7 @@ def forward(self, input, other):
                 return torch.matmul(input, other)
 
         inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
-        test_explicit_batch_dim = not(
+        test_explicit_batch_dim = not (
             input_shape[0] == other_shape[0]
             and len(input_shape) > 2
             and len(other_shape) > 2
@@ -108,7 +107,8 @@ def forward(self, input, other):
             test_explicit_batch_dim=test_explicit_batch_dim,
         )
 
-    #FIXME: dynamic shape is giving bmm
+    # FIXME: dynamic shape is giving bmm
+
 
 if __name__ == "__main__":
-    run_tests()
\ No newline at end of file
+    run_tests()

From 1d78f436a9a423a8486338ecefd972fd08777f63 Mon Sep 17 00:00:00 2001
From: Michael Feliz <104801882+mfeliz-cruise@users.noreply.github.com>
Date: Wed, 19 Apr 2023 15:30:32 -0700
Subject: [PATCH 17/25] feat: Add ts converter support for aten::all.dim
 (#1840)

---
 core/conversion/converters/impl/reduce.cpp    | 76 +++++++++++++------
 .../conversion/converters/test_reduce.cpp     | 53 ++++++++++++-
 2 files changed, 105 insertions(+), 24 deletions(-)

diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index 249ae916ef..e3c7498c47 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -9,6 +9,36 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ITensor* anyDimImplementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* in_tensor,
+    int dim,
+    bool keepdim) {
+  auto in_dims = in_tensor->getDimensions();
+  LOG_DEBUG("Dim to reduce (original): " << dim);
+  dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
+  LOG_DEBUG("Dim to reduce (converted): " << dim);
+
+  uint32_t axis_mask = 1 << dim;
+  LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
+  LOG_DEBUG("Keep dims: " << keepdim);
+
+  // Reduce does not work on bool inputs
+  if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
+    in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
+  }
+  auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
+
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+  sum_layer->setName(util::node_info(n).c_str());
+  auto out_tensor =
+      castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+  out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+  return out_tensor;
+}
+
 auto reduce_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
             {"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in_tensor = args[0].ITensorOrFreeze(ctx);
-               auto in_dims = in_tensor->getDimensions();
                auto dim = args[1].unwrapToInt();
-               LOG_DEBUG("Dim to reduce (original): " << dim);
-               dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
-               LOG_DEBUG("Dim to reduce (converted): " << dim);
-
-               uint32_t axis_mask = 1 << dim;
-               LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
-
                auto keepdim = args[2].unwrapToBool();
-               LOG_DEBUG("Keep dims: " << keepdim);
-
-               // Reduce does not work on bool inputs
-               if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
-                 in_tensor =
-                     castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
-               }
-               auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
-
-               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
-
-               sum_layer->setName(util::node_info(n).c_str());
-               auto out_tensor = castITensor(
-                   ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+               auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim);
                out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
+             }})
+        .pattern(
+            {"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // use Not(Any(Not(input))) to calculate all without a direct all reduction
+               auto in_tensor = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();
+               auto keepdim = args[2].unwrapToBool();
+               if (in_tensor->getType() != nvinfer1::DataType::kBOOL) {
+                 // unary not layer only supports bool inputs
+                 in_tensor = castITensor(
+                     ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str());
+               }
+               auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n);
+               not_input_layer->setName((util::node_info(n) + "_not_in").c_str());
+               auto not_in = not_input_layer->getOutput(0);
+               auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim);
+               auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_output_layer, "Unable to create logical_not layer from node: " << *n);
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
+               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+               return true;
              }});
 } // namespace
 } // namespace impl
diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp
index 40835a8dea..47e8b8d154 100644
--- a/tests/core/conversion/converters/test_reduce.cpp
+++ b/tests/core/conversion/converters/test_reduce.cpp
@@ -62,7 +62,7 @@ std::string gen_keepdim_graph(const std::string& op) {
         return (%5))IR";
 }
 
-void test_body(const std::string& graph, at::Tensor& in) {
+void test_body(const std::string& graph, at::Tensor& in, bool dynamic = false) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
@@ -71,7 +71,12 @@ void test_body(const std::string& graph, at::Tensor& in) {
 
   in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  std::vector<at::Tensor> trt_results;
+  if (dynamic) {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});
+  } else {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  }
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 } // namespace
@@ -344,6 +349,50 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenAllDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=-1]()
+      %3 : bool = prim::Constant[value=0]()
+      %5 : Tensor = aten::all(%0, %1, %3)
+      return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimKeepDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=0]()
+      %3 : bool = prim::Constant[value=1]()
+      %5 : Tensor = aten::all(%0, %1, %3)
+      return (%5))IR";
+  auto in = at::randint(-2, 2, {2, 32}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimAllTrueConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=1]()
+      %3 : bool = prim::Constant[value=0]()
+      %5 : Tensor = aten::all(%0, %1, %3)
+      return (%5))IR";
+  auto in = at::ones({2, 32}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimDynamicConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=-1]()
+      %3 : bool = prim::Constant[value=0]()
+      %5 : Tensor = aten::all(%0, %1, %3)
+      return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf);
+  test_body(graph, in, true);
+}
+
 TEST(Converters, UnpackVarLowersCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):

From ce7f122aa47d632120ec7c36e91f359afbe612d8 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Thu, 20 Apr 2023 12:33:02 -0700
Subject: [PATCH 18/25] Correcting rsqrt and rsub operator

---
 .../fx/test/converters/aten_op/test_rsqrt_aten.py             | 4 ++--
 .../fx/test/converters/aten_op/test_rsub_aten.py              | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
index da3aa30cb7..c80216654c 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
@@ -15,13 +15,13 @@ class TestRSubConverter(DispatchTestCase):
     def test_rsqrt(self, _, x, alpha):
         class rsqrt(nn.Module):
             def forward(self, input):
-                return torch.rsqrt(input, input, alpha)
+                return torch.rsqrt(input)
 
         inputs = [torch.randn(x) + 1]
         self.run_test(
             rsqrt(),
             inputs,
-            expected_ops=torch.ops.aten.rsqrt,
+            expected_ops={torch.ops.aten.rsqrt.default},
         )
 
 
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
index 9be23fc419..dddd72f732 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
@@ -15,13 +15,13 @@ class TestRSubConverter(DispatchTestCase):
     def test_rsub(self, _, x, alpha):
         class rsub(nn.Module):
             def forward(self, input):
-                return torch.rsub(input, input, alpha)
+                return torch.rsub(input, input, alpha = alpha)
 
         inputs = [torch.randn(x)]
         self.run_test(
             rsub(),
             inputs,
-            expected_ops=torch.ops.aten.rsub,
+            expected_ops={torch.ops.aten.rsub.Tensor},
         )
 
 

From 30c5fd6e654f0ac3a7025c49d60b24cd8f96df40 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Thu, 20 Apr 2023 13:08:34 -0700
Subject: [PATCH 19/25] python linting issues and removing chunk test

---
 .../fx/converters/aten_ops_converters.py      | 25 ++------
 py/torch_tensorrt/fx/converters/operator.py   | 12 +++-
 .../converters/aten_op/test_chunk_aten.py     | 58 -------------------
 .../test/converters/aten_op/test_rsub_aten.py |  2 +-
 4 files changed, 15 insertions(+), 82 deletions(-)
 delete mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py

diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index d47f30a790..defa88d18b 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -672,23 +672,6 @@ def aten_ops_squeeze(
     return add_squeeze(network, target, kwargs_new, name)
 
 
-# FIXME: need to confirm lower basic passes
-# @tensorrt_converter(torch.ops.aten.chunk)
-# def aten_ops_chunk(
-#     network: TRTNetwork,
-#     target: Target,
-#     args: Tuple[Argument, ...],
-#     kwargs: Dict[str, Argument],
-#     name: str,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     kwargs_new = {
-#         "input": args[0],
-#         "chunks": args[1],
-#         "dim": args[2],
-#     }
-#     return add_chunk(network, target, kwargs_new, name)
-
-
 @tensorrt_converter(torch.ops.aten.where.self)
 def aten_ops_where(
     network: TRTNetwork,
@@ -705,7 +688,7 @@ def aten_ops_where(
     return add_where(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.rsub)
+@tensorrt_converter(torch.ops.aten.rsub.Tensor)
 def aten_ops_rsub(
     network: TRTNetwork,
     target: Target,
@@ -713,15 +696,17 @@ def aten_ops_rsub(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if "alpha" in kwargs:
+        alpha = kwargs["alpha"]
     kwargs_new = {
         "input": args[0],
         "other": args[1],
-        "alpha": args[2],
+        "alpha": alpha,
     }
     return add_rsub(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.rsqrt)
+@tensorrt_converter(torch.ops.aten.rsqrt.default)
 def aten_ops_rsqrt(
     network: TRTNetwork,
     target: Target,
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 1e53b1ccc5..ffd6a1bab5 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -1526,7 +1526,13 @@ def add_scale(network, target, kwargs, name):
 
 
 def add_rsub(network, target, kwargs, name):
-    scaled_tensor = add_scale(network, target, kwargs, name)
+    kwargs_new = {}
+    if "alpha" in kwargs:
+        kwargs_new["input"] = kwargs["other"]
+        kwargs_new["other"] = kwargs["alpha"]
+        scaled_tensor = add_mul(network, target, kwargs_new, name + "_mul")
+    else:
+        scaled_tensor = kwargs["other"]
     input = kwargs["input"]
     return add_binary_elementwise_layer(
         network,
@@ -1534,7 +1540,7 @@ def add_rsub(network, target, kwargs, name):
         scaled_tensor,
         trt.ElementWiseOperation.SUB,
         target,
-        name,
+        name + "_sub",
     )
 
 
@@ -1546,7 +1552,7 @@ def add_sqrt(network, target, kwargs, name):
 
 def add_rsqrt(network, target, kwargs, name):
     sqrt_trt = add_sqrt(network, target, kwargs, name)
-    div_trt = add_binary_elementwise_layer(
+    return add_binary_elementwise_layer(
         network,
         1,
         sqrt_trt,
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py
deleted file mode 100644
index 8fae6da293..0000000000
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import unittest
-
-import torch
-import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
-from parameterized import param, parameterized
-from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
-
-
-class TestSelectConverterImplicitBatch(DispatchTestCase):
-    @parameterized.expand(
-        [
-            ("select_chunk_dim", 6, 0),
-        ]
-    )
-    def test_chunk(self, _, chunk, dim):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, input):
-                out = torch.ops.aten.chunk(input, chunk, dim)
-                return out
-
-        input = [torch.randn(11)]
-        self.run_test(
-            TestModule(),
-            input,
-            expected_ops={torch.ops.aten.chunk},
-        )
-
-
-class TestSelectConverterExplicitBatch(DispatchTestCase):
-    @parameterized.expand(
-        [
-            ("select_chunk_dim", 6, 0),
-        ]
-    )
-    def test_chunk(self, _, chunk, dim):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, input):
-                out = torch.ops.aten.chunk(input, chunk, dim)
-                return out
-
-        input = [torch.randn(12)]
-        self.run_test(
-            TestModule(),
-            input,
-            expected_ops={torch.ops.aten.chunk},
-            test_explicit_precision=True,
-        )
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
index dddd72f732..268df8ccfd 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
@@ -15,7 +15,7 @@ class TestRSubConverter(DispatchTestCase):
     def test_rsub(self, _, x, alpha):
         class rsub(nn.Module):
             def forward(self, input):
-                return torch.rsub(input, input, alpha = alpha)
+                return torch.rsub(input, input, alpha=alpha)
 
         inputs = [torch.randn(x)]
         self.run_test(

From 7ab071d91cc09281d8d518ce4f0dd406c6537955 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Thu, 20 Apr 2023 16:00:48 -0700
Subject: [PATCH 20/25] Correcting acc squeeze test

---
 .../fx/test/converters/acc_op/test_squeeze.py              | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
index d265def896..c9b4776dd3 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
@@ -12,7 +12,12 @@ def forward(self, x):
                 return x.squeeze(2)
 
         inputs = [torch.randn(1, 2, 1)]
-        self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze})
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops={acc_ops.squeeze},
+            test_implicit_batch_dim=False,
+        )
 
     # Testing with shape=(-1, -1, -1, -1) results in error:
     # AssertionError: We don't support squeeze dynamic dim.

From 36ac0cf341286865cdea67c87fac2f3f9cf8b8b9 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Thu, 20 Apr 2023 17:23:15 -0700
Subject: [PATCH 21/25] test_reshape expected ops aten.reshape since aten.view
 has been removed in lowering

---
 .../fx/test/converters/aten_op/test_reshape_aten.py         | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
index 538e575d6e..385ec05b8b 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
@@ -31,7 +31,7 @@ def forward(self, x):
         self.run_test(
             TestModule(target_shape),
             inputs,
-            expected_ops={torch.ops.aten.view.default},
+            expected_ops={torch.ops.aten.reshape},
         )
 
     @parameterized.expand(
@@ -64,7 +64,7 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(
             TestModule(target_shape),
             input_specs,
-            expected_ops={torch.ops.aten.view.default},
+            expected_ops={torch.ops.aten.reshape},
         )
 
     @unittest.skipIf(
@@ -94,7 +94,7 @@ def forward(self, x, y):
         self.run_test_with_dynamic_shape(
             TestModule(),
             input_specs,
-            expected_ops={torch.ops.aten.view.default},
+            expected_ops={torch.ops.aten.reshape},
         )
 
 

From eb851b19880dbc00acf6c69e78dd509e87bd1e81 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Thu, 20 Apr 2023 21:43:07 -0700
Subject: [PATCH 22/25] removing aten.view in lowering pass

---
 py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py        | 1 -
 .../fx/test/converters/aten_op/test_reshape_aten.py         | 6 +++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
index 0d6b1c28de..6790962621 100644
--- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
+++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -258,7 +258,6 @@ def remove_ops(
     for n in module.graph.nodes:
         if n.op == "call_function" and n.target in (
             torch.ops.aten._unsafe_view.default,
-            torch.ops.aten.view.default,
         ):
             modified = True
             node = n
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
index 385ec05b8b..538e575d6e 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
@@ -31,7 +31,7 @@ def forward(self, x):
         self.run_test(
             TestModule(target_shape),
             inputs,
-            expected_ops={torch.ops.aten.reshape},
+            expected_ops={torch.ops.aten.view.default},
         )
 
     @parameterized.expand(
@@ -64,7 +64,7 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(
             TestModule(target_shape),
             input_specs,
-            expected_ops={torch.ops.aten.reshape},
+            expected_ops={torch.ops.aten.view.default},
         )
 
     @unittest.skipIf(
@@ -94,7 +94,7 @@ def forward(self, x, y):
         self.run_test_with_dynamic_shape(
             TestModule(),
             input_specs,
-            expected_ops={torch.ops.aten.reshape},
+            expected_ops={torch.ops.aten.view.default},
         )
 
 

From 6b234e0f34a9a27851eb438d70327c316976368e Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Thu, 20 Apr 2023 22:47:12 -0700
Subject: [PATCH 23/25] layer_norm test

---
 .../aten_op/test_layer_norm_aten.py           | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
index cf97e828d0..e204f4ec8b 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -19,26 +19,26 @@ def forward(self, x):
         )
 
 
-def test_layernorm_with_dynamic_shape(self):
-    class TestModule(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.ln = torch.nn.LayerNorm([3, 224, 224])
-
-        def forward(self, x):
-            return self.ln(x)
-
-    input_specs = [
-        InputTensorSpec(
-            shape=(-1, 3, 224, 224),
-            dtype=torch.float32,
-            shape_ranges=[(1, 3, 1, 1)],
-        ),
-    ]
-
-    self.run_test_with_dynamic_shape(
-        TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
-    )
+    def test_layernorm_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+            def forward(self, x):
+                return self.ln(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[(1, 3, 1, 1)],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
+        )
 
 
 if __name__ == "__main__":

From 95c1adab0143bf7d2f1aeb99988bde00f54a6be4 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Thu, 20 Apr 2023 22:49:21 -0700
Subject: [PATCH 24/25] correcting linting error

---
 .../fx/test/converters/aten_op/test_layer_norm_aten.py           | 1 -
 1 file changed, 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
index e204f4ec8b..6662d91b9a 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -18,7 +18,6 @@ def forward(self, x):
             TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default}
         )
 
-
     def test_layernorm_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
             def __init__(self):

From 1a1b809b7b2f90043cbcab0318141f6302057021 Mon Sep 17 00:00:00 2001
From: apbose <apbose694@gmail.com>
Date: Fri, 21 Apr 2023 05:06:07 -0700
Subject: [PATCH 25/25] correcting dynamic shape layer norm

---
 .../fx/test/converters/aten_op/test_layer_norm_aten.py        | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
index 6662d91b9a..fab398ac0f 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -31,12 +31,12 @@ def forward(self, x):
             InputTensorSpec(
                 shape=(-1, 3, 224, 224),
                 dtype=torch.float32,
-                shape_ranges=[(1, 3, 1, 1)],
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
             ),
         ]
 
         self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
+            TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default}
         )