
🐛 [Bug] Unable to compile Resample layer when using Torch-TensorRT #1596

Closed
@dathudeptrai

Description

Bug Description

Running the following code fails with a RuntimeError when compiling the scripted Resample module with Torch-TensorRT; the full traceback is in Additional context below.

To Reproduce

import torch
import torch_tensorrt
import torchaudio.transforms as T

# Kaiser-window resampler: 16 kHz -> 32 kHz
resample = T.Resample(
    16000, 32000, "kaiser_window", 12, 0.5, 4.663800127934911, dtype=torch.float32
).cuda()
resample(torch.randn(1, 1, 32000).cuda())  # eager forward pass works
resample = torch.jit.script(resample)

trt_ts_module = torch_tensorrt.compile(
    resample,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1, 1, 16000],
            opt_shape=[1, 1, 32000],
            max_shape=[1, 1, 64000],
            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
)
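
To see where the failing convolution comes from, the scripted module's graph can be inspected before compiling; torchaudio applies its resampling kernel with a 1-D convolution, which presumably becomes the aten::_convolution node in the error below. A minimal sketch, using the scripted resample module from the snippet above:

# Print the forward graph with all function calls inlined and look for
# the aten::_convolution node that the converter later rejects.
print(resample.inlined_graph)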

Expected behavior

The scripted Resample module compiles to TensorRT successfully.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages; a sketch of how to enable them follows the list below.

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0a0
  • PyTorch Version (e.g. 1.0): 1.14.0a0+410ce96
  • CPU Architecture:
  • OS (e.g., Linux): Ubuntu
  • How you installed PyTorch (conda, pip, libtorch, source): docker pull nvcr.io/nvidia/pytorch:22.12-py3
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10
  • CUDA version: 11.8
  • GPU models and configuration: T4 Google Cloud
  • Any other relevant information:
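
A sketch of turning those debug messages on, using the same torch_tensorrt.logging.debug() context manager that is commented out in the traceback below, wrapped around the compile call from the repro:

# Enable debug-level logging so build information and converter
# messages are printed during compilation.
with torch_tensorrt.logging.debug():
    trt_ts_module = torch_tensorrt.compile(
        resample,
        inputs=[
            torch_tensorrt.Input(
                min_shape=[1, 1, 16000],
                opt_shape=[1, 1, 32000],
                max_shape=[1, 1, 64000],
                dtype=torch.float32,
            )
        ],
        enabled_precisions={torch.float32},
        truncate_long_and_double=True,
    )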

Additional context

RuntimeError                              Traceback (most recent call last)
Cell In[55], line 2
      1 # with torch_tensorrt.logging.debug():
----> 2 trt_ts_module = torch_tensorrt.compile(
      3    resample, 
      4    inputs=[
      5        torch_tensorrt.Input( # Specify input object with shape and dtype
      6             min_shape=[1, 1, 16000],
      7             opt_shape=[1, 1, 32000],
      8             max_shape=[1, 1, 64000],
      9             dtype=torch.float32)
     10     ],
     11     enabled_precisions={torch.float32},
     12     truncate_long_and_double=True
     13 )

File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py:125, in compile(module, ir, inputs, enabled_precisions, **kwargs)
    120         logging.log(
    121             logging.Level.Info,
    122             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
    123         )
    124         ts_mod = torch.jit.script(module)
--> 125     return torch_tensorrt.ts.compile(
    126         ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    127     )
    128 elif target_ir == _IRType.fx:
    129     if (
    130         torch.float16 in enabled_precisions
    131         or torch_tensorrt.dtype.half in enabled_precisions
    132     ):

File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py:136, in compile(module, inputs, input_signature, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
    110     raise ValueError(
    111         f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}"
    112     )
    114 spec = {
    115     "inputs": inputs,
    116     "input_signature": input_signature,
   (...)
    133     },
    134 }
--> 136 compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
    137 compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
    138 return compiled_module

RuntimeError: [Error thrown at core/conversion/converters/impl/conv_deconv.cpp:115] Expected orig_dims.nbDims > 2 to be true but got false
Unable to create convolution layer from node: %6 : Tensor = aten::_convolution(%5, %7, %3, %8, %11, %8, %10, %11, %4, %10, %10, %10, %10)
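
A possible stop-gap while the converter is fixed: the ts compile signature shown above accepts torch_executed_ops, which keeps the listed ops in PyTorch and offers the rest of the graph to TensorRT (partial compilation). A sketch, untested for this model, assuming the graph can be split around the failing node:

# Keep the op the converter rejects in PyTorch; everything else is still
# eligible for TensorRT conversion.
trt_ts_module = torch_tensorrt.compile(
    resample,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1, 1, 16000],
            opt_shape=[1, 1, 32000],
            max_shape=[1, 1, 64000],
            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
    torch_executed_ops=["aten::_convolution"],
)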

@narendasan
