diff --git a/core/conversion/converters/impl/max.cpp b/core/conversion/converters/impl/max.cpp index 3ccf165bbe..b07da3e311 100644 --- a/core/conversion/converters/impl/max.cpp +++ b/core/conversion/converters/impl/max.cpp @@ -22,6 +22,11 @@ bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin if (dim < 0) { dim = selfDim.size() + dim; } + bool int_input = self->getType() == nvinfer1::DataType::kINT32; + if (int_input) { + LOG_DEBUG("topk layer does not support int32 inputs, adding cast to float"); + self = castITensor(ctx, self, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_input"); + } uint32_t reduce_axes_mask = 1 << dim; auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask); TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n); @@ -44,7 +49,10 @@ bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0)); out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1)); } - + if (int_input) { + LOG_DEBUG("Adding cast of topK layer output back to int32"); + out0 = castITensor(ctx, out0, nvinfer1::DataType::kINT32, util::node_info(n) + "_output"); + } LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); @@ -59,6 +67,10 @@ bool arg_min_max(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin if (dim < 0) { dim = selfDim.size() + dim; } + if (self->getType() == nvinfer1::DataType::kINT32) { + LOG_DEBUG("topk layer does not support int32 inputs, adding cast to float"); + self = castITensor(ctx, self, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_input"); + } uint32_t reduce_axes_mask = 1 << dim; auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask); TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n); diff --git 
a/tests/core/conversion/converters/test_max.cpp b/tests/core/conversion/converters/test_max.cpp index dfc2432c24..30111edf36 100644 --- a/tests/core/conversion/converters/test_max.cpp +++ b/tests/core/conversion/converters/test_max.cpp @@ -29,6 +29,29 @@ TEST(Converters, ATenMaxDimConvertsCorrectly) { torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6)); } +TEST(Converters, ATenMaxDimIntInputConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() + %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3) + return (%4, %5))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(-5, 5, {5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6)); +} + TEST(Converters, ATenMinDimConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor): @@ -77,6 +100,28 @@ TEST(Converters, ATenArgMaxConvertsCorrectly) { torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } +TEST(Converters, ATenArgMaxIntInputConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() + %4 : Tensor = aten::argmax(%x.1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(-5, 5, {5, 5}, {at::kCUDA}); + + auto params = 
torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenArgMaxKeepdimConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor):