diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD index 901ca94998..a8c57b1b41 100644 --- a/tests/core/conversion/converters/BUILD +++ b/tests/core/conversion/converters/BUILD @@ -95,6 +95,10 @@ converter_test( name = "test_matrix_multiply", ) +converter_test( + name = "test_masked_fill", +) + converter_test( name = "test_max", ) @@ -115,6 +119,10 @@ converter_test( name = "test_reduce", ) +converter_test( + name = "test_roll", +) + converter_test( name = "test_reflection_pad", ) @@ -123,6 +131,10 @@ converter_test( name = "test_replication_pad", ) +converter_test( + name = "test_scatter", +) + converter_test( name = "test_shuffle", ) @@ -139,6 +151,10 @@ converter_test( name = "test_interpolate", ) +converter_test( + name = "test_index", +) + converter_test( name = "test_select", ) @@ -147,6 +163,14 @@ converter_test( name = "test_stack", ) +converter_test( + name = "test_slice", +) + +converter_test( + name = "test_split", +) + converter_test( name = "test_topk", ) @@ -159,10 +183,22 @@ converter_test( name = "test_unsqueeze", ) +converter_test( + name = "test_unbind", +) + +converter_test( + name = "test_unpack", +) + converter_test( name = "test_squeeze", ) +converter_test( + name = "test_where", +) + test_suite( name = "converter_tests", tests = [ @@ -185,22 +221,31 @@ test_suite( ":test_expand", ":test_instance_norm", ":test_interpolate", + ":test_index", ":test_layer_norm", ":test_linear", ":test_lstm_cell", ":test_matrix_multiply", + ":test_masked_fill", ":test_max", ":test_normalize", ":test_pooling", ":test_reduce", + ":test_roll", ":test_replication_pad", + ":test_scatter", ":test_select", ":test_shuffle", ":test_softmax", ":test_squeeze", ":test_stack", + ":test_split", + ":test_slice", ":test_topk", ":test_unary", ":test_unsqueeze", + ":test_unbind", + ":test_unpack", + ":test_where", ], ) diff --git a/tests/core/conversion/converters/test_index.cpp b/tests/core/conversion/converters/test_index.cpp new file mode 100644 index 0000000000..b405d7a436 --- /dev/null +++ b/tests/core/conversion/converters/test_index.cpp @@ -0,0 +1,295 @@ +#include +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenIndexSelectConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %index : Int (2)): + %2 : int = prim::Constant[value=0]() + %3 : Tensor = aten::index_select(%0, %2, %index) + return (%3))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); + auto index = at::randint(0, 4, {2}, {at::kCUDA}).to(torch::kI32); + + auto jit_in = at::clone(in); + auto jit_index = at::clone(index); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_index = at::clone(index); + auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in}); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenIndexSelectNegativeDimConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %index : Int (5)): + %2 : int = prim::Constant[value=-1]() + %3 : Tensor = aten::index_select(%0, %2, %index) + return (%3))IR"; + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {5, 3, 9}, {at::kCUDA}); + auto index = at::randint(0, 9, {5}, {at::kCUDA}).to(torch::kI32); + + auto jit_in = at::clone(in); + auto jit_index = at::clone(index); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_index = at::clone(index); + auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in}); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index : Tensor): + %18 : Tensor?[] = prim::ListConstruct(%index) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA}); + auto in2 = at::full({2}, 4, {at::kCUDA}); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto in2_trt = at::full({2}, 4, {options}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor, + %index2 : Tensor): + %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); + auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); + auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); + auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + auto index2_trt = index2.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenIndexTensorRepeatedFullIndicesConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor, + %index2 : Tensor): + %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2) + %19 : Tensor = aten::index(%x.1, %18) + %20 : Tensor = aten::index(%x.1, %18) + return (%19, %20))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); + auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); + auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); + auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + auto index2_trt = index2.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6)); +} + +TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor): + %5 : NoneType = prim::Constant() + %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %5) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); + auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); + auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); + LOG_DEBUG(trt_results); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenIndexTensorIdx0NoneIdx1ConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor): + %5 : NoneType = prim::Constant() + %18 : Tensor?[] = prim::ListConstruct(%index0, %5, %index1) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); + auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); + auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor): + %5 : NoneType = prim::Constant() + %18 : Tensor?[] = prim::ListConstruct(%5, %index0, %index1) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); + auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); + auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor, + %index2 : Tensor): + %5 : NoneType = prim::Constant() + %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA}); + auto index0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong); + auto index1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong); + auto index2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + auto index2_trt = index2.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor): + %5 : NoneType = prim::Constant() + %18 : Tensor?[] = prim::ListConstruct(%5, %index0) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {1, 3, 480, 928}, {at::kCUDA}); + auto index0 = at::tensor({2, 1, 0}, {at::kCUDA}).to(torch::kLong); + + auto index0_trt = index0.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_masked_fill.cpp b/tests/core/conversion/converters/test_masked_fill.cpp new file mode 100644 index 0000000000..2c375463e5 --- /dev/null +++ b/tests/core/conversion/converters/test_masked_fill.cpp @@ -0,0 +1,100 @@ +#include +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %44 : Device = prim::Constant[value="cuda"]() + %8 : bool = prim::Constant[value=0]() + %7 : None = prim::Constant() + %f32_dtype: int = prim::Constant[value=11]() + %1 : int = prim::Constant[value=0]() # bert.py:5:26 + %2 : int = prim::Constant[value=1]() # bert.py:5:32 + %33 : int = prim::Constant[value=2]() # bert.py:6:31 + %3 : int[] = prim::ListConstruct(%1, %1, %2) + %4 : int[] = prim::ListConstruct(%2, %2, %1) + %5 : int[][] = prim::ListConstruct(%3, %4) + %9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11 + %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11 + %mask.2 : Tensor = trt::const(%mask.1) + %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11 + return (%34, %mask.2))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + auto in = at::zeros({1, 2, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + torch_tensorrt::core::lowering::passes::RemoveNOPs(g); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, %x.2 : Tensor): + %val : float = prim::Constant[value=4.0]() + %out : Tensor = aten::masked_fill(%x.1, %x.2, %val) + return (%out))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + // Input is a float tensor, filled with an int --> expecting float tensor out + auto in1 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kFloat32); + auto in2 = (2 * at::rand({2, 3, 5, 7}, {at::kCUDA})).to(torch::kBool); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); + + // Ensure data types match in outputs + ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype()); +} + +TEST(Converters, ATenMaskedFillMixedTypesIntFloatConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, %x.2 : Tensor): + %val : int = prim::Constant[value=4]() + %out : Tensor = aten::masked_fill(%x.1, %x.2, %val) + return (%out))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + // Input is an integer tensor, filled with a float --> expecting integer tensor out + auto in1 = at::rand({1, 3, 5, 7}, {at::kCUDA}).to(torch::kInt32); + auto in2 = (2 * at::rand({1, 3, 5, 7}, {at::kCUDA})).to(torch::kBool); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); + + // Ensure data types match in outputs + ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype()); +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp index 47e8b8d154..3cdb2d3b84 100644 --- a/tests/core/conversion/converters/test_reduce.cpp +++ b/tests/core/conversion/converters/test_reduce.cpp @@ -392,240 +392,3 @@ TEST(Converters, ATenAllDimDynamicConvertsCorrectly) { auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf); test_body(graph, in, true); } - -TEST(Converters, UnpackVarLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) - %7 : Tensor = aten::var(%x.1, %6, %5, %4) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - torch::jit::EliminateCommonSubexpression(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackVarKeepDimsLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) - %7 : Tensor = aten::var(%x.1, %6, %5, %5) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - torch::jit::EliminateCommonSubexpression(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackVarUnbiasedLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) - %7 : Tensor = aten::var(%x.1, %6, %4, %4) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - torch::jit::EliminateCommonSubexpression(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackVarUnbiasedKeepDimsLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) - %7 : Tensor = aten::var(%x.1, %6, %4, %5) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - torch::jit::EliminateCommonSubexpression(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackStdLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) - %7 : Tensor = aten::std(%x.1, %6, %5, %4) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackStd(g); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackStdKeepDimsLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) - %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackStd(g); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackStdUnbiasedLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) - %7 : Tensor = aten::std(%x.1, %6, %4, %4) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackStd(g); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - torch::jit::EliminateCommonSubexpression(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 - %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 - %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %one : int = prim::Constant[value=1]() - %6 : int[] = prim::ListConstruct(%3, %one) - %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26 - return (%7))IR"; - - auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackStd(g); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - torch::jit::EliminateCommonSubexpression(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, UnpackVarUnbiasedNegAxisLowersCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %37 : bool = prim::Constant[value=1]() - %53 : int[] = prim::Constant[value=[-1]]() - %69 : Tensor = aten::var(%x.1, %53, %37, %37) - return (%69))IR"; - - auto in = at::randint(-5, 5, {2, 20, 768}, at::kCUDA).to(at::kFloat); - - auto jit_in = at::clone(in); - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - in = at::clone(in); - torch_tensorrt::core::lowering::passes::UnpackVar(g); - torch::jit::EliminateCommonSubexpression(g); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {jit_in}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} diff --git a/tests/core/conversion/converters/test_roll.cpp b/tests/core/conversion/converters/test_roll.cpp new file mode 100644 index 0000000000..693fd47aef --- /dev/null +++ b/tests/core/conversion/converters/test_roll.cpp @@ -0,0 +1,84 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenRollConvertsCorrectly) { + const auto graph = R"IR( + graph(%1 : Tensor): + %2 : int[] = prim::Constant[value=[1, 0, 3, 7]]() + %3 : int[] = prim::Constant[value=[0, 1, 2, 3]]() + %4 : Tensor = aten::roll(%1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + // Run Pytorch + auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) { + const auto graph = R"IR( + graph(%1 : Tensor): + %2 : int[] = prim::Constant[value=[0, -3, -3]]() + %3 : int[] = prim::Constant[value=[1, 2, 3]]() + %4 : Tensor = aten::roll(%1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + // Run Pytorch + auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) { + const auto graph = R"IR( + graph(%1 : Tensor): + %2 : int[] = prim::Constant[value=[0, -3, -3]]() + %3 : int[] = prim::Constant[value=[1, 2, -1]]() + %4 : Tensor = aten::roll(%1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + // Run Pytorch + auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_scatter.cpp b/tests/core/conversion/converters/test_scatter.cpp new file mode 100644 index 0000000000..b7d0883249 --- /dev/null +++ b/tests/core/conversion/converters/test_scatter.cpp @@ -0,0 +1,79 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ScatterValueConvertsCorrectly) { + const auto graph = R"IR( + graph(%data : Tensor, + %index.1 : Tensor): + %value : int = prim::Constant[value=100]() + %dim : int = prim::Constant[value=1]() + %5 : NoneType = prim::Constant() + %6 : bool = prim::Constant[value=0]() + %7 : int = prim::Constant[value=4]() + %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) + %10 : Tensor = aten::scatter(%data, %dim, %index, %value) + return (%10))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto index = at::randint(0, 5, {2, 2}, {at::kCUDA}); + auto data = at::randn({5, 5}, {at::kCUDA}); + + auto jit_index = at::clone(index); + auto jit_data = at::clone(data); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index}); + + auto trt_index = at::clone(index); + auto trt_data = at::clone(data); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ScatterSrcConvertsCorrectly) { + const auto graph = R"IR( + graph(%data : Tensor, + %src : Tensor, + %index.1 : Tensor): + %dim : int = prim::Constant[value=1]() + %5 : NoneType = prim::Constant() + %6 : bool = prim::Constant[value=0]() + %7 : int = prim::Constant[value=4]() + %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) + %10 : Tensor = aten::scatter(%data, %dim, %index, %src) + return (%10))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto index = at::randint(0, 4, {2, 2}, {at::kCUDA}); + auto data = at::randn({5, 5}, {at::kCUDA}); + auto src = at::randn({2, 2}, {at::kCUDA}); + + auto jit_index = at::clone(index); + auto jit_data = at::clone(data); + auto jit_src = at::clone(src); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index}); + + auto trt_index = at::clone(index); + auto trt_data = at::clone(data); + auto trt_src = at::clone(src); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index d93dd5b2c5..d2af33f099 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -165,60 +165,6 @@ TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::sameShape(jit_results[0], trt_results[0])); } -TEST(Converters, ATenIndexSelectConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, %index : Int (2)): - %2 : int = prim::Constant[value=0]() - %3 : Tensor = aten::index_select(%0, %2, %index) - return (%3))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); - auto index = at::randint(0, 4, {2}, {at::kCUDA}).to(torch::kI32); - - auto jit_in = at::clone(in); - auto jit_index = at::clone(index); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_index = at::clone(index); - auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in}); - - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenIndexSelectNegativeDimConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, %index : Int (5)): - %2 : int = prim::Constant[value=-1]() - %3 : Tensor = aten::index_select(%0, %2, %index) - return (%3))IR"; - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {5, 3, 9}, {at::kCUDA}); - auto index = at::randint(0, 9, {5}, {at::kCUDA}).to(torch::kI32); - - auto jit_in = at::clone(in); - auto jit_index = at::clone(index); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_index = at::clone(index); - auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in}); - - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor): @@ -273,1119 +219,3 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } - -TEST(Converters, ATenRollConvertsCorrectly) { - const auto graph = R"IR( - graph(%1 : Tensor): - %2 : int[] = prim::Constant[value=[1, 0, 3, 7]]() - %3 : int[] = prim::Constant[value=[0, 1, 2, 3]]() - %4 : Tensor = aten::roll(%1, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - // Run Pytorch - auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) { - const auto graph = R"IR( - graph(%1 : Tensor): - %2 : int[] = prim::Constant[value=[0, -3, -3]]() - %3 : int[] = prim::Constant[value=[1, 2, 3]]() - %4 : Tensor = aten::roll(%1, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - // Run Pytorch - auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) { - const auto graph = R"IR( - graph(%1 : Tensor): - %2 : int[] = prim::Constant[value=[0, -3, -3]]() - %3 : int[] = prim::Constant[value=[1, 2, -1]]() - %4 : Tensor = aten::roll(%1, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - // Run Pytorch - auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %3 : int = prim::Constant[value=2]() - %4 : int = prim::Constant[value=4]() - %5 : int = prim::Constant[value=1]() - %6 : int = prim::Constant[value=0]() - %7 : Tensor = aten::select(%x.1, %6, %6) - %8 : Tensor = aten::select(%7, %6, %5) - %9 : Tensor = aten::slice(%8, %6, %5, %4, %3) - %10 : Tensor = aten::slice(%9, %5, %2, %2, %5) - return (%10))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 5, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceNegStartIndexConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : int = prim::Constant[value=9223372036854775807]() - %4 : int = prim::Constant[value=-2]() - %5 : int = prim::Constant[value=0]() - %6 : Tensor = aten::slice(%x.1, %5, %4, %3, %2) - %7 : Tensor = aten::slice(%6, %2, %5, %3, %2) - return (%7))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {6, 3}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=3]() - %3 : int = prim::Constant[value=9223372036854775807]() - %4 : int = prim::Constant[value=2]() - %5 : int = prim::Constant[value=-3]() - %6 : int = prim::Constant[value=1]() - %7 : int = prim::Constant[value=-2]() - %8 : int = prim::Constant[value=0]() - %9 : Tensor = aten::slice(%x.1, %8, %8, %7, %6) - %10 : Tensor = aten::slice(%9, %6, %8, %5, %6) - %11 : Tensor = aten::slice(%10, %4, %8, %3, %6) - %12 : Tensor = aten::slice(%11, %2, %8, %3, %6) - return (%12))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceListConvertsCorrectly) { - const auto graph = R"IR( - graph(%x : Tensor): - %1 : NoneType = prim::Constant() - %2 : int = prim::Constant[value=2]() - %3 : int = prim::Constant[value=1]() - %4 : int = prim::Constant[value=3]() - %list : Tensor[] = aten::unbind(%x, %4) - %slice : Tensor[] = aten::slice(%list, %1, %2, %3) - %out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice) - return (%out.1, %out.2))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); - - auto jit_in_x = at::clone(in_x); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x}); - - auto trt_in_x = at::clone(in_x); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=1]() - %end : int = prim::Constant[value=15]() - %step : int = prim::Constant[value=2]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=1]() - %end : int = prim::Constant[value=9223372036854775807]() - %step : int = prim::Constant[value=2]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicNegStartBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=-15]() - %end : int = prim::Constant[value=15]() - %step : int = prim::Constant[value=2]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicNegEndBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=1]() - %end : int = prim::Constant[value=-2]() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicNoneBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %dim : int = prim::Constant[value=0]() - %start : None = prim::Constant() - %end : None = prim::Constant() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=1]() - %start : int = prim::Constant[value=3]() - %end : int = prim::Constant[value=32]() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in dim 1, slice in dim 1 - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamic2ConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=1]() - %start : int = prim::Constant[value=3]() - %end : int = prim::Constant[value=17]() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch, slice in dim 1 - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int[] = prim::Constant[value=[1, 2]]() - %3 : int = prim::Constant[value=1]() - %4 : Tensor[] = aten::split(%x.1, %2, %3) - %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4) - return (%x1.1, %x2.1))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitSizesinTracingConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int[] = prim::Constant[value=[1, 2]]() - %3 : int = prim::Constant[value=1]() - %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3) - %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4) - return (%5, %6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitFixedConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : Tensor[] = aten::split(%argument_1.1, %2, %2) - %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) - return (%4, %5, %6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitFixedHasRemainderConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int = prim::Constant[value=2]() - %2.1 : int = prim::Constant[value=1]() - %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) - %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) - return (%4, %5, %6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::randint(1, 10, {1, 5, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitAndAddConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int = prim::Constant[value=2]() - %2.1 : int = prim::Constant[value=1]() - %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) - %4 : Tensor, %5 : Tensor = prim::ListUnpack(%3) - %6 : Tensor = aten::add(%4, %5, %2.1) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitNegativeDimsConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=1]() - %n1 : int = prim::Constant[value=-1]() - %3 : Tensor[] = aten::split(%x.1, %2, %n1) - %4 : Tensor, %5 : Tensor, %6 : Tensor, %7 : Tensor = prim::ListUnpack(%3) - return (%4, %5, %6, %7))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %44 : Device = prim::Constant[value="cuda"]() - %8 : bool = prim::Constant[value=0]() - %7 : None = prim::Constant() - %f32_dtype: int = prim::Constant[value=11]() - %1 : int = prim::Constant[value=0]() # bert.py:5:26 - %2 : int = prim::Constant[value=1]() # bert.py:5:32 - %33 : int = prim::Constant[value=2]() # bert.py:6:31 - %3 : int[] = prim::ListConstruct(%1, %1, %2) - %4 : int[] = prim::ListConstruct(%2, %2, %1) - %5 : int[][] = prim::ListConstruct(%3, %4) - %9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11 - %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11 - %mask.2 : Tensor = trt::const(%mask.1) - %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11 - return (%34, %mask.2))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::zeros({1, 2, 3}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - torch_tensorrt::core::lowering::passes::RemoveNOPs(g); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, %x.2 : Tensor): - %val : float = prim::Constant[value=4.0]() - %out : Tensor = aten::masked_fill(%x.1, %x.2, %val) - return (%out))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input is a float tensor, filled with an int --> expecting float tensor out - auto in1 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kFloat32); - auto in2 = (2 * at::rand({2, 3, 5, 7}, {at::kCUDA})).to(torch::kBool); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); - - // Ensure data types match in outputs - ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype()); -} - -TEST(Converters, ATenMaskedFillMixedTypesIntFloatConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, %x.2 : Tensor): - %val : int = prim::Constant[value=4]() - %out : Tensor = aten::masked_fill(%x.1, %x.2, %val) - return (%out))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input is an integer tensor, filled with a float --> expecting integer tensor out - auto in1 = at::rand({1, 3, 5, 7}, {at::kCUDA}).to(torch::kInt32); - auto in2 = (2 * at::rand({1, 3, 5, 7}, {at::kCUDA})).to(torch::kBool); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); - - // Ensure data types match in outputs - ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype()); -} - -TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index : Tensor): - %18 : Tensor?[] = prim::ListConstruct(%index) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA}); - auto in2 = at::full({2}, 4, {at::kCUDA}); - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto in2_trt = at::full({2}, 4, {options}); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor, - %index2 : Tensor): - %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); - auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - auto index2_trt = index2.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorRepeatedFullIndicesConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor, - %index2 : Tensor): - %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2) - %19 : Tensor = aten::index(%x.1, %18) - %20 : Tensor = aten::index(%x.1, %18) - return (%19, %20))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); - auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - auto index2_trt = index2.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6)); -} - -TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %5) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); - LOG_DEBUG(trt_results); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorIdx0NoneIdx1ConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = prim::ListConstruct(%index0, %5, %index1) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = prim::ListConstruct(%5, %index0, %index1) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor, - %index2 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA}); - auto index0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong); - auto index1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong); - auto index2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - auto index2_trt = index2.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = prim::ListConstruct(%5, %index0) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {1, 3, 480, 928}, {at::kCUDA}); - auto index0 = at::tensor({2, 1, 0}, {at::kCUDA}).to(torch::kLong); - - auto index0_trt = index0.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, ATenUnbindConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=0]() - %3 : Tensor[] = aten::unbind(%x.1, %2) - %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3) - return (%o1.1, %o2.1))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {2, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i]; - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=-1]() - %3 : Tensor[] = aten::unbind(%x.1, %2) - %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3) - return (%o1.1, %o2.1))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {5, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i]; - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenUnbindEvaluatedTensor) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %3 : int[] = aten::size(%x.1) - %z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2) - %5 : int = prim::Constant[value=-1]() - %6 : Tensor[] = aten::unbind(%z.1, %5) - %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%6) - return (%o1.1, %o2.1))IR"; - - auto in = at::randint(1, 10, {2}, {at::kCUDA}); - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i]; - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i].cuda(), trt, 2e-6)); - } -} - -TEST(Converters, ScatterValueConvertsCorrectly) { - const auto graph = R"IR( - graph(%data : Tensor, - %index.1 : Tensor): - %value : int = prim::Constant[value=100]() - %dim : int = prim::Constant[value=1]() - %5 : NoneType = prim::Constant() - %6 : bool = prim::Constant[value=0]() - %7 : int = prim::Constant[value=4]() - %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) - %10 : Tensor = aten::scatter(%data, %dim, %index, %value) - return (%10))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto index = at::randint(0, 5, {2, 2}, {at::kCUDA}); - auto data = at::randn({5, 5}, {at::kCUDA}); - - auto jit_index = at::clone(index); - auto jit_data = at::clone(data); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index}); - - auto trt_index = at::clone(index); - auto trt_data = at::clone(data); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ScatterSrcConvertsCorrectly) { - const auto graph = R"IR( - graph(%data : Tensor, - %src : Tensor, - %index.1 : Tensor): - %dim : int = prim::Constant[value=1]() - %5 : NoneType = prim::Constant() - %6 : bool = prim::Constant[value=0]() - %7 : int = prim::Constant[value=4]() - %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) - %10 : Tensor = aten::scatter(%data, %dim, %index, %src) - return (%10))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto index = at::randint(0, 4, {2, 2}, {at::kCUDA}); - auto data = at::randn({5, 5}, {at::kCUDA}); - auto src = at::randn({2, 2}, {at::kCUDA}); - - auto jit_index = at::clone(index); - auto jit_data = at::clone(data); - auto jit_src = at::clone(src); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index}); - - auto trt_index = at::clone(index); - auto trt_data = at::clone(data); - auto trt_src = at::clone(src); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, WhereConvertsCorrectly) { - const auto graph = R"IR( - graph(%condition : Tensor, - %x : Tensor, - %y : Tensor): - %out : Tensor = aten::where(%condition, %x, %y) - return (%out))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto condition = at::randint(0, 2, {5, 5}, {at::kCUDA}).to(torch::kBool); - auto x = at::randn({5, 5}, {at::kCUDA}); - auto y = at::randn({5, 5}, {at::kCUDA}); - - auto jit_condition = at::clone(condition); - auto jit_x = at::clone(x); - auto jit_y = at::clone(y); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y}); - - auto trt_condition = at::clone(condition); - auto trt_x = at::clone(x); - auto trt_y = at::clone(y); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, WhereConvertsMismatchedShapesCorrectly) { - const auto graph = R"IR( - graph(%condition : Tensor, - %x : Tensor, - %y : Tensor): - %out : Tensor = aten::where(%condition, %x, %y) - return (%out))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - // As per Torch behavior, the input Tensors are expected to be broadcasted - // along their respective dimension in the largest-rank Tensor provided - auto condition = at::randint(0, 2, {7, 5}, {at::kCUDA}).to(torch::kBool); - auto x = at::randn({2, 7, 5}, {at::kCUDA}); - auto y = at::randn({5}, {at::kCUDA}); - - auto jit_condition = at::clone(condition); - auto jit_x = at::clone(x); - auto jit_y = at::clone(y); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y}); - - auto trt_condition = at::clone(condition); - auto trt_x = at::clone(x); - auto trt_y = at::clone(y); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} diff --git a/tests/core/conversion/converters/test_slice.cpp b/tests/core/conversion/converters/test_slice.cpp new file mode 100644 index 0000000000..83ba879291 --- /dev/null +++ b/tests/core/conversion/converters/test_slice.cpp @@ -0,0 +1,332 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenSliceConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %3 : int = prim::Constant[value=2]() + %4 : int = prim::Constant[value=4]() + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=0]() + %7 : Tensor = aten::select(%x.1, %6, %6) + %8 : Tensor = aten::select(%7, %6, %5) + %9 : Tensor = aten::slice(%8, %6, %5, %4, %3) + %10 : Tensor = aten::slice(%9, %5, %2, %2, %5) + return (%10))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3, 5, 5}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceNegStartIndexConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=1]() + %3 : int = prim::Constant[value=9223372036854775807]() + %4 : int = prim::Constant[value=-2]() + %5 : int = prim::Constant[value=0]() + %6 : Tensor = aten::slice(%x.1, %5, %4, %3, %2) + %7 : Tensor = aten::slice(%6, %2, %5, %3, %2) + return (%7))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {6, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=3]() + %3 : int = prim::Constant[value=9223372036854775807]() + %4 : int = prim::Constant[value=2]() + %5 : int = prim::Constant[value=-3]() + %6 : int = prim::Constant[value=1]() + %7 : int = prim::Constant[value=-2]() + %8 : int = prim::Constant[value=0]() + %9 : Tensor = aten::slice(%x.1, %8, %8, %7, %6) + %10 : Tensor = aten::slice(%9, %6, %8, %5, %6) + %11 : Tensor = aten::slice(%10, %4, %8, %3, %6) + %12 : Tensor = aten::slice(%11, %2, %8, %3, %6) + return (%12))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceListConvertsCorrectly) { + const auto graph = R"IR( + graph(%x : Tensor): + %1 : NoneType = prim::Constant() + %2 : int = prim::Constant[value=2]() + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=3]() + %list : Tensor[] = aten::unbind(%x, %4) + %slice : Tensor[] = aten::slice(%list, %1, %2, %3) + %out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice) + return (%out.1, %out.2))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); + + auto jit_in_x = at::clone(in_x); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x}); + + auto trt_in_x = at::clone(in_x); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=1]() + %end : int = prim::Constant[value=15]() + %step : int = prim::Constant[value=2]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=1]() + %end : int = prim::Constant[value=9223372036854775807]() + %step : int = prim::Constant[value=2]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicNegStartBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=-15]() + %end : int = prim::Constant[value=15]() + %step : int = prim::Constant[value=2]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicNegEndBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=1]() + %end : int = prim::Constant[value=-2]() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicNoneBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %dim : int = prim::Constant[value=0]() + %start : None = prim::Constant() + %end : None = prim::Constant() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=1]() + %start : int = prim::Constant[value=3]() + %end : int = prim::Constant[value=32]() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in dim 1, slice in dim 1 + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamic2ConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=1]() + %start : int = prim::Constant[value=3]() + %end : int = prim::Constant[value=17]() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch, slice in dim 1 + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_split.cpp b/tests/core/conversion/converters/test_split.cpp new file mode 100644 index 0000000000..87bd5a16e0 --- /dev/null +++ b/tests/core/conversion/converters/test_split.cpp @@ -0,0 +1,174 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[1, 2]]() + %3 : int = prim::Constant[value=1]() + %4 : Tensor[] = aten::split(%x.1, %2, %3) + %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4) + return (%x1.1, %x2.1))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenSplitSizesinTracingConvertsCorrectly) { + const auto graph = R"IR( + graph(%argument_1.1 : Tensor): + %2 : int[] = prim::Constant[value=[1, 2]]() + %3 : int = prim::Constant[value=1]() + %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3) + %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4) + return (%5, %6))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenSplitFixedConvertsCorrectly) { + const auto graph = R"IR( + graph(%argument_1.1 : Tensor): + %2 : int = prim::Constant[value=1]() + %3 : Tensor[] = aten::split(%argument_1.1, %2, %2) + %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) + return (%4, %5, %6))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenSplitFixedHasRemainderConvertsCorrectly) { + const auto graph = R"IR( + graph(%argument_1.1 : Tensor): + %2 : int = prim::Constant[value=2]() + %2.1 : int = prim::Constant[value=1]() + %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) + %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) + return (%4, %5, %6))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(1, 10, {1, 5, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenSplitAndAddConvertsCorrectly) { + const auto graph = R"IR( + graph(%argument_1.1 : Tensor): + %2 : int = prim::Constant[value=2]() + %2.1 : int = prim::Constant[value=1]() + %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) + %4 : Tensor, %5 : Tensor = prim::ListUnpack(%3) + %6 : Tensor = aten::add(%4, %5, %2.1) + return (%6))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenSplitNegativeDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=1]() + %n1 : int = prim::Constant[value=-1]() + %3 : Tensor[] = aten::split(%x.1, %2, %n1) + %4 : Tensor, %5 : Tensor, %6 : Tensor, %7 : Tensor = prim::ListUnpack(%3) + return (%4, %5, %6, %7))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i].reshape(jit_results[i].sizes()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_unbind.cpp b/tests/core/conversion/converters/test_unbind.cpp new file mode 100644 index 0000000000..0062a055bb --- /dev/null +++ b/tests/core/conversion/converters/test_unbind.cpp @@ -0,0 +1,88 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenUnbindConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : Tensor[] = aten::unbind(%x.1, %2) + %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3) + return (%o1.1, %o2.1))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {2, 3, 4, 4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i]; + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=-1]() + %3 : Tensor[] = aten::unbind(%x.1, %2) + %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3) + return (%o1.1, %o2.1))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {5, 2}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i]; + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); + } +} + +TEST(Converters, ATenUnbindEvaluatedTensor) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %3 : int[] = aten::size(%x.1) + %z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2) + %5 : int = prim::Constant[value=-1]() + %6 : Tensor[] = aten::unbind(%z.1, %5) + %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%6) + return (%o1.1, %o2.1))IR"; + + auto in = at::randint(1, 10, {2}, {at::kCUDA}); + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i]; + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i].cuda(), trt, 2e-6)); + } +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_unpack.cpp b/tests/core/conversion/converters/test_unpack.cpp new file mode 100644 index 0000000000..858462b003 --- /dev/null +++ b/tests/core/conversion/converters/test_unpack.cpp @@ -0,0 +1,245 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" +#include "torch/torch.h" + +TEST(Converters, UnpackVarLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %5, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackVarKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %5, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackVarUnbiasedLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %4, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackVarUnbiasedKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::var(%x.1, %6, %4, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %5, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackStd(g); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackStd(g); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdUnbiasedLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %6 : int[] = prim::ListConstruct(%3) + %7 : Tensor = aten::std(%x.1, %6, %4, %4) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackStd(g); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 + %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 + %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 + %one : int = prim::Constant[value=1]() + %6 : int[] = prim::ListConstruct(%3, %one) + %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26 + return (%7))IR"; + + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackStd(g); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, UnpackVarUnbiasedNegAxisLowersCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %37 : bool = prim::Constant[value=1]() + %53 : int[] = prim::Constant[value=[-1]]() + %69 : Tensor = aten::var(%x.1, %53, %37, %37) + return (%69))IR"; + + auto in = at::randint(-5, 5, {2, 20, 768}, at::kCUDA).to(at::kFloat); + + auto jit_in = at::clone(in); + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + in = at::clone(in); + torch_tensorrt::core::lowering::passes::UnpackVar(g); + torch::jit::EliminateCommonSubexpression(g); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {jit_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} \ No newline at end of file diff --git a/tests/core/conversion/converters/test_where.cpp b/tests/core/conversion/converters/test_where.cpp new file mode 100644 index 0000000000..34b3696582 --- /dev/null +++ b/tests/core/conversion/converters/test_where.cpp @@ -0,0 +1,69 @@ +#include +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, WhereConvertsCorrectly) { + const auto graph = R"IR( + graph(%condition : Tensor, + %x : Tensor, + %y : Tensor): + %out : Tensor = aten::where(%condition, %x, %y) + return (%out))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto condition = at::randint(0, 2, {5, 5}, {at::kCUDA}).to(torch::kBool); + auto x = at::randn({5, 5}, {at::kCUDA}); + auto y = at::randn({5, 5}, {at::kCUDA}); + + auto jit_condition = at::clone(condition); + auto jit_x = at::clone(x); + auto jit_y = at::clone(y); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y}); + + auto trt_condition = at::clone(condition); + auto trt_x = at::clone(x); + auto trt_y = at::clone(y); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, WhereConvertsMismatchedShapesCorrectly) { + const auto graph = R"IR( + graph(%condition : Tensor, + %x : Tensor, + %y : Tensor): + %out : Tensor = aten::where(%condition, %x, %y) + return (%out))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + // As per Torch behavior, the input Tensors are expected to be broadcasted + // along their respective dimension in the largest-rank Tensor provided + auto condition = at::randint(0, 2, {7, 5}, {at::kCUDA}).to(torch::kBool); + auto x = at::randn({2, 7, 5}, {at::kCUDA}); + auto y = at::randn({5}, {at::kCUDA}); + + auto jit_condition = at::clone(condition); + auto jit_x = at::clone(x); + auto jit_y = at::clone(y); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y}); + + auto trt_condition = at::clone(condition); + auto trt_x = at::clone(x); + auto trt_y = at::clone(y); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} \ No newline at end of file