diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index 32c7050289..8b08a5505a 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -390,6 +390,18 @@ auto element_wise_registrations TORCHTRT_UNUSED = LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; }}) + .pattern( + {"aten::square(Tensor self) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto self = args[0].ITensorOrFreeze(ctx); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n)); + TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n); + + mul->setName(util::node_info(n).c_str()); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + return true; + }}) .pattern( {"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { diff --git a/tests/core/conversion/converters/test_element_wise.cpp b/tests/core/conversion/converters/test_element_wise.cpp index 3ecfdb2019..6b1c26bbab 100644 --- a/tests/core/conversion/converters/test_element_wise.cpp +++ b/tests/core/conversion/converters/test_element_wise.cpp @@ -145,6 +145,14 @@ TEST(Converters, ATenMulConvertsCorrectly) { pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat); } +TEST(Converters, ATenSquareConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : Tensor = aten::square(%0) + return (%1))IR"; + pointwise_test_helper(graph, true); +} + TEST(Converters, ATenMulWithScalarConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor):