Skip to content

❓ [Question] Output shape error in deconvolution layer when model is quantized with pytorch-quantization and using torch-tensorrt via torchscript #2723

Closed
@oazeybekoglu

Description

@oazeybekoglu

❓ Question

When a simple model quantized to int8 with pytorch-quantization has a deconvolution (ConvTranspose2d) as its output layer, converting the traced TorchScript module with Torch-TensorRT fails with a "wrong number of output channels" error. If a regular Conv2d layer is used instead of the deconvolution, the conversion succeeds without error.

What you have already tried

import torch_tensorrt
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
from torchvision import transforms
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
import torch.nn.functional as F

class customodel(nn.Module):
    """Minimal U-Net-style encoder/decoder used to reproduce the issue.

    The output head is deliberately a 1x1 ``ConvTranspose2d`` — the layer
    type that triggers the Torch-TensorRT conversion failure.
    """

    def __init__(self):
        super().__init__()
        # Encoder: two 3x3 convs followed by a 2x2 max-pool.
        self.e11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Decoder: upsample back, merge with the skip connection, refine.
        self.upconv4 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # 1x1 deconvolution output head (the problematic layer).
        self.outconv = nn.ConvTranspose2d(64, 10, kernel_size=1)

    def forward(self, x):
        """Return a 10-channel map at the input's spatial resolution.

        Assumes the input's height/width are even so the upsampled tensor
        matches the skip connection for concatenation.
        """
        enc = F.relu(self.e11(x))
        skip = F.relu(self.e12(enc))
        down = self.pool1(skip)
        up = self.upconv4(down)
        # Channel-wise merge of the upsampled path with the skip path.
        fused = torch.cat([up, skip], dim=1)
        out = F.relu(self.d41(fused))
        out = F.relu(self.d42(out))
        return self.outconv(out)

def collect_stats(model, data_loader, num_batches):
    """Run calibration data through *model* to collect quantizer statistics.

    Puts every ``TensorQuantizer`` into calibration mode (quantization
    disabled), feeds exactly ``num_batches`` batches through the model,
    then restores quantization mode.

    Args:
        model: module containing ``quant_nn.TensorQuantizer`` submodules.
        data_loader: iterable yielding ``(image, label)`` batches.
        num_batches: number of batches to use for calibration.
    """
    # Enter calibration mode: collect statistics instead of quantizing.
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        # Break BEFORE the forward pass so exactly num_batches batches are
        # processed (the original checked after, running num_batches + 1).
        if i >= num_batches:
            break
        model(image.cuda())

    # Leave calibration mode: re-enable quantization for inference/export.
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

def compute_amax(model, **kwargs):
    """Load the calibrated amax (dynamic range) into every quantizer.

    Quantizers whose calibrator is a ``MaxCalibrator`` take no extra
    arguments; all others are forwarded ``**kwargs`` (e.g. ``method``,
    ``percentile``).
    """
    for _, module in model.named_modules():
        # Skip everything that is not a quantizer with a calibrator attached.
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue
        if module._calibrator is None:
            continue
        if isinstance(module._calibrator, calib.MaxCalibrator):
            module.load_calib_amax()
        else:
            module.load_calib_amax(**kwargs)


def main():
    """Calibrate the int8-quantized model on CIFAR-10 and export TorchScript.

    Side effects: downloads CIFAR-10 to ./data, requires a CUDA device,
    and writes the traced module to "custom.pt".
    """
    # Monkey-patch torch.nn so the model is built with quantized layers.
    quant_modules.initialize()

    # Use histogram calibration for the input quantizers of all layer types.
    hist_desc = QuantDescriptor(calib_method='histogram')
    quant_nn.QuantConv2d.set_default_quant_desc_input(hist_desc)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(hist_desc)
    quant_nn.QuantLinear.set_default_quant_desc_input(hist_desc)

    model = customodel().cuda()

    preprocess = transforms.Compose([
        transforms.Resize((572, 572)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    full_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, transform=preprocess, download=True)

    # Calibrate on a small (3%) slice of the training set.
    subset_size = int(0.03 * len(full_set))
    calib_set = torch.utils.data.Subset(full_set, range(subset_size))
    loader = torch.utils.data.DataLoader(
        dataset=calib_set, batch_size=12, shuffle=True)

    with torch.no_grad():
        collect_stats(model, loader, num_batches=10)
        compute_amax(model, method="percentile", percentile=99.99)

    # Emit aten::fake_quantize_* ops so the trace is convertible to TensorRT.
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    with torch.no_grad():
        sample, _ = next(iter(loader))
        traced = torch.jit.trace(model, sample.to("cuda"))
        torch.jit.save(traced, "custom.pt")
def main2():
    """Load the exported TorchScript module and compile it with Torch-TensorRT.

    This is the step that fails: the int8 deconvolution output layer is
    rejected with an output-channel mismatch during conversion.
    """
    ts_module = torch.jit.load('/content/custom.pt').eval()
    trt_mod = torch_tensorrt.compile(
        ts_module,
        inputs=[torch_tensorrt.Input([2, 3, 572, 572])],
        enabled_precisions=torch.int8,
        ir='torchscript',
    )
if __name__ == '__main__':
    # Calibrate + export the TorchScript model, then attempt the
    # Torch-TensorRT compilation that reproduces the reported error.
    main()
    main2()
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: (Unnamed Layer* 53) [Deconvolution]: weight input tensor shape not consistent with the nbOutputMaps in addConvolutionNd/addDeconvolutionNd API. Expected output channels 64 kernel spatial dims [1,1]. But got output channels 10 kernel spatial dims [1,1]
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: (Unnamed Layer* 53) [Deconvolution]: weight input tensor shape not consistent with the nbOutputMaps in addConvolutionNd/addDeconvolutionNd API. Expected output channels 64 kernel spatial dims [1,1]. But got output channels 10 kernel spatial dims [1,1]
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2212] Error Code 4: Internal Error ((Unnamed Layer* 53) [Deconvolution]: output shape can not be computed)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: (Unnamed Layer* 53) [Deconvolution]: weight input tensor shape not consistent with the nbOutputMaps in addConvolutionNd/addDeconvolutionNd API. Expected output channels 64 kernel spatial dims [1,1]. But got output channels 10 kernel spatial dims [1,1]
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [network.cpp::validate::3121] Error Code 4: Internal Error (Layer (Unnamed Layer* 53) [Deconvolution] failed validation)

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.2.0
  • PyTorch Version (e.g., 1.0): 2.2.1+cu121
  • Pytorch-quantization: 2.1.3
  • Python version: 3.10.12
  • CUDA version: 12.2
  • OS (e.g., Linux): Ubuntu 22.04.3 LTS
  • GPU models and configuration: T4
  • Any other relevant information:

Metadata

Metadata

Assignees

Labels

question — Further information is requested

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions