From 11e4830b6d12d806e6acdf7b626e96d9b43e0baa Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Mon, 27 Feb 2023 15:51:02 -0800 Subject: [PATCH] aten::index fix --- core/conversion/converters/impl/select.cpp | 2 +- .../conversion/converters/test_select.cpp | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index c569a6088e..dab9670cc7 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -337,7 +337,7 @@ auto select_registrations TORCHTRT_UNUSED = // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices // from - auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0); + auto gather_layer = ctx->net->addGather(*in, *indicesTensor, adv_idx_indices[0]); TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); auto gather_out = gather_layer->getOutput(0); diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 991a1b792c..bb2402bcaa 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -1093,6 +1093,32 @@ TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } +TEST(Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor): + %5 : NoneType = prim::Constant() + %18 : Tensor?[] = prim::ListConstruct(%5, %index0) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {1, 3, 480, 928}, {at::kCUDA}); + auto index0 = at::tensor({2, 1, 0}, {at::kCUDA}).to(torch::kLong); + + auto index0_trt = index0.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenUnbindConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor):