From df01f966321362a3390f8aff08ce415031f08feb Mon Sep 17 00:00:00 2001
From: Anurag Dixit <a.dixit91@gmail.com>
Date: Wed, 5 Apr 2023 10:34:03 -0700
Subject: [PATCH 1/2] feat: Added support for aten::unflatten converter

Signed-off-by: Anurag Dixit <a.dixit91@gmail.com>
---
 core/conversion/converters/impl/shuffle.cpp   |  98 ++++++++++++
 .../conversion/converters/test_shuffle.cpp    |  52 +++++++
 tests/cpp/BUILD                               |  14 ++
 tests/cpp/test_dynamic_size.cpp               | 139 ++++++++++++++++++
 4 files changed, 303 insertions(+)
 create mode 100644 tests/cpp/test_dynamic_size.cpp

diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index 2df7e653ef..314fe74582 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -64,6 +64,104 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
                return true;
              }})
+        .pattern(
+            {"aten::unflatten.int(Tensor self, int dim, int[] sizes) -> (Tensor)",
+              // Converts aten::unflatten into a TensorRT shuffle (reshape) layer: dimension `dim` of
+              // `self` is replaced by the entries of `sizes`, all other dimensions are kept.
+              //  - static input shapes: the output shape is computed here and set as reshape dimensions
+              //  - dynamic input shapes: the output shape is passed to the layer as a shape tensor
+              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                auto in = args[0].ITensorOrFreeze(ctx);
+                auto dim = args[1].unwrapToInt();
+                auto in_shape = util::toVec(in->getDimensions());
+                const auto rank = static_cast<int64_t>(in_shape.size());
+                std::vector<int64_t> new_shape;
+                // Only built on the dynamic path; stays null if `sizes` has an unsupported type.
+                nvinfer1::ITensor* shape_tensor = nullptr;
+                if (ctx->input_is_dynamic) {
+                  // Normalize a negative dim, then reject anything still out of range: the index
+                  // vectors and iterator offsets below are all derived from dim.
+                  if (dim < 0) {
+                    dim += rank;
+                  }
+                  TORCHTRT_CHECK(dim >= 0 && dim < rank, "Dimension " << dim << " is out of range for node: " << *n);
+                  LOG_DEBUG("Using dynamic version of reshape layer");
+                  if (args[2].isITensorList()) {
+                    LOG_DEBUG("Shape tensor is an ITensorList");
+                    auto expand_shape = args[2].unwrapToITensorList();
+                    auto shape_layer = ctx->net->addShape(*in);
+                    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+                    auto shape_1d_tensor = shape_layer->getOutput(0);
+
+                    // Gathers `count` consecutive entries of the runtime input shape starting at index
+                    // `first`. Returns nullptr when count is 0 so empty pieces are left out of the concat.
+                    auto gather_dims = [&](int64_t first, int64_t count) -> nvinfer1::ITensor* {
+                      if (count <= 0) {
+                        return nullptr;
+                      }
+                      std::vector<int> indices_vector(count);
+                      std::iota(indices_vector.begin(), indices_vector.end(), first);
+                      at::Tensor indices = torch::tensor(indices_vector).to(torch::kI32);
+                      auto indices_out = converters::tensor_to_const(ctx, indices);
+                      auto gather_layer = ctx->net->addGather(*shape_1d_tensor, *indices_out, 0);
+                      TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+                      return gather_layer->getOutput(0);
+                    };
+                    auto before_dim_gather_out = gather_dims(0, dim);
+                    auto after_dim_gather_out = gather_dims(dim + 1, rank - (dim + 1));
+
+                    // Runtime output shape = [dims before dim] + sizes + [dims after dim].
+                    std::vector<nvinfer1::ITensor*> shape_tensors;
+                    if (before_dim_gather_out) {
+                      shape_tensors.push_back(before_dim_gather_out);
+                    }
+                    for (auto new_shape_tensor : expand_shape) {
+                      shape_tensors.push_back(new_shape_tensor);
+                    }
+                    if (after_dim_gather_out) {
+                      shape_tensors.push_back(after_dim_gather_out);
+                    }
+
+                    auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
+                    TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
+                    shape_tensor = shape_cat_layer->getOutput(0);
+                    LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
+                  } else if (args[2].isIntList()) {
+                    auto shape_vec = args[2].unwrapToIntList().vec();
+                    // Splice the requested sizes in place of dimension dim.
+                    new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
+                    new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
+                    new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());
+
+                    shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
+                  } else {
+                    LOG_ERROR(
+                      "Invalid IValue type of " <<  args[2].ivalue_type()
+                                                << " detected for shape tensor from node: " << *n);
+                  }
+                  // Fail the conversion instead of dereferencing a null shape tensor further down.
+                  TORCHTRT_CHECK(shape_tensor, "Unable to create shape tensor from node: " << *n);
+                } else {
+                  // Let torch::unflatten compute the output shape from a dummy tensor of the input's shape.
+                  new_shape = torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
+                }
+
+                auto shuffle = ctx->net->addShuffle(*in);
+                // Null-check the layer before the first call on it.
+                TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+                shuffle->setName(util::node_info(n).c_str());
+
+                if (ctx->input_is_dynamic) {
+                  shuffle->setInput(1, *shape_tensor);
+                } else {
+                  shuffle->setReshapeDimensions(util::toDims(new_shape));
+                }
+
+                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+                return true;
+              }})
         .pattern(
             {"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/tests/core/conversion/converters/test_shuffle.cpp b/tests/core/conversion/converters/test_shuffle.cpp
index fad50c9340..9c972ba988 100644
--- a/tests/core/conversion/converters/test_shuffle.cpp
+++ b/tests/core/conversion/converters/test_shuffle.cpp
@@ -364,3 +364,55 @@ TEST(Converters, ATenPixelShuffle5DConvertsCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
+
+TEST(Converters, ATenUnflattenConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=1]()
+      %3 : int = prim::Constant[value=512]()
+      %4 : int = prim::Constant[value=1]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int[] = prim::ListConstruct(%3, %4, %5)
+      %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+      return (%7))IR";
+  // Static-shape case: dim 1 (size 512) of a {1, 512} input is unflattened into (512, 1, 1).
+  auto parsed_graph = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, parsed_graph.get());
+
+  auto input = at::randint(0, 5, {1, 512}, {at::kCUDA});
+  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {input});
+
+  // The engine run gets its own copy of the input and a fresh parameter map.
+  auto trt_input = at::clone(input);
+  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenNegativeDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=-1]()
+      %3 : int = prim::Constant[value=512]()
+      %4 : int = prim::Constant[value=1]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int[] = prim::ListConstruct(%3, %4, %5)
+      %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+      return (%7))IR";
+  // Static-shape case with dim = -1: the last dim (size 512) of a {1, 512} input becomes (512, 1, 1).
+  auto parsed_graph = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, parsed_graph.get());
+
+  auto input = at::randint(0, 5, {1, 512}, {at::kCUDA});
+  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {input});
+
+  // The engine run gets its own copy of the input and a fresh parameter map.
+  auto trt_input = at::clone(input);
+  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
\ No newline at end of file
diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD
index c34aa09372..709187e1b2 100644
--- a/tests/cpp/BUILD
+++ b/tests/cpp/BUILD
@@ -16,6 +16,7 @@ test_suite(
         ":test_compiled_modules",
         ":test_default_input_types",
         ":test_dynamic_fallback",
+        ":test_dynamic_size",
         ":test_example_tensors",
         ":test_module_fallback",
         ":test_modules_as_engines",
@@ -32,6 +33,7 @@ test_suite(
         ":test_compiled_modules",
         ":test_default_input_types",
         ":test_dynamic_fallback",
+        ":test_dynamic_size",
         ":test_example_tensors",
         ":test_module_fallback",
         ":test_modules_as_engines",
@@ -142,6 +144,18 @@ cc_test(
     }),
 )
 
+cc_test(
+    name = "test_dynamic_size",  # aten::unflatten tests run through RunGraphEngineDynamic
+    srcs = ["test_dynamic_size.cpp"],
+    deps = [
+        "//tests/util",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+)
+
 cc_test(
     name = "test_collections",
     srcs = ["test_collections.cpp"],
diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
new file mode 100644
index 0000000000..bd2bf90d0d
--- /dev/null
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -0,0 +1,139 @@
+#include <torch/torch.h>
+#include <string>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+TEST(Converters, ATenUnflattenDynShapeShapeCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = prim::Constant[value=512]()
+            %4 : int = prim::Constant[value=1]()
+            %5 : int = prim::Constant[value=1]()
+            %6 : int[] = prim::ListConstruct(%3, %4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+  // Constant sizes: dim 1 (size 512) of a {1, 512} input is unflattened into (512, 1, 1).
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512}, {at::kCUDA});
+  // Reference output from RunGraph on its own copy of the input.
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+  // Output from RunGraphEngineDynamic, likewise on its own copy of the input.
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeNegativeDimsShapeCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=-2]()
+            %3 : int = prim::Constant[value=512]()
+            %4 : int = prim::Constant[value=1]()
+            %5 : int = prim::Constant[value=1]()
+            %6 : int[] = prim::ListConstruct(%3, %4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+  // Negative dim: -2 addresses the middle dim (size 512) of a {1, 512, 2} input, split into (512, 1, 1).
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512, 2}, {at::kCUDA});
+  // Reference output from RunGraph on its own copy of the input.
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+  // Output from RunGraphEngineDynamic, likewise on its own copy of the input.
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = aten::size(%x.1, %2)
+            %4 : int = prim::Constant[value=256]()
+            %5 : int = prim::Constant[value=2]()
+            %6 : int[] = prim::ListConstruct(%4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+  // Dim 1 (size 512) of a {1, 512, 1} input is unflattened into (256, 2).
+  auto in = at::randint(0, 10, {1, 512, 1}, {at::kCUDA});
+  // NOTE(review): %3 (aten::size) is unused and %6 holds only constants - confirm the intended sizes list.
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+  // Output from RunGraphEngineDynamic, on its own copy of the input.
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyFirstDim) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %1 : int = prim::Constant[value=0]()
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = aten::size(%x.1, %1)
+            %6 : int[] = prim::ListConstruct(%2, %2, %3, %2, %2)
+            %7 : Tensor = aten::unflatten(%x.1, %1, %6)
+            return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+  // First dim: dim 0 of a {64, 512, 1} input becomes (1, 1, size(x, 0), 1, 1), one entry from aten::size.
+  auto in = at::randint(0, 10, {64, 512, 1}, {at::kCUDA});
+  // Reference output from RunGraph on its own copy of the input.
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+  // Output from RunGraphEngineDynamic, likewise on its own copy of the input.
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyLastDim) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %1 : int = prim::Constant[value=2]()
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = aten::size(%x.1, %1)
+            %5 : int = prim::Constant[value=2]()
+            %6 : int[] = prim::ListConstruct(%3, %2, %2)
+            %7 : Tensor = aten::unflatten(%x.1, %5, %6)
+            return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+  // Last dim: dim 2 of a {1, 512, 9} input becomes (size(x, 2), 1, 1), first entry from aten::size.
+  auto in = at::randint(0, 10, {1, 512, 9}, {at::kCUDA});
+  // Reference output from RunGraph on its own copy of the input.
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+  // Output from RunGraphEngineDynamic, likewise on its own copy of the input.
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
\ No newline at end of file

From 0c23a2dd444942262f582bca5b46bdbdda0bd6ba Mon Sep 17 00:00:00 2001
From: Anurag Dixit <a.dixit91@gmail.com>
Date: Wed, 5 Apr 2023 10:44:21 -0700
Subject: [PATCH 2/2] chore: Trigger lint

Signed-off-by: Anurag Dixit <a.dixit91@gmail.com>
---
 tests/cpp/test_dynamic_size.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
index bd2bf90d0d..5870796547 100644
--- a/tests/cpp/test_dynamic_size.cpp
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -5,6 +5,7 @@
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
 
+
 TEST(Converters, ATenUnflattenDynShapeShapeCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor):