diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp
index e090261a2a..e379614ad3 100644
--- a/core/conversion/converters/impl/expand.cpp
+++ b/core/conversion/converters/impl/expand.cpp
@@ -282,6 +282,116 @@ auto expand_registrations TORCHTRT_UNUSED =
        auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

        LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());

+       return true;
+     }})
+     .pattern(
+         {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
+          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto repeats = args[1].unwrapToScalar().to<int>();
+
+            auto input_shape = self->getDimensions();
+
+            int dim;
+            if (args[2].IValue()->isNone()) {
+              dim = 0;
+
+              // Flatten self tensor
+              int size;
+              if (ctx->input_is_dynamic) {
+                // Set size to -1 if input is dynamic
+                size = -1;
+              } else {
+                size = 1;
+                for (int i = 0; i < input_shape.nbDims; i++) {
+                  size *= input_shape.d[i];
+                }
+              }
+              auto flatten = ctx->net->addShuffle(*self);
+              TORCHTRT_CHECK(flatten, "Unable to create shuffle layer from node: " << *n);
+              flatten->setReshapeDimensions(util::toDims(std::vector<int64_t>({size})));
+              self = flatten->getOutput(0);
+              input_shape = self->getDimensions();
+            } else {
+              dim = args[2].unwrapToScalar().to<int>();
+            }
+
+            if (ctx->input_is_dynamic) {
+              int dynamic_dims = 0;
+              for (int idx = 0; idx < input_shape.nbDims; idx++) {
+                if (input_shape.d[idx] == -1) {
+                  dynamic_dims++;
+                }
+              }
+
+              if (dynamic_dims > 1) {
+                TORCHTRT_THROW_ERROR(
+                    "Repeat_interleave is currently not supported when target shape contains more than one dynamic dimension");
+              }
+            }
+
+            // Insert singleton dimension after desired repeat dimension
+            std::vector<int64_t> repeat_shape_vec;
+            for (int j = 0; j < input_shape.nbDims; j++) {
+              repeat_shape_vec.push_back(input_shape.d[j]);
+              if (j == dim) {
+                repeat_shape_vec.push_back(1);
+              }
+            }
+            auto expand = ctx->net->addShuffle(*self);
+            TORCHTRT_CHECK(expand, "Unable to create shuffle layer from node: " << *n);
+            auto repeat_shape_dims = util::toDims(repeat_shape_vec);
+            expand->setReshapeDimensions(repeat_shape_dims);
+
+            // Expand on newly created singleton dimension
+            repeat_shape_dims.d[dim + 1] = repeats;
+            std::vector<int64_t> start_vec(repeat_shape_dims.nbDims, 0);
+            auto start_dims = util::toDims(start_vec);
+
+            std::vector<int64_t> strides_vec(repeat_shape_dims.nbDims, 1);
+            strides_vec[dim + 1] = 0;
+            auto strides_dims = util::toDims(strides_vec);
+
+            auto slice = ctx->net->addSlice(*expand->getOutput(0), start_dims, repeat_shape_dims, strides_dims);
+
+            if (ctx->input_is_dynamic) {
+              auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32));
+
+              auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0);
+              std::vector<int64_t> repeat_const_vec(repeat_shape_dims.nbDims, 1);
+              repeat_const_vec[dim + 1] = repeats;
+              auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32));
+              auto repeat_shape_tensor =
+                  ctx->net
+                      ->addElementWise(*expand_output_shape, *repeat_const, nvinfer1::ElementWiseOperation::kPROD)
+                      ->getOutput(0);
+
+              auto strides_tensor = tensor_to_const(ctx, torch::tensor(strides_vec, torch::kInt32));
+              slice->setInput(1, *start_tensor);
+              slice->setInput(2, *repeat_shape_tensor);
+              slice->setInput(3, *strides_tensor);
+            }
+
+            // Collapse repeated dimension back into desired dimension
+            std::vector<int64_t> collapse_shape_vec;
+            for (int k = 0; k < repeat_shape_dims.nbDims; k++) {
+              if (k == dim) {
+                int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k + 1];
+                // Set dim size to -1 if repeat is being done on dynamic dim
+                collapse_dim = std::max(collapse_dim, (int64_t)-1);
+                collapse_shape_vec.push_back(collapse_dim);
+                // Skip the repeat dimension that was just folded in
+                k++;
+              } else {
+                collapse_shape_vec.push_back(repeat_shape_dims.d[k]);
+              }
+            }
+            auto collapse = ctx->net->addShuffle(*slice->getOutput(0));
+            TORCHTRT_CHECK(collapse, "Unable to create shuffle layer from node: " << *n);
+            collapse->setReshapeDimensions(util::toDims(collapse_shape_vec));
+
+            collapse->setName(util::node_info(n).c_str());
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
+            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
            return true;
          }});
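Note on the approach: the converter lowers aten::repeat_interleave to a reshape (insert a singleton after `dim`), a stride-0 slice that broadcasts the singleton `repeats` times, and a final reshape that folds the new axis back into `dim`. At the ATen level this mirrors the unsqueeze -> expand -> reshape identity. A minimal standalone sketch of that identity, assuming libtorch is available (not part of this diff):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randint(1, 10, {2, 3, 2});
  int64_t r = 3, d = 1;

  auto expanded = x.sizes().vec();
  expanded.insert(expanded.begin() + d + 1, r);  // {2, 3, 3, 2}
  auto collapsed = x.sizes().vec();
  collapsed[d] *= r;                             // {2, 9, 2}

  // unsqueeze -> expand -> reshape matches repeat_interleave(r, d)
  auto y = x.unsqueeze(d + 1).expand(expanded).reshape(collapsed);
  std::cout << y.equal(x.repeat_interleave(r, d)) << '\n';  // expect 1
}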
diff --git a/tests/core/conversion/converters/test_expand.cpp b/tests/core/conversion/converters/test_expand.cpp
index bf62266f32..53630b661a 100644
--- a/tests/core/conversion/converters/test_expand.cpp
+++ b/tests/core/conversion/converters/test_expand.cpp
@@ -445,3 +445,227 @@ TEST(Converters, ATenRepeatExtraDimsConvertsCorrectlyWithDynamicInput) {

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
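The ScalarNoDim cases (this one and the dynamic variant below) exercise the dim=None path, where aten::repeat_interleave flattens its input before repeating; the converter's flatten shuffle reproduces exactly that. A small sketch of the expected semantics, assuming libtorch (not part of this diff):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::arange(1, 5).reshape({2, 2});  // [[1, 2], [3, 4]]
  auto y = x.repeat_interleave(3);               // dim=None: flatten, then repeat
  std::cout << y << '\n';  // 1 1 1 2 2 2 3 3 3 4 4 4, shape {12}
}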
+
+TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
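For the 3d dim=1 cases above, the layer sequence the converter builds yields these intermediate shapes (a sketch of the static-shape path):

// in: {2, 3, 2}, repeats = 3, dim = 1
// addShuffle (expand):   {2, 3, 2}    -> {2, 3, 1, 2}   insert singleton after dim
// addSlice (stride 0):   {2, 3, 1, 2} -> {2, 3, 3, 2}   broadcast the singleton
// addShuffle (collapse): {2, 3, 3, 2} -> {2, 9, 2}      fold repeats into dim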
+
+TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : None = prim::Constant()
+        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
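One case the suite above does not cover is repeating along dim 0. A hypothetical test following the same pattern (a sketch, not part of this diff) could look like:

TEST(Converters, ATenRepeatInterleaveScalarDimZeroConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : int = prim::Constant[value=0]()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}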