diff --git a/core/conversion/converters/impl/max.cpp b/core/conversion/converters/impl/max.cpp index adc8d06ed0..1f03b97260 100644 --- a/core/conversion/converters/impl/max.cpp +++ b/core/conversion/converters/impl/max.cpp @@ -18,17 +18,36 @@ auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto self = args[0].ITensorOrFreeze(ctx); auto dim = args[1].unwrapToInt(); + auto keep_dims = args[2].unwrapToBool(); auto selfDim = util::toVec(self->getDimensions()); if (dim < 0) { dim = selfDim.size() + dim; } uint32_t shiftDim = 1 << dim; auto TopKOperation = nvinfer1::TopKOperation::kMAX; - auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim); - TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n); + auto topk_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim); + TORCHTRT_CHECK(topk_layer, "Unable to create max layer from node: " << *n); + auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions()); - auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); - auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1)); + nvinfer1::ITensor* out0; + nvinfer1::ITensor* out1; + if (!keep_dims) { + if (topk_dims[dim] == 1) { + auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0)); + squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim)); + TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n); + out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0)); + + auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1)); + squeeze_layer_indices->setReshapeDimensions( + util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim)); + TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n); + out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0)); + } + } else { + out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0)); + out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1)); + } LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp index 199a31bfb4..bf5234dc19 100644 --- a/tests/core/conversion/converters/test_reduce.cpp +++ b/tests/core/conversion/converters/test_reduce.cpp @@ -212,6 +212,18 @@ TEST(Converters, ATenProdKeepDimsConvertsCorrectly) { test_body(graph, in); } +TEST(Converters, ATenMaxKeepDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%x : Tensor): + %2 : int = prim::Constant[value=-1]() + %3 : bool = prim::Constant[value=1]() + %keep.1 : Tensor, %6 : Tensor = aten::max(%x, %2, %3) + return (%keep.1, %6))IR"; + + auto in = at::randint(-5, 5, {4, 4}, at::kCUDA); + test_body(graph, in); +} + TEST(Converters, ATenMeanDimNegOneIndexConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor):