[pytorch] Disable fast path in MultiheadAttention in Export (pytorch#106824)
Summary:
The `aten._native_multi_head_attention` op, which is not in the core ATen op set, is left in the exported graph and causes problems downstream at runtime.
Two proposed solutions:
1. Disable the fast path while tracing so that the non-optimized path is traced and decomposed instead; the offending op then never appears in the exported graph
2. Add a decomp rule for `aten._native_multi_head_attention`
After discussion with kimishpatel and bdhirsh, option 1 is preferred; it was verified to immediately unblock the critical model enablement work for PP.
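
The idea behind option 1 can be illustrated with a small, PyTorch-free sketch. Every name below (`tracing`, `why_not_fast_path`, `attention_ops`, and the op-name lists) is hypothetical and chosen for illustration only; this is not the code changed by this PR. It only shows the pattern of refusing the fused path while a trace is in progress, so that only decomposable ops get recorded.

```python
from contextlib import contextmanager

# Hypothetical op names for illustration; this is not PyTorch code.
FUSED_OP = "_native_multi_head_attention"
DECOMPOSED_OPS = ["linear", "scaled_dot_product_attention", "linear"]

_tracing = False  # assumed module-level flag that a tracer would set


@contextmanager
def tracing():
    """Mark the enclosed region as running under an export trace."""
    global _tracing
    previous = _tracing
    _tracing = True
    try:
        yield
    finally:
        _tracing = previous  # restore even if the traced code raises


def why_not_fast_path(training):
    """Return a reason to skip the fused op, or '' if it may be used."""
    if training:
        return "module is in training mode"
    if _tracing:
        return "an export trace is in progress"
    return ""


def attention_ops(training=False):
    """Return the op names a forward call would record in a graph."""
    if why_not_fast_path(training):
        return list(DECOMPOSED_OPS)  # slow path: ops with known decompositions
    return [FUSED_OP]  # fast path: single fused op


eager_ops = attention_ops()       # outside a trace: fused op is used
with tracing():
    traced_ops = attention_ops()  # inside a trace: fused op is avoided
after_ops = attention_ops()       # flag restored after the trace ends
```

In this sketch the eager calls still use the fused op, so the change in behavior is confined to tracing.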
Test Plan: CI
Differential Revision: D48169806
Pull Request resolved: pytorch#106824
Approved by: https://github.com/kimishpatel