From 54e022f4719e9765b7747d62e6ee20ef93d69c3d Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 27 Feb 2023 15:13:11 -0800
Subject: [PATCH] fix: Add schemas to conv replace

- Add support for transposed conv2d and transposed conv3d
- Add testing for all convolution functions, rename test accordingly
---
 core/lowering/lowering.cpp                    |   2 +
 .../lowering/passes/convNd_to_convolution.cpp |  28 ++
 core/lowering/passes/passes.h                 |   2 +
 tests/core/lowering/BUILD                     |   4 +-
 tests/core/lowering/test_conv1d_pass.cpp      | 202 ---------
 tests/core/lowering/test_conv_pass.cpp        | 422 ++++++++++++++++++
 6 files changed, 456 insertions(+), 204 deletions(-)
 delete mode 100644 tests/core/lowering/test_conv1d_pass.cpp
 create mode 100644 tests/core/lowering/test_conv_pass.cpp

diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index b1406446f1..8b3c0c8119 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -124,7 +124,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   passes::Conv1DToConvolution(g);
   passes::ConvTransposed1DToConvolution(g);
   passes::Conv2DToConvolution(g);
+  passes::ConvTransposed2DToConvolution(g);
   passes::Conv3DToConvolution(g);
+  passes::ConvTransposed3DToConvolution(g);
   passes::FuseAddMMBranches(g);
   passes::LinearToAddMM(g);
diff --git a/core/lowering/passes/convNd_to_convolution.cpp b/core/lowering/passes/convNd_to_convolution.cpp
--- a/core/lowering/passes/convNd_to_convolution.cpp
+++ b/core/lowering/passes/convNd_to_convolution.cpp
@@ -82,6 +82,20 @@ void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
   LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
 }
 
+void ConvTransposed2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv_transpose2d_node_kind = "aten::conv_transpose2d";
+  const std::string convolution_pattern = R"IR(
+    graph(%x, %w, %b, %s, %p, %o, %g, %d):
+      %1 : bool = prim::Constant[value=1]()
+      %2 : bool = prim::Constant[value=1]()
+      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
+      return (%4))IR";
+
+  // Schema is aten::conv_transpose2d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
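+  // For reference, the aten::_convolution signature being targeted is:
+  //   _convolution(input, weight, bias, stride, padding, dilation, transposed,
+  //                output_padding, groups, benchmark, deterministic,
+  //                cudnn_enabled, allow_tf32)
+  // so %1 (true) marks the convolution as transposed and %2 (true) fills the
+  // four trailing backend flags.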
+  replaceConv(graph->block(), conv_transpose2d_node_kind, convolution_pattern, 8);
+  LOG_GRAPH("Post map conv_transpose2d -> _convolution: " << *graph);
+}
+
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
   const std::string conv3d_node_kind = "aten::conv3d";
   const std::string convolution_pattern = R"IR(
@@ -96,6 +110,20 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
   LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
 }
 
+void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv_transpose3d_node_kind = "aten::conv_transpose3d";
+  const std::string convolution_pattern = R"IR(
+    graph(%x, %w, %b, %s, %p, %o, %g, %d):
+      %1 : bool = prim::Constant[value=1]()
+      %2 : bool = prim::Constant[value=1]()
+      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
+      return (%4))IR";
+
+  // Schema is aten::conv_transpose3d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose3d_node_kind, convolution_pattern, 8);
+  LOG_GRAPH("Post map conv_transpose3d -> _convolution: " << *graph);
+}
+
 } // namespace passes
 } // namespace lowering
 } // namespace core
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index 94d4353851..a524296e4c 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -15,7 +15,9 @@ void NotateModuleForFallback(
 void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+void ConvTransposed2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index 8cc2c3a1e9..801c6009c9 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -28,7 +28,7 @@ cc_test(
 )
 
 lowering_test(
-    name = "test_conv1d_pass",
+    name = "test_conv_pass",
 )
 
 lowering_test(
@@ -102,7 +102,7 @@ lowering_test(
 test_suite(
     name = "lowering_tests",
     tests = [
-        ":test_conv1d_pass",
+        ":test_conv_pass",
         ":test_device_casting",
         ":test_exception_elimination_pass",
         ":test_linear_to_addmm",
diff --git a/tests/core/lowering/test_conv1d_pass.cpp b/tests/core/lowering/test_conv1d_pass.cpp
deleted file mode 100644
index 3694559108..0000000000
--- a/tests/core/lowering/test_conv1d_pass.cpp
+++ /dev/null
@@ -1,202 +0,0 @@
-#include <string>
-#include "core/compiler.h"
-#include "core/lowering/passes/passes.h"
-#include "gtest/gtest.h"
-#include "tests/util/util.h"
-#include "torch/csrc/jit/ir/irparser.h"
-#include "torch/csrc/jit/ir/subgraph_matcher.h"
-#include "torch/csrc/jit/passes/canonicalize.h"
-#include "torch/csrc/jit/passes/constant_pooling.h"
-
-TEST(LoweringPasses, Conv1dCorrectly) {
-  const auto source_graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %6 : int = prim::Constant[value=1]()
-        %stride : int[] = prim::ListConstruct(%6)
-        %padding : int[] = prim::ListConstruct(%4)
-        %dilation : int[] = prim::ListConstruct(%5)
-        %12 : Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
-        return (%12))IR";
-
-  const auto target_graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %3 : bool = prim::Constant[value=0]()
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %6 : int = prim::Constant[value=1]()
-        %stride : int[] = prim::ListConstruct(%6)
-        %padding : int[] = prim::ListConstruct(%4)
-        %dilation : int[] = prim::ListConstruct(%5)
-        %output_padding : int[] = prim::Constant[value=[0]]()
-        %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
-        return (%12))IR";
-
-  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
-      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
-  auto sg = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(source_graph, &*sg);
-  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
-
-  auto tg = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(target_graph, &*tg);
-
-  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
-  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
-  auto b = at::randint(1, 10, {4}, {at::kCUDA});
-
-  auto trt_in = at::clone(in);
-  auto trt_w = at::clone(w);
-  auto trt_b = at::clone(b);
-  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
-  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
-
-  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
-  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
-
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
-}
-
-TEST(LoweringPasses, ConvTransposed1dCorrectly) {
-  const auto source_graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(8, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %3 : int = prim::Constant[value=1]()
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %6 : int = prim::Constant[value=0]()
-        %stride : int[] = prim::ListConstruct(%3)
-        %padding : int[] = prim::ListConstruct(%4)
-        %dilation : int[] = prim::ListConstruct(%5)
-        %output_padding : int[] = prim::ListConstruct(%6)
-        %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %stride, %padding, %output_padding, %3, %dilation)
-        return (%12))IR";
-
-  const auto target_graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(8, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %3 : int = prim::Constant[value=1]()
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %6 : int = prim::Constant[value=0]()
-        %7 : bool = prim::Constant[value=0]()
-        %8 : bool = prim::Constant[value=1]()
-        %stride : int[] = prim::ListConstruct(%3)
-        %padding : int[] = prim::ListConstruct(%4)
-        %dilation : int[] = prim::ListConstruct(%5)
-        %output_padding : int[] = prim::ListConstruct(%6)
-        %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7)
-        return (%12))IR";
-
-  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
-      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
-  auto sg = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(source_graph, &*sg);
-  torch_tensorrt::core::lowering::passes::ConvTransposed1DToConvolution(sg);
-
-  auto tg = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(target_graph, &*tg);
-
-  auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA});
-  auto w = at::randint(1, 2, {8, 3, 3}, {at::kCUDA});
-  auto b = at::randint(1, 10, {3}, {at::kCUDA});
-
-  auto trt_in = at::clone(in);
-  auto trt_w = at::clone(w);
-  auto trt_b = at::clone(b);
-  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
-  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
-
-  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
-  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
-
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
-}
-
-TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
-  std::string source_graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %6 : int = prim::Constant[value=1]()
-        %stride : int[] = prim::ListConstruct(%6)
-        %padding : int[] = prim::ListConstruct(%4)
-        %dilation : int[] = prim::ListConstruct(%5)
-
-        # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
-        %true : bool = prim::Constant[value=1]()
-        %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
-        %12 : Tensor = prim::If(%true)
-          block0():
-            %res: Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
-            -> (%res)
-          block1():
-            %res: Tensor = aten::conv1d(%invalid_weight, %1, %2, %stride, %padding, %dilation, %6)
-            -> (%res)
-        return (%12))IR";
-
-  std::string target_graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %true : bool = prim::Constant[value=1]()
-        %3 : bool = prim::Constant[value=0]()
-        %output_padding : int[] = prim::Constant[value=[0]]()
-        %6 : int = prim::Constant[value=1]()
-        %stride : int[] = prim::ListConstruct(%6)
-        %padding : int[] = prim::ListConstruct(%4)
-        %dilation : int[] = prim::ListConstruct(%5)
-
-        # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
-        %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
-        %12 : Tensor = prim::If(%true)
-          block0():
-            %res: Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
-            -> (%res)
-          block1():
-            %res: Tensor = aten::_convolution(%invalid_weight, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
-            -> (%res)
-        return (%12))IR";
-
-  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
-      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
-  auto sg = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(source_graph, &*sg);
-  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
-  torch::jit::ConstantPooling(sg);
-  sg = torch::jit::Canonicalize(sg, false);
-
-  auto tg = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(target_graph, &*tg);
-  torch::jit::ConstantPooling(tg);
-  tg = torch::jit::Canonicalize(tg, false);
-
-  // Validate identical graphs after pooling constants and canonicalizing
-  ASSERT_TRUE((tg->toString() == sg->toString()));
-
-  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
-  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
-  auto b = at::randint(1, 10, {4}, {at::kCUDA});
-
-  auto trt_in = at::clone(in);
-  auto trt_w = at::clone(w);
-  auto trt_b = at::clone(b);
-  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
-  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
-
-  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
-  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
-
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
-}
diff --git a/tests/core/lowering/test_conv_pass.cpp b/tests/core/lowering/test_conv_pass.cpp
new file mode 100644
index 0000000000..d3bc9a385d
--- /dev/null
+++ b/tests/core/lowering/test_conv_pass.cpp
@@ -0,0 +1,422 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+#include "torch/csrc/jit/passes/canonicalize.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
+
+TEST(LoweringPasses, Conv1dCorrectly) {
+  const auto source_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(4)):
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=1]()
+        %stride : int[] = prim::ListConstruct(%6)
+        %padding : int[] = prim::ListConstruct(%4)
+        %dilation : int[] = prim::ListConstruct(%5)
+        %12 : Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
+        return (%12))IR";
+
+  const auto target_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(4)):
+        %3 : bool = prim::Constant[value=0]()
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=1]()
+        %stride : int[] = prim::ListConstruct(%6)
+        %padding : int[] = prim::ListConstruct(%4)
+        %dilation : int[] = prim::ListConstruct(%5)
+        %output_padding : int[] = prim::Constant[value=[0]]()
+        %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+        return (%12))IR";
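+  // Conv1DToConvolution should rewrite aten::conv1d into the aten::_convolution
+  // call shown in target_graph: transposed=false (%3) and constant [0] output_padding.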
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, ConvTransposed1dCorrectly) {
+  const auto source_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(8, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int = prim::Constant[value=1]()
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=0]()
+        %stride : int[] = prim::ListConstruct(%3)
+        %padding : int[] = prim::ListConstruct(%4)
+        %dilation : int[] = prim::ListConstruct(%5)
+        %output_padding : int[] = prim::ListConstruct(%6)
+        %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %stride, %padding, %output_padding, %3, %dilation)
+        return (%12))IR";
+
+  const auto target_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(8, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int = prim::Constant[value=1]()
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=0]()
+        %7 : bool = prim::Constant[value=0]()
+        %8 : bool = prim::Constant[value=1]()
+        %stride : int[] = prim::ListConstruct(%3)
+        %padding : int[] = prim::ListConstruct(%4)
+        %dilation : int[] = prim::ListConstruct(%5)
+        %output_padding : int[] = prim::ListConstruct(%6)
+        %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7)
+        return (%12))IR";
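+  // In the rewritten graph, %8 (true) is the transposed flag and %7 (false)
+  // fills the four trailing backend flags of aten::_convolution.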
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ConvTransposed1DToConvolution(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {8, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {3}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, Conv2dCorrectly) {
+  const auto source_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, strides=[36, 9, 3, 1]),
+            %2 : Float(3)):
+        %16 : int[] = prim::Constant[value=[0, 0]]()
+        %15 : int[] = prim::Constant[value=[1, 1]]()
+        %5 : int = prim::Constant[value=1]()
+        %11 : Tensor = aten::conv2d(%0, %1, %2, %15, %16, %15, %5)
+        return (%11))IR";
+
+  const auto target_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, strides=[36, 9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int[] = prim::Constant[value=[0, 0]]()
+        %4 : int[] = prim::Constant[value=[1, 1]]()
+        %5 : int = prim::Constant[value=1]()
+        %7 : bool = prim::Constant[value=0]()
+        %8 : int[] = prim::Constant[value=[0, 0]]()
+        %9 : Tensor = aten::_convolution(%0, %1, %2, %4, %3, %4, %7, %8, %5, %7, %7, %7, %7)
+        return (%9))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::Conv2DToConvolution(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+
+  auto in = at::randint(1, 2, {3, 4, 4, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {3, 4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {3}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, ConvTransposed2dCorrectly) {
+  const auto source_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, strides=[36, 9, 3, 1]),
+            %2 : Float(4)):
+        %93 : int[] = prim::Constant[value=[0, 0]]()
+        %92 : int[] = prim::Constant[value=[1, 1]]()
+        %12 : int = prim::Constant[value=1]()
+        %88 : Tensor = aten::conv_transpose2d(%0, %1, %2, %92, %93, %93, %12, %92)
+        return (%88))IR";
+
+  const auto target_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, strides=[36, 9, 3, 1]),
+            %2 : Float(4)):
+        %3 : int[] = prim::Constant[value=[0, 0]]()
+        %4 : int[] = prim::Constant[value=[1, 1]]()
+        %5 : int = prim::Constant[value=1]()
+        %7 : bool = prim::Constant[value=1]()
+        %8 : bool = prim::Constant[value=1]()
+        %9 : Tensor = aten::_convolution(%0, %1, %2, %4, %3, %4, %7, %3, %5, %8, %8, %8, %8)
+        return (%9))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ConvTransposed2DToConvolution(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+
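+  // For transposed convolutions the weight layout is (in_channels, out_channels,
+  // kH, kW), so the (3, 4, 3, 3) weight maps 3 input channels to 4 outputs here.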
+  auto in = at::randint(1, 2, {3, 3, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {3, 4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, Conv3dCorrectly) {
+  const auto source_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, 3, strides=[108, 27, 9, 3, 1]),
+            %2 : Float(3)):
+        %16 : int[] = prim::Constant[value=[0, 0, 0]]()
+        %15 : int[] = prim::Constant[value=[1, 1, 1]]()
+        %5 : int = prim::Constant[value=1]()
+        %11 : Tensor = aten::conv3d(%0, %1, %2, %15, %16, %15, %5)
+        return (%11))IR";
+
+  const auto target_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, 3, strides=[108, 27, 9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int[] = prim::Constant[value=[0, 0, 0]]()
+        %4 : int[] = prim::Constant[value=[1, 1, 1]]()
+        %5 : int = prim::Constant[value=1]()
+        %7 : bool = prim::Constant[value=0]()
+        %8 : int[] = prim::Constant[value=[0, 0, 0]]()
+        %9 : Tensor = aten::_convolution(%0, %1, %2, %4, %3, %4, %7, %8, %5, %7, %7, %7, %7)
+        return (%9))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::Conv3DToConvolution(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+
+  auto in = at::randint(1, 2, {4, 4, 4, 4, 4}, {at::kCUDA});
+  auto w = at::randint(1, 2, {3, 4, 3, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {3}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, ConvTransposed3dCorrectly) {
+  const auto source_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, 3, strides=[108, 27, 9, 3, 1]),
+            %2 : Float(4)):
+        %93 : int[] = prim::Constant[value=[0, 0, 0]]()
+        %92 : int[] = prim::Constant[value=[1, 1, 1]]()
+        %13 : int = prim::Constant[value=1]()
+        %88 : Tensor = aten::conv_transpose3d(%0, %1, %2, %92, %93, %93, %13, %92)
+        return (%88))IR";
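+  // Argument order for this op is aten::conv_transpose3d(input, weight, bias,
+  // stride, padding, output_padding, groups, dilation) -- dilation comes last.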
+
+  const auto target_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(3, 4, 3, 3, 3, strides=[108, 27, 9, 3, 1]),
+            %2 : Float(4)):
+        %3 : int[] = prim::Constant[value=[0, 0, 0]]()
+        %4 : int[] = prim::Constant[value=[1, 1, 1]]()
+        %5 : int = prim::Constant[value=1]()
+        %7 : bool = prim::Constant[value=1]()
+        %8 : bool = prim::Constant[value=1]()
+        %9 : Tensor = aten::_convolution(%0, %1, %2, %4, %3, %4, %7, %3, %5, %8, %8, %8, %8)
+        return (%9))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ConvTransposed3DToConvolution(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+
+  auto in = at::randint(1, 2, {3, 3, 3, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {3, 4, 3, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
+TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
+  std::string source_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=1]()
+        %stride : int[] = prim::ListConstruct(%6)
+        %padding : int[] = prim::ListConstruct(%4)
+        %dilation : int[] = prim::ListConstruct(%5)
+
+        # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+        %true : bool = prim::Constant[value=1]()
+        %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+        %12 : Tensor = prim::If(%true)
+          block0():
+            %res: Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
+            -> (%res)
+          block1():
+            %res: Tensor = aten::conv1d(%invalid_weight, %1, %2, %stride, %padding, %dilation, %6)
+            -> (%res)
+        return (%12))IR";
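+  // The conv1d calls sit inside prim::If blocks, so this test verifies that the
+  // lowering pass descends into nested blocks instead of only the top-level graph.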
+
+  std::string target_graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %true : bool = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=0]()
+        %output_padding : int[] = prim::Constant[value=[0]]()
+        %6 : int = prim::Constant[value=1]()
+        %stride : int[] = prim::ListConstruct(%6)
+        %padding : int[] = prim::ListConstruct(%4)
+        %dilation : int[] = prim::ListConstruct(%5)
+
+        # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+        %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+        %12 : Tensor = prim::If(%true)
+          block0():
+            %res: Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+            -> (%res)
+          block1():
+            %res: Tensor = aten::_convolution(%invalid_weight, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+            -> (%res)
+        return (%12))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
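
---

Note for reviewers (not part of the patch): a minimal sketch of how the new pass
can be exercised standalone, assuming only the headers touched above and a
hypothetical driver program:

    #include <memory>
    #include "core/lowering/passes/passes.h"
    #include "torch/csrc/jit/ir/irparser.h"

    int main() {
      // Graph with a single aten::conv_transpose2d node, following the
      // 8-input schema handled by the new pass.
      const auto src = R"IR(
          graph(%x : Tensor, %w : Tensor, %b : Tensor):
            %s : int[] = prim::Constant[value=[1, 1]]()
            %p : int[] = prim::Constant[value=[0, 0]]()
            %g : int = prim::Constant[value=1]()
            %out : Tensor = aten::conv_transpose2d(%x, %w, %b, %s, %p, %p, %g, %s)
            return (%out))IR";
      auto g = std::make_shared<torch::jit::Graph>();
      torch::jit::parseIR(src, g.get());
      torch_tensorrt::core::lowering::passes::ConvTransposed2DToConvolution(g);
      g->dump();  // should now contain aten::_convolution with transposed=true
      return 0;
    }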