diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp index fad2ca5121..b9a5f631b0 100644 --- a/core/conversion/converters/impl/interpolate.cpp +++ b/core/conversion/converters/impl/interpolate.cpp @@ -520,6 +520,37 @@ auto interpolate_registrations TORCHTRT_UNUSED = resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); } + return true; + }}) + .pattern( + {"aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensorOrFreeze(ctx); + auto grid = args[1].ITensorOrFreeze(ctx); + auto interpolation_mode = args[2].unwrapToInt(); + auto padding_mode = args[3].unwrapToInt(); + auto align_corners = args[4].unwrapToBool(); + + static const auto sample_map = std::map{ + {0, nvinfer1::SampleMode::kFILL}, + {1, nvinfer1::SampleMode::kCLAMP}, + {2, nvinfer1::SampleMode::kREFLECT}}; + + static const auto interpolation_map = std::map{ + {0, nvinfer1::InterpolationMode::kLINEAR}, + {1, nvinfer1::InterpolationMode::kNEAREST}, + {2, nvinfer1::InterpolationMode::kCUBIC}}; + + auto grid_sample_layer = ctx->net->addGridSample(*in, *grid); + TORCHTRT_CHECK( + grid_sample_layer, "Unable to create grid_sample layer from node: " << util::node_info(n)); + + grid_sample_layer->setAlignCorners(align_corners); + grid_sample_layer->setSampleMode(sample_map.at(padding_mode)); + grid_sample_layer->setInterpolationMode(interpolation_map.at(interpolation_mode)); + + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], grid_sample_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; }}); diff --git a/tests/core/conversion/converters/test_interpolate.cpp b/tests/core/conversion/converters/test_interpolate.cpp index 22931bf9ec..c3b92c3be1 100644 --- a/tests/core/conversion/converters/test_interpolate.cpp +++ b/tests/core/conversion/converters/test_interpolate.cpp @@ -377,3 +377,99 @@ ATEN_INTERPOLATE_STATIC_ONLY_TEST( %7 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6) return (%7))IR", std::vector({10, 2, 2, 2, 2})); + +TEST(Converters, GridSampleConvertsCorrectly) { + const auto graph = R"IR( + graph(%input : Tensor, %grid : Tensor): + %5 : int = prim::Constant[value=2]() + %6 : int = prim::Constant[value=2]() + %7 : bool = prim::Constant[value=1]() + %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7) + return (%8))IR"; + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA); + auto d = at::linspace(-1, 1, 8); + auto mesh = at::meshgrid({d, d}); + auto mesh_x = mesh[0]; + auto mesh_y = mesh[1]; + auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA); + + auto trt_input = input.clone(); + auto trt_grid = grid.clone(); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid}); + + for (size_t i = 0; i < jit_results.size(); i++) { + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6)); + } +} + +TEST(Converters, GridSampleOptions1ConvertsCorrectly) { + const auto graph = R"IR( + graph(%input : Tensor, %grid : Tensor): + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : bool = prim::Constant[value=0]() + %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7) + return (%8))IR"; + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA); + auto d = at::linspace(-1, 1, 8); + auto mesh = at::meshgrid({d, d}); + auto mesh_x = mesh[0]; + auto mesh_y = mesh[1]; + auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA); + + auto trt_input = input.clone(); + auto trt_grid = grid.clone(); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid}); + + for (size_t i = 0; i < jit_results.size(); i++) { + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6)); + } +} + +TEST(Converters, GridSampleOptions2ConvertsCorrectly) { + const auto graph = R"IR( + graph(%input : Tensor, %grid : Tensor): + %5 : int = prim::Constant[value=0]() + %6 : int = prim::Constant[value=0]() + %7 : bool = prim::Constant[value=0]() + %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7) + return (%8))IR"; + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA); + auto d = at::linspace(-1, 1, 8); + auto mesh = at::meshgrid({d, d}); + auto mesh_x = mesh[0]; + auto mesh_y = mesh[1]; + auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA); + + auto trt_input = input.clone(); + auto trt_grid = grid.clone(); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid}); + + for (size_t i = 0; i < jit_results.size(); i++) { + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6)); + } +} \ No newline at end of file