diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 02a49f1bdb..b5405e118c 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -67,6 +67,34 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
   return true;
 }
 
+nvinfer1::ITensor* roll(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* in,
+    int shift,
+    int dim,
+    const std::vector<int64_t>& in_shape) {
+  auto in_dim = in_shape[dim];
+
+  auto start = (in_dim - shift) % in_dim;
+  // Behavior of % is different in C++ vs Python for negative numbers. This
+  // corrects the difference.
+  if (start < 0) {
+    start = start + in_dim;
+  }
+  at::Tensor index0 = at::arange(start, in_dim, 1, torch::kInt32);
+  at::Tensor index;
+  if (start == 0) {
+    index = index0;
+  } else {
+    at::Tensor index1 = at::arange(start, torch::kInt32);
+    index = at::cat({index0, index1}, 0);
+  }
+  auto index_tensor = tensor_to_const(ctx, index);
+  auto gather_layer = ctx->net->addGather(*in, *index_tensor, dim);
+  auto out = gather_layer->getOutput(0);
+  return out;
+}
+
 auto select_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
@@ -200,6 +228,69 @@ auto select_registrations TORCHTRT_UNUSED =
 
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
 
+                    return true;
+                  }})
+        .pattern({"aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensor();
+                    auto shifts = args[1].unwrapToIntList().vec();
+                    auto dims = args[2].unwrapToIntList().vec();
+
+                    TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()");
+                    if (ctx->input_is_dynamic) {
+                      TORCHTRT_THROW_ERROR("aten::roll is currently not supported in dynamic input shape compilation");
+                    } else {
+                      auto in_shape = util::toVec(in->getDimensions());
+                      for (size_t i = 0; i < dims.size(); i++) {
+                        auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i];
+                        TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range");
+                        in = roll(ctx, in, shifts[i], dim, in_shape);
+                      }
+                      auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+
+                      LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+
+                      return true;
+                    }
+                  }})
+        .pattern(
+            {"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto ts = args[1].IValue()->toListRef();
+
+               std::vector<nvinfer1::ITensor*> tensors;
+               for (auto t : ts) {
+                 if (t.isTensor()) {
+                   auto torch_tensor = t.toTensor();
+                   tensors.push_back(tensor_to_const(ctx, torch_tensor));
+                 } else {
+                   auto cont = t.toCustomClass<TensorContainer>();
+                   tensors.push_back(cont->tensor());
+                 }
+               }
+
+               // In TorchScript, aten::index.Tensor indexes the self tensor along each of its dimensions with
+               // several index tensors. In this version of Torch-TensorRT, it can only receive one index tensor,
+               // which means it only indexes the self tensor along dimension 0.
+               TORCHTRT_CHECK(
+                   tensors.size() == 1,
+                   "In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
+               auto indicesTensor = tensors[0];
+               // Set datatype for indices tensor to INT32
+               auto identity = ctx->net->addIdentity(*indicesTensor);
+               identity->setOutputType(0, nvinfer1::DataType::kINT32);
+               indicesTensor = identity->getOutput(0);
+
+               // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
+               // from
+               auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
+               TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+               auto gather_out = gather_layer->getOutput(0);
+
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
 
                     return true;
                   }})
         .pattern(
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index d27510b628..eb0651890b 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -195,6 +195,84 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenRollConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%1 : Tensor):
+          %2 : int[] = prim::Constant[value=[1, 0, 3, 7]]()
+          %3 : int[] = prim::Constant[value=[0, 1, 2, 3]]()
+          %4 : Tensor = aten::roll(%1, %2, %3)
+          return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  // Run Pytorch
+  auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%1 : Tensor):
+          %2 : int[] = prim::Constant[value=[0, -3, -3]]()
+          %3 : int[] = prim::Constant[value=[1, 2, 3]]()
+          %4 : Tensor = aten::roll(%1, %2, %3)
+          return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  // Run Pytorch
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%1 : Tensor):
+          %2 : int[] = prim::Constant[value=[0, -3, -3]]()
+          %3 : int[] = prim::Constant[value=[1, 2, -1]]()
+          %4 : Tensor = aten::roll(%1, %2, %3)
+          return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  // Run Pytorch
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenSliceConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
@@ -463,3 +541,29 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
+
+TEST(Converters, ATenIndexTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor,
+            %index : Tensor):
+        %18 : Tensor?[] = prim::ListConstruct(%index)
+        %19 : Tensor = aten::index(%x.1, %18)
+        return (%19))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA});
+  auto in2 = at::full({2}, 4, {at::kCUDA});
+  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+  auto in2_trt = at::full({2}, 4, {options});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
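
Reviewer note, not part of the patch: the gather-based `roll` helper in the converter can be sanity-checked against `at::roll` with plain LibTorch, independent of TensorRT. The sketch below is illustrative only; `roll_via_gather` is a hypothetical name, and it mirrors the converter's `(in_dim - shift) % in_dim` start-index arithmetic, including the fix-up for C++'s negative modulo.

// Minimal standalone sketch, assuming only a LibTorch install (CPU is enough).
#include <torch/torch.h>

#include <iostream>

// Mirrors the converter's index arithmetic: rolling by `shift` along `dim` is a
// gather with indices [start, ..., n-1, 0, ..., start-1], where start = (n - shift) % n,
// shifted into the non-negative range because % in C++ can return negative values.
at::Tensor roll_via_gather(const at::Tensor& x, int64_t shift, int64_t dim) {
  auto n = x.size(dim);
  auto start = (n - shift) % n;
  if (start < 0) {
    start += n;
  }
  auto index0 = at::arange(start, n, 1, torch::kLong);
  auto index = (start == 0) ? index0 : at::cat({index0, at::arange(start, torch::kLong)}, 0);
  return x.index_select(dim, index);
}

int main() {
  auto x = torch::randn({2, 3, 4, 5});
  // Positive and negative shifts should both match aten::roll exactly.
  std::cout << torch::allclose(roll_via_gather(x, 3, 2), torch::roll(x, {3}, {2})) << "\n";
  std::cout << torch::allclose(roll_via_gather(x, -2, 1), torch::roll(x, {-2}, {1})) << "\n";
  return 0;
}

If both checks print 1, the gather indices reproduce aten::roll semantics, which is the same equivalence the ATenRoll* tests above exercise through the TensorRT engine.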