diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index c78602963c..acac34cd7f 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -95,6 +95,7 @@ convert(sqrt, kSQRT); convert(exp, kEXP); convert(neg, kNEG); convert(erf, kERF); +convert(sign, kSIGN); convert(asinh, kASINH); convert(acosh, kACOSH); convert(atanh, kATANH); diff --git a/tests/core/conversion/converters/test_unary.cpp b/tests/core/conversion/converters/test_unary.cpp index 06f092ff36..cc6aa0420b 100644 --- a/tests/core/conversion/converters/test_unary.cpp +++ b/tests/core/conversion/converters/test_unary.cpp @@ -47,6 +47,40 @@ TEST(Converters, ATenReciprocalIntConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0])); } +TEST(Converters, ATenSignConvertsCorrectly) { + const auto graph = gen_test_graph("sign"); + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + // Resize range to [-10, 10] to span negative values + auto in = -20 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 10; + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenSignConvertsZerosCorrectly) { + const auto graph = gen_test_graph("sign"); + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + // Resize range to [-1, 1] to span negative values, cast to int to include zero + auto in = (-2 * at::rand({7, 3, 1, 5}, {at::kCUDA}) + 1).to(torch::kInt32); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + #define test_unary(unary, name) \ TEST(Converters, ATen##name##ConvertsCorrectly) { \ const auto graph = gen_test_graph(#unary); \