Closed
Description
Bug Description
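Compiling a scripted `torchaudio.transforms.Resample` module with Torch-TensorRT fails with a `RuntimeError` (full traceback under "Additional context" below).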
To Reproduce
# Reproduction: compile a scripted torchaudio Resample module with Torch-TensorRT.
# torch and torch_tensorrt must be imported explicitly; without them the script
# fails with NameError before reaching the conversion step being reported.
import torch
import torch_tensorrt
import torchaudio.transforms as T

# Same arguments as the positional call
#   T.Resample(16000, 32000, "kaiser_window", 12, 0.5, 4.663800127934911, dtype=...)
# spelled out by keyword for readability.
# NOTE(review): newer torchaudio releases call this method "sinc_interp_kaiser";
# "kaiser_window" is the name used by the release this report targets — confirm.
resample = T.Resample(
    orig_freq=16000,
    new_freq=32000,
    resampling_method="kaiser_window",
    lowpass_filter_width=12,
    rolloff=0.5,
    beta=4.663800127934911,
    dtype=torch.float32,
).cuda()

# Eager sanity check on a (1, 1, 32000) CUDA input before scripting/compiling.
resample(torch.randn(1, 1, 32000).cuda())

resample = torch.jit.script(resample)

# Dynamic length along the last axis. This call is the one that raises
# "Expected orig_dims.nbDims > 2 to be true but got false"
# (core/conversion/converters/impl/conv_deconv.cpp:115) on aten::_convolution.
trt_ts_module = torch_tensorrt.compile(
    resample,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1, 1, 16000],
            opt_shape=[1, 1, 32000],
            max_shape=[1, 1, 64000],
            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
)
Expected behavior
The scripted module compiles to a TensorRT-backed TorchScript module without errors.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.3.0a0
- PyTorch Version (e.g. 1.0): 1.14.0a0+410ce96
- CPU Architecture:
- OS (e.g., Linux): Ubuntu
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): `docker pull nvcr.io/nvidia/pytorch:22.12-py3`
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
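- Python version: 3.10 (note: the traceback below shows package paths under `/usr/local/lib/python3.8`, so the interpreter inside the container appears to be Python 3.8)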
- CUDA version: 11.8
- GPU models and configuration: T4 Google Cloud
- Any other relevant information:
Additional context
RuntimeError Traceback (most recent call last)
Cell In[55], line 2
1 # with torch_tensorrt.logging.debug():
----> 2 trt_ts_module = torch_tensorrt.compile(
3 resample,
4 inputs=[
5 torch_tensorrt.Input( # Specify input object with shape and dtype
6 min_shape=[1, 1, 16000],
7 opt_shape=[1, 1, 32000],
8 max_shape=[1, 1, 64000],
9 dtype=torch.float32)
10 ],
11 enabled_precisions={torch.float32},
12 truncate_long_and_double=True
13 )
File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py:125, in compile(module, ir, inputs, enabled_precisions, **kwargs)
120 logging.log(
121 logging.Level.Info,
122 "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
123 )
124 ts_mod = torch.jit.script(module)
--> 125 return torch_tensorrt.ts.compile(
126 ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
127 )
128 elif target_ir == _IRType.fx:
129 if (
130 torch.float16 in enabled_precisions
131 or torch_tensorrt.dtype.half in enabled_precisions
132 ):
File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py:136, in compile(module, inputs, input_signature, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
110 raise ValueError(
111 f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}"
112 )
114 spec = {
115 "inputs": inputs,
116 "input_signature": input_signature,
(...)
133 },
134 }
--> 136 compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
137 compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
138 return compiled_module
RuntimeError: [Error thrown at core/conversion/converters/impl/conv_deconv.cpp:115] Expected orig_dims.nbDims > 2 to be true but got false
Unable to create convolution layer from node: %6 : Tensor = aten::_convolution(%5, %7, %3, %8, %11, %8, %10, %11, %4, %10, %10, %10, %10)