Bug Description
When I compile a model with dynamic shapes in Torch-TensorRT, the conversion fails with the following errors:
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Internal Error (Tensor [SLICE]-[aten_ops.expand.default]-[__/expand]_output has axis 0 with inherently negative length. Proven upper bound is -1. Network must have an instance where axis has non-negative length.)
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Internal Error (Output shape can not be computed for node [SLICE]-[aten_ops.expand.default]-[__/expand].)
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Internal Error (Output shape can not be computed for node [SLICE]-[aten_ops.expand.default]-[__/expand].)
Traceback (most recent call last):
File "/larec/tzrec/tests/test3.py", line 73, in <module>
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
trt_gm = compile_module(gm, inputs, settings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 418, in compile_module
trt_module = convert_module(
^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 106, in convert_module
interpreter_result = interpret_module_to_result(module, inputs, settings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 87, in interpret_module_to_result
interpreter_result = interpreter.run()
^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 327, in run
super().run()
File "/opt/conda/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 372, in run_node
trt_node: torch.fx.Node = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 487, in call_function
return converter(self.ctx, target, args, kwargs, self._cur_node_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1937, in aten_ops_sub
return impl.elementwise.sub(
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py", line 492, in sub
return convert_binary_elementwise(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py", line 154, in convert_binary_elementwise
lhs_val, rhs_val = broadcast(
^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/fx/converters/converter_utils.py", line 404, in broadcast
a_shape = tuple(a.shape)
^^^^^^^^^^^^^^
ValueError: __len__() should return >= 0
While executing %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%expand, %args1_1), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7fe317191230>: ((s0, 41), torch.float32, False, (41, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7fe3170105b0>: ((s0, 1, 41), torch.float32, False, (41, 41, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7fe3174f3c70>: ((s0, 50, 41), torch.float32, False, (41, 0, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7fe317026cb0>: ((s0, 50, 41), torch.float32, False, (2050, 41, 1), torch.contiguous_format, False, {})}})
Original traceback:
File "<eval_with_key>.0 from /larec/tzrec/tests/test3.py:32 in forward", line 22, in forward
sub = expand - getitem_1
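The final `ValueError` is raised by Python, not by TensorRT: `tuple(a.shape)` asks the shape object for its length and rejects a negative answer. The sketch below reproduces only that mechanism. It assumes (the report does not confirm this) that the shape of the invalid `expand` output reports a negative number of dimensions; `FakeDims` and `shape_as_tuple` are hypothetical stand-ins written for illustration, not TensorRT or Torch-TensorRT classes.

```python
class FakeDims:
    """Hypothetical stand-in for a shape object whose rank is invalid."""

    def __init__(self, nb_dims):
        self.nb_dims = nb_dims

    def __len__(self):
        # Returning a negative value here violates the Python data model.
        return self.nb_dims

    def __iter__(self):
        # No dimensions to yield; only the length matters for this sketch.
        return iter(())


def shape_as_tuple(shape):
    # Mirrors the failing `tuple(a.shape)` call from the traceback.
    return tuple(shape)


try:
    shape_as_tuple(FakeDims(-1))
except ValueError as exc:
    message = str(exc)
    print(message)
```

If that assumption holds, the `ValueError` is only a symptom, and the TensorRT errors logged for the `expand` node earlier are the part worth investigating.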
With static shapes the model compiles fine. Deleting these lines is enough to avoid the error:
torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
To Reproduce
Steps to reproduce the behavior:
import torch
import torch_tensorrt
from torch import nn
from typing import Dict, List


@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features


@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)


class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query", "sequence", "sequence_length"]
        attn_mlp = {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        # NOTE: MLP is not defined or imported anywhere in this report.
        self.mlp = MLP(in_features=41 * 4, **attn_mlp)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)

    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        grouped_features = _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        # Broadcast query from (batch, 41) to (batch, max_seq_length, 41).
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        return attn_input


model = MatMul().eval().cuda()
a = torch.randn(8196, 41).cuda()
b = torch.randn(8196, 50, 41).cuda()
c = torch.randn(8196).cuda()
torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
inputs = [a, b, c]
print(model(*inputs)[0][0][0])
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace
model = symbolic_trace(model)
exp_program = torch.export.export(model, (*inputs,))
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
print(trt_gm(*inputs)[0][0][0])
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(
    trt_gm,
    example_inputs=(a, b, c),
    strict=False,
)
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")
model_gpu = torch.jit.load(
    "./scripted_model_trt.pt", map_location="cuda:0"
)
print("load:", model_gpu(*inputs)[0][0][0])

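For reference, the `expand(-1, max_seq_length, -1)` call in the repro keeps axes 0 and 2 of `query.unsqueeze(1)` and broadcasts axis 1, which matches the `(s0, 1, 41)` to `(s0, 50, 41)` shapes shown in the log. A pure-Python sketch of that shape rule follows. `expand_shape` is a hypothetical helper written for illustration, not part of PyTorch or Torch-TensorRT, and the string `"s0"` only stands in for the symbolic batch size.

```python
def expand_shape(in_shape, sizes):
    """Resolve the output shape of an expand call (same-rank case only)."""
    if len(in_shape) != len(sizes):
        raise ValueError("this sketch only handles same-rank expand")
    out = []
    for dim, size in zip(in_shape, sizes):
        if size == -1:
            # -1 means "keep the input dimension", even a symbolic one.
            out.append(dim)
        elif dim == 1 or dim == size:
            # A size-1 axis can be broadcast; a matching axis is unchanged.
            out.append(size)
        else:
            raise ValueError(f"cannot expand dimension {dim} to {size}")
    return tuple(out)


# Shapes taken from the error log: query.unsqueeze(1) is (s0, 1, 41).
print(expand_shape(("s0", 1, 41), (-1, 50, -1)))
```

Each `-1` has to be replaced by the real input dimension before any shape arithmetic. The message about axis 0 having a proven upper bound of -1 suggests this substitution may not have happened for the dynamic batch axis, but the report does not confirm the root cause.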
Environment
CPU(s): 104
On-line CPU(s) list: 0-103
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 26
Socket(s): 2
Stepping: 7
CPU max MHz: 3800.0000
CPU min MHz: 1200.0000
BogoMIPS: 5000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 1.6 MiB (52 instances)
L1i cache: 1.6 MiB (52 instances)
L2 cache: 52 MiB (52 instances)
L3 cache: 71.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-103
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] torch==2.4.0
[pip3] torch_tensorrt==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchelastic==0.2.2
[pip3] torchmetrics==1.0.3
[pip3] torchrec==0.8.0+cu121
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.8 py311h5eee18b_0
[conda] mkl_random 1.2.4 py311hdb19cb5_0
[conda] numpy 1.26.4 py311h08b1b3b_0
[conda] numpy-base 1.26.4 py311hf175353_0
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch 2.4.0 py3.11_cuda12.1_cudnn9.1.0_0 pytorch
[conda] pytorch-cuda 12.1 ha16c6d3_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch-tensorrt 2.4.0 pypi_0 pypi
[conda] torchaudio 2.4.0 py311_cu121 pytorch
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torchmetrics 1.0.3 pypi_0 pypi
[conda] torchrec 0.8.0+cu121 pypi_0 pypi
[conda] torchtriton 3.0.0 py311 pytorch
[conda] torchvision 0.19.0 py311_cu121 pytorch