From d94b4495c3ea63ac627ff33d5dc2a45f08e3baea Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 5 Oct 2022 14:34:21 -0700 Subject: [PATCH 1/2] feat: Add converter for sign unary operator - Add sign operator - Update test cases to test op - Ensure tests cover both int and float cases with negative and positive sign - Ensure tests cover cases where elements equal zero --- core/conversion/converters/impl/unary.cpp | 2 ++ .../core/conversion/converters/test_unary.cpp | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index c78602963c..2d04afa44f 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -95,6 +95,8 @@ convert(sqrt, kSQRT); convert(exp, kEXP); convert(neg, kNEG); convert(erf, kERF); +convert(sign, kSIGN); +convert(round, kROUND); convert(asinh, kASINH); convert(acosh, kACOSH); convert(atanh, kATANH); diff --git a/tests/core/conversion/converters/test_unary.cpp b/tests/core/conversion/converters/test_unary.cpp index 06f092ff36..cc6aa0420b 100644 --- a/tests/core/conversion/converters/test_unary.cpp +++ b/tests/core/conversion/converters/test_unary.cpp @@ -47,6 +47,40 @@ TEST(Converters, ATenReciprocalIntConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0])); } +TEST(Converters, ATenSignConvertsCorrectly) { + const auto graph = gen_test_graph("sign"); + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + // Resize range to [-10, 10] to span negative values + auto in = -20 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 10; + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenSignConvertsZerosCorrectly) { + const auto graph = gen_test_graph("sign"); + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + // Resize range to [-1, 1] to span negative values, cast to int to include zero + auto in = (-2 * at::rand({7, 3, 1, 5}, {at::kCUDA}) + 1).to(torch::kInt32); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + #define test_unary(unary, name) \ TEST(Converters, ATen##name##ConvertsCorrectly) { \ const auto graph = gen_test_graph(#unary); \ From 0342c646f563138968dede8f2dcfb8273bf755ea Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 5 Oct 2022 14:42:00 -0700 Subject: [PATCH 2/2] Remove round unary from PR --- core/conversion/converters/impl/unary.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index 2d04afa44f..acac34cd7f 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -96,7 +96,6 @@ convert(exp, kEXP); convert(neg, kNEG); convert(erf, kERF); convert(sign, kSIGN); -convert(round, kROUND); convert(asinh, kASINH); convert(acosh, kACOSH); convert(atanh, kATANH);