Bug Description
For model types which are fully supported in Torch-TRT, but which return complex collections of outputs, for example `Tuple[Tensor, Tensor]` or `Tuple[Tensor, Tuple[Tensor, Tensor]]`, the compilation behavior of Torch-TRT differs depending on whether the flag `require_full_compilation` is enabled.

When `require_full_compilation=False`, the model compiles successfully, and the only operations executed in Torch are collections-processing operations. When `require_full_compilation=True`, however, the model fails with:
```
RuntimeError: [Error thrown at core/conversion/conversion.cpp:230] Tuple type. Only a single tensor or a TensorList type is supported.
```
To Reproduce
Steps to reproduce the behavior:
- Define a fully convertible Torch model whose `forward` function has the following form:
```python
def forward(self, input: Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
- Define the Torch-TRT Inputs and compilation settings with `require_full_compilation=True`, then compile the scripted model:
```python
compile_settings = {
    "inputs": [torch_tensorrt.Input((5, 5), dtype=torch.float)],
    "enabled_precisions": {torch.float},
    "truncate_long_and_double": True,
    "require_full_compilation": True,
}
trt_ts_module = torch_tensorrt.ts.compile(scripted_model, **compile_settings)
```
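The steps above can be reproduced with a minimal model along the following lines. This is an illustrative sketch, not the exact model from the report: the class name, layer sizes, and chosen ops are assumptions; the only requirement is that every op is TRT-convertible and the output is a nested tuple.

```python
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor


class NestedTupleModel(nn.Module):
    """Hypothetical repro model: fully convertible ops, nested-tuple output."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(5, 5)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # Every op here is individually convertible to TensorRT; only the
        # packaging of the outputs into a nested tuple needs collections
        # handling in Torch.
        a = torch.relu(self.linear(input))
        b = torch.sigmoid(self.linear(input))
        return a, (b, a + b)


# Script the model; the scripted module is what gets passed to
# torch_tensorrt.ts.compile(scripted_model, **compile_settings) above.
scripted_model = torch.jit.script(NestedTupleModel())
out = scripted_model(torch.rand(5, 5))
```

With `require_full_compilation=False`, compiling this scripted module succeeds and leaves only the tuple construction in Torch; with `require_full_compilation=True`, it hits the `Tuple type` conversion error quoted above.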
Expected behavior
The model should compile successfully with `require_full_compilation=True` when it returns a nested Tuple collection.
Environment
- Torch-TensorRT Version: 1.4.0.dev0+f43be5b6
- PyTorch Version: 1.14.0.dev20221114+cu116
- CPU Architecture: Intel Xeon CPU
- OS: Ubuntu 20.04
- How you installed PyTorch: pip
- Build command you used: `python setup.py develop`
- Are you using local sources or building from archives: local
- Python version: 3.8.13
- CUDA version: 11.6
Additional context
This bug is related to, but not the same as #1595, as this bug relates to nested collection outputs and not inputs. The resolution to both of these bugs should resolve the overall issue of having fully-compiled models with complex collections of inputs and outputs.
This particular bug also appears when compiling the HuggingFace BERT uncased model with `require_full_compilation=True`.