Skip to content

Commit 8554782

Browse files
authored
fix: aten.index converter (#2487)
1 parent 2f569c3 commit 8554782

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,19 @@ def aten_ops_sigmoid(
392392
)
393393

394394

395-
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
395+
def index_dtype_validator(node: Node) -> bool:
396+
index = node.args[1]
397+
for ind in index:
398+
if ind is not None:
399+
val = ind.meta.get("val")
400+
if val is not None and val.dtype != torch.int32:
401+
return False
402+
return True
403+
404+
405+
@dynamo_tensorrt_converter(
406+
torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
407+
)
396408
@enforce_tensor_types(
397409
{
398410
0: (TRTTensor,),

0 commit comments

Comments
 (0)