diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp index b3db09ffd7..68d4c41fa5 100644 --- a/core/conversion/converters/impl/reduce.cpp +++ b/core/conversion/converters/impl/reduce.cpp @@ -72,6 +72,13 @@ auto reduce_registrations TORCHTRT_UNUSED = auto in_dims = util::toVec(in_tensor->getDimensions()); LOG_WARNING("Sum Converter disregards dtype"); + if (in_tensor->getType() == nvinfer1::DataType::kBOOL) { + LOG_DEBUG( + "Found type " << in_tensor->getType() << " in aten::sum, casting to " + << nvinfer1::DataType::kINT32 << " for compatibility."); + in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32); + } + uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1); auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, false); diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp index 4699427d5e..3bcef3db77 100644 --- a/tests/core/conversion/converters/test_reduce.cpp +++ b/tests/core/conversion/converters/test_reduce.cpp @@ -137,6 +137,16 @@ converts_keepdims_correctly(mean, Mean); #undef converts_keepdims_correctly +TEST(Converters, ATenSumBoolConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %4 : None = prim::Constant() + %5 : Tensor = aten::sum(%0, %4) + return (%5))IR"; + auto in = at::randint(-1, 2, {4, 4, 4}, at::kCUDA).to(at::kBool); + test_body(graph, in); +} + TEST(Converters, ATenSumDimNegOneIndexConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor):