Closed
Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
Environment
- TVM@3c7adf
- PyTorch v2.3.1
Steps to reproduce
repro
import torch
import torch.fx
from tvm.relax.frontend.torch import from_fx


def main():
    class PermuteTest(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Dims are passed as a single tuple argument; this triggers the error below.
            x = torch.permute(x, (0, 2, 3, 1))
            return x

    model = PermuteTest()
    x = torch.randn(2, 3, 5, 7)
    graph_module = torch.fx.symbolic_trace(model)
    with torch.no_grad():
        mod = from_fx(graph_module, [(x.shape, "float32")])


if __name__ == "__main__":
    main()
Error message
TVMError: In function relax.op.permute_dims(0: RelayExpr, 1: Array) -> RelayExpr: error while converting argument 1: [12:21:18] /home/ubuntu/data/project/torch-fx-to-tvm-relax/3rdparty/tvm/include/tvm/runtime/packed_func.h:2056: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[IntImm], but got Array[index 0: Array]
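The error says `relax.op.permute_dims` received `Array[Array]` for its axes, i.e. the `(0, 2, 3, 1)` tuple arrived wrapped in an outer list. A plausible reading (an assumption, not confirmed in this report) is that the FX converter forwards every argument after the input tensor as the axes list. That list is flat for the method form `x.permute(0, 2, 3, 1)` but nested for the function form `torch.permute(x, (0, 2, 3, 1))`. The sketch below shows the kind of normalization that would accept both forms; `normalize_permute_dims` is a hypothetical helper, not TVM code.

```python
from typing import Sequence, Tuple, Union

# Hypothetical helper (not part of TVM). It assumes it receives the node
# arguments that follow the input tensor:
#   torch.permute(x, (0, 2, 3, 1))  -> [(0, 2, 3, 1)]  (one nested sequence)
#   x.permute(0, 2, 3, 1)           -> [0, 2, 3, 1]    (flat varargs)
DimArg = Union[int, Sequence[int]]


def normalize_permute_dims(dim_args: Sequence[DimArg]) -> Tuple[int, ...]:
    """Return a flat tuple of integer dims for either calling convention."""
    if len(dim_args) == 1 and isinstance(dim_args[0], (list, tuple)):
        # Single sequence argument: unwrap it instead of nesting it.
        dims = tuple(dim_args[0])
    else:
        # Varargs form: the arguments already are the dims.
        dims = tuple(dim_args)
    if not all(isinstance(d, int) for d in dims):
        raise TypeError(f"expected integer dims, got {dims!r}")
    return dims


if __name__ == "__main__":
    print(normalize_permute_dims([(0, 2, 3, 1)]))  # function form
    print(normalize_permute_dims([0, 2, 3, 1]))    # method form
```

Both calls yield the same flat tuple `(0, 2, 3, 1)`. Until the converter handles the function form, rewriting the model to use `x.permute(0, 2, 3, 1)` may sidestep the error, but that is untested here.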
Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage