diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 7c116b6c2c..3dcd2e9d80 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -85,10 +85,10 @@ nvinfer1::ILayer* add_elementwise( const std::string& name) { if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) { LOG_DEBUG("Type mismatch, casting other to " << self->getType()); - other = castITensor(ctx, other, self->getType()); + other = castITensor(ctx, other, self->getType(), name); } else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) { LOG_DEBUG("Type mismatch, casting self to " << other->getType()); - self = castITensor(ctx, self, other->getType()); + self = castITensor(ctx, self, other->getType(), name); } // ensure self to have larger number of dimension bool swapSelfOther = false; @@ -106,13 +106,13 @@ nvinfer1::ILayer* add_elementwise( LOG_DEBUG( "Element-wise op type promotion adding cast from " << self->getType() << " to " << promo_type << " for layer " << name); - self = castITensor(ctx, self, promo_type); + self = castITensor(ctx, self, promo_type, name); } if (other->getType() != promo_type) { LOG_DEBUG( "Element-wise op type promotion adding cast from " << other->getType() << " to " << promo_type << " for layer " << name); - other = castITensor(ctx, other, promo_type); + other = castITensor(ctx, other, promo_type, name); } } diff --git a/tests/core/conversion/converters/test_add_sub_mul.cpp b/tests/core/conversion/converters/test_add_sub_mul.cpp index 3631a92a44..997b2b20ab 100644 --- a/tests/core/conversion/converters/test_add_sub_mul.cpp +++ b/tests/core/conversion/converters/test_add_sub_mul.cpp @@ -178,3 +178,14 @@ TEST(Converters, ATenPowScalarConvertsCorrectly) { return (%3))IR"; pointwise_test_helper(graph, true); } + +TEST(Converters, ElementWiseTypePromotionDisambiguatesCastNames) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : int = prim::Constant[value=1]() + %3 : Tensor = aten::add(%0, %1, %2) + %4 : Tensor = aten::add(%0, %1, %2) + %5 : Tensor = aten::add(%3, %4, %2) + return (%5))IR"; + pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kFloat); +}