diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index ac781697b7..73c28af1a1 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -26,6 +26,31 @@ nvinfer1::ITensor* clamp_util(
   return clamp_layer_out;
 }

+void cast_int_int_div_tensors(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor*& a,
+    nvinfer1::ITensor*& b) {
+  // Torch automatically produces a float for int/int division
+  if (a->getType() == nvinfer1::DataType::kINT32 && b->getType() == nvinfer1::DataType::kINT32) {
+    a = castITensor(ctx, a, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_a_cast");
+    b = castITensor(ctx, b, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_b_cast");
+  }
+}
+
+bool element_wise_divide_implementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* a,
+    nvinfer1::ITensor* b) {
+  cast_int_int_div_tensors(ctx, n, a, b);
+  auto element_wise = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, a, b, util::node_info(n));
+  TORCHTRT_CHECK(element_wise, "Unable to create element_wise layer from node: " << *n);
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], element_wise->getOutput(0));
+  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+  return true;
+}
+
 auto element_wise_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -296,18 +321,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               // Should implement self / other
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = args[1].ITensorOrFreeze(ctx);
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)",
@@ -349,6 +365,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
                 div = add_elementwise(
                     ctx, nvinfer1::ElementWiseOperation::kPROD, floor, sign->getOutput(0), util::node_info(n));
               } else {
+                 cast_int_int_div_tensors(ctx, n, self, other);
                 div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
               }

@@ -365,42 +382,21 @@ auto element_wise_registrations TORCHTRT_UNUSED =
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               // TODO: Remove with functionalization
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = args[1].ITensorOrFreeze(ctx);
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::square(Tensor self) -> Tensor",
diff --git a/tests/core/conversion/converters/test_div.cpp b/tests/core/conversion/converters/test_div.cpp
index 7670132be4..0b879c1829 100644
--- a/tests/core/conversion/converters/test_div.cpp
+++ b/tests/core/conversion/converters/test_div.cpp
@@ -18,6 +18,7 @@ TEST(Converters, ATenDivConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }

 TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -29,6 +30,16 @@ TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
   pointwise_test_helper(graph, true);
 }

+TEST(Converters, ATenDivWithScalarIntConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %scalar : int = prim::Constant[value=2]()
+        %1 : Tensor = aten::div(%0, %scalar)
+        return (%1))IR";
+  pointwise_test_helper(graph, true);
+  pointwise_test_helper(graph, true, false, {5}, {1}, false, at::kInt);
+}
+
 TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):
@@ -42,6 +53,7 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }

 TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -57,6 +69,7 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }

 TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -70,6 +83,7 @@ TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }

 TEST(Converters, ATenDivRoundingTruncWithIntsConvertsCorrectly) {
@@ -107,6 +121,7 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }

 TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
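Note on the motivation for cast_int_int_div_tensors: eager-mode PyTorch performs "true" division for aten::div, so an int/int quotient is promoted to a floating-point result, while TensorRT's elementwise kDIV on two kINT32 operands would otherwise compute an integer quotient. The helper reproduces the eager behavior by casting both operands to kFLOAT first. A minimal standalone LibTorch sketch of the promotion being matched (illustration only, not part of this patch; any recent LibTorch should behave this way):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Int32 operands, plain aten::div (no rounding mode).
      auto a = torch::tensor({1, 5, 9}, torch::kInt);
      auto b = torch::tensor({2, 2, 2}, torch::kInt);

      auto q = torch::div(a, b);
      std::cout << q.dtype() << "\n"; // Float: int/int is promoted
      std::cout << q << "\n";         // 0.5, 2.5, 4.5
      return 0;
    }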
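Relatedly, in the aten::div.Tensor_mode converter the cast is added only on the final else branch (rounding_mode is None), since Torch keeps the integer dtype whenever a rounding mode is supplied; this is also why the floor/trunc tests still expect integer behavior. A sketch of the three eager behaviors, under the same standalone-LibTorch assumption:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto a = torch::tensor({7, -7}, torch::kInt);
      auto b = torch::tensor({2, 2}, torch::kInt);

      // With a rounding mode, the integer dtype is preserved.
      std::cout << torch::div(a, b, "floor") << "\n"; // 3, -4 (Int)
      std::cout << torch::div(a, b, "trunc") << "\n"; // 3, -3 (Int)

      // Without one (rounding_mode=None), the result is promoted.
      std::cout << torch::div(a, b) << "\n";          // 3.5, -3.5 (Float)
      return 0;
    }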