Bug Description
Passing a boolean value inside the dict given to the kwarg_inputs parameter of torch_tensorrt.compile results in:
ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}
It seems that apart from the collection types (list, tuple, dict), only torch.Tensor values are allowed at the leaf level. This contradicts the documentation (https://pytorch.org/TensorRT/py_api/torch_tensorrt.html?highlight=compile), which states:
kwarg_inputs: Optional[dict[Any, Any]] = None
To Reproduce
Steps to reproduce the behavior:
- Execute the following minimal example:
import torch
import torch_tensorrt

class TestModel(torch.nn.Module):
    def forward(self, param1, additional_param: bool | None = None):
        pass

compiled_model = torch_tensorrt.compile(
    TestModel(),
    ir="dynamo",
    inputs=[torch.rand(1)],
    kwarg_inputs={"additional_param": True},
)
- The result is:
Traceback (most recent call last):
File "...\test_bug.py", line 8, in <module>
compiled_model = torch_tensorrt.compile(
File "...\lib\site-packages\torch_tensorrt\_compile.py", line 284, in compile
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
File "...\lib\site-packages\torch_tensorrt\dynamo\utils.py", line 272, in prepare_inputs
torchtrt_input = prepare_inputs(
File "...\lib\site-packages\torch_tensorrt\dynamo\utils.py", line 280, in prepare_inputs
raise ValueError(
ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}
Expected behavior
The minimal example should compile fine. In my opinion, values other than torch tensors should be accepted in both inputs and kwarg_inputs. It would additionally be nice if the documentation were more verbose about this (IMHO important) topic: how inputs are treated by the compiler and what happens at runtime of the compiled model.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
I am sorry, I do not know a canonical way of "turning on debug messages" in Python, so I cannot translate this into something actionable.
- Torch-TensorRT Version (e.g. 1.0.0) / PyTorch Version (e.g. 1.0):
tensorrt==10.7.0
tensorrt_cu12==10.7.0
tensorrt_cu12_bindings==10.7.0
tensorrt_cu12_libs==10.7.0
torch==2.6.0+cu124
torch_tensorrt==2.6.0+cu124
- CPU Architecture: Intel x86_64
- OS (e.g., Linux): Windows 10
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Python version: 3.10.16
- CUDA version: 12.4
Activity
cehongwang commented on Feb 4, 2025
Please make sure to include everything the model's forward function takes in kwarg_inputs. If a kwarg input is a tensor, give it a tensor of the same shape. If the kwarg has a bool, you can try
torch.tensor(True)
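Applied to the snippet from this issue, that would look roughly like this (a sketch; I have not verified how the bool tensor behaves inside the compiled engine):

import torch
import torch_tensorrt

class TestModel(torch.nn.Module):
    def forward(self, param1, additional_param: bool | None = None):
        pass

compiled_model = torch_tensorrt.compile(
    TestModel(),
    ir="dynamo",
    inputs=[torch.rand(1)],
    # Wrap the Python bool in a 0-dim tensor so the input parser accepts it.
    kwarg_inputs={"additional_param": torch.tensor(True)},
)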
You can also try MutableTorchTensorRTModule. It handles the sample inputs for you and supports extra functionality such as weight updates. See TensorRT/examples/dynamo/mutable_torchtrt_module_example.py (line 76 at commit 54e36db), which includes an example of compiling the UNet of a Stable Diffusion pipeline.
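A rough usage sketch (the compile settings here are placeholders; see the linked example for realistic ones):

import torch
import torch_tensorrt

model = TestModel().eval().cuda()  # e.g. the model from the repro above
# Compilation happens lazily on the first call; later weight changes trigger a refit.
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(model, min_block_size=1)
output = mutable_module(torch.rand(1).cuda())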
derMart commented on Feb 6, 2025
@cehongwang thank you for pointing me to MutableTorchTensorRTModule. I will probably need to check that out. But apart from that, I am really sorry, I do not understand how your comment relates to the issue I opened. I did not pass torch.tensor(True) as kwarg_inputs but the Python boolean True, which raises an exception even though, according to the docs, it should not.
I gave a minimal reproducible example, which is the topic of this issue and which raises an exception. It would be really nice to stick to that topic here.
cehongwang commented on Feb 6, 2025
Use
torch.tensor(True)
derMart commented on Feb 6, 2025
@cehongwang The bug is that it does not work with a Python boolean in kwarg_inputs, although it should according to the docs. Changing the reproducible example that triggers the bug is not a fix for the bug :D Fixing the code is.
cehongwang commented on Feb 6, 2025
Sure. We will fix the documentation in our next release. You can use this to bypass the error for now :)
derMart commented on Feb 6, 2025
The problem is not a documentation issue. For models where the non-tensor inputs are constants, I can work around the issue by writing a wrapper class which passes the Python-typed inputs, something like this (a minimal sketch; the wrapper shape is illustrative):
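import torch

class KwargsWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module, forward_kwargs: dict):
        super().__init__()
        self.model = model
        self.forward_kwargs = forward_kwargs

    def forward(self, *args):
        # Bake the Python-typed constants into the call so that only
        # tensors cross the compilation boundary.
        return self.model(*args, **self.forward_kwargs)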
where self.forward_kwargs are the Python-typed constants and self.model is the original model I want to compile.
But this approach does not work if I need to change a Python-typed input at runtime of the model. Wrapping it into a tensor, as you suggest, does not work here. My current problematic example is indeed a Stable Diffusion ControlNet, which takes conditioning_scale as a Python-typed float value.
If I try to wrap that into a tensor, roughly like this (the actual ControlNet call is omitted):
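import torch

# Illustrative: pass conditioning_scale as a 0-dim tensor instead of a Python float.
kwarg_inputs = {"conditioning_scale": torch.tensor(1.0)}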
and try to compile this, I get a different error.
I can give you a reproducible example for this as well if you need that.
I also had a quick look at MutableTorchTensorRTModule. Even if it worked (I haven't tested it yet), it still seems not to be a solution, since the docs currently state that saving and loading compiled models is not possible with the Python version of TensorRT. If I didn't want to save the model, I could just use regular torch.compile. Also, MutableTorchTensorRTModule compiles at first use rather than ahead of time, which is not a very nice API.
And also ... I am sure not all Python types can be wrapped in tensors ;-)
cehongwang commented on Feb 6, 2025
Currently, we have limited support for some Python-typed objects due to the inflexibility of ahead-of-time compilation. Please give us a reproducible example and we can look into it.