Bug Description
This is a bug in the tutorial below.
- Web: https://nvidia.github.io/TRTorch/tutorials/use_from_pytorch.html
- Original ReST doc: https://github.com/NVIDIA/TRTorch/blob/51a2043217a3f6c93f393169961750733c0d26ec/docsrc/tutorials/use_from_pytorch.rst
To compile a model, torch._C._jit_to_tensorrt() is called with two parameters, script_model._c and spec.
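Condensed, the call from the tutorial looks like this (model and spec as defined in the full repro further below):

# Condensed from the tutorial: script the model, then pass the
# TorchScript C++ module handle and the compile spec to the API.
script_model = torch.jit.script(model)
trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)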
However, the following error occurs:
Traceback (most recent call last):
  File "simple.py", line 34, in <module>
    main()
  File "simple.py", line 27, in main
    trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 6, in __setstate__
    self.__processed_module = state[1]
    self.__create_backend()
    self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
                     ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: isGenericDict() INTERNAL ASSERT FAILED at "/opt/python/cp36-cp36m/lib/python3.6/site-packages/torch/include/ATen/core/ivalue_inl.h":830, please report a bug to PyTorch. Expected GenericDict but got Object
This error looks like it is caused by a wrong parameter. In the test code (https://github.com/NVIDIA/TRTorch/blob/b93627ecc7bb7123d2e32aff9d762dd0e6bc3166/tests/py/test_to_backend_api.py), this API is called with a different second argument, as shown below. It looks like wrapping the spec in an additional dict keyed by the method name forward is necessary; presumably the backend's compile() expects a dict mapping method names to compile specs, which would explain the "Expected GenericDict but got Object" assertion.
trt_model = torch._C._jit_to_tensorrt(script_model._c, {'forward': spec})
To Reproduce
Steps to reproduce the behavior:
- Launch the NGC TRT container nvcr.io/nvidia/tensorrt:20.03-py3.
- Install the required libraries: pip install https://github.com/NVIDIA/TRTorch/releases/download/v0.1.0/trtorch-0.1.0-cp36-cp36m-linux_x86_64.whl torchvision==0.7.0
- Run each step described in the tutorial.
My repro code is below. Note that I modified the tutorial code to use FP32 for this test.
import torch
import trtorch
import torchvision.models as models

def main():
    model = models.mobilenet_v2(pretrained=True)
    script_model = torch.jit.script(model)
    spec = {
        "forward": trtorch.TensorRTCompileSpec({
            "input_shapes": [[1, 3, 300, 300]],
            "op_precision": torch.float32,
            "refit": False,
            "debug": False,
            "strict_types": False,
            "allow_gpu_fallback": True,
            "device_type": "gpu",
            "capability": trtorch.EngineCapability.default,
            "num_min_timing_iters": 2,
            "num_avg_timing_iters": 1,
            "max_batch_size": 0,
        })
    }
    # trt_model = torch._C._jit_to_tensorrt(script_model._c, {'forward': spec})
    trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
    x = torch.randn((1, 3, 300, 300)).to('cuda').to(torch.float32)
    print(trt_model.forward(x))

if __name__ == '__main__':
    main()
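For comparison, a minimal sketch of the call in the style of the linked test code, where the bare TensorRTCompileSpec is wrapped under the forward key at the call site (same options as the repro above, reusing script_model; this mirrors the test, it is not a confirmed fix):

# Sketch of the test-code-style call: the compile spec object itself
# is wrapped in a {method_name: spec} dict at the call site.
compile_spec = trtorch.TensorRTCompileSpec({
    "input_shapes": [[1, 3, 300, 300]],
    "op_precision": torch.float32,
    "refit": False,
    "debug": False,
    "strict_types": False,
    "allow_gpu_fallback": True,
    "device_type": "gpu",
    "capability": trtorch.EngineCapability.default,
    "num_min_timing_iters": 2,
    "num_avg_timing_iters": 1,
    "max_batch_size": 0,
})
trt_model = torch._C._jit_to_tensorrt(script_model._c, {"forward": compile_spec})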
Expected behavior
No exception.
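Beyond not raising, one would also expect the compiled module's output to match the original model. A quick sanity check along these lines (a sketch only; the tolerance is arbitrary):

# Hypothetical check once compilation succeeds: compare the
# TensorRT-backed module against the original PyTorch model.
model = model.to('cuda').eval()
with torch.no_grad():
    ref = model(x)
print(torch.allclose(trt_model.forward(x), ref, atol=1e-2))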
Environment
Build information about the TRTorch compiler can be found by turning on debug messages
- PyTorch Version (e.g., 1.0): 1.6.0
- CPU Architecture: Intel(R) Core(TM) i7-6850K
- OS (e.g., Linux): Ubuntu 18.04.5 LTS (host) / Ubuntu 18.04.4 LTS (container)
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): N/A
- Are you using local sources or building from archives: N/A
- Python version: 3.6.9
- CUDA version: 10.2.89
- GPU models and configuration: TitanX (Pascal)
- Any other relevant information: N/A
Additional context
N/A