diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp index 3381e34def..014ac207bd 100644 --- a/core/conversion/converters/impl/conv_deconv.cpp +++ b/core/conversion/converters/impl/conv_deconv.cpp @@ -132,12 +132,34 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) nvinfer1::ILayer* new_layer; if (transposed) { + // Refer to + // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734 + nvinfer1::Dims begPadding = padding; + bool hasOutputPadding = false; + int nbSpatialDims = out_padding.nbDims; + // When there is out_padding, if padding is larger than out_padding, just adjust padding Or reduce out_padding as + // minimum as possible. + for (int i = 0; i < nbSpatialDims; ++i) { + if (padding.d[i] - out_padding.d[i] >= 0) { + padding.d[i] -= out_padding.d[i]; + out_padding.d[i] = 0; + } else { + // Reduce out_padding as possible. + out_padding.d[i] -= padding.d[i]; + padding.d[i] = 0; + hasOutputPadding = true; + } + } + // shape of deconvolution's weight: [in, out/groups, ...] - auto deconv = ctx->net->addDeconvolutionNd(*in, w.shape.d[1] * groups, w.kernel_shape, w.data, bias.data); + // If there is still output padding, remove the bias. Bias will be added below. + auto deconv = ctx->net->addDeconvolutionNd( + *in, w.shape.d[1] * groups, w.kernel_shape, w.data, hasOutputPadding ? 
nvinfer1::Weights{} : bias.data); TORCHTRT_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n); deconv->setStrideNd(stride); - deconv->setPaddingNd(padding); + deconv->setPrePadding(begPadding); + deconv->setPostPadding(padding); #if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR >= 1) deconv->setDilationNd(dilation); deconv->setNbGroups(groups); @@ -147,7 +169,56 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) TORCHTRT_CHECK(dilation.d[idx] == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1"); } #endif - new_layer = deconv; + if (hasOutputPadding) { + LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding); + + // Add padding layer + nvinfer1::ITensor* start; + nvinfer1::ITensor* totalPadding; + auto in_nbDims = orig_dims.nbDims; + std::vector<int32_t> startVec(in_nbDims, 0); + std::vector<int32_t> totalPaddingVec(in_nbDims, 0); + int32_t diff = in_nbDims - out_padding.nbDims; + for (int32_t i = diff; i < in_nbDims; i++) { + int32_t idx = i - diff; + startVec[i] = 0; // Don't need begin padding, only post padding + totalPaddingVec[i] = out_padding.d[idx]; + } + start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32)); + totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32)); + + nvinfer1::ITensor* tensorPtr = deconv->getOutput(0); + nvinfer1::ITensor* deconvOutShape = ctx->net->addShape(*tensorPtr)->getOutput(0); + const auto size = + ctx->net->addElementWise(*deconvOutShape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); + + nvinfer1::Dims stride; + stride.nbDims = in_nbDims; + for (size_t i = 0; i < in_nbDims; i++) { + stride.d[i] = 1; + } + const auto& dummy = stride; + auto* sliceLayer = ctx->net->addSlice(*tensorPtr, dummy, dummy, stride); + sliceLayer->setInput(1, *start); + sliceLayer->setInput(2, *size); + sliceLayer->setMode(nvinfer1::SliceMode::kFILL); + tensorPtr = sliceLayer->getOutput(0); + + 
nvinfer1::Dims constantDims; + constantDims.nbDims = in_nbDims; + for (size_t i = 0; i < in_nbDims; i++) { + constantDims.d[i] = 1; + } + constantDims.d[diff - 1] = + bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast + auto const_layer = ctx->net->addConstant(constantDims, bias.data); + auto add_bias_layer = + ctx->net->addElementWise(*tensorPtr, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM); + + new_layer = add_bias_layer; + } else { + new_layer = deconv; + } } else { // shape of convolution's weight: [out, in/groups, ...] auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data); diff --git a/tests/core/conversion/converters/test_conv_deconv.cpp b/tests/core/conversion/converters/test_conv_deconv.cpp index 676e7978e1..c77681cd80 100644 --- a/tests/core/conversion/converters/test_conv_deconv.cpp +++ b/tests/core/conversion/converters/test_conv_deconv.cpp @@ -570,6 +570,131 @@ TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } +TEST(Converters, ATenConv1dTransposeWithPaddingOutPaddingConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(4, 3, 3, strides=[9, 3, 1])): + %2 : None = prim::Constant() + %3 : int = prim::Constant[value=2]() + %4 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : bool = prim::Constant[value=1]() + %8 : int[] = prim::ListConstruct(%3) + %9 : int[] = prim::ListConstruct(%4) + %10 : int[] = prim::ListConstruct(%5) + %11 : int[] = prim::ListConstruct(%6) + %12 : int = prim::Constant[value=1]() + %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7) + return (%13))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA}); + auto w = at::randint(1, 
2, {3, 4, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto jit_w = at::clone(w); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_w = at::clone(w); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenConvTransposeWithPaddingOutPaddingConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(4, 3, 4, 4, strides=[48, 16, 4, 1]), + %2 : Float(4)): + %3 : int = prim::Constant[value=2]() + %4 : int = prim::Constant[value=2]() + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : bool = prim::Constant[value=1]() + %8 : int[] = prim::ListConstruct(%3, %3) + %9 : int[] = prim::ListConstruct(%4, %4) + %10 : int[] = prim::ListConstruct(%5, %5) + %11 : int[] = prim::ListConstruct(%6, %6) + %12 : int = prim::Constant[value=1]() + %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7) + return (%13))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA}); + auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA}); + auto b = at::randint(1, 10, {3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto jit_w = at::clone(w); + auto jit_b = at::clone(b); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w, jit_b}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_w = at::clone(w); + auto trt_b = at::clone(b); + params = 
torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w, trt_b}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + auto trt = trt_results[0]; + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenConvTransposeOutPaddingBiggerThanPaddingConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(4, 3, 4, 4, strides=[48, 16, 4, 1]), + %2 : Float(4)): + %3 : int = prim::Constant[value=4]() + %4 : int = prim::Constant[value=2]() + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=3]() + %7 : bool = prim::Constant[value=1]() + %8 : int[] = prim::ListConstruct(%3, %3) + %9 : int[] = prim::ListConstruct(%4, %4) + %10 : int[] = prim::ListConstruct(%5, %5) + %11 : int[] = prim::ListConstruct(%6, %6) + %12 : int = prim::Constant[value=1]() + %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7) + return (%13))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA}); + auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA}); + auto b = at::randint(1, 10, {3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto jit_w = at::clone(w); + auto jit_b = at::clone(b); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w, jit_b}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_w = at::clone(w); + auto trt_b = at::clone(b); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w, trt_b}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + TEST(Converters, ATenConvolutionWithGroupConvertsCorrectly) { const auto 
graph = R"IR( graph(%0 : Tensor,