
Fused RMSNorm incompatible with PP tracing (dynamic stride) #217

@wconstab

The incompatibility is that in its backward pass, fused_rmsnorm branches on tensor strides (dynamic control flow over strides), which isn't safe for the export tracing used by PP:

        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
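One possible direction, sketched under the assumption that the kernel only needs a 2-D gradient whose last dimension is unit-stride: drop the Python-level branch and call .contiguous() unconditionally. Tensor.contiguous() returns the tensor itself when it is already contiguous and a contiguous copy otherwise, so no stride value is read in Python. Both function names are hypothetical, and this has not been checked against the PP export tracer.

```python
import torch


def flatten_grad_branching(dy: torch.Tensor) -> torch.Tensor:
    # Pattern quoted in the issue: a Python `if` on a stride value.
    # Reading the stride here is what Dynamo rejects in strict mode.
    dy = dy.view(-1, dy.shape[-1])
    if dy.stride(-1) != 1:
        dy = dy.contiguous()
    return dy


def flatten_grad_unconditional(dy: torch.Tensor) -> torch.Tensor:
    # Hypothetical export-friendly variant: no Python branch on strides.
    # contiguous() returns dy itself if it is already contiguous,
    # otherwise a contiguous copy.
    return dy.view(-1, dy.shape[-1]).contiguous()


if __name__ == "__main__":
    # Already contiguous input: shape is flattened, storage is shared.
    x = torch.randn(2, 3, 8)
    out = flatten_grad_unconditional(x)
    print(out.shape, out.data_ptr() == x.data_ptr())

    # Transposed gradient (last-dim stride 4): both variants copy.
    y = torch.randn(8, 4).t()
    print(torch.equal(flatten_grad_branching(y), flatten_grad_unconditional(y)))
```

The cost of removing the branch is an extra copy in one case: a gradient whose last dimension is already unit-stride but whose rows are not packed (for example, a column slice such as t[:, :8]) is now made fully contiguous, whereas the original check would have passed it through unchanged.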

This leads to a stack trace ending in:

    File "/data/users/whc/pytorch/torch/_dynamo/variables/tensor.py", line 326, in var_getattr
      unimplemented(f"Illegal getattr invocation {name} in strict mode")
    File "/data/users/whc/pytorch/torch/_dynamo/exc.py", line 204, in unimplemented
      raise Unsupported(msg)
  torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode

Would it be possible to refactor this in a more export-friendly way, or would that be difficult?

cc @lessw2020, @kwen2501

Labels: bug (Something isn't working)
