From 3db25f8aa34b466f4160e7ba6aedfc08a1d8a264 Mon Sep 17 00:00:00 2001 From: inocsin Date: Fri, 5 Feb 2021 20:07:44 +0800 Subject: [PATCH] convert the output type of bool operation to int Signed-off-by: inocsin --- .../converters/impl/element_wise.cpp | 76 ++++++++++--------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index f89cd59d73..9495cd7e89 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -9,6 +9,14 @@ namespace converters { namespace impl { namespace { +nvinfer1::IIdentityLayer* convert_to_int(nvinfer1::ILayer* op, ConversionCtx* ctx, const torch::jit::Node* n) { + auto identity = ctx->net->addIdentity(*op->getOutput(0)); + identity->setOutputType(0, nvinfer1::DataType::kINT32); + TRTORCH_CHECK(identity, "Unable to create Identity layer from node: " << *n); + identity->setName(util::node_info(n).c_str()); + return identity; +} + nvinfer1::ILayer* add_elementwise( ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, @@ -535,12 +543,12 @@ auto element_wise_registrations TRTORCH_UNUSED = [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto self = args[0].ITensorOrFreeze(ctx); auto other = args[1].ITensorOrFreeze(ctx); - auto gt = - add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n)); + auto gt = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n) + "_gt"); TRTORCH_CHECK(gt, "Unable to create greater layer from node: " << *n); - gt->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gt->getOutput(0)); + auto identity = convert_to_int(gt, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -549,12 +557,12 @@ auto element_wise_registrations TRTORCH_UNUSED = auto self = args[0].ITensorOrFreeze(ctx); auto otherScalar = args[1].unwrapToScalar().to(); auto other = tensor_to_const(ctx, torch::tensor({otherScalar})); - auto gt = - add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n)); + auto gt = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n) + "_gt"); TRTORCH_CHECK(gt, "Unable to create greater layer from node: " << *n); - gt->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gt->getOutput(0)); + auto identity = convert_to_int(gt, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -562,12 +570,12 @@ auto element_wise_registrations TRTORCH_UNUSED = [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto self = args[0].ITensorOrFreeze(ctx); auto other = args[1].ITensorOrFreeze(ctx); - auto lt = - add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n)); + auto lt = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n) + "_lt"); TRTORCH_CHECK(lt, "Unable to create less layer from node: " << *n); - lt->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], lt->getOutput(0)); + auto identity = convert_to_int(lt, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -576,12 +584,12 @@ auto element_wise_registrations TRTORCH_UNUSED = auto self = args[0].ITensorOrFreeze(ctx); auto otherScalar = args[1].unwrapToScalar().to(); auto other = tensor_to_const(ctx, torch::tensor({otherScalar})); - auto lt = - add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n)); + auto lt = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n) + "_lt"); TRTORCH_CHECK(lt, "Unable to create less layer from node: " << *n); - lt->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], lt->getOutput(0)); + auto identity = convert_to_int(lt, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -589,12 +597,12 @@ auto element_wise_registrations TRTORCH_UNUSED = [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto self = args[0].ITensorOrFreeze(ctx); auto other = args[1].ITensorOrFreeze(ctx); - auto eq = - add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n)); + auto eq = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n) + "_eq"); TRTORCH_CHECK(eq, "Unable to create equal layer from node: " << *n); - eq->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], eq->getOutput(0)); + auto identity = convert_to_int(eq, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -603,12 +611,12 @@ auto element_wise_registrations TRTORCH_UNUSED = auto self = args[0].ITensorOrFreeze(ctx); auto otherScalar = args[1].unwrapToScalar().to(); auto other = tensor_to_const(ctx, torch::tensor({otherScalar})); - auto eq = - add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n)); + auto eq = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n) + "_eq"); TRTORCH_CHECK(eq, "Unable to create equal layer from node: " << *n); - eq->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], eq->getOutput(0)); + auto identity = convert_to_int(eq, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -629,9 +637,8 @@ auto element_wise_registrations TRTORCH_UNUSED = *greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR); TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n); - or_op->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0)); - + auto identity = convert_to_int(or_op, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -653,9 +660,8 @@ auto element_wise_registrations TRTORCH_UNUSED = *greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR); TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n); - or_op->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0)); - + auto identity = convert_to_int(or_op, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -676,9 +682,8 @@ auto element_wise_registrations TRTORCH_UNUSED = *less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR); TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n); - or_op->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0)); - + auto identity = convert_to_int(or_op, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) @@ -700,9 +705,8 @@ auto element_wise_registrations TRTORCH_UNUSED = *less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR); TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n); - or_op->setName(util::node_info(n).c_str()); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0)); - + auto identity = convert_to_int(or_op, ctx, n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }});