From e7bd9acac4ec3f3f116ab1908db48ede74f19e6d Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Thu, 1 Dec 2022 10:48:32 -0800 Subject: [PATCH 1/2] Support aten::sum with bool tensor input TensorRT sum layers do not support bool tensor inputs. Add support by casting the input to int32. Fixes # (issue) Please delete options that are not relevant and/or add your own. - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - This change requires a documentation update - [ ] My code follows the style guidelines of this project (You can use the linters) - [ ] I have performed a self-review of my own code - [ ] I have commented my code, particularly in hard-to-understand areas and hacks - [ ] I have made corresponding changes to the documentation - [ ] I have added tests to verify my fix or my feature - [ ] New and existing unit tests pass locally with my changes - [ ] I have added the relevant labels to my PR in so that relevant reviewers are notified --- core/conversion/converters/impl/reduce.cpp | 7 +++++++ tests/core/conversion/converters/test_reduce.cpp | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp index b3db09ffd7..871c8a3138 100644 --- a/core/conversion/converters/impl/reduce.cpp +++ b/core/conversion/converters/impl/reduce.cpp @@ -71,6 +71,13 @@ auto reduce_registrations TORCHTRT_UNUSED = auto in_tensor = args[0].ITensorOrFreeze(ctx); auto in_dims = util::toVec(in_tensor->getDimensions()); LOG_WARNING("Sum Converter disregards dtype"); + + if (in_tensor->getType() == nvinfer1::DataType::kBOOL) { + LOG_DEBUG( + "Found type " << in_tensor->getType() << " in aten::sum, casting to " + << nvinfer1::DataType::kINT32 << " for compatibility."); + in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32); + } uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1); diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp index 4699427d5e..3bcef3db77 100644 --- a/tests/core/conversion/converters/test_reduce.cpp +++ b/tests/core/conversion/converters/test_reduce.cpp @@ -137,6 +137,16 @@ converts_keepdims_correctly(mean, Mean); #undef converts_keepdims_correctly +TEST(Converters, ATenSumBoolConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %4 : None = prim::Constant() + %5 : Tensor = aten::sum(%0, %4) + return (%5))IR"; + auto in = at::randint(-1, 2, {4, 4, 4}, at::kCUDA).to(at::kBool); + test_body(graph, in); +} + TEST(Converters, ATenSumDimNegOneIndexConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor): From 20a0716ca21817c06b0d7bb4fd24c9c25b8a4bdf Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Thu, 1 Dec 2022 11:42:51 -0800 Subject: [PATCH 2/2] lint --- core/conversion/converters/impl/reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp index 871c8a3138..68d4c41fa5 100644 --- a/core/conversion/converters/impl/reduce.cpp +++ b/core/conversion/converters/impl/reduce.cpp @@ -71,7 +71,7 @@ auto reduce_registrations TORCHTRT_UNUSED = auto in_tensor = args[0].ITensorOrFreeze(ctx); auto in_dims = util::toVec(in_tensor->getDimensions()); LOG_WARNING("Sum Converter disregards dtype"); - + if (in_tensor->getType() == nvinfer1::DataType::kBOOL) { LOG_DEBUG( "Found type " << in_tensor->getType() << " in aten::sum, casting to "