diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index cb1fd97327..ce38a1f292 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -153,6 +153,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& graph, std::st
 void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph);
 
 // utility functions exposed for testing
 std::string unmangle_cls_name(const std::string& name);
diff --git a/core/lowering/passes/tile_to_repeat.cpp b/core/lowering/passes/tile_to_repeat.cpp
new file mode 100644
index 0000000000..7ecb2bc13d
--- /dev/null
+++ b/core/lowering/passes/tile_to_repeat.cpp
@@ -0,0 +1,25 @@
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string tile_pattern = R"IR(
+    graph(%input, %1):
+      %2 = aten::tile(%input, %1)
+      return (%2))IR";
+  std::string repeat_pattern = R"IR(
+    graph(%input, %1):
+      %2 = aten::repeat(%input, %1)
+      return (%2))IR";
+  torch::jit::SubgraphRewriter tile_to_repeat;
+  tile_to_repeat.RegisterRewritePattern(tile_pattern, repeat_pattern);
+  tile_to_repeat.runOnGraph(graph);
+  LOG_GRAPH("Mapping tile -> repeat: " << *graph);
+}
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/docsrc/contributors/lowering.rst b/docsrc/contributors/lowering.rst
index 956c2004e1..a82f497ed2 100644
--- a/docsrc/contributors/lowering.rst
+++ b/docsrc/contributors/lowering.rst
@@ -205,3 +205,10 @@ Unroll Loops
     `torch/csrc/jit/passes/loop_unrolling.h <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/loop_unrolling.h>`_
 
 Unrolls the operations of compatible loops (e.g. sufficiently short) so that you only have to go through the loop once.
+
+Replace Tile with Repeat
+***************************************
+
+    `Torch-TensorRT/core/lowering/passes/tile_to_repeat.cpp <https://github.com/pytorch/TensorRT/blob/master/core/lowering/passes/tile_to_repeat.cpp>`_
+
+Replaces ``aten::tile`` with ``aten::repeat``, which has identical behavior and an existing converter, so ``aten::tile`` does not need a converter of its own.
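
For context on why this rewrite is sound: aten::tile and aten::repeat produce identical results whenever the repeat spec has at least as many entries as the input has dimensions (for shorter specs, tile pads with leading ones, which repeat does not accept). The sketch below is illustrative and not part of the patch; it checks the equivalence with plain libtorch:

    #include <iostream>
    #include <torch/torch.h>

    int main() {
      // tile and repeat agree when the spec rank >= the tensor rank,
      // which is the case the lowering pass and its tests exercise.
      auto x = torch::arange(6).reshape({2, 3});
      auto tiled = x.tile({4, 1});      // what the model author wrote
      auto repeated = x.repeat({4, 1}); // what the lowered graph runs
      std::cout << std::boolalpha << tiled.equal(repeated) << std::endl; // true
      return 0;
    }
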
diff --git a/tests/core/conversion/converters/test_expand.cpp b/tests/core/conversion/converters/test_expand.cpp
index 77b42fb1d9..341fe29aa4 100644
--- a/tests/core/conversion/converters/test_expand.cpp
+++ b/tests/core/conversion/converters/test_expand.cpp
@@ -1,6 +1,7 @@
 #include <torch/torch.h>
 #include <string>
 #include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
@@ -670,6 +671,131 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenTileConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTileRepeatRankConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTile3dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[2, 2, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[2, 2, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenMeshGridConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x : Tensor, %y : Tensor, %z : Tensor):
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index 081443ecb3..30f1fd8e5a 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -103,6 +103,10 @@ lowering_test(
     name = "test_replace_aten_pad_pass",
 )
 
+lowering_test(
+    name = "test_tile_to_repeat_pass",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
@@ -122,6 +126,7 @@ test_suite(
         ":test_remove_unnecessary_casts",
         ":test_replace_aten_pad_pass",
         ":test_rewrite_inputs_with_params",
+        ":test_tile_to_repeat_pass",
         ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
         ":test_unpack_reduce_ops",
diff --git a/tests/core/lowering/test_tile_to_repeat_pass.cpp b/tests/core/lowering/test_tile_to_repeat_pass.cpp
new file mode 100644
index 0000000000..8357007091
--- /dev/null
+++ b/tests/core/lowering/test_tile_to_repeat_pass.cpp
@@ -0,0 +1,26 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, TileToRepeatCorrectly) {
+  std::string source_graph = R"IR(
+      graph(%input, %dim):
+        %o : Tensor = aten::tile(%input, %dim)
+        return (%o))IR";
+  std::string target_graph = R"IR(
+      graph(%input, %dim):
+        %o : Tensor = aten::repeat(%input, %dim)
+        return (%o))IR";
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
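
Taken together, the pattern-match test above confirms the rewrite fires, and the converter tests confirm the rewritten graph runs through TensorRT. A standalone driver along the following lines shows the pass in isolation; the main() below is a hypothetical sketch, not part of the patch, and its includes mirror the test file:

    #include <iostream>
    #include <memory>
    #include <string>
    #include "core/lowering/passes/passes.h"
    #include "torch/csrc/jit/ir/ir.h"
    #include "torch/csrc/jit/ir/irparser.h"

    int main() {
      // A graph containing the op we want to lower away.
      const std::string src = R"IR(
          graph(%input, %dims):
            %out : Tensor = aten::tile(%input, %dims)
            return (%out))IR";
      auto g = std::make_shared<torch::jit::Graph>();
      torch::jit::parseIR(src, g.get());
      // Run the new pass; SubgraphRewriter swaps the node in place.
      torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
      std::cout << *g; // the aten::tile node now appears as aten::repeat
      return 0;
    }
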