diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index fd11246855..0b347c6647 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -166,11 +166,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // Should implement self - alpha * other
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto scalar = args[2].unwrapToScalar().to<float>();
               auto other = args[1].ITensorOrFreeze(ctx);
+              auto scalar = args[2].unwrapToScalar();
 
-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
@@ -214,11 +214,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // Should implement self - alpha * other
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto scalar = args[2].unwrapToScalar().to<float>();
               auto other = args[1].ITensorOrFreeze(ctx);
+              auto scalar = args[2].unwrapToScalar();
 
-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
@@ -351,8 +351,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
 
               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -381,8 +380,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
 
               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -481,18 +479,12 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto scalar = args[1].unwrapToScalar();
-              nvinfer1::ITensor* scalar_tensor;
-              if (self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF) {
-                scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
-              } else {
-                scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
-              }
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
               auto equal = add_elementwise(
                   ctx,
                   nvinfer1::ElementWiseOperation::kEQUAL,
                   self,
-                  scalar_tensor,
+                  other,
                   util::node_info(n) + std::string("is_equal"));
               TORCHTRT_CHECK(equal, "Unable to create elementwise equal layer from node: " << *n);
               // XOR with ones negates and produces not_equal result
@@ -534,8 +526,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto exponentScalar = args[1].unwrapToScalar().to<float>();
-              auto exponent = tensor_to_const(ctx, torch::tensor({exponentScalar}));
+              auto exponent = scalar_to_tensor(ctx, args[1].unwrapToScalar());
               auto pow = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponent, util::node_info(n));
 
               TORCHTRT_CHECK(pow, "Unable to create Power layer from node: " << *n);
@@ -681,9 +672,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            {"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
               if (self->getType() == nvinfer1::DataType::kBOOL) {
+                auto otherScalar = args[1].unwrapToScalar().to<float>();
                 if (otherScalar == 0 || otherScalar == 1) {
                   LOG_DEBUG("Since input tensor is type bool, casting input tensor and scalar to int32");
                   other = castITensor(ctx, other, nvinfer1::DataType::kINT32);
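Note: the definition of `scalar_to_tensor` is not part of this diff, so the sketch below is only an illustration of the contract the call sites above rely on, assuming the helper dispatches on the `at::Scalar`'s runtime type (the same distinction the removed `aten::ne.Scalar` branch drew between float/half and other inputs):

```cpp
// Hypothetical sketch, not the PR's actual definition: build a one-element
// TRT constant that preserves the scalar's own type instead of forcing a
// float cast the way the old per-site `.to<float>()` calls did.
nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, const at::Scalar& s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    // Integer scalars become integer constants, so e.g. `tensor != 2` on an
    // int32 tensor compares against an int32 constant rather than a float.
    return tensor_to_const(ctx, torch::tensor({s.to<int>()}));
  }
  return tensor_to_const(ctx, torch::tensor({s.to<float>()}));
}
```

With a helper of this shape, every `Scalar` operand goes through a single code path, which is why the converters above can drop both their per-site `.to<float>()` casts and the hand-rolled float/int branching in `aten::ne.Scalar`.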