
[Bug] pytorch relax frontend failed to import models with torch.permute #17183

@mshr-h

Description

Environment

  • TVM@3c7adf
  • PyTorch v2.3.1

Steps to reproduce

repro

import torch
import torch.fx
from tvm.relax.frontend.torch import from_fx

def main():
  class PermuteTest(torch.nn.Module):
    def __init__(self):
      super().__init__()

    def forward(self, x):
      # Passing the axes as a single tuple is the form that fails to import.
      x = torch.permute(x, (0, 2, 3, 1))
      return x

  model = PermuteTest()
  x = torch.randn(2, 3, 5, 7)

  graph_module = torch.fx.symbolic_trace(model)

  with torch.no_grad():
    mod = from_fx(graph_module, [(x.shape, "float32")])

if __name__ == "__main__":
  main()

Error message

TVMError: In function relax.op.permute_dims(0: RelayExpr, 1: Array) -> RelayExpr: error while converting argument 1: [12:21:18] /home/ubuntu/data/project/torch-fx-to-tvm-relax/3rdparty/tvm/include/tvm/runtime/packed_func.h:2056: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[IntImm], but got Array[index 0: Array]
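The error says permute_dims received a nested array (an Array inside index 0) where it expected a flat list of integer axes. A plausible explanation, which is an assumption not confirmed in this issue: torch.fx records torch.permute(x, (0, 2, 3, 1)) with args (x, (0, 2, 3, 1)), whereas the method form x.permute(0, 2, 3, 1) records args (x, 0, 2, 3, 1). A converter that takes args[1:] as the axes handles only the second form and produces ((0, 2, 3, 1),) for the first. The sketch below shows how both call forms could be normalized. normalize_permute_dims is a hypothetical helper, not part of TVM or PyTorch, and plain strings stand in for FX nodes.

```python
from typing import Any, Sequence, Tuple


def normalize_permute_dims(args: Sequence[Any]) -> Tuple[Any, ...]:
    """Return the permutation axes from FX-style call args (hypothetical helper).

    args[0] is the input; the axes follow either packed into one sequence,
    as in torch.permute(x, (0, 2, 3, 1)), or as separate values, as in
    x.permute(0, 2, 3, 1).
    """
    if len(args) < 2:
        raise ValueError("permute needs an input and at least one axis")
    rest = tuple(args[1:])
    # Sequence form: the whole permutation sits in a single argument.
    # torch.Size is a tuple subclass, so it is covered by this check too.
    if len(rest) == 1 and isinstance(rest[0], (tuple, list)):
        rest = tuple(rest[0])
    return rest


# A naive args[1:] keeps the nesting that the TVMError complains about.
print(tuple(("x", (0, 2, 3, 1))[1:]))                # ((0, 2, 3, 1),)
print(normalize_permute_dims(("x", (0, 2, 3, 1))))  # (0, 2, 3, 1)
print(normalize_permute_dims(("x", 0, 2, 3, 1)))    # (0, 2, 3, 1)
```

Checking for a single sequence argument, rather than always unpacking args[1], keeps the variadic method form working unchanged.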

Triage

Please refer to the list of label tags to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage

Metadata

  • Assignees: no one assigned
  • Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
