Skip to content

Commit 2fe72e5

Browse files
committed
Add converter support for aten::grid_sampler
1 parent 7f14221 commit 2fe72e5

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

core/conversion/converters/impl/interpolate.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,37 @@ auto interpolate_registrations TORCHTRT_UNUSED =
520520
resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
521521
}
522522

523+
return true;
524+
}})
525+
.pattern(
526+
{"aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor",
527+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
528+
auto in = args[0].ITensorOrFreeze(ctx);
529+
auto grid = args[1].ITensorOrFreeze(ctx);
530+
auto interpolation_mode = args[2].unwrapToInt();
531+
auto padding_mode = args[3].unwrapToInt();
532+
auto align_corners = args[4].unwrapToBool();
533+
534+
static const auto sample_map = std::map<int, nvinfer1::SampleMode>{
535+
{0, nvinfer1::SampleMode::kFILL},
536+
{1, nvinfer1::SampleMode::kCLAMP},
537+
{2, nvinfer1::SampleMode::kREFLECT}};
538+
539+
static const auto interpolation_map = std::map<int, nvinfer1::ResizeMode>{
540+
{0, nvinfer1::InterpolationMode::kLINEAR},
541+
{1, nvinfer1::InterpolationMode::kNEAREST},
542+
{2, nvinfer1::InterpolationMode::kCUBIC}};
543+
544+
auto grid_sample_layer = ctx->net->addGridSample(*in, *grid);
545+
TORCHTRT_CHECK(
546+
grid_sample_layer, "Unable to create grid_sample layer from node: " << util::node_info(n));
547+
548+
grid_sample_layer->setAlignCorners(align_corners);
549+
grid_sample_layer->setSampleMode(sample_map.at(padding_mode));
550+
grid_sample_layer->setInterpolationMode(interpolation_map.at(interpolation_mode));
551+
552+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], grid_sample_layer->getOutput(0));
553+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
523554
return true;
524555
}});
525556

tests/core/conversion/converters/test_interpolate.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,99 @@ ATEN_INTERPOLATE_STATIC_ONLY_TEST(
377377
%7 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6)
378378
return (%7))IR",
379379
std::vector<int64_t>({10, 2, 2, 2, 2}));
380+
381+
TEST(Converters, GridSampleConvertsCorrectly) {
382+
const auto graph = R"IR(
383+
graph(%input : Tensor, %grid : Tensor):
384+
%5 : int = prim::Constant[value=2]()
385+
%6 : int = prim::Constant[value=2]()
386+
%7 : bool = prim::Constant[value=1]()
387+
%8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
388+
return (%8))IR";
389+
auto g = std::make_shared<torch::jit::Graph>();
390+
391+
torch::jit::parseIR(graph, g.get());
392+
393+
auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA);
394+
auto d = at::linspace(-1, 1, 8);
395+
auto mesh = at::meshgrid({d, d});
396+
auto mesh_x = mesh[0];
397+
auto mesh_y = mesh[1];
398+
auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA);
399+
400+
auto trt_input = input.clone();
401+
auto trt_grid = grid.clone();
402+
403+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
404+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid});
405+
406+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid});
407+
408+
for (size_t i = 0; i < jit_results.size(); i++) {
409+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
410+
}
411+
}
412+
413+
TEST(Converters, GridSampleOptions1ConvertsCorrectly) {
414+
const auto graph = R"IR(
415+
graph(%input : Tensor, %grid : Tensor):
416+
%5 : int = prim::Constant[value=1]()
417+
%6 : int = prim::Constant[value=1]()
418+
%7 : bool = prim::Constant[value=0]()
419+
%8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
420+
return (%8))IR";
421+
auto g = std::make_shared<torch::jit::Graph>();
422+
423+
torch::jit::parseIR(graph, g.get());
424+
425+
auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA);
426+
auto d = at::linspace(-1, 1, 8);
427+
auto mesh = at::meshgrid({d, d});
428+
auto mesh_x = mesh[0];
429+
auto mesh_y = mesh[1];
430+
auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA);
431+
432+
auto trt_input = input.clone();
433+
auto trt_grid = grid.clone();
434+
435+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
436+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid});
437+
438+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid});
439+
440+
for (size_t i = 0; i < jit_results.size(); i++) {
441+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
442+
}
443+
}
444+
445+
TEST(Converters, GridSampleOptions2ConvertsCorrectly) {
446+
const auto graph = R"IR(
447+
graph(%input : Tensor, %grid : Tensor):
448+
%5 : int = prim::Constant[value=0]()
449+
%6 : int = prim::Constant[value=0]()
450+
%7 : bool = prim::Constant[value=0]()
451+
%8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
452+
return (%8))IR";
453+
auto g = std::make_shared<torch::jit::Graph>();
454+
455+
torch::jit::parseIR(graph, g.get());
456+
457+
auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA);
458+
auto d = at::linspace(-1, 1, 8);
459+
auto mesh = at::meshgrid({d, d});
460+
auto mesh_x = mesh[0];
461+
auto mesh_y = mesh[1];
462+
auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA);
463+
464+
auto trt_input = input.clone();
465+
auto trt_grid = grid.clone();
466+
467+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
468+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid});
469+
470+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid});
471+
472+
for (size_t i = 0; i < jit_results.size(); i++) {
473+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
474+
}
475+
}

0 commit comments

Comments
 (0)