From 14870e022068ae6a2118c82e24267e9249201ccd Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 11 Dec 2023 16:03:17 -0800
Subject: [PATCH 1/6] chore: fix deconv padding

Signed-off-by: Dheeraj Peri
---
 .../converters/impl/conv_deconv.cpp | 137 ++++++++++--------
 1 file changed, 78 insertions(+), 59 deletions(-)

diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index fc0e97b7ee..8d8e9a3c88 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -10,6 +10,74 @@ namespace converters {
 namespace impl {
 namespace {
 
+void add_output_padding(nvinfer1::Dims& padding, nvinfer1::Dims& out_padding, bool& has_output_padding) {
+  int nbSpatialDims = out_padding.nbDims;
+  // When there is out_padding: if padding is larger than out_padding, just reduce padding; otherwise, reduce
+  // out_padding as much as possible.
+  for (int i = 0; i < nbSpatialDims; ++i) {
+    if (padding.d[i] - out_padding.d[i] >= 0) {
+      padding.d[i] -= out_padding.d[i];
+      out_padding.d[i] = 0;
+    } else {
+      // Reduce out_padding as much as possible.
+      out_padding.d[i] -= padding.d[i];
+      padding.d[i] = 0;
+      has_output_padding = true;
+    }
+  }
+}
+
+nvinfer1::ILayer* add_bias_layer(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* input_tensor,
+    nvinfer1::Dims& input_dims,
+    nvinfer1::Dims& output_padding,
+    Weights& bias) {
+  nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0);
+  // Add padding layer
+  nvinfer1::ITensor* start;
+  nvinfer1::ITensor* totalPadding;
+  auto in_nbDims = input_dims.nbDims;
+  std::vector<int64_t> startVec(in_nbDims, 0);
+  std::vector<int64_t> totalPaddingVec(in_nbDims, 0);
+  int32_t diff = in_nbDims - output_padding.nbDims;
+  for (int32_t i = diff; i < in_nbDims; i++) {
+    int32_t idx = i - diff;
+    startVec[i] = 0; // Don't need begin padding, only post padding
+    totalPaddingVec[i] = output_padding.d[idx];
+  }
+  start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32));
+  totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32));
+
+  const auto size =
+      ctx->net->addElementWise(*input_shape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
+
+  nvinfer1::Dims stride;
+  stride.nbDims = in_nbDims;
+  for (int64_t i = 0; i < in_nbDims; i++) {
+    stride.d[i] = 1;
+  }
+  const auto& dummy = stride;
+  auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride);
+  sliceLayer->setInput(1, *start);
+  sliceLayer->setInput(2, *size);
+  sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
+  nvinfer1::ITensor* slide_output = sliceLayer->getOutput(0);
+
+  nvinfer1::Dims constantDims;
+  constantDims.nbDims = in_nbDims;
+  for (int64_t i = 0; i < in_nbDims; i++) {
+    constantDims.d[i] = 1;
+  }
+  constantDims.d[diff - 1] =
+      bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast
+  auto const_layer = ctx->net->addConstant(constantDims, bias.data);
+  auto bias_layer =
+      ctx->net->addElementWise(*slide_output, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
+
+  return bias_layer;
+}
+
 bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
   // Input to conv/deconv
   auto in = args[0].ITensor();
@@ -76,12 +144,19 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   nvinfer1::ILayer* layer = nullptr;
   if (transposed) {
+    // Fix padding based on output_padding provided
+    nvinfer1::Dims begPadding = padding;
+    bool hasOutputPadding = false;
+    add_output_padding(padding, out_padding, hasOutputPadding);
+
     nvinfer1::IDeconvolutionLayer* deconvLayer =
         ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
     deconvLayer->setStrideNd(stride);
     deconvLayer->setDilationNd(dilation);
     deconvLayer->setNbGroups(groups);
-    deconvLayer->setPaddingNd(padding);
+    deconvLayer->setPrePadding(begPadding);
+    deconvLayer->setPostPadding(padding);
+
     // Set deconv kernel weights
     deconvLayer->setInput(1, *kernel);
     TORCHTRT_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
@@ -155,20 +230,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734
     nvinfer1::Dims begPadding = padding;
     bool hasOutputPadding = false;
-    int nbSpatialDims = out_padding.nbDims;
-    // When there is out_padding, if padding is larger than out_padding, just adjust padding Or reduce out_padding as
-    // minimum as possible.
-    for (int i = 0; i < nbSpatialDims; ++i) {
-      if (padding.d[i] - out_padding.d[i] >= 0) {
-        padding.d[i] -= out_padding.d[i];
-        out_padding.d[i] = 0;
-      } else {
-        // Reduce out_padding as possible.
-        out_padding.d[i] -= padding.d[i];
-        padding.d[i] = 0;
-        hasOutputPadding = true;
-      }
-    }
+    add_output_padding(padding, out_padding, hasOutputPadding);
 
     // shape of deconvolution's weight: [in, out/groups, ...]
     // If there is still output padding, remove the bias. Bias will be added below.
@@ -190,51 +252,8 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
 #endif
     if (hasOutputPadding) {
       LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
-
-      // Add padding layer
-      nvinfer1::ITensor* start;
-      nvinfer1::ITensor* totalPadding;
-      auto in_nbDims = orig_dims.nbDims;
-      std::vector<int64_t> startVec(in_nbDims, 0);
-      std::vector<int64_t> totalPaddingVec(in_nbDims, 0);
-      int32_t diff = in_nbDims - out_padding.nbDims;
-      for (int32_t i = diff; i < in_nbDims; i++) {
-        int32_t idx = i - diff;
-        startVec[i] = 0; // Don't need begin padding, only post padding
-        totalPaddingVec[i] = out_padding.d[idx];
-      }
-      start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32));
-      totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32));
       nvinfer1::ITensor* tensorPtr = deconv->getOutput(0);
-      nvinfer1::ITensor* deconvOutShape = ctx->net->addShape(*tensorPtr)->getOutput(0);
-      const auto size =
-          ctx->net->addElementWise(*deconvOutShape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
-
-      nvinfer1::Dims stride;
-      stride.nbDims = in_nbDims;
-      for (int64_t i = 0; i < in_nbDims; i++) {
-        stride.d[i] = 1;
-      }
-      const auto& dummy = stride;
-      auto* sliceLayer = ctx->net->addSlice(*tensorPtr, dummy, dummy, stride);
-      sliceLayer->setInput(1, *start);
-      sliceLayer->setInput(2, *size);
-      sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
-      tensorPtr = sliceLayer->getOutput(0);
-
-      nvinfer1::Dims constantDims;
-      constantDims.nbDims = in_nbDims;
-      for (int64_t i = 0; i < in_nbDims; i++) {
-        constantDims.d[i] = 1;
-      }
-      constantDims.d[diff - 1] =
-          bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast
-      auto const_layer = ctx->net->addConstant(constantDims, bias.data);
-      auto add_bias_layer =
-          ctx->net->addElementWise(*tensorPtr, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
-
-      new_layer = add_bias_layer;
+      new_layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
     } else {
       new_layer = deconv;
     }
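
Note: the heart of PATCH 1 is that output_padding is now folded into the deconvolution's pre-/post-padding (setPrePadding(begPadding) / setPostPadding(padding)), mirroring the onnx-tensorrt importer the code cites. The following standalone sketch restates the arithmetic of add_output_padding with plain integers in place of nvinfer1::Dims; the names here are illustrative, not part of the patch:

    #include <cassert>

    // One spatial dimension of add_output_padding, in isolation. output_padding
    // is absorbed into the deconvolution's post-padding where possible; any
    // remainder has to be materialized later by the fill/slice path.
    void fold_output_padding(int& padding, int& out_padding, bool& has_output_padding) {
      if (padding - out_padding >= 0) {
        padding -= out_padding; // fully absorbed: post-padding shrinks instead
        out_padding = 0;
      } else {
        out_padding -= padding; // partially absorbed: a remainder is left over
        padding = 0;
        has_output_padding = true;
      }
    }

    int main() {
      int p = 1, op = 1;
      bool leftover = false;
      fold_output_padding(p, op, leftover); // padding=1, output_padding=1
      assert(p == 0 && op == 0 && !leftover); // fully absorbed, no slice needed

      p = 0, op = 1, leftover = false;
      fold_output_padding(p, op, leftover); // padding=0, output_padding=1
      assert(p == 0 && op == 1 && leftover); // one trailing element must be filled
      return 0;
    }
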
From 3375e9e3d4612b345fb5d4484804a4d6a096b6fc Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 11 Dec 2023 16:09:12 -0800
Subject: [PATCH 2/6] chore: add bias layer for QAT deconv

Signed-off-by: Dheeraj Peri
---
 .../converters/impl/conv_deconv.cpp | 238 +++++++++---------
 1 file changed, 121 insertions(+), 117 deletions(-)

diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index 8d8e9a3c88..fe8c49e86f 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -149,8 +149,8 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     bool hasOutputPadding = false;
     add_output_padding(padding, out_padding, hasOutputPadding);
 
-    nvinfer1::IDeconvolutionLayer* deconvLayer =
-        ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+    nvinfer1::IDeconvolutionLayer* deconvLayer = ctx->net->addDeconvolutionNd(
+        *in, kernel_dims.d[0], filter_dim, kernel_weights, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
     deconvLayer->setStrideNd(stride);
     deconvLayer->setDilationNd(dilation);
     deconvLayer->setNbGroups(groups);
@@ -161,151 +161,155 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     deconvLayer->setInput(1, *kernel);
     TORCHTRT_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
     layer = deconvLayer;
-  } else {
-    nvinfer1::IConvolutionLayer* convLayer =
-        ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
-    convLayer->setStrideNd(stride);
-    convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
-    convLayer->setPaddingNd(padding);
-    convLayer->setPostPadding(out_padding);
-    convLayer->setDilationNd(dilation);
-    convLayer->setNbGroups(groups);
+    if (hasOutputPadding) {
+      LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
+      nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0);
+      layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
+    } else {
+      nvinfer1::IConvolutionLayer* convLayer =
+          ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+      convLayer->setStrideNd(stride);
+      convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+      convLayer->setPaddingNd(padding);
+      convLayer->setPostPadding(out_padding);
+      convLayer->setDilationNd(dilation);
+      convLayer->setNbGroups(groups);
 
-    // Set conv kernel weights
-    convLayer->setInput(1, *kernel);
-    layer = convLayer;
-  }
+      // Set conv kernel weights
+      convLayer->setInput(1, *kernel);
+      layer = convLayer;
+    }
 
-  ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
-  LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
-  return true;
-}
+    ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+    LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
+    return true;
+  }
 
-  auto w = Weights(ctx, args[1].unwrapToTensor());
-  // TODO: Remove this when conv3d with kernel size=1 bug is fixed.
-  // Github issue: https://github.com/pytorch/TensorRT/issues/1445
-  bool is_kernel_size_one = true;
-  bool is_3d_kernel = w.kernel_shape.nbDims == 3;
-  for (int64_t i = 0; i < w.kernel_shape.nbDims; i++) {
-    if (w.kernel_shape.d[i] != 1.0f) {
-      is_kernel_size_one = false;
+    auto w = Weights(ctx, args[1].unwrapToTensor());
+    // TODO: Remove this when conv3d with kernel size=1 bug is fixed.
+    // Github issue: https://github.com/pytorch/TensorRT/issues/1445
+    bool is_kernel_size_one = true;
+    bool is_3d_kernel = w.kernel_shape.nbDims == 3;
+    for (int64_t i = 0; i < w.kernel_shape.nbDims; i++) {
+      if (w.kernel_shape.d[i] != 1.0f) {
+        is_kernel_size_one = false;
+      }
     }
-  }
-  if (is_kernel_size_one && is_3d_kernel) {
-    LOG_WARNING(
-        "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
+    if (is_kernel_size_one && is_3d_kernel) {
+      LOG_WARNING(
+          "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
 Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue.");
-  }
-  auto dims = in->getDimensions();
-  auto orig_dims = dims;
-  LOG_DEBUG("Input dims: " << orig_dims);
-  LOG_DEBUG("Weights: " << w);
-  LOG_DEBUG("stride: " << stride);
-  LOG_DEBUG("padding: " << padding);
-  LOG_DEBUG("dilation: " << dilation);
-  LOG_DEBUG("out_padding: " << out_padding);
-  LOG_DEBUG("groups: " << groups);
+    }
+    auto dims = in->getDimensions();
+    auto orig_dims = dims;
+    LOG_DEBUG("Input dims: " << orig_dims);
+    LOG_DEBUG("Weights: " << w);
+    LOG_DEBUG("stride: " << stride);
+    LOG_DEBUG("padding: " << padding);
+    LOG_DEBUG("dilation: " << dilation);
+    LOG_DEBUG("out_padding: " << out_padding);
+    LOG_DEBUG("groups: " << groups);
 
-  TORCHTRT_CHECK(orig_dims.nbDims > 2, "Unable to create convolution layer from node: " << *n);
+    TORCHTRT_CHECK(orig_dims.nbDims > 2, "Unable to create convolution layer from node: " << *n);
 
-  bool expandDims = (orig_dims.nbDims < 4);
-  if (expandDims) {
-    in = addPadding(ctx, n, in, 4);
-    dims = in->getDimensions();
-    LOG_DEBUG("Reshaped Input dims: " << dims);
-  }
-  if (w.shape.nbDims < 4) {
-    for (int i = w.shape.nbDims; i < 4; ++i) {
-      w.shape.d[i] = 1;
+    bool expandDims = (orig_dims.nbDims < 4);
+    if (expandDims) {
+      in = addPadding(ctx, n, in, 4);
+      dims = in->getDimensions();
+      LOG_DEBUG("Reshaped Input dims: " << dims);
+    }
+    if (w.shape.nbDims < 4) {
+      for (int i = w.shape.nbDims; i < 4; ++i) {
+        w.shape.d[i] = 1;
+      }
+      w.shape.nbDims = 4;
+      w.kernel_shape.nbDims = 2;
+      w.kernel_shape.d[1] = 1;
+      LOG_DEBUG("Reshaped Weights: " << w);
     }
-    w.shape.nbDims = 4;
-    w.kernel_shape.nbDims = 2;
-    w.kernel_shape.d[1] = 1;
-    LOG_DEBUG("Reshaped Weights: " << w);
-  }
 
-  nvinfer1::ILayer* new_layer;
-  if (transposed) {
-    // Refer to
-    // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734
-    nvinfer1::Dims begPadding = padding;
-    bool hasOutputPadding = false;
-    add_output_padding(padding, out_padding, hasOutputPadding);
+    nvinfer1::ILayer* new_layer;
+    if (transposed) {
+      // Refer to
+      // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734
+      nvinfer1::Dims begPadding = padding;
+      bool hasOutputPadding = false;
+      add_output_padding(padding, out_padding, hasOutputPadding);
 
-    // shape of deconvolution's weight: [in, out/groups, ...]
-    // If there is still output padding, remove the bias. Bias will be added below.
-    auto deconv = ctx->net->addDeconvolutionNd(
-        *in, w.shape.d[1] * groups, w.kernel_shape, w.data, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
-    TORCHTRT_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
+      // shape of deconvolution's weight: [in, out/groups, ...]
+      // If there is still output padding, remove the bias. Bias will be added below.
+      auto deconv = ctx->net->addDeconvolutionNd(
+          *in, w.shape.d[1] * groups, w.kernel_shape, w.data, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
+      TORCHTRT_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
 
-    deconv->setStrideNd(stride);
-    deconv->setPrePadding(begPadding);
-    deconv->setPostPadding(padding);
+      deconv->setStrideNd(stride);
+      deconv->setPrePadding(begPadding);
+      deconv->setPostPadding(padding);
 #if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR >= 1)
-    deconv->setDilationNd(dilation);
-    deconv->setNbGroups(groups);
+      deconv->setDilationNd(dilation);
+      deconv->setNbGroups(groups);
 #else
-    TORCHTRT_CHECK(groups == 1, "for deconv with groups > 1, require TensorRT version >= 7.1");
-    for (int idx = 0; idx < dilation.nbDims; idx++) {
-      TORCHTRT_CHECK(dilation.d[idx] == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1");
-    }
+      TORCHTRT_CHECK(groups == 1, "for deconv with groups > 1, require TensorRT version >= 7.1");
+      for (int idx = 0; idx < dilation.nbDims; idx++) {
+        TORCHTRT_CHECK(dilation.d[idx] == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1");
+      }
 #endif
-    if (hasOutputPadding) {
-      LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
-      nvinfer1::ITensor* tensorPtr = deconv->getOutput(0);
-      new_layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
+      if (hasOutputPadding) {
+        LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
+        nvinfer1::ITensor* tensorPtr = deconv->getOutput(0);
+        new_layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
+      } else {
+        new_layer = deconv;
+      }
     } else {
-      new_layer = deconv;
-    }
-  } else {
-    // shape of convolution's weight: [out, in/groups, ...]
-    auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
-    TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);
+      // shape of convolution's weight: [out, in/groups, ...]
+      auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
+      TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);
 
-    conv->setStrideNd(stride);
-    conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
-    conv->setPaddingNd(padding);
-    conv->setPostPadding(out_padding);
-    conv->setDilationNd(dilation);
-    conv->setNbGroups(groups);
-    new_layer = conv;
-  }
+      conv->setStrideNd(stride);
+      conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+      conv->setPaddingNd(padding);
+      conv->setPostPadding(out_padding);
+      conv->setDilationNd(dilation);
+      conv->setNbGroups(groups);
+      new_layer = conv;
+    }
 
-  new_layer->setName(util::node_info(n).c_str());
+    new_layer->setName(util::node_info(n).c_str());
 
-  // Un-expand spatial dims back to 1D if needed
-  auto out = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims);
+    // Un-expand spatial dims back to 1D if needed
+    auto out = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims);
 
-  ctx->AssociateValueAndTensor(n->outputs()[0], out);
+    ctx->AssociateValueAndTensor(n->outputs()[0], out);
 
-  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
 
-  return true;
-}
+    return true;
+  }
 
-auto conv_registrations TORCHTRT_UNUSED =
-    RegisterNodeConversionPatterns()
-        .pattern({
-            R"SIG(aten::_convolution(Tensor input, Tensor weight,
+  auto conv_registrations TORCHTRT_UNUSED =
+      RegisterNodeConversionPatterns()
+          .pattern({
+              R"SIG(aten::_convolution(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG",
-            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-              return add_conv_deconv(ctx, n, args);
-            }})
-        .pattern({
-            R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
+              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                return add_conv_deconv(ctx, n, args);
+              }})
+          .pattern({
+              R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
-            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-              // This pattern is only matched for traced JIT models which do not
-              // have allow_tf32 bool in the function signature. The TRT conversion
-              // code is exactly same as the above call.
-              return add_conv_deconv(ctx, n, args);
-            }});
+              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                // This pattern is only matched for traced JIT models which do not
+                // have allow_tf32 bool in the function signature. The TRT conversion
+                // code is exactly same as the above call.
+                return add_conv_deconv(ctx, n, args);
+              }});
 } // namespace
 } // namespace impl
 } // namespace converters
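
Note: PATCH 2 extends the same treatment to the QAT path (kernel supplied as an ITensor): the built-in bias is dropped (hasOutputPadding ? nvinfer1::Weights{} : bias.data) and re-added by add_bias_layer after the fill. The ordering matters, because the kFILL slice pads with zeros; padding a biased output would leave the appended cells at 0 instead of the bias value. A minimal single-channel, 1-D illustration (an assumed example, not converter code):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> deconv_out = {1.f, 2.f, 3.f}; // bias-free deconvolution output
      float bias = 0.5f;
      int out_padding = 1;

      // Step 1: extend by out_padding with fill value 0 (what SliceMode::kFILL does).
      std::vector<float> padded = deconv_out;
      padded.insert(padded.end(), out_padding, 0.f);

      // Step 2: broadcast-add the bias, so the appended cell correctly becomes 0.5.
      for (auto& v : padded) v += bias;

      for (float v : padded) std::printf("%.1f ", v); // prints: 1.5 2.5 3.5 0.5
      return 0;
    }
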
From 05b6318d981a08841470afc84e8951784be7e5b6 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 11 Dec 2023 16:12:10 -0800
Subject: [PATCH 3/6] chore: fix formatting

Signed-off-by: Dheeraj Peri
---
 .../converters/impl/conv_deconv.cpp | 229 +++++++++---------
 1 file changed, 115 insertions(+), 114 deletions(-)

diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index fe8c49e86f..c7b867b539 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -165,151 +165,152 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
       LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
       nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0);
       layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
-    } else {
-      nvinfer1::IConvolutionLayer* convLayer =
-          ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
-      convLayer->setStrideNd(stride);
-      convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
-      convLayer->setPaddingNd(padding);
-      convLayer->setPostPadding(out_padding);
-      convLayer->setDilationNd(dilation);
-      convLayer->setNbGroups(groups);
+    }
+  } else {
+    nvinfer1::IConvolutionLayer* convLayer =
+        ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+    convLayer->setStrideNd(stride);
+    convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+    convLayer->setPaddingNd(padding);
+    convLayer->setPostPadding(out_padding);
+    convLayer->setDilationNd(dilation);
+    convLayer->setNbGroups(groups);
 
-      // Set conv kernel weights
-      convLayer->setInput(1, *kernel);
-      layer = convLayer;
-    }
+    // Set conv kernel weights
+    convLayer->setInput(1, *kernel);
+    layer = convLayer;
+  }
 
-    ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
-    LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
-    return true;
-  }
+  ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+  LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
+  return true;
+}
 
-    auto w = Weights(ctx, args[1].unwrapToTensor());
-    // TODO: Remove this when conv3d with kernel size=1 bug is fixed.
-    // Github issue: https://github.com/pytorch/TensorRT/issues/1445
-    bool is_kernel_size_one = true;
-    bool is_3d_kernel = w.kernel_shape.nbDims == 3;
-    for (int64_t i = 0; i < w.kernel_shape.nbDims; i++) {
-      if (w.kernel_shape.d[i] != 1.0f) {
-        is_kernel_size_one = false;
-      }
+  auto w = Weights(ctx, args[1].unwrapToTensor());
+  // TODO: Remove this when conv3d with kernel size=1 bug is fixed.
+  // Github issue: https://github.com/pytorch/TensorRT/issues/1445
+  bool is_kernel_size_one = true;
+  bool is_3d_kernel = w.kernel_shape.nbDims == 3;
+  for (int64_t i = 0; i < w.kernel_shape.nbDims; i++) {
+    if (w.kernel_shape.d[i] != 1.0f) {
+      is_kernel_size_one = false;
     }
-    if (is_kernel_size_one && is_3d_kernel) {
-      LOG_WARNING(
-          "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
+  }
+  if (is_kernel_size_one && is_3d_kernel) {
+    LOG_WARNING(
+        "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
 Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue.");
-    }
-    auto dims = in->getDimensions();
-    auto orig_dims = dims;
-    LOG_DEBUG("Input dims: " << orig_dims);
-    LOG_DEBUG("Weights: " << w);
-    LOG_DEBUG("stride: " << stride);
-    LOG_DEBUG("padding: " << padding);
-    LOG_DEBUG("dilation: " << dilation);
-    LOG_DEBUG("out_padding: " << out_padding);
-    LOG_DEBUG("groups: " << groups);
+  }
+  auto dims = in->getDimensions();
+  auto orig_dims = dims;
+  LOG_DEBUG("Input dims: " << orig_dims);
+  LOG_DEBUG("Weights: " << w);
+  LOG_DEBUG("stride: " << stride);
+  LOG_DEBUG("padding: " << padding);
+  LOG_DEBUG("dilation: " << dilation);
+  LOG_DEBUG("out_padding: " << out_padding);
+  LOG_DEBUG("groups: " << groups);
 
-    TORCHTRT_CHECK(orig_dims.nbDims > 2, "Unable to create convolution layer from node: " << *n);
+  TORCHTRT_CHECK(orig_dims.nbDims > 2, "Unable to create convolution layer from node: " << *n);
 
-    bool expandDims = (orig_dims.nbDims < 4);
-    if (expandDims) {
-      in = addPadding(ctx, n, in, 4);
-      dims = in->getDimensions();
-      LOG_DEBUG("Reshaped Input dims: " << dims);
-    }
-    if (w.shape.nbDims < 4) {
-      for (int i = w.shape.nbDims; i < 4; ++i) {
-        w.shape.d[i] = 1;
-      }
-      w.shape.nbDims = 4;
-      w.kernel_shape.nbDims = 2;
-      w.kernel_shape.d[1] = 1;
-      LOG_DEBUG("Reshaped Weights: " << w);
+  bool expandDims = (orig_dims.nbDims < 4);
+  if (expandDims) {
+    in = addPadding(ctx, n, in, 4);
+    dims = in->getDimensions();
+    LOG_DEBUG("Reshaped Input dims: " << dims);
+  }
+  if (w.shape.nbDims < 4) {
+    for (int i = w.shape.nbDims; i < 4; ++i) {
+      w.shape.d[i] = 1;
     }
+    w.shape.nbDims = 4;
+    w.kernel_shape.nbDims = 2;
+    w.kernel_shape.d[1] = 1;
+    LOG_DEBUG("Reshaped Weights: " << w);
+  }
 
-    nvinfer1::ILayer* new_layer;
-    if (transposed) {
-      // Refer to
-      // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734
-      nvinfer1::Dims begPadding = padding;
-      bool hasOutputPadding = false;
-      add_output_padding(padding, out_padding, hasOutputPadding);
+  nvinfer1::ILayer* new_layer;
+  if (transposed) {
+    // Refer to
+    // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734
+    nvinfer1::Dims begPadding = padding;
+    bool hasOutputPadding = false;
+    add_output_padding(padding, out_padding, hasOutputPadding);
 
-      // shape of deconvolution's weight: [in, out/groups, ...]
-      // If there is still output padding, remove the bias. Bias will be added below.
-      auto deconv = ctx->net->addDeconvolutionNd(
-          *in, w.shape.d[1] * groups, w.kernel_shape, w.data, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
-      TORCHTRT_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
+    // shape of deconvolution's weight: [in, out/groups, ...]
+    // If there is still output padding, remove the bias. Bias will be added below.
+    auto deconv = ctx->net->addDeconvolutionNd(
+        *in, w.shape.d[1] * groups, w.kernel_shape, w.data, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
+    TORCHTRT_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
 
-      deconv->setStrideNd(stride);
-      deconv->setPrePadding(begPadding);
-      deconv->setPostPadding(padding);
+    deconv->setStrideNd(stride);
+    deconv->setPrePadding(begPadding);
+    deconv->setPostPadding(padding);
 #if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR >= 1)
-      deconv->setDilationNd(dilation);
-      deconv->setNbGroups(groups);
+    deconv->setDilationNd(dilation);
+    deconv->setNbGroups(groups);
 #else
-      TORCHTRT_CHECK(groups == 1, "for deconv with groups > 1, require TensorRT version >= 7.1");
-      for (int idx = 0; idx < dilation.nbDims; idx++) {
-        TORCHTRT_CHECK(dilation.d[idx] == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1");
-      }
+    TORCHTRT_CHECK(groups == 1, "for deconv with groups > 1, require TensorRT version >= 7.1");
+    for (int idx = 0; idx < dilation.nbDims; idx++) {
+      TORCHTRT_CHECK(dilation.d[idx] == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1");
+    }
 #endif
-      if (hasOutputPadding) {
-        LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
-        nvinfer1::ITensor* tensorPtr = deconv->getOutput(0);
-        new_layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
-      } else {
-        new_layer = deconv;
-      }
+    if (hasOutputPadding) {
+      LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
+      nvinfer1::ITensor* tensorPtr = deconv->getOutput(0);
+      new_layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
     } else {
-      // shape of convolution's weight: [out, in/groups, ...]
-      auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
-      TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);
+      new_layer = deconv;
+    }
+  } else {
+    // shape of convolution's weight: [out, in/groups, ...]
+    auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
+    TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);
 
-      conv->setStrideNd(stride);
-      conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
-      conv->setPaddingNd(padding);
-      conv->setPostPadding(out_padding);
-      conv->setDilationNd(dilation);
-      conv->setNbGroups(groups);
-      new_layer = conv;
-    }
+    conv->setStrideNd(stride);
+    conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+    conv->setPaddingNd(padding);
+    conv->setPostPadding(out_padding);
+    conv->setDilationNd(dilation);
+    conv->setNbGroups(groups);
+    new_layer = conv;
+  }
 
-    new_layer->setName(util::node_info(n).c_str());
+  new_layer->setName(util::node_info(n).c_str());
 
-    // Un-expand spatial dims back to 1D if needed
-    auto out = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims);
+  // Un-expand spatial dims back to 1D if needed
+  auto out = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims);
 
-    ctx->AssociateValueAndTensor(n->outputs()[0], out);
+  ctx->AssociateValueAndTensor(n->outputs()[0], out);
 
-    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
 
-    return true;
-  }
+  return true;
+}
 
-  auto conv_registrations TORCHTRT_UNUSED =
-      RegisterNodeConversionPatterns()
-          .pattern({
-              R"SIG(aten::_convolution(Tensor input, Tensor weight,
+auto conv_registrations TORCHTRT_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern({
+            R"SIG(aten::_convolution(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG",
-              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                return add_conv_deconv(ctx, n, args);
-              }})
-          .pattern({
-              R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              return add_conv_deconv(ctx, n, args);
+            }})
+        .pattern({
+            R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
-              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                // This pattern is only matched for traced JIT models which do not
-                // have allow_tf32 bool in the function signature. The TRT conversion
-                // code is exactly same as the above call.
-                return add_conv_deconv(ctx, n, args);
-              }});
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              // This pattern is only matched for traced JIT models which do not
+              // have allow_tf32 bool in the function signature. The TRT conversion
+              // code is exactly same as the above call.
+              return add_conv_deconv(ctx, n, args);
+            }});
 } // namespace
 } // namespace impl
 } // namespace converters
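
Note: the invariant all of these changes preserve is PyTorch's transposed-convolution output size, which per spatial dimension is out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1. A deconvolution configured with pre-/post-padding alone stops short of the output_padding term, which is exactly what the slice/fill path appends. As a sketch of the formula (the function name is illustrative):

    // PyTorch's expected ConvTranspose output size for one spatial dimension.
    int deconv_out_size(int in, int stride, int padding, int dilation, int kernel, int output_padding) {
      return (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1;
    }
    // e.g. deconv_out_size(4, 2, 1, 1, 3, 1) == 8; without the output_padding
    // term the result would be 7, one element short.
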
From cf9681e0cf6fab4071daf15c17c6e68e7ebd6bba Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 11 Dec 2023 16:40:06 -0800
Subject: [PATCH 4/6] chore: fix input dims

Signed-off-by: Dheeraj Peri
---
 core/conversion/converters/impl/conv_deconv.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index c7b867b539..1c0b308a49 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -164,7 +164,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     if (hasOutputPadding) {
       LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
       nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0);
-      layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
+      layer = add_bias_layer(ctx, tensorPtr, in->getDimensions(), out_padding, bias);
     }
   } else {
     nvinfer1::IConvolutionLayer* convLayer =

From 29a0670d9e4eb0c5ed1bba7ff73103ef36acb7a3 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 11 Dec 2023 17:03:22 -0800
Subject: [PATCH 5/6] chore: fixes

Signed-off-by: Dheeraj Peri
---
 core/conversion/converters/impl/conv_deconv.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index 1c0b308a49..8e73eef130 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -164,7 +164,8 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     if (hasOutputPadding) {
       LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
       nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0);
-      layer = add_bias_layer(ctx, tensorPtr, in->getDimensions(), out_padding, bias);
+      auto dims = in->getDimensions();
+      layer = add_bias_layer(ctx, tensorPtr, dims, out_padding, bias);
     }
   } else {
     nvinfer1::IConvolutionLayer* convLayer =
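
Note: PATCH 4 and PATCH 5 look cosmetic, but the second is most plausibly a compile fix rather than a behavior change: add_bias_layer takes its dims parameter as nvinfer1::Dims&, a non-const lvalue reference, and the temporary returned by in->getDimensions() in PATCH 4 cannot bind to it. A minimal reproduction of the rule with stand-in types (the names here are illustrative, not from the codebase):

    struct Dims { int nbDims; };
    Dims get_dimensions() { return Dims{4}; } // stand-in for in->getDimensions()
    void add_bias(Dims& d) { (void)d; }       // stand-in for add_bias_layer's Dims& parameter

    int main() {
      // add_bias(get_dimensions()); // error: non-const lvalue reference cannot bind to a temporary
      Dims dims = get_dimensions();  // PATCH 5's fix: materialize a named lvalue first
      add_bias(dims);
      return 0;
    }
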
From 6c139080fb477f5e647464a8ea3ffb05bbaaf6f5 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 18 Dec 2023 13:24:44 -0800
Subject: [PATCH 6/6] chore: address review comment

Signed-off-by: Dheeraj Peri
---
 core/conversion/converters/impl/conv_deconv.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index 8e73eef130..083a4ecc2f 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -62,7 +62,7 @@ nvinfer1::ILayer* add_bias_layer(
   sliceLayer->setInput(1, *start);
   sliceLayer->setInput(2, *size);
   sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
-  nvinfer1::ITensor* slide_output = sliceLayer->getOutput(0);
+  nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0);
 
   nvinfer1::Dims constantDims;
   constantDims.nbDims = in_nbDims;
@@ -73,7 +73,7 @@ nvinfer1::ILayer* add_bias_layer(
       bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast
   auto const_layer = ctx->net->addConstant(constantDims, bias.data);
   auto bias_layer =
-      ctx->net->addElementWise(*slide_output, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
+      ctx->net->addElementWise(*slice_output, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
 
   return bias_layer;
 }
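
Note: the end-to-end behavior the series converges on can be sanity-checked against eager PyTorch. A LibTorch-only sketch of the shape semantics the converter has to reproduce (an assumed check, not part of the patch set or its test suite):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // ConvTranspose2d with output_padding = 1: one extra row/column beyond
      // what stride and padding alone would produce.
      auto deconv = torch::nn::ConvTranspose2d(
          torch::nn::ConvTranspose2dOptions(3, 8, 3).stride(2).padding(1).output_padding(1));
      auto y = deconv(torch::randn({1, 3, 8, 8}));
      // (8 - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 16
      std::cout << y.sizes() << std::endl; // [1, 8, 16, 16]
      return 0;
    }
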