diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index cf57e7c83c..b1406446f1 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -144,6 +144,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector& graph, std::str
 void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
 
 // utility functions exposed for testing
 std::string unmangle_cls_name(const std::string& name);
diff --git a/core/lowering/passes/replace_aten_pad.cpp b/core/lowering/passes/replace_aten_pad.cpp
new file mode 100644
index 0000000000..f99a0349c1
--- /dev/null
+++ b/core/lowering/passes/replace_aten_pad.cpp
@@ -0,0 +1,125 @@
+#include <torch/csrc/jit/ir/ir.h>
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph) {
+  for (auto it = graph->block()->nodes().begin(), end = graph->block()->nodes().end(); it != end; ++it) {
+    if (it->kind() == c10::Symbol::fromQualString("aten::pad")) {
+      // aten::pad(Tensor self, int[] pad, str mode='constant', float? value=None) -> (Tensor)
+      auto mode = it->inputs()[2];
+      if (mode->type()->isSubtypeOf(c10::StringType::get())) {
+        std::string mode_str = torch::jit::toIValue(mode)->to<std::string>();
+        if (mode_str == "reflect") {
+          auto pad = it->inputs()[1];
+          c10::List<int64_t> pad_list = torch::jit::toIValue(pad)->to<c10::List<int64_t>>();
+          if (pad_list.size() == 2) {
+            // aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor)
+            torch::jit::Node* new_node;
+            new_node = graph->create(
+                c10::Symbol::fromQualString("aten::reflection_pad1d"),
+                torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1]}),
+                1);
+            new_node->insertAfter(*it);
+            new_node->outputs()[0]->setType(c10::TensorType::get());
+            it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+            auto pre = --it;
+            ++it;
+            it->destroy();
+            it = pre;
+          } else if (pad_list.size() == 4) {
+            // aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor)
+            torch::jit::Node* new_node;
+            new_node = graph->create(
+                c10::Symbol::fromQualString("aten::reflection_pad2d"),
+                torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1]}),
+                1);
+            new_node->insertAfter(*it);
+            new_node->outputs()[0]->setType(c10::TensorType::get());
+            it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+            auto pre = --it;
+            ++it;
+            it->destroy();
+            it = pre;
+          } else if (pad_list.size() == 6) {
+            LOG_ERROR("Torch-TRT doesn't support aten::reflection_pad3d currently.");
+          }
+
+        } else if (mode_str == "replicate") {
+          auto pad = it->inputs()[1];
+          c10::List<int64_t> pad_list = torch::jit::toIValue(pad)->to<c10::List<int64_t>>();
+          if (pad_list.size() == 2) {
+            // aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)
+            torch::jit::Node* new_node;
+            new_node = graph->create(
+                c10::Symbol::fromQualString("aten::replication_pad1d"),
+                torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1]}),
+                1);
+            new_node->insertAfter(*it);
+            new_node->outputs()[0]->setType(c10::TensorType::get());
+            it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+            auto pre = --it;
+            ++it;
+            it->destroy();
+            it = pre;
+          } else if (pad_list.size() == 4) {
+            // aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)
+            torch::jit::Node* new_node;
+            new_node = graph->create(
+                c10::Symbol::fromQualString("aten::replication_pad2d"),
+                torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1]}),
+                1);
+            new_node->insertAfter(*it);
+            new_node->outputs()[0]->setType(c10::TensorType::get());
+            it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+            auto pre = --it;
+            ++it;
+            it->destroy();
+            it = pre;
+          } else if (pad_list.size() == 6) {
+            // aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)
+            torch::jit::Node* new_node;
+            new_node = graph->create(
+                c10::Symbol::fromQualString("aten::replication_pad3d"),
+                torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1]}),
+                1);
+            new_node->insertAfter(*it);
+            new_node->outputs()[0]->setType(c10::TensorType::get());
+            it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+            auto pre = --it;
+            ++it;
+            it->destroy();
+            it = pre;
+          }
+
+        } else if (mode_str == "constant") {
+          // aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
+          torch::jit::Node* new_node;
+          new_node = graph->create(
+              c10::Symbol::fromQualString("aten::constant_pad_nd"),
+              torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1], it->inputs()[3]}),
+              1);
+          new_node->insertAfter(*it);
+          new_node->outputs()[0]->setType(c10::TensorType::get());
+          it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+          auto pre = --it;
+          ++it;
+          it->destroy();
+          it = pre;
+        } else if (mode_str == "circular") {
+          LOG_ERROR("Torch-TRT doesn't support circular padding currently.");
+        }
+      }
+    }
+  }
+  LOG_GRAPH("Post map aten::pad -> aten::constant_pad_nd/aten::reflection_padXd/aten::replication_padXd: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index 7f4e53d8a6..8cc2c3a1e9 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -95,6 +95,10 @@ lowering_test(
     name = "test_rewrite_inputs_with_params",
 )
 
+lowering_test(
+    name = "test_replace_aten_pad_pass",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
@@ -111,6 +115,7 @@ test_suite(
         ":test_remove_detach_pass",
         ":test_remove_dropout_pass",
         ":test_remove_unnecessary_casts",
+        ":test_replace_aten_pad_pass",
         ":test_rewrite_inputs_with_params",
         ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
diff --git a/tests/core/lowering/test_replace_aten_pad_pass.cpp b/tests/core/lowering/test_replace_aten_pad_pass.cpp
new file mode 100644
index 0000000000..ddad5efb16
--- /dev/null
+++ b/tests/core/lowering/test_replace_aten_pad_pass.cpp
@@ -0,0 +1,224 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, AtenPadConstantCorrectly) {
+  const auto source_graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : str = prim::Constant[value="constant"]()
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+          %3 : float = prim::Constant[value=0.0]()
+          %4 : Tensor = aten::pad(%0, %1, %2, %3)
+          return (%4))IR";
+
+  const auto target_graph = R"IR(
+        graph(%0 : Tensor):
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+          %2 : Scalar = prim::Constant[value=0.0]()
+          %3 : Tensor = aten::constant_pad_nd(%0, %1, %2)
+          return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, AtenPadReflect1dCorrectly) {
+  const auto source_graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : str = prim::Constant[value="reflect"]()
+          %1 : int[] = prim::Constant[value=[2, 3]]()
+          %3 : float = prim::Constant[value=0.0]()
+          %4 : Tensor = aten::pad(%0, %1, %2, %3)
+          return (%4))IR";
+
+  const auto target_graph = R"IR(
+        graph(%0 : Tensor):
+          %1 : int[] = prim::Constant[value=[2, 3]]()
+          %3 : Tensor = aten::reflection_pad1d(%0, %1)
+          return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, AtenPadReflect2dCorrectly) {
+  const auto source_graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : str = prim::Constant[value="reflect"]()
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+          %3 : float = prim::Constant[value=0.0]()
+          %4 : Tensor = aten::pad(%0, %1, %2, %3)
+          return (%4))IR";
+
+  const auto target_graph = R"IR(
+        graph(%0 : Tensor):
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+          %3 : Tensor = aten::reflection_pad2d(%0, %1)
+          return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, AtenPadReplicate1dCorrectly) {
+  const auto source_graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : str = prim::Constant[value="replicate"]()
+          %1 : int[] = prim::Constant[value=[2, 3]]()
+          %3 : float = prim::Constant[value=0.0]()
+          %4 : Tensor = aten::pad(%0, %1, %2, %3)
+          return (%4))IR";
+
+  const auto target_graph = R"IR(
+        graph(%0 : Tensor):
+          %1 : int[] = prim::Constant[value=[2, 3]]()
+          %3 : Tensor = aten::replication_pad1d(%0, %1)
+          return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, AtenPadReplicate2dCorrectly) {
+  const auto source_graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : str = prim::Constant[value="replicate"]()
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+          %3 : float = prim::Constant[value=0.0]()
+          %4 : Tensor = aten::pad(%0, %1, %2, %3)
+          return (%4))IR";
+
+  const auto target_graph = R"IR(
+        graph(%0 : Tensor):
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+          %3 : Tensor = aten::replication_pad2d(%0, %1)
+          return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, AtenPadReplicate3dCorrectly) {
+  const auto source_graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : str = prim::Constant[value="replicate"]()
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]()
+          %3 : float = prim::Constant[value=0.0]()
+          %4 : Tensor = aten::pad(%0, %1, %2, %3)
+          return (%4))IR";
+
+  const auto target_graph = R"IR(
+        graph(%0 : Tensor):
+          %1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]()
+          %3 : Tensor = aten::replication_pad3d(%0, %1)
+          return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 5, 3}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
\ No newline at end of file