Bug Description
When performing inference with a Torch-TRT converted BART network (https://huggingface.co/facebook/bart-base), the following error is encountered:
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(243): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(325): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(850): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(1233): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(976): trace_module
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(759): trace
RuntimeError: The size of tensor a (1536) must match the size of tensor b (128) at non-singleton dimension 3
Note that compilation of the model succeeds.
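For reference, the runtime failure is an elementwise broadcasting error. Below is a minimal sketch of the kind of shape mismatch involved; the exact internal tensor shapes are assumptions reconstructed from the error message and from BART-base's 12 attention heads with a sequence length of 128, not values taken from the converted graph:

```python
import torch

# Hypothetical shapes reconstructed from the error message:
# 1536 = 12 heads * 128 tokens, suggesting the head dimension was
# flattened into the last dimension of one operand but not the other.
attn_weights = torch.randn(1, 12, 128, 1536)  # assumed flattened operand
attn_mask = torch.randn(1, 1, 128, 128)       # assumed expected operand

# Raises: RuntimeError: The size of tensor a (1536) must match the size of
# tensor b (128) at non-singleton dimension 3
attn_weights + attn_mask
```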
To Reproduce
Steps to reproduce the behavior:
- Run torch_tensorrt.compile with the BART model as input, using fp32 precision.
- Use two fixed-size inputs, each of shape [1, 128], and enable truncate_long_and_double with a 12 GB workspace.
- Pass model keyword arguments to disable attention and hidden-state outputs.
- Run inference using the compiled model on two sample inputs (a sketch of these steps is included below).
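A minimal sketch of the compilation and inference calls described above; the argument values are assumptions matching the listed steps, not a verbatim copy of the original script:

```python
import torch
import torch_tensorrt
from transformers import BartModel

model = BartModel.from_pretrained("facebook/bart-base").eval().cuda()

# Two fixed-size [1, 128] sample inputs.
input_ids = torch.randint(0, model.config.vocab_size, (1, 128), dtype=torch.int64).cuda()
attention_mask = torch.ones((1, 128), dtype=torch.int64).cuda()

# Disable attention/hidden-state outputs so the traced graph returns plain tensors.
model.config.output_attentions = False
model.config.output_hidden_states = False
model.config.return_dict = False

trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int64),
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int64),
    ],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
    workspace_size=12 << 30,  # 12 GB
)

# Compilation succeeds; inference with the compiled model raises the error above.
out = trt_model(input_ids, attention_mask)
```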
Expected behavior
The model should successfully perform inference with Torch-TRT. Specifically, internal shape issues should either be caught at compile time or not cause errors at runtime.
Environment
- Torch-TensorRT Version: 1.4.0.dev0+81f2dabb
- PyTorch Version: 1.14.0.dev20221114+cu116
- CPU Architecture: Intel Xeon CPU
- OS: Ubuntu 20.04
- How you installed PyTorch: pip
- Build command you used:
python setup.py develop
- Are you using local sources or building from archives: local
- Python version: 3.8.13
- CUDA version: 11.6
Additional context
The problem currently seems to be related to Torch-TensorRT flattening input tensors in a way that is inconsistent with the analogous PyTorch behavior. Two potentially relevant operations are aten::mul and aten::add, which appear frequently in the lowered BART graph as replacements for linear layers, inserted by the LinearToAddMM lowering pass (TensorRT/core/lowering/passes/linear_to_addmm.cpp, lines 47 to 61 at commit aa93a12).
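For context, that pass decomposes aten::linear into a matmul followed by an elementwise add of the bias. A rough Python-level illustration of the decomposition (a conceptual sketch, not the actual C++ pass; the shapes are illustrative):

```python
import torch

x = torch.randn(1, 128, 768)     # activations: (batch, seq, hidden)
weight = torch.randn(768, 768)   # linear weight
bias = torch.randn(768)          # linear bias

# What the original graph computes:
y_linear = torch.nn.functional.linear(x, weight, bias)

# Roughly what the lowered graph computes after LinearToAddMM:
# a matmul followed by a broadcasting elementwise add of the bias.
y_lowered = torch.matmul(x, weight.t()) + bias

assert torch.allclose(y_linear, y_lowered, atol=1e-4)
```

If the converter flattens one operand of such an elementwise aten::add (or of an aten::mul elsewhere in the attention block) differently than PyTorch would, the broadcast could fail with exactly the kind of dimension mismatch shown in the traceback above.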
Temporary Solution
A temporary fix for this problem is to add the following to the compilation arguments in torch_tensorrt.compile:
torch_tensorrt.compile(..., torch_executed_ops=["aten::mul"], ...)
This works because it happens to exclude the problematic code from TensorRT conversion, suggesting the issue could be related to the aten::mul operator itself.
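In context, the workaround looks roughly like this (the remaining arguments are assumed to be the same ones used in the failing compilation above):

```python
trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int64),
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int64),
    ],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
    workspace_size=12 << 30,
    # Keep all aten::mul nodes in Torch rather than converting them to TensorRT.
    torch_executed_ops=["aten::mul"],
)
```

Running aten::mul in Torch forces the partitioner to segment the graph around those nodes, so the problematic converted subgraph is avoided at the cost of extra Torch/TensorRT transitions.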
Related Issues
Potentially related to Issue #1455, as a similar error appears under certain compilation configurations for that model as well.
Additional Note
The bug appears to be nondeterministic: after recompiling the model and rerunning inference several times, inference ultimately completes successfully.