
🐛 [Bug] Compilation of transformer models on non-cuda:0 devices fails #1764

Closed
@gs-olive


Bug Description

When compiling certain models, such as the HuggingFace GPT2 and BERT base-uncased models, from TorchScript to Torch-TensorRT, an error is encountered during partitioning: internally-generated tensors are placed on cuda:0 regardless of the user-specified device.

Despite the user specifying cuda:4 as the desired device, cuda:0 is used for some internal tensors. See the error below:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py(2210): embedding
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/sparse.py(162): forward
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py(1488): _slow_forward
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py(1501): _call_impl
/usr/local/lib/python3.8/dist-packages/transformers/models/bert/modeling_bert.py(230): forward
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py(1488): _slow_forward
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py(1501): _call_impl
/usr/local/lib/python3.8/dist-packages/transformers/models/bert/modeling_bert.py(1013): forward
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py(1488): _slow_forward
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py(1501): _call_impl
/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py(1056): trace_module
/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py(794): trace
trt_run_distilbert.py(55): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
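For reference, the same failure mode can be reproduced in isolation by calling the embedding op with its index tensor and weight table on different devices. This is a minimal sketch of the mismatch (assuming a machine with at least five visible GPUs); it is not the actual code path taken inside the partitioner:

import torch
import torch.nn.functional as F

weight = torch.randn(30522, 768, device="cuda:4")    # embedding table on the user-specified device
index = torch.tensor([[101, 102]], device="cuda:0")  # index tensor left on cuda:0, as in the bug
out = F.embedding(index, weight)                     # raises: Expected all tensors to be on the same device ...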

To Reproduce

Steps to reproduce the behavior:

  1. Run torch_tensorrt.compile with the BERT base-uncased model as input, using fp32 precision, and specify device=torch_tensorrt.Device("cuda:4") (or any cuda:X with X nonzero); a compile-call sketch is included after these steps
  2. Choose a fixed input size of [1, 128] and enable truncate_long_and_double with a 12 GB workspace
  3. Pass in the following model keyword args to disable attention and hidden state outputs:
from transformers import BertModel

# Disable cached/auxiliary outputs so the traced module returns plain tensors
bert_kwargs = {
    "use_cache": False,
    "output_attentions": False,
    "output_hidden_states": False,
    "torchscript": True,
}

model = BertModel.from_pretrained("bert-base-uncased", **bert_kwargs).eval().to("cuda:4")
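For completeness, a compilation call along the following lines triggers the failure. This is a hedged sketch of the repro rather than the original trt_run_distilbert.py script (which is not included in this issue); the input dtypes and argument spellings follow the public Torch-TensorRT TorchScript frontend and should be treated as assumptions:

import torch
import torch_tensorrt

# `model` is the BertModel instance loaded above, already on cuda:4
inputs = [
    torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),  # input_ids
    torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),  # attention_mask
]

trt_model = torch_tensorrt.compile(
    model,
    inputs=inputs,
    enabled_precisions={torch.float32},      # fp32 precision
    device=torch_tensorrt.Device("cuda:4"),  # non-cuda:0 target device
    truncate_long_and_double=True,
    workspace_size=12 << 30,                 # 12 GB workspace
)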

Expected behavior

The module should compile and perform inference on the user-specified device (cuda:4) when compiled via the TorchScript path

Environment

  • Transformers: 4.27.2
  • Torch-TensorRT Version: a1d4af0
  • PyTorch Version: 2.1.0.dev20230219+cu117
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.7

Additional Considerations

Note that the BERT and GPT2 models compile successfully via the TorchScript path otherwise, i.e., when the default cuda:0 device is targeted.
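As a possible interim workaround (not verified against this issue), restricting device visibility so that the desired GPU is enumerated as cuda:0 avoids the mismatch with the internally-generated tensors:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"  # must be set before torch/CUDA is initialized
import torch
import torch_tensorrt

# Physical GPU 4 now appears as cuda:0, so the default device matches the user target
device = torch_tensorrt.Device("cuda:0")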
