diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index 991d11fc58..8834f6ca6c 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -9,6 +9,14 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::IIdentityLayer* convert_to_op_input_type(nvinfer1::ILayer* op, ConversionCtx* ctx, const torch::jit::Node* n) {
+  auto identity = ctx->net->addIdentity(*op->getOutput(0));
+  identity->setOutputType(0, op->getInput(0)->getType());
+  TRTORCH_CHECK(identity, "Unable to create Identity layer from node: " << *n);
+  identity->setName(util::node_info(n).c_str());
+  return identity;
+}
+
 nvinfer1::ILayer* add_elementwise(
     ConversionCtx* ctx,
     nvinfer1::ElementWiseOperation op,
@@ -340,7 +348,7 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
-        .pattern({"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
+        .pattern({"aten::div_.Scalar(Tensor self, Scalar other) -> (Tensor)",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto self = args[0].ITensorOrFreeze(ctx);
                     auto otherScalar = args[1].unwrapToScalar().to<float>();
@@ -535,12 +543,12 @@ auto element_wise_registrations TRTORCH_UNUSED =
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto self = args[0].ITensorOrFreeze(ctx);
                     auto other = args[1].ITensorOrFreeze(ctx);
-                    auto gt =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n));
+                    auto gt = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n) + "_gt");
                     TRTORCH_CHECK(gt, "Unable to create greater layer from node: " << *n);
 
-                    gt->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gt->getOutput(0));
+                    auto identity = convert_to_op_input_type(gt, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -549,12 +557,12 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     auto self = args[0].ITensorOrFreeze(ctx);
                     auto otherScalar = args[1].unwrapToScalar().to<float>();
                     auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
-                    auto gt =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n));
+                    auto gt = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n) + "_gt");
                     TRTORCH_CHECK(gt, "Unable to create greater layer from node: " << *n);
 
-                    gt->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gt->getOutput(0));
+                    auto identity = convert_to_op_input_type(gt, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -562,12 +570,12 @@ auto element_wise_registrations TRTORCH_UNUSED =
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto self = args[0].ITensorOrFreeze(ctx);
                     auto other = args[1].ITensorOrFreeze(ctx);
-                    auto lt =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n));
+                    auto lt = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n) + "_lt");
                     TRTORCH_CHECK(lt, "Unable to create less layer from node: " << *n);
 
-                    lt->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], lt->getOutput(0));
+                    auto identity = convert_to_op_input_type(lt, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -576,12 +584,12 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     auto self = args[0].ITensorOrFreeze(ctx);
                     auto otherScalar = args[1].unwrapToScalar().to<float>();
                     auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
-                    auto lt =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n));
+                    auto lt = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n) + "_lt");
                     TRTORCH_CHECK(lt, "Unable to create less layer from node: " << *n);
 
-                    lt->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], lt->getOutput(0));
+                    auto identity = convert_to_op_input_type(lt, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -589,12 +597,12 @@ auto element_wise_registrations TRTORCH_UNUSED =
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto self = args[0].ITensorOrFreeze(ctx);
                     auto other = args[1].ITensorOrFreeze(ctx);
-                    auto eq =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n));
+                    auto eq = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n) + "_eq");
                     TRTORCH_CHECK(eq, "Unable to create equal layer from node: " << *n);
 
-                    eq->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], eq->getOutput(0));
+                    auto identity = convert_to_op_input_type(eq, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -603,12 +611,12 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     auto self = args[0].ITensorOrFreeze(ctx);
                     auto otherScalar = args[1].unwrapToScalar().to<float>();
                     auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
-                    auto eq =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n));
+                    auto eq = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n) + "_eq");
                     TRTORCH_CHECK(eq, "Unable to create equal layer from node: " << *n);
 
-                    eq->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], eq->getOutput(0));
+                    auto identity = convert_to_op_input_type(eq, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -629,9 +637,8 @@ auto element_wise_registrations TRTORCH_UNUSED =
                         *greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
                     TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
 
-                    or_op->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
-
+                    auto identity = convert_to_op_input_type(or_op, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -653,9 +660,8 @@ auto element_wise_registrations TRTORCH_UNUSED =
                         *greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
                     TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
 
-                    or_op->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
-
+                    auto identity = convert_to_op_input_type(or_op, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -676,9 +682,8 @@ auto element_wise_registrations TRTORCH_UNUSED =
                         *less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
                     TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
 
-                    or_op->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
-
+                    auto identity = convert_to_op_input_type(or_op, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
@@ -700,9 +705,8 @@ auto element_wise_registrations TRTORCH_UNUSED =
                         *less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
                     TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
 
-                    or_op->setName(util::node_info(n).c_str());
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
-
+                    auto identity = convert_to_op_input_type(or_op, ctx, n);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }});
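
Note (not part of the patch): the new convert_to_op_input_type helper appends a TensorRT IIdentityLayer after a comparison layer and requests that its output use the dtype of the comparison's first input, since TensorRT's kGREATER/kLESS/kEQUAL element-wise operations produce a boolean tensor. The standalone sketch below illustrates that pattern against the plain TensorRT C++ API, outside of TRTorch; the logger class, tensor names, and shapes are illustrative assumptions, not code from this repository.

// sketch.cpp -- minimal illustration of the identity-layer cast pattern used above.
// Assumes TensorRT headers and libraries are available; all names are made up.
#include <NvInfer.h>
#include <iostream>

class SketchLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    // Print warnings and errors only.
    if (severity <= Severity::kWARNING) {
      std::cout << msg << std::endl;
    }
  }
};

int main() {
  SketchLogger logger;
  auto builder = nvinfer1::createInferBuilder(logger);
  auto network = builder->createNetworkV2(
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

  // Two float inputs standing in for `self` and `other`.
  auto self = network->addInput("self", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2(1, 4));
  auto other = network->addInput("other", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2(1, 4));

  // Comparison element-wise ops (kGREATER/kLESS/kEQUAL) emit a boolean tensor.
  auto gt = network->addElementWise(*self, *other, nvinfer1::ElementWiseOperation::kGREATER);
  gt->setName("self_gt_other_gt");

  // The pattern from the patch: append an identity layer and request that its
  // output use the comparison's input dtype (kFLOAT here), casting the result.
  auto identity = network->addIdentity(*gt->getOutput(0));
  identity->setOutputType(0, gt->getInput(0)->getType());
  identity->setName("self_gt_other");

  network->markOutput(*identity->getOutput(0));
  // Engine building, execution, and cleanup are omitted; only graph construction is shown.
  return 0;
}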