Bug Description
AssertionError: to_numpy can only be called on None or a torch.Tensor, got: <tensorrt_bindings.tensorrt.ITensor object at 0x7f72c6108d30> While executing %batch_norm
This occurs with the new export workflow from the https://github.com/pytorch/TensorRT/tree/dynamo_export_refactor branch.
The issue seems to come from partitioning (reused from the torch.compile workflow): when the graph is copied, all constants are registered as placeholders. As a result, constants such as weights and biases are treated as ITensors, while the batch norm converter expects them to be constants.
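The suspected failure mode can be sketched in plain Python. This is only an illustration under stated assumptions: `Constant`, `NetworkTensor`, `to_values`, `convert_batch_norm`, and `lift_constants_to_placeholders` are hypothetical stand-ins invented for this sketch, not Torch-TensorRT or TensorRT APIs. It shows how a converter that only accepts build-time constants fails once a graph copy has turned those constants into symbolic inputs.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Union


@dataclass
class Constant:
    """Hypothetical stand-in for a frozen weight with known values."""
    values: Sequence[float]


@dataclass
class NetworkTensor:
    """Hypothetical stand-in for a symbolic tensor (ITensor-like): named, no values."""
    name: str


Arg = Optional[Union[Constant, NetworkTensor]]


def to_values(arg: Arg) -> Optional[List[float]]:
    """Accept only None or a Constant, mirroring the shape of the failing check."""
    assert arg is None or isinstance(arg, Constant), (
        f"to_values can only be called on None or a Constant, got: {arg!r}"
    )
    return None if arg is None else list(arg.values)


def convert_batch_norm(weight: Arg, bias: Arg) -> Dict[str, Optional[List[float]]]:
    """Toy converter that needs concrete weight and bias values at build time."""
    return {"scale": to_values(weight), "shift": to_values(bias)}


def lift_constants_to_placeholders(params: Dict[str, Constant]) -> Dict[str, NetworkTensor]:
    """Simulate the graph copy: every constant becomes a symbolic graph input."""
    return {name: NetworkTensor(name) for name in params}


params = {"weight": Constant([1.0, 2.0]), "bias": Constant([0.0, 0.0])}

# With real constants the toy converter succeeds.
print(convert_batch_norm(params["weight"], params["bias"]))

# After lifting, the same converter receives symbolic tensors and the assertion fires.
lifted = lift_constants_to_placeholders(params)
try:
    convert_batch_norm(lifted["weight"], lifted["bias"])
except AssertionError as err:
    print(f"AssertionError: {err}")
```

If this diagnosis holds, a fix would either keep weights and biases as constants through the graph copy or teach the converter to accept symbolic inputs. The sketch does not claim which approach the project took.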
To Reproduce
import torch
import torch_tensorrt
import torchvision.models as models
import timm


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=False)
        self.bn = torch.nn.BatchNorm2d(16)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return out


model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            input.shape, dtype=torch.float, format=torch.contiguous_format
        )
    ],
    "enabled_precisions": {torch.float},
    "debug": True,
    "is_aten": True,
    "min_block_size": 1,
}

trt_mod = torch_tensorrt.dynamo.export.compile(model, **compile_spec)
Expected behavior
Compilation should succeed without raising the AssertionError above.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information: