diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index f89cd59d73..da582c9548 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -359,8 +359,22 @@ auto element_wise_registrations TRTORCH_UNUSED = // Should implement self * other auto self = args[0].ITensorOrFreeze(ctx); auto other = args[1].ITensorOrFreeze(ctx); - auto mul = + nvinfer1::ILayer* mul = nullptr; + if (self->getType() ==nvinfer1::DataType::kBOOL || other->getType() == nvinfer1::DataType::kBOOL) { + auto self_id = ctx->net->addIdentity(*self); + auto other_id = ctx->net->addIdentity(*other); + if (self->getType() == nvinfer1::DataType::kBOOL) { + self_id->getOutput(0)->setType(nvinfer1::DataType::kINT32); + } + if (other->getType() == nvinfer1::DataType::kBOOL) { + other_id->getOutput(0)->setType(nvinfer1::DataType::kINT32); + } + mul = + add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self_id->getOutput(0), other_id->getOutput(0), util::node_info(n)); + } else { + mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n)); + } TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n); mul->setName(util::node_info(n).c_str());