
🐛 [Bug] Runtime Error on BART #1532

Closed
@gs-olive


Bug Description

When performing inference with a Torch-TRT converted BART network (https://huggingface.co/facebook/bart-base), the following error is encountered:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(243): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(325): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(850): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(1233): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(976): trace_module
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(759): trace
RuntimeError: The size of tensor a (1536) must match the size of tensor b (128) at non-singleton dimension 3

Note that compilation of the model succeeds.

To Reproduce

Steps to reproduce the behavior (a hedged sketch follows the list):

  1. Run torch_tensorrt.compile with the BART model as input, using fp32 precision.
  2. Choose two fixed-size inputs, each of shape [1, 128], and enable truncate_long_and_double with a 12 GB workspace.
  3. Pass model keyword arguments to disable attention and hidden state outputs.
  4. Run inference using the compiled model on two sample inputs.
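A minimal sketch of these steps follows. The exact input names, dtypes, and the use of torch.jit.trace are assumptions inferred from the traceback above, not taken from the original script:

import torch
import torch_tensorrt
from transformers import BartModel

# Load BART with attention and hidden-state outputs disabled (step 3)
model = BartModel.from_pretrained(
    "facebook/bart-base",
    torchscript=True,
    output_attentions=False,
    output_hidden_states=False,
).eval().cuda()

# Two fixed-size [1, 128] inputs (assumed here to be input_ids and attention_mask)
input_ids = torch.randint(0, model.config.vocab_size, (1, 128), dtype=torch.int64).cuda()
attention_mask = torch.ones((1, 128), dtype=torch.int64).cuda()

# The traceback shows torch.jit.trace in the call chain, so trace before compiling
traced = torch.jit.trace(model, [input_ids, attention_mask])

trt_model = torch_tensorrt.compile(
    traced,
    inputs=[
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),
    ],
    enabled_precisions={torch.float32},  # fp32 precision (step 1)
    truncate_long_and_double=True,       # int64 inputs truncated to int32 (step 2)
    workspace_size=12 << 30,             # 12 GB workspace (step 2)
)

# Step 4: inference on the compiled module raises the RuntimeError above
out = trt_model(input_ids, attention_mask)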

Expected behavior

The model should successfully perform inference with Torch-TRT. Specifically, internal shape issues should either be caught at compile time or should otherwise not cause errors at runtime.

Environment

  • Torch-TensorRT Version: 1.4.0.dev0+81f2dabb
  • PyTorch Version: 1.14.0.dev20221114+cu116
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.6

Additional context

The problem currently appears to be related to Torch-TensorRT flattening input tensors in a way that is inconsistent with the analogous PyTorch behavior. Two potentially relevant operations are aten::mul and aten::add, which appear frequently in the BART code and are inserted as replacements for the linear layer by the LinearToAddMM lowering pass:

void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
  // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
  std::string flatten_linear_pattern = R"IR(
    graph(%input, %weight, %bias):
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";
  std::string fused_linear = R"IR(
    graph(%input, %weight_t, %bias):
        %1: int = prim::Constant[value=1]()
        %weight = aten::t(%weight_t)
        %mm: Tensor = aten::matmul(%input, %weight)
        %b_f: Tensor = trt::const(%bias)
        %out: Tensor = aten::add(%b_f, %mm, %1)
        return (%out))IR";

Temporary Solution

A temporary fix to this problem is to add the following to the compilation arguments in torch_tensorrt.compile:

torch_tensorrt.compile(..., torch_executed_ops=["aten::mul"], ...)

This workaround is effective because it happens to exclude the problematic code from TensorRT conversion, which suggests the issue could be related to the aten::mul operator itself.
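In terms of the reproduction sketch above, the workaround would look like the following (all arguments other than torch_executed_ops are the assumed ones from that sketch):

trt_model = torch_tensorrt.compile(
    traced,
    inputs=[
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),
    ],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
    workspace_size=12 << 30,
    torch_executed_ops=["aten::mul"],  # force aten::mul to run in PyTorch rather than TensorRT
)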

Related Issues

Potentially related to Issue #1455, as a similar error appears under certain compilation configurations for that model as well.

Additional Note

The bug appears to be nondeterministic: after recompiling and running inference with the model many times, inference ultimately completes successfully.
